In [1]:
import torch 
import matplotlib.pyplot as plt
import numpy  
from visualization import * 
import torch.nn as nn
from utils import *
from dataloader import *

import sklearn.metrics as metrics 

from txlstm_szpool import *
from baselines import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#edit as required

device = 'cpu'
data_root = '/home/deeksha/EEG_Sz/miccai23/'
manifest = read_manifest('/home/deeksha/EEG_Sz/miccai23/'+ 'data/tuh_single_windowed_manifest.csv', ',')


n=15
use_cuda = False

In [3]:
def detection_metrics(dataloader, model):
    model.eval()
    running_loss = 0.0 
    running_corrects = 0.0
    true_positives = 0
    true_negatives = 0
    positives = 0
    length = 0
    detloss = nn.CrossEntropyLoss()
    trues = []
    proba = []
    for batch_idx, data in enumerate(dataloader): 
        
        inputs = data['buffers']
        labels = data['sz_labels']
        labels = labels.reshape(-1).long()
   
        inputs = inputs.to(torch.DoubleTensor())
        chn_output, _, _ = model(inputs)[:3]
        outputs = chn_output.reshape(-1,2)
        p = F.softmax(outputs, dim=-1).detach().cpu().numpy()[:, 1]
        proba.append(p)
        trues.append(labels.detach().cpu().numpy().reshape(-1))
        del inputs
        
       
        pred = torch.argmax(outputs.data, 1).long()  #cross entropy
        
        length += (pred.shape)[0]
        running_corrects += pred.eq(labels.data.view_as(pred)).cpu().sum()
        true_positives += ((pred==1)&(labels.data.view_as(pred)==1)).cpu().sum()   
        true_negatives += ((pred==0)&(labels.data.view_as(pred)==0)).cpu().sum()  
        positives += labels.data.view_as(pred).cpu().sum()

        del pred, labels
        torch.cuda.empty_cache()
    del model   
    epoch_loss = running_loss / len(dataloader)
    acc = round(running_corrects.item()/ length, 3)
    sens = round(true_positives.item()/positives.item() , 3)
    spec = round(true_negatives.item()/ (length-positives.item()), 3)
    trues = np.concatenate(trues)
    proba = np.concatenate(proba)
    auc = metrics.roc_auc_score(trues, proba)
    return epoch_loss, acc, sens, spec, auc

# Detection results

## Window level

In [7]:

detmodels = {'txlstm_szpool':{'auc': [], 'sens':[], 'spec':[]},
             'txlstm_maxpool':{'auc': [], 'sens':[], 'spec':[]},
             'tgcn_szpool':{'auc': [], 'sens':[], 'spec':[]} ,
             'sztrack_szpool':{'auc': [], 'sens':[], 'spec':[]}, 
             'cnnblstm':{'auc': [], 'sens':[], 'spec':[]}
             }

In [None]:
res_dict = []
for cvfold in range(0,15,1):
    foldroot = 'final_models/fold'+str(cvfold)+'/'
    allfiles = os.listdir(foldroot)
    modelfiles = os.listdir(foldroot)
    testfile = list(filter(lambda f:f.startswith('pts_test'), allfiles))[0]
    
    val_pts = np.load(foldroot+testfile)
    val_set =  pretrainLoader(data_root, val_pts, manifest, addNoise=False, 
                                          input_mask=None, ablate=False, permute=False, normalize=True)
    validation_loader = DataLoader(val_set, batch_size=1, shuffle=True)
                
    
    for modelname in list(detmodels.keys()):
        print(cvfold, modelname)
        
        
        statedict = torch.load(foldroot+modelfile)
        
        if modelname == 'cnnblstm':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            model = CNN_BLSTM()
            model.load_state_dict(statedict)
      
        elif modelname == 'txlstm_maxpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn =foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='maxpool')
            model.load_state_dict(statedict)

        elif modelname == 'tgcn_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn = foldroot+list(filter(lambda f:f.startswith('tgcn'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'tgcn', pooltype='szpool')
            model.load_state_dict(statedict)
        
        elif modelname =='sztrack_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
                model = sztrack()
                model.load_state_dict(statedict)
        else:
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
            prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
            model.load_state_dict(statedict)
    
        model.double()
        loss, acc, sens, spec, auc = detection_metrics(validation_loader, model)
        
        detmodels[modelname]['auc'].append(auc)
        detmodels[modelname]['sens'].append(sens)
        detmodels[modelname]['spec'].append(spec)
        
        del model
    #break
    del validation_loader
        

