In [1]:
import subprocess
import pandas as pd
import os
import sys
import pprint
import local_models.local_models
import logging
import ml_battery.log
from Todd_eeg_utils import *

import rpy2
import numpy as np
import rpy2.robjects.numpy2ri
from rpy2.robjects.packages import importr
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)

In [2]:
data_dir = "/home/brown/disk2/eeg/Phasespace/Phasespace/data/eeg-text" 
transformed_data_dir = "/home/brown/disk2/eeg/transformed_data"

In [3]:
data_info = pd.read_csv(os.path.join(data_dir, "fileinformation.csv"), skiprows=1).iloc[:,2:]

In [4]:
data_info

Unnamed: 0,Unnamed: 2,Number of Records,Time of Seizure,Unnamed: 5,Point of Seizure,250
0,DAT.F00012,3963799,14040,,3510000,3510000
1,DAT.F00013,3632699,12720,,3180000,3180000
2,DAT.F00016,4447824,15960,,3990000,3990000
3,DAT.F00017,1827224,0,,0,0
4,DAT.F00018,2985924,10020,,2505000,2505000
5,DAT.F00019,3692374,12960,,3240000,3240000
6,DAT.F00020,3270974,11280,,2820000,2820000
7,DAT.F00022,2120524,0,,0,0
8,DAT.F00024,3378499,11700,,2925000,2925000
9,DAT.F00026,3370999,0,,0,0


In [5]:
data_info.shape
how_many_epis = len([which for which in range(data_info.shape[0]) if data_info.iloc[which,4]>0])
how_many_epis

40

In [6]:
positive_samples = []
negative_samples = []
for i in range(data_info.shape[0]):
    data_file = data_info.iloc[i,0]
    data_epipoint = data_info.iloc[i,4]
    data_len = data_info.iloc[i,1]
    if data_len > data_epipoint > 0:
        transformed_data_file_dir = os.path.join(transformed_data_dir, data_file)
        transformed_data_files = os.listdir(transformed_data_file_dir)
        negative_data_files = sorted([f for f in transformed_data_files if "negative" in f])
        positive_data_files = sorted([f for f in transformed_data_files if "negative" not in f])
        positive_sample_all_channels = []
        negative_sample_all_channels = []
        for ndf, pdf in zip(negative_data_files, positive_data_files):
            positive_sample_all_channels.append(np.loadtxt(os.path.join(transformed_data_file_dir, pdf))[:,0])
            negative_sample_all_channels.append(np.loadtxt(os.path.join(transformed_data_file_dir, ndf))[:,0])
        positive_samples.append(np.stack(positive_sample_all_channels,axis=1))
        negative_samples.append(np.stack(negative_sample_all_channels,axis=1))


In [7]:
positive_samples[0].shape

(1000, 21)

In [8]:
positive_samples = np.stack(positive_samples)
negative_samples = np.stack(negative_samples)

In [9]:
positive_samples.shape, negative_samples.shape

((39, 1000, 21), (39, 1000, 21))

In [10]:
np.random.seed(0)
indices = list(range(39))
np.random.shuffle(indices)

In [11]:
train_set = indices[:20]
test_set = indices[20:]

In [12]:
positive_train = positive_samples[train_set]
negative_train = negative_samples[train_set]
positive_test = positive_samples[test_set]
negative_test = negative_samples[test_set]
train = np.concatenate((positive_train, negative_train))
test = np.concatenate((positive_test, negative_test))
train_labels = np.concatenate((np.ones(positive_train.shape[0]), np.zeros(negative_train.shape[0])))
test_labels = np.concatenate((np.ones(positive_test.shape[0]), np.zeros(negative_test.shape[0])))

In [13]:
positive_samples.shape

(39, 1000, 21)

In [14]:
rpy2.robjects.numpy2ri.activate()
    
# Set up our R namespaces
R = rpy2.robjects.r
DTW = importr('dtw')

In [17]:
cdists = np.empty((test.shape[0], train.shape[0]))

In [18]:
cdists.shape

(38, 40)

In [19]:
timelog = local_models.local_models.loggin.TimeLogger(
    logger=logger, 
    how_often=1, total=len(train_set)*len(test_set)*4, 
    tag="dtw_matrix")

