In [2]:
import subprocess
import pandas as pd
import os
import sys
import pprint
import local_models.local_models
import logging
import ml_battery.log
from Todd_eeg_utils import *

import rpy2
import numpy as np
import rpy2.robjects.numpy2ri
from rpy2.robjects.packages import importr
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)

In [3]:
data_dir = "/home/brown/disk2/eeg/Phasespace/Phasespace/data/eeg-text" 
transformed_data_dir = "/home/brown/disk2/eeg/transformed_data"

In [4]:
data_info = pd.read_csv(os.path.join(data_dir, "fileinformation.csv"), skiprows=1).iloc[:,2:]

In [5]:
data_info

Unnamed: 0,Unnamed: 2,Number of Records,Time of Seizure,Unnamed: 5,Point of Seizure,250
0,DAT.F00012,3963799,14040,,3510000,3510000
1,DAT.F00013,3632699,12720,,3180000,3180000
2,DAT.F00016,4447824,15960,,3990000,3990000
3,DAT.F00017,1827224,0,,0,0
4,DAT.F00018,2985924,10020,,2505000,2505000
5,DAT.F00019,3692374,12960,,3240000,3240000
6,DAT.F00020,3270974,11280,,2820000,2820000
7,DAT.F00022,2120524,0,,0,0
8,DAT.F00024,3378499,11700,,2925000,2925000
9,DAT.F00026,3370999,0,,0,0


In [6]:
data_info.shape
how_many_epis = len([which for which in range(data_info.shape[0]) if data_info.iloc[which,4]>0])
how_many_epis

40

In [10]:
short_classification_data_dir = os.path.join(data_dir, "shortened_classification_data")
os.makedirs(short_classification_data_dir, exist_ok=1)

In [11]:
subsample_rate=5
gpr_subsample_rate=10
timelog = local_models.local_models.loggin.TimeLogger(
    logger=logger, 
    how_often=1, total=how_many_epis, 
    tag="getting_filtered_data")
for i in range(data_info.shape[0]):
    data_file = data_info.iloc[i,0]
    data_epipoint = data_info.iloc[i,4]
    data_len = data_info.iloc[i,1]
    if data_len > data_epipoint > 0:
        with timelog:
            shortened_data_onset_file = os.path.join(short_classification_data_dir, "{}_onset.dat".format(data_file))
            shortened_data_negative_file = os.path.join(short_classification_data_dir, "{}_negative.dat".format(data_file))
            if not os.path.isfile(shortened_data_onset_file):
                data, data_offset = get_filtered_data(data_file, data_dir)
                data_epipoint = int((data_epipoint - data_offset)/subsample_rate)
                subsampled_dat = data[::subsample_rate]
                HZ = int(SIGNAL_HZ/subsample_rate)
                bandwidth = 2*HZ
                l = HZ*SECONDS_OF_SIGNAL
                n = 2*bandwidth-1

                ictal_rng = (max(0,data_epipoint-l), min(subsampled_dat.shape[0], data_epipoint+l))
                negative_ictal_rng = (max(0, int(data_epipoint/2)-l), min(subsampled_dat.shape[0], int(data_epipoint/2)+l))
                subsample_ictal_rng = (np.array(ictal_rng)/gpr_subsample_rate).astype(int)
                subsample_negative_ictal_rng = (np.array(negative_ictal_rng)/gpr_subsample_rate).astype(int)
                lm_kernel = local_models.local_models.TriCubeKernel(bandwidth=bandwidth)
                index_X = np.arange(subsampled_dat.shape[0]*1.).reshape(-1,1)
                index = local_models.local_models.ConstantDistanceSortedIndex(index_X.flatten())
                exemplar_rng = (HZ*4,HZ*4+n)
                exemplar_X = index_X[slice(*exemplar_rng)]
                exemplar_y = subsampled_dat[slice(*exemplar_rng)]
                ictal_X = index_X[slice(*ictal_rng)]
                ictal_X_gpr_subsampled = index_X[ictal_rng[0] : ictal_rng[1] : gpr_subsample_rate]
                exemplar_X_gpr_subsampled = index_X[exemplar_rng[0] : exemplar_rng[1] : gpr_subsample_rate]
                negative_ictal_X = index_X[slice(*negative_ictal_rng)]
                negative_ictal_X_gpr_subsampled = index_X[negative_ictal_rng[0] : negative_ictal_rng[1] : gpr_subsample_rate]

                np.savetxt(shortened_data_onset_file, subsampled_dat[slice(*ictal_rng)])
                np.savetxt(shortened_data_negative_file, subsampled_dat[slice(*negative_ictal_rng)])


  places = np.log10(np.abs(number))