## seizure level

In [14]:
ablate = {'sztrack_szpool':{'fpr': [], 'sens':[], 'lat':[]}}

In [11]:
class MovingAverage(nn.Module):
    def __init__(self, winlen = 21):
        super(MovingAverage,self).__init__()
        self.winlen = winlen
        self.layer = nn.AvgPool1d(kernel_size=winlen, stride=1, padding=int((winlen-1)/2), count_include_pad=True)
    
    def forward(self, x):
        return self.layer(x)
    
    
sig = nn.Sigmoid()

In [12]:
thresrange = np.arange(0.3, 0.75, 0.05)
smoother = MovingAverage(31)

In [None]:
for cvfold in range(0,15,1):
    foldroot = 'final_models/fold'+str(cvfold)+'/'
    allfiles = os.listdir(foldroot)
    #modelfiles = list(filter(lambda f:f.endswith('.tar'), allfiles))
    modelfiles = os.listdir(foldroot+'models/')
    testfile = list(filter(lambda f:f.startswith('pts_test'), allfiles))[0]
    
    val_pts = np.load(foldroot+testfile)
    val_set =  pretrainLoader(data_root, val_pts, manifest, addNoise=False, 
                                          input_mask=None, ablate=False, permute=False, normalize=True)
    test_loader = DataLoader(val_set, batch_size=1, shuffle=True)
                
    
        
    for modelname in list(detmodels.keys()):
        print(cvfold, modelname)
        
        
        statedict = torch.load(foldroot+modelfile)
        
        if modelname == 'cnnblstm':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            model = CNN_BLSTM()
            model.load_state_dict(statedict)
      
        elif modelname == 'txlstm_maxpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn =foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='maxpool')
            model.load_state_dict(statedict)

        elif modelname == 'tgcn_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn = foldroot+list(filter(lambda f:f.startswith('tgcn'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'tgcn', pooltype='szpool')
            model.load_state_dict(statedict)
        
        elif modelname =='sztrack_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
                model = sztrack()
                model.load_state_dict(statedict)
        else:
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
            prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
            model.load_state_dict(statedict)
    
        model.double()
        final_thres = 0.2
        
        for j, thres in enumerate(thresrange):
            fp_count = 0
            time = 0
            for data in validation_loader:
                x = data['buffers'].double()
                ytrue = data['sz_labels'].reshape(-1).detach().cpu()
                nsz = x.shape[1]
                y, _, _ = model(x)[:3]
                proba = F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(nsz, -1)
                p_smooth = smoother(proba.detach().cpu()[None, :, :])[:, ].reshape(-1)
                ypred = torch.zeros_like(ytrue)
                ypred[p_smooth >= thres] = 1 
                
                time += nsz*600 / 3600
                fp_count += (ypred[ytrue==0]==1).sum()
            fpr = fp_count/time
            if fpr <= 120:
                final_thres = thres
                break


        fp_count, time = 0, 0
        latency = []
        tp = 0 
        p = 0
        FPM = []

        time_control = 0
        for data in test_loader:
            x = data['buffers'][:, :1].double()
            ytrue = data['sz_labels'][:, :1].reshape(-1).detach().numpy()
            nsz = x.shape[1]
            y, _, _ = model(x)[:3]
            proba = F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(nsz, -1)
            p_smooth = smoother(proba.detach().cpu()[None, :, :])[:, ].reshape(-1).detach().numpy()
            ypred = np.zeros_like(ytrue)
            ypred[p_smooth > final_thres] = 1 
            
            ypred = ypred.reshape(nsz, -1)
            p_smooth = p_smooth.reshape(nsz, -1)

            
           
            p += nsz
            ytrue = ytrue.reshape(nsz, -1)
                                 
            time += nsz*600 / 3600
            ton = np.argmax(ytrue, 1)
            tonpred = np.argmax(ypred, 1)
            for j in range(nsz):
                s = max(0, ton[j] - 10)
                time_control += s+1
                if s==0:
                    s = 1
                temp = ypred[j, :s] - np.concatenate(([0], ypred[j, : s-1]))
                fp_count += (temp==1).sum()
                ysz = ypred[j]
                tp += int((ysz[ytrue[j]==1]==1).any())
            
                
                temp = ypred[j, ] - np.concatenate(([0], ypred[j, : -1]))
                onsets = np.where(temp==1)[0]
                offsets = np.where(temp==-1)[0]
                if not len(onsets) == 0:
                    
                    for eno in range(len(onsets)):
                        try: 
                            e = offsets[eno] 
                        except:
                            e = -1
                        if (ytrue[j, onsets[eno]:e] == 1).any():
                            break
                    latency.append(onsets[eno] - ton[j])
           
                fpm = ypred[j, :s].sum()/60
                FPM.append(fpm)
        
        del model
        
        detmodels[modelname]['fpr'].append( np.mean(FPM))
        detmodels[modelname]['sens'].append(tp/p)
        detmodels[modelname]['lat'].append(np.mean(latency))