import gc
# Calculate the alignment vector and corresponding distance
for test_i in range(cdists.shape[0]):
    for train_i in range(cdists.shape[1]):
        with timelog:
            alignment = R.dtw(test[test_i], train[train_i], keep_internals=False, distance_only=True)
            dist = alignment.rx('distance')[0][0]
            print(dist)
            cdists[test_i, train_i] = dist
            gc.collect()
            R('gc()')
            gc.collect()
        

13186.588829274286


  places = np.log10(np.abs(number))


14290.266968102667
12811.776329182972
10860.793892280444
8861.893163924642
13542.868796339331
11016.149584055996
16827.26072261249
8270.241742567574
15727.073927481602
15600.083435197455
10079.51170184304
11903.421416319186
18500.850977255246
11034.26447234431
11090.294905590876
10212.086833934103
10906.105213275448
10499.432905690719
12163.430105579477
16362.516499375715
15464.282295807065
18442.080679907987
15277.481404534587
10526.882705307107
21670.184526621208
14295.908435031872
16632.503603224137
17220.494288423477
17598.277227751754
12190.174175798767
12827.729227071823
7763.549026213526
12789.679210798187
13746.068903965466
12310.287806963523
16991.97077830681
16394.41979837164
15729.591049381146
11894.856071014929
12379.868335901489
14892.006801732176
20327.356238640423
15288.78486139668
15055.569668649257
17305.751109311434
14873.504444766475
29496.638219236396
14608.533691651179
16194.734297379926
5728.802272678272
11969.461476638593
27164.17269732408
12492.302388002805
1025

10416.943401664565
14110.750010493679
10480.98852140896
32507.889305217454
11035.02546971969
14179.096820618377
11726.841522943698
9349.03602327302
17131.547286558853
15058.356624924143
12829.974106642094
12715.416510122068
13406.923853873472
9663.944232900707
10893.342288143378
13143.882917665282
9883.505619303938
12278.791052265582
11697.083159760585
13774.362939492745
12598.946731332617
14150.195224309879
11178.254150336434
23185.352176591132
12780.300357663573
13552.89960702926
12753.73172834109
14805.401581518465
16487.060932678432
13787.477777691576
13938.472258648608
14690.914073434977
16409.928914498007
10802.067571600694
11357.520113287588
11715.198872333245
18138.316042189126
17937.825498697774
15433.262122714681
15620.50825348037
14610.491632620304
14930.823636892845
17718.85616646797
12746.4873820649
11573.275434953426
21423.887600149606
23741.35931548568
14738.266076634556
10659.506318476027
22031.169980802202
11532.980809596498
11422.830498047677
13199.257263819758
18225.

21664.828259008962
11887.312863830619
8008.867376992656
15967.177965963167
15067.821750216404
19554.35315889408
13282.032382800197
12761.1794916531
9767.82596840182
15394.515989692822
7865.505461865393
8060.557723108939
14771.1305568108
11643.860082432093
9780.882888067235
12299.324568555683
8295.375247884049
3272.8099519670645
13530.456901576837
10844.492268158818
29040.725939737993
12915.819945299247
11464.140053981928
12347.28338895037
14403.285811160204
9909.687484803113
10626.207749176066
13685.64789281105
9987.631043250716
16011.82308054114
10335.881132557108
10620.360143232556
13411.006345943442
8251.650584656916
9201.976611690903
10876.352290596958
8721.842025285558
8161.584560326116
9635.034001460293
4972.250166835917
25666.370997511607
8032.343624488772
7972.821552503335
10114.52760520417
13150.198063378462
21904.485568264416
13185.132768706999
11051.589726696518
12717.122037161114
13076.296134875121
7834.286337109625
6648.521782461024
11166.906573224755
6129.45447911442
8105

13575.342154716955
7111.143493142249
7794.004744813033
17512.662956995624
13684.756092720338
5239.509194198223
8513.660528639586
7795.931856893312
10161.753366291721
10488.248085258065
7503.217707928938
10114.920306267231
12272.549066672069
13495.514560958421
12744.258778825597
10952.957571556442
16182.824059031858
11235.062083383102
29825.65000516029
9671.782341032338
14077.04527185887
3802.69457519468
7016.853107900749
11170.548174706473
9077.567830307533
5543.186199557812
10224.450258132198
7853.100920628676
11282.75697452764
11279.381022506193
5704.927078875165
36718.68788070313
33967.15776264878
43099.90861778805
42864.52680634306
44220.452865634776
37937.63851843583
36330.47550273082
43564.53237280618
41739.22921698639
38779.94142176498
33878.63224831661
43139.37710209662
62863.06124077336
36539.18259356129
45607.26233966271
47984.69265048473
45591.220549183905
34759.20290404368
39071.475325051215
38928.54620948642
33662.400634071215
34435.21569362244
33358.50977599618
36100.5198