In [16]:
positive_samples = []
negative_samples = []
for i in range(data_info.shape[0]):
    data_file = data_info.iloc[i,0]
    data_epipoint = data_info.iloc[i,4]
    data_len = data_info.iloc[i,1]
    if data_len > data_epipoint > 0:
        shortened_data_onset_file = os.path.join(short_classification_data_dir, "{}_onset.dat".format(data_file))
        shortened_data_negative_file = os.path.join(short_classification_data_dir, "{}_negative.dat".format(data_file))
        positive_samples.append(np.loadtxt(shortened_data_onset_file))
        negative_samples.append(np.loadtxt(shortened_data_negative_file))


In [17]:
positive_samples = np.stack(positive_samples)
negative_samples = np.stack(negative_samples)

In [18]:
positive_samples.shape, negative_samples.shape

((39, 10000, 21), (39, 10000, 21))

In [19]:
np.random.seed(0)
indices = list(range(39))
np.random.shuffle(indices)

In [44]:
indices

[4,
 28,
 29,
 33,
 34,
 25,
 10,
 22,
 11,
 27,
 18,
 15,
 2,
 38,
 20,
 36,
 16,
 35,
 8,
 13,
 5,
 17,
 14,
 32,
 7,
 31,
 1,
 26,
 12,
 30,
 24,
 6,
 23,
 21,
 19,
 9,
 37,
 3,
 0]

In [20]:
train_set = indices[:20]
test_set = indices[20:]

In [21]:
positive_train = positive_samples[train_set]
negative_train = negative_samples[train_set]
positive_test = positive_samples[test_set]
negative_test = negative_samples[test_set]
train = np.concatenate((positive_train, negative_train))
test = np.concatenate((positive_test, negative_test))
train_labels = np.concatenate((np.ones(positive_train.shape[0]), np.zeros(negative_train.shape[0])))
test_labels = np.concatenate((np.ones(positive_test.shape[0]), np.zeros(negative_test.shape[0])))

In [30]:
train.shape

(40, 10000, 21)

In [80]:
positive_samples.shape

(39, 10000, 21)

In [44]:
rpy2.robjects.numpy2ri.activate()
    
# Set up our R namespaces
R = rpy2.robjects.r
DTW = importr('dtw')

In [45]:
gc.collect()
R('gc()')

0,1,2,3,4,5,6
310089.0,251455233.0,16.6,...,844850459.0,31.8,6445.7


In [46]:
cdists = np.empty((test.shape[0], train.shape[0]))

In [47]:
cdists.shape

(38, 40)

In [49]:
timelog = local_models.local_models.loggin.TimeLogger(
    logger=logger, 
    how_often=1, total=len(train_set)*len(test_set)*4, 
    tag="dtw_matrix")

import gc
# Calculate the alignment vector and corresponding distance
for test_i in range(cdists.shape[0]):
    for train_i in range(cdists.shape[1]):
        with timelog:
            alignment = R.dtw(test[test_i], train[train_i], keep_internals=False, distance_only=True)
            dist = alignment.rx('distance')[0][0]
            print(dist)
            cdists[test_i, train_i] = dist
            gc.collect()
            R('gc()')
            gc.collect()
        

24914157.263403278
24564923.343759436
28960553.958771948
24425609.975232482
24141469.55875256
27718232.88415126
22420659.2334804
40032357.483874656
25783131.124216273
25355728.348947853
20618836.501595125
27946719.11779935
49071843.0651387
26442474.87366295
26429087.41376496
25897086.96157728
28705266.855245594
22855749.501263052
24167231.25683878
25284114.9812834
20072891.37236362
21390135.139175557
20515940.75361534
22473806.626501862
24170142.088724557
21224738.917030375
21453514.311813354
39541334.96674075
20707180.50428427
21944400.106313795
21632597.318683628
21639363.62007214
28735052.87101651
24364327.708778568
22201394.487392865
26427208.854715798
24420007.880339425
20450730.498402208
21340125.761142876
21444891.949008815
11824270.548342286
11522099.213107187
16449448.479426112
11516916.400600381
11797930.64109011
14976368.565582922
9486625.019526776
26626574.399414204
13069887.099328876
12692009.61633277
6255855.588854969
13716421.892283717
36571044.5048895
13034659.26767112