In [16]:
(np.mean(ablate['sztrack_szpool']['fpr']), np.std(ablate['sztrack_szpool']['fpr']))

(2.0665415579802517, 0.8443198665507652)

## Delong test

In [11]:
from delong import *

In [20]:
delongmodels = ['txlstm_szpool', 'cnnblstm', 'tgcn_szpool', 'txlstm_maxpool', 'sztrack']

In [None]:

for cvfold in range(0,15,1):
    foldroot = 'final_models/fold'+str(cvfold)+'/'
    allfiles = os.listdir(foldroot)
    modelfiles = list(filter(lambda f:f.endswith('.tar'), allfiles))
    testfile = list(filter(lambda f:f.startswith('pts_test'), allfiles))[0]
    
    val_pts = np.load(foldroot+testfile)
    val_set =  pretrainLoader(data_root, val_pts, manifest, addNoise=False, 
                                          input_mask=None, ablate=False, permute=False, normalize=True)
    test_loader = DataLoader(val_set, batch_size=1, shuffle=True)
                
    ground = []
    pred_txlstm = []
    pred_tgcn = []
    pred_cnnblstm = []
    pred_max = []
    pred_sztrack = []
    for modelname in list(delongmodels):
   
        
        if modelname == 'cnnblstm':
                modelfile = list(filter(lambda f:f.startswith(modelname+'_pret'), allfiles))[0]
                statedict = torch.load(foldroot+modelfile)
                model_cnnblstm = CNN_BLSTM()
                model_cnnblstm.load_state_dict(statedict)
                model_cnnblstm.double()

        elif modelname == 'tgcn_szpool':
                modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), allfiles))[0]
                statedict = torch.load(foldroot+modelfile)
                prefn = foldroot+list(filter(lambda f:f.startswith('tgcn'+'_pre'), allfiles))[0]
                model_tgcn = txlstm_szpool(transformer_dropout=0.1, 
                                  pretrained = prefn, 
                                  modelname = 'tgcn', pooltype='szpool')
                model_tgcn.load_state_dict(statedict)
                model_tgcn.double()
                
        elif modelname =='sztrack_szpool':
                modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
                statedict = torch.load(foldroot+modelfile)            
                model_sztrack = sztrack()
                model_sztrack.load_state_dict(statedict)
                
        elif modelname == 'txlstm_maxpool':
                modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), allfiles))[0]
                statedict = torch.load(foldroot+modelfile)
                prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
                model_max = txlstm_szpool(transformer_dropout=0.1, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
                model_max.load_state_dict(statedict)
                model_max.double()         
        
        else:
                modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), allfiles))[0]
                statedict = torch.load(foldroot+modelfile)
                prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
                model_txlstm = txlstm_szpool(transformer_dropout=0.1, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
                model_txlstm.load_state_dict(statedict)
                model_txlstm.double()
        
    model_txlstm.eval()
    model_tgcn.eval()
    model_max.eval()
    model_cnnblstm.eval()
    model_sztrack.eval()
    for idx, data in enumerate(test_loader):
            x = data['buffers'].double()
            ytrue = data['sz_labels'].reshape(-1).detach().numpy()
            nsz = x.shape[1]
            
           
            y, _, _ = model_txlstm(x)[:3]
            pred_txlstm.append(F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(-1).detach().numpy())
            del y
        
            
            
            y, _, _ = model_tgcn(x)[:3]
            pred_tgcn.append(F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(-1).detach().numpy())
            del y
            
            
            y, _, _ = model_cnnblstm(x)[:3]
            pred_cnnblstm.append(F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(-1).detach().numpy())
            del y
           
            y, _, _ = model_sztrack(x)[:3]
            pred_sztrack.append(F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(-1).detach().numpy())
            del y
           
            
            y, _, _ = model_max(x)[:3]
            pred_max.append(F.softmax(y.reshape(-1, 2), -1)[:, 1].reshape(-1).detach().numpy())
            del y
            
            ground.append(ytrue)
            

            #if idx==5:
            #    break
    del model_txlstm, model_tgcn,model_cnnblstm, model_sztrack
    pred_txlstm = np.concatenate(pred_txlstm)
    pred_sztrack = np.concatenate(pred_sztrack)
    pred_tgcn = np.concatenate(pred_tgcn)
    pred_cnnblstm = np.concatenate(pred_cnnblstm)
    pred_max = np.concatenate(pred_max)
  
    ground = np.concatenate(ground)
    print('tgcn', np.round(10**delong_roc_test(ground, pred_txlstm, pred_tgcn), 3))
    print('cnnblstm', np.round( 10**delong_roc_test(ground, pred_txlstm, pred_cnnblstm), 3))
    print('sztrack', np.round( 10**delong_roc_test(ground, pred_txlstm, pred_sztrack), 3))
    print('txlstm_max', np.round( 10**delong_roc_test(ground, pred_txlstm, pred_max), 3))
    

# Localization results

In [7]:
chn_neighbours = {0: [1,2,3,4], 
                  1: [0,4,5,6], 
                  2: [0,3,4,7,8], 
                  3: [0,2,4,8,9], 
                  4: [0,1,3,5,9], 
                  5: [1,4,6,9,10],
                  6: [1,4,5,10,11], 
                  7: [2,8,12,13, 17], 
                  8: [2,3,4,7,9,12,13,14], 
                  9: [3,4,5,8,10,13,14,15], 
                 10: [4,5,6,9,11,14,15,16], 
                 11: [6, 10, 15, 16, 18], 
                 12: [7, 8, 13, 17], 
                 13: [7, 8, 9, 12, 14, 17],
                 14: [8,9,10,13,15,17,18],
                 15: [9,10,11,14,16,18], 
                 16: [10,11,15,18], 
                 17: [7,12,13,14,18], 
                 18: [11, 14,15, 16, 17]}

def check_neighborhood(max_chn, onset_map):
    check = False
    for i in range(19):
        if onset_map[i]==1:
            surrounding = chn_neighbours[i]
            if max_chn in surrounding:
                check = True
    return check

def final_loc(psoz, true_onset):
            n = psoz.shape[0]
            m = psoz.max(1).reshape(n,1)
            psoz = psoz/m
            ysoz = psoz.mean(0)
            #ysoz /= ysoz.max()
            max_chn_loc = np.argmax(ysoz)                    
            max_chn_correct = 1 if true_onset[max_chn_loc]==1 else 0
            
            
            if check_neighborhood(max_chn_loc, true_onset) and true_onset.sum()<=4:
                        max_chn_correct =1 
                       
 
  
            Uev = psoz.var(0)

            return ysoz, Uev, max_chn_correct

## patient level

In [4]:

locmodels = {'txlstm_szpool':{'auc': [], 'sens':[], 'spec':[]},
             'txlstm_maxpool':{'auc': [], 'sens':[], 'spec':[]},
             'tgcn_szpool':{'auc': [], 'sens':[], 'spec':[]} ,
             'sztrack_szpool':{'auc': [], 'sens':[], 'spec':[]}
            
             }

In [None]:
for cvfold in range(5,6,1):
    foldroot = 'final_models/fold'+str(cvfold)+'/'
    allfiles = os.listdir(foldroot)
    #modelfiles = list(filter(lambda f:f.endswith('.tar'), allfiles))
    modelfiles = os.listdir(foldroot)
    testfile = list(filter(lambda f:f.startswith('pts_test'), allfiles))[0]
    
    val_pts = np.load(foldroot+testfile)
   
  
    for modelname in list(locmodels.keys()):
        print(cvfold, modelname)
        
        
        statedict = torch.load(foldroot+modelfile)
        

      
        if modelname == 'txlstm_maxpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn =foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='maxpool')
            model.load_state_dict(statedict)

        elif modelname == 'tgcn_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn = foldroot+list(filter(lambda f:f.startswith('tgcn'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'tgcn', pooltype='szpool')
            model.load_state_dict(statedict)
        
        elif modelname =='sztrack_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
                model = sztrack()
                model.load_state_dict(statedict)
        else:
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
            prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
            model.load_state_dict(statedict)
       
        model.double()
        corr_pt = 0
        uncs = []
        for pt in val_pts:
            
            sets =  pretrainLoader(data_root, [pt], manifest, addNoise=False, normalize=True, 
                                   input_mask=None, ablate=False, permute=False)
            loader = DataLoader(sets, batch_size=1, shuffle=True)
            
            
            correct_loc = 0
            tot_sz = 0
            ysoz_all = []
            a_all = []
      
            for idx, data in enumerate(loader):
                    
                    inputs = data['buffers'].double()
                    B,Nsz,T, C, N = inputs.shape
                    chn_pos = torch.arange(C)
                    true_onset = data['onset map'].detach().numpy().reshape(-1)
    
                    output, psoz, _, a = model(inputs)
                    ysoz_all.append(psoz.reshape(Nsz, C).detach().cpu().numpy())
                    a_all.append(a)
                    tot_sz += Nsz
                    
            
           
            ysoz_all = np.concatenate(ysoz_all).reshape(tot_sz, 19)
            psoz = ysoz_all
            ysoz, conf_y, max_chn_correct_y = final_loc(psoz, true_onset)

            
            corr_pt += max_chn_correct_y
            
            uncs.append(conf_y)
            #f, ax = plt.subplots(1,1, figsize = (4,4))
            
            #pos2d = topoplot(ysoz, ax,
            #                    title=pt, zone=1,
            #                   lobe_correct=max_chn_correct_y,
            #                   lat_correct=max_chn_correct_y,
            #                  onset_map=true_onset)
            
            #plt.show()
        print(modelname, cvfold, corr_pt, np.array(uncs).mean() )
        del model
        
        locmodels[modelname]['ptcorr'].append(corr_pt)
        locmodels[modelname]['ptunc'].append(np.array(uncs).max(1).mean())
        
    print('\n')

In [11]:
jsonf = json.dumps(locmodels)

# open file for writing, "w" 
f = open("loc_finetuned_finalresults_nomask.json","w")

# write json object to file
f.write(jsonf)

# close file
f.close()

In [34]:
with open("loc_finetuned_finalresults.json") as json_file:
    temp = json.load(json_file)

In [35]:
for m in list(temp.keys()):

    print(np.round(np.mean(temp[m]['ptcorr']), 3)/24, np.round(np.std(temp[m]['ptcorr']), 3)/24)

0.6930416666666667 0.10675
0.411125 0.075875
0.486125 0.12329166666666667
0.6805416666666666 0.082875
0.6555416666666667 0.11020833333333334


## sz level

In [None]:
n_samples =20
for cvfold in range(0,15,1):
    foldroot = 'final_models/fold'+str(cvfold)+'/'
    allfiles = os.listdir(foldroot)
    #modelfiles = list(filter(lambda f:f.endswith('.tar'), allfiles))
    modelfiles = os.listdir(foldroot+'/txmodels/')
    testfile = list(filter(lambda f:f.startswith('pts_test'), allfiles))[0]
    val_pts = np.load(foldroot+testfile)
    sets =  pretrainLoader(data_root, val_pts, manifest, addNoise=False, normalize=True, 
                                   input_mask=None, ablate=False, permute=False)
    test_loader = DataLoader(sets, batch_size=1, shuffle=True)
    
    val_pts = np.load(foldroot+testfile)
   
  
        for modelname in list(locmodels.keys()):
        print(cvfold, modelname)
        
        
        statedict = torch.load(foldroot+modelfile)
        

      
        if modelname == 'txlstm_maxpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn =foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='maxpool')
            model.load_state_dict(statedict)

        elif modelname == 'tgcn_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)
            prefn = foldroot+list(filter(lambda f:f.startswith('tgcn'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'tgcn', pooltype='szpool')
            model.load_state_dict(statedict)
        
        elif modelname =='sztrack_szpool':
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
                model = sztrack()
                model.load_state_dict(statedict)
        else:
            modelfile = list(filter(lambda f:f.startswith(modelname+'_fine'), modelfiles))[0]
            statedict = torch.load(foldroot+modelfile)            
            prefn = foldroot+list(filter(lambda f:f.startswith('txlstm'+'_pre'), allfiles))[0]
            model = txlstm_szpool(transformer_dropout=0.15, 
                                  pretrained = prefn, 
                                  modelname = 'txlstm', pooltype='szpool')
            model.load
        
        del statedict 
        corr_sz = []
        mc_var = []
        
        for idx, data in enumerate(test_loader):

            ysoz_all = []
            #print(idx)
            
            inputs = data['buffers'].long()
            
            B,Nsz,T, C, N = inputs.shape
            chn_pos = torch.arange(C)
            true_onset = data['onset map'].detach().numpy().reshape(-1)
            #tot_sz += Nsz
            for mcmc in range(n_samples):   
                with torch.no_grad():
                    try:
                        _, psoz, _, _ = model(inputs)
                    except:
                        _, psoz, _, _ = model(inputs.double())
                    ysoz_all.append(psoz.reshape(1, Nsz, C).detach().cpu().numpy())

               
            ysoz_all = np.concatenate(ysoz_all) #.reshape(tot_sz, 19)
            #print(ysoz_all.shape)
            for szn in range(Nsz):
                
                    ysoz, conf_y, max_chn_correct_y = final_loc(ysoz_all[:, szn, :].reshape(-1, 19), true_onset)
                    corr_sz.append(max_chn_correct_y)
                
                    mc_var.append(conf_y)

            #f, ax = plt.subplots(1,10, figsize = (20,2))
            '''
            ysoz, conf_y, max_chn_correct_y = final_loc(psoz[j,:,].reshape(1,-1),
                                                            true_onset)
            pos2d = topoplot(ysoz, ax[j],
                                title='ysoz', zone=1,
                                lobe_correct=max_chn_correct_y,
                                lat_correct=max_chn_correct_y,
                                onset_map=true_onset)
            plt.show()
            #print(Nsz, data['patient numbers'])
            '''
            del ysoz_all
        del model
        
        print(modelname, cvfold, np.array(corr_sz).mean())
        
        locmodels[modelname]['szcorr'].append(np.array(corr_sz).mean())
        locmodels[modelname]['szunc'].append(np.array(mc_var).mean())
        del corr_sz, mc_var
    del test_loader

In [22]:
jsonf = json.dumps(locmodels)

# open file for writing, "w" 
f = open("loc_finetuned_results_final.json","w")

# write json object to file
f.write(jsonf)

# close file
f.close()