In [25]:
import sklearn.metrics

In [26]:
cdists.shape

(38, 40)

In [27]:
cdists_file = os.path.join(transformed_data_dir, "dtw_cdists.dat")
if "cdists" in globals() and not os.path.exists(cdists_file):
    np.savetxt(cdists_file, cdists)
else:
    cdists = np.loadtxt(cdists_file)

In [28]:
test_labels.shape

(38,)

In [29]:
np.argmin(cdists, axis=1).shape

(38,)

In [30]:
cm = sklearn.metrics.confusion_matrix(test_labels, train_labels[np.argmin(cdists, axis=1)])

In [31]:
print(cm)

[[16  3]
 [10  9]]


In [34]:
pd.DataFrame(np.round(cdists/10**3,0))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,30,31,32,33,34,35,36,37,38,39
0,13.0,14.0,13.0,11.0,9.0,14.0,11.0,17.0,8.0,16.0,...,12.0,13.0,8.0,13.0,14.0,12.0,17.0,16.0,16.0,12.0
1,12.0,15.0,20.0,15.0,15.0,17.0,15.0,29.0,15.0,16.0,...,7.0,8.0,20.0,11.0,7.0,15.0,9.0,11.0,11.0,7.0
2,12.0,14.0,16.0,14.0,12.0,14.0,15.0,20.0,12.0,17.0,...,9.0,9.0,12.0,12.0,7.0,11.0,9.0,15.0,15.0,9.0
3,10.0,9.0,10.0,8.0,3.0,10.0,8.0,25.0,10.0,9.0,...,11.0,13.0,10.0,9.0,12.0,10.0,15.0,8.0,9.0,12.0
4,10.0,9.0,12.0,10.0,9.0,10.0,7.0,29.0,9.0,10.0,...,10.0,13.0,12.0,12.0,12.0,12.0,15.0,10.0,9.0,11.0
5,13.0,9.0,9.0,9.0,8.0,9.0,10.0,24.0,13.0,9.0,...,15.0,16.0,13.0,12.0,16.0,12.0,18.0,10.0,11.0,15.0
6,10.0,11.0,12.0,12.0,11.0,11.0,8.0,25.0,10.0,11.0,...,13.0,15.0,14.0,13.0,14.0,16.0,16.0,9.0,8.0,13.0
7,12.0,8.0,14.0,14.0,14.0,12.0,11.0,29.0,19.0,10.0,...,15.0,15.0,26.0,12.0,15.0,19.0,17.0,7.0,10.0,14.0
8,11.0,14.0,15.0,14.0,13.0,14.0,15.0,11.0,9.0,19.0,...,11.0,11.0,10.0,13.0,11.0,15.0,14.0,15.0,16.0,10.0
9,12.0,11.0,11.0,10.0,8.0,10.0,12.0,22.0,11.0,14.0,...,16.0,16.0,12.0,13.0,16.0,13.0,20.0,14.0,16.0,15.0


In [35]:
np.argmin(cdists, axis=1)

array([32, 10, 14,  4, 26,  4, 26, 22, 19,  4, 27, 11, 32, 34, 14, 26, 28,
       20,  4, 20, 30, 14, 24,  6, 25, 26, 22, 34, 22, 27,  8, 32, 34, 30,
       37, 20, 20, 20])

In [41]:
pd.DataFrame(cm, index=[["true"]*2,["-","+"]], columns=[["pred"]*2, ["-", "+"]])

Unnamed: 0_level_0,Unnamed: 1_level_0,pred,pred
Unnamed: 0_level_1,Unnamed: 1_level_1,-,+
True,-,16,3
True,+,10,9


In [42]:
acc = np.sum(np.diag(cm))/np.sum(cm)
prec = cm[1,1]/np.sum(cm[:,1])
rec = cm[1,1]/np.sum(cm[1])
acc,prec,rec

(0.6578947368421053, 0.75, 0.47368421052631576)