20171774.532634158
23338361.93462889
17540225.18337415
35464455.0783256
21618880.81729438
20916866.089438334
15835245.124590443
23623750.527103763
45267334.08323404
21628352.315648746
22100062.85458563
21706454.78806527
24254252.97207035
18555456.598038953
20071686.449032195
20810801.017315287
15260563.345120966
16596033.29458983
15735248.78717716
17786426.299842272
20397489.807156127
16300520.36033075
16616682.840938699
35509586.04737333
15796922.692926275
17113484.22203264
16874846.15776655
16840572.16588708
25326009.59412055
19639478.13252784
17548935.337987565
22104531.020682633
19765996.964210283
15571595.153859217
16430679.706217043
16723958.43621881
22167374.37168263
21785233.8429915
26493731.212823715
21682833.376808934
21942272.162986983
25108099.412702512
19867761.256890588
35850028.96516068
23182589.621599544
22886285.09581543
17461873.320253957
24539668.98125822
45821754.392097615
22426245.596892286
22719590.343323812
21030654.983994447
25409610.654173993
20482853.793001004

13289778.01108031
31553458.845546618
16903095.51165349
15698541.303829005
12148554.59854004
19275169.597729433
40853283.04187412
16515862.798032422
17667438.88618003
17040662.25322826
20127797.43189672
13591133.438956412
15284183.948656127
16714577.058075389
11209390.301931204
12141959.361157693
11528277.292388942
13242527.388803747
13874538.60881836
12075724.492472105
12422629.199365158
31272309.70228264
11569297.764331896
12673075.240189316
13036555.021746987
13228431.853231026
19915020.73537892
15118959.944641499
13646567.761309609
17172338.248420134
15949461.479944099
11521880.41715722
12378812.720736375
12901471.45511305
11500876.883005515
11199048.386056503
14761013.277051937
11026723.888212044
11596462.772415029
14006039.25321744
7516515.248952692
27177495.99971836
12641227.238250751
10800642.93439433
7373387.762508173
14795175.421227196
36955763.04824403
13476960.603920504
13554953.046958774
13332747.971265333
15812201.53901328
9676666.671208784
10920522.36815289
12340174.97059

14459789.391473232
8122337.861114944
15212415.05622413
38457944.59459394
15137395.338256633
13192642.38670921
14582305.775844455
16203484.773998465
12167386.640286129
13536996.296208344
12906948.467428373
8889089.56715665
10264985.706154829
9351989.21196538
11425720.853706
14010223.390026167
9986911.374025248
10308375.098549588
28816827.439633787
9327976.624272907
10824063.676326452
8609609.481912155
9467197.55809529
18749128.411846094
12776883.555672562
9356276.411441425
15035208.726358168
12027857.959167749
9200956.589984175
10158194.051185096
8791659.451011805
8786613.38012208
8366351.3201194005
13393779.803071832
8495078.296835678
8841600.313037755
11896838.182017589
6356657.50726777
23459399.723959357
10038936.152489923
9526640.869303536
3944696.60045037
11493064.09558982
34283603.0780325
10133975.234838182
10425662.427230252
10357988.612426378
12418121.344980108
6987229.417475753
8540666.710532045
8830104.466190558
3721116.248893408
5178349.090847653
4108641.490525395
6166058.188

In [7]:
import sklearn.metrics

In [12]:
cdists_file = os.path.join(short_classification_data_dir, "dtw_cdists.dat")
if "cdists" in globals() and not os.path.exists(cdists_file):
    np.savetxt(cdists_file, cdists)
else:
    cdists = np.loadtxt(cdists_file)

In [13]:
np.argmin(cdists, axis=1).shape

(38,)

In [43]:
sum(np.argmin(cdists, axis=1)[:19] == 10) + sum(np.argmin(cdists, axis=1)[:19] == 20)

17

In [42]:
sum(np.argmin(cdists, axis=1)[19:] == 10) + sum(np.argmin(cdists, axis=1)[19:] == 20)

15

In [22]:
cm = sklearn.metrics.confusion_matrix(test_labels, train_labels[np.argmin(cdists, axis=1)])

In [23]:
print(cm)

[[14  5]
 [14  5]]


In [24]:
pd.DataFrame(np.round(cdists/10**6,0))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,30,31,32,33,34,35,36,37,38,39
0,25.0,25.0,29.0,24.0,24.0,28.0,22.0,40.0,26.0,25.0,...,22.0,22.0,29.0,24.0,22.0,26.0,24.0,20.0,21.0,21.0
1,12.0,12.0,16.0,12.0,12.0,15.0,9.0,27.0,13.0,13.0,...,7.0,7.0,17.0,11.0,7.0,13.0,10.0,7.0,8.0,7.0
2,19.0,19.0,23.0,19.0,19.0,22.0,17.0,34.0,20.0,20.0,...,14.0,15.0,24.0,18.0,15.0,20.0,17.0,15.0,16.0,14.0
3,14.0,13.0,18.0,14.0,13.0,17.0,12.0,30.0,15.0,14.0,...,11.0,11.0,19.0,13.0,12.0,16.0,14.0,9.0,11.0,11.0
4,15.0,15.0,18.0,14.0,15.0,17.0,11.0,31.0,16.0,14.0,...,12.0,12.0,20.0,15.0,13.0,17.0,15.0,11.0,11.0,12.0
5,20.0,18.0,21.0,18.0,18.0,21.0,17.0,35.0,21.0,18.0,...,17.0,17.0,25.0,17.0,17.0,19.0,20.0,15.0,16.0,17.0
6,12.0,12.0,16.0,12.0,12.0,15.0,9.0,27.0,13.0,12.0,...,9.0,9.0,17.0,12.0,10.0,14.0,12.0,7.0,8.0,9.0
7,11.0,10.0,15.0,10.0,11.0,14.0,8.0,26.0,12.0,11.0,...,7.0,7.0,16.0,10.0,8.0,12.0,10.0,6.0,7.0,7.0
8,21.0,21.0,26.0,21.0,21.0,24.0,19.0,34.0,22.0,22.0,...,17.0,17.0,26.0,20.0,18.0,23.0,20.0,17.0,18.0,17.0
9,29.0,27.0,32.0,27.0,27.0,30.0,26.0,44.0,30.0,28.0,...,25.0,25.0,33.0,27.0,26.0,29.0,28.0,24.0,25.0,25.0


In [25]:
np.argmin(cdists, axis=1)

array([20, 10, 10, 20, 20, 22, 20, 22, 10, 20, 20, 20, 20, 10, 10, 20, 20,
       20, 20, 20, 10, 10, 20, 26, 25, 20, 22, 10, 22, 20, 20, 20, 10, 10,
       20, 20, 20, 20])

In [76]:
cols = [0,5,10,20,22,25,26,32,37]
pd.DataFrame(np.round(cdists[:,cols]/10**6,0),columns=cols)

Unnamed: 0,0,5,10,20,22,25,26,32,37
0,25.0,28.0,21.0,20.0,21.0,21.0,21.0,29.0,20.0
1,12.0,15.0,6.0,7.0,7.0,8.0,8.0,17.0,7.0
2,19.0,22.0,14.0,14.0,15.0,15.0,16.0,24.0,15.0
3,14.0,17.0,10.0,9.0,9.0,10.0,11.0,19.0,9.0
4,15.0,17.0,11.0,10.0,11.0,11.0,10.0,20.0,11.0
5,20.0,21.0,16.0,15.0,14.0,14.0,16.0,25.0,15.0
6,12.0,15.0,8.0,7.0,8.0,8.0,8.0,17.0,7.0
7,11.0,14.0,6.0,6.0,5.0,6.0,7.0,16.0,6.0
8,21.0,24.0,16.0,16.0,17.0,18.0,18.0,26.0,17.0
9,29.0,30.0,24.0,24.0,24.0,24.0,25.0,33.0,24.0


In [28]:
pd.DataFrame(cm, index=[["true"]*2,["-","+"]], columns=[["pred"]*2, ["-", "+"]])

Unnamed: 0_level_0,Unnamed: 1_level_0,pred,pred
Unnamed: 0_level_1,Unnamed: 1_level_1,-,+
True,-,14,5
True,+,14,5


In [29]:
acc = np.sum(np.diag(cm))/np.sum(cm)
prec = cm[1,1]/np.sum(cm[:,1])
rec = cm[1,1]/np.sum(cm[1])
acc,prec,rec

(0.5, 0.5, 0.2631578947368421)