In [None]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import EarlyStopping
from random import shuffle
import math
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import losses
import time
import tensorflow as tf
import os
import keras
import pandas as pd
import pickle

In [None]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this pickle if it comes from a trusted source.
DATA_PATH = "downstream_labeling_data.pickle"
with open(DATA_PATH, mode='rb') as pickle_file:
    DATA = pickle.load(pickle_file)

In [None]:
# Data-parallel distribution strategy; models built inside `mirrored_strategy.scope()`
# are replicated across devices. devices=None is the library default (use every
# device TensorFlow considers available).
mirrored_strategy = tf.distribute.MirroredStrategy(devices=None)

In [None]:
# Inspect one sample; CustomDataset later reads it as spec[0][0][channel].
first_sample = DATA["specEEG"][0]
first_sample.shape

(1, 1, 22, 55, 114)

In [None]:
class CustomDataset(tf.keras.utils.Sequence):
    """Keras Sequence that yields 22 per-channel arrays plus labels per batch.

    Each sample ``specEEG[ID]`` is indexed as ``spec[0][0][ch]`` for ch in 0..21,
    and every such slice is copied into a batch buffer of shape
    ``(batch_size, *target_size)`` (assumes the slice fits ``target_size`` —
    TODO confirm for new data).

    Args:
        specEEG: indexable collection of samples.
        labels: indexable collection of labels, aligned with ``specEEG``.
        batch_size: number of samples per batch.
        target_size: shape of one channel slice (default ``(1, 55, 114)``).
        shuffle: if truthy, sample order is permuted in ``on_epoch_end``.
        n_classes: stored on the instance; not used by this class.
    """

    N_CHANNELS = 22  # number of per-channel arrays yielded per batch

    def __init__(self, specEEG, labels, batch_size, target_size=(1, 55, 114), shuffle=False, n_classes=1):
        # Required by Keras 3 (PyDataset); a harmless no-op on older Keras.
        super().__init__()
        self.batch_size = batch_size
        self.dim = target_size
        self.labels = labels
        self.specEEG = specEEG
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.c = 0  # running count of generated samples (not read in this class)
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return int(np.floor(len(self.specEEG) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as the flat tuple ``(X0, ..., X21, y)``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        return self.__data_generation(indexes)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``shuffle`` is enabled."""
        self.indexes = np.arange(len(self.specEEG))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Fill fresh batch buffers for the given sample ids.

        The buffers are local: writing them into ``globals()`` would leak
        X0..X21 into the notebook namespace and share them between any
        batches generated concurrently.
        """
        channels = [np.empty((self.batch_size, *self.dim)) for _ in range(self.N_CHANNELS)]
        y = np.empty((self.batch_size), dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            spec = self.specEEG[ID]
            for ch, X in enumerate(channels):
                X[i,] = spec[0][0][ch]
            y[i] = self.labels[ID]
            self.c += 1

        # Labels are returned as integers (not one-hot) for the sigmoid head.
        return (*channels, y)

    
class Generator(keras.utils.Sequence):
    """Wrap a CustomDataset into the ``(inputs, targets)`` form Keras expects.

    CustomDataset yields a flat tuple ``(X0, ..., X21, y)``; this wrapper
    regroups it into ``([X0, ..., X21], y)`` so the per-channel arrays feed the
    model's separate inputs.

    Args:
        X: samples passed through to CustomDataset.
        Y: labels passed through to CustomDataset.
        batch_size: samples per batch.
        target_size: per-channel slice shape (default ``(1, 55, 114)``).
        shuffle: forwarded to CustomDataset (default False, as before).
    """

    def __init__(self, X, Y, batch_size, target_size=(1, 55, 114), shuffle=False):
        # Required by Keras 3 (PyDataset); a harmless no-op on older Keras.
        super().__init__()
        self.genX = CustomDataset(X, Y, batch_size=batch_size, shuffle=shuffle, target_size=target_size)

    def __len__(self):
        return len(self.genX)

    def on_epoch_end(self):
        # Keras calls on_epoch_end on this wrapper, not on the wrapped dataset;
        # forward it so shuffling (when enabled) happens every epoch.
        self.genX.on_epoch_end()

    def __getitem__(self, index):
        *x_batches, y_batch = self.genX[index]
        return list(x_batches), y_batch

In [None]:
def createModel_L_G_CoCL(l_cocl_path='./saved_models/SCL_model/L_CoCL/model.h5',
                         g_cocl_path='./saved_models/SCL_model/G_CoCL/model.h5',
                         encoder_training=None,
                         learning_rate=0.0001):
    """Build and compile the fine-tuning classifier on two pre-trained encoders.

    Each of the 22 per-channel inputs (shape ``(1, 55, 114)``) is passed through
    both loaded encoders; each encoder is shared across all channels. Each
    encoder's 22 outputs are concatenated on axis -3 and reshaped to
    ``(1, 22, 55, 114)``. They are then processed by that encoder's own
    three-stage Conv3D/MaxPool/BatchNorm stack. The two flattened feature
    vectors are concatenated and classified by a Dense(256) -> Dense(1, sigmoid)
    head.

    Args:
        l_cocl_path: saved L_CoCL encoder, loaded with ``keras.models.load_model``.
        g_cocl_path: saved G_CoCL encoder.
        encoder_training: forwarded as ``training=`` to both encoders. ``None``
            (default) lets the outer model's phase propagate: training mode in
            ``fit`` and inference mode in ``predict``/``evaluate``/validation.
            Pass ``True`` to force training-mode behaviour everywhere, as the
            original implementation did.
        learning_rate: Adam learning rate.

    Returns:
        A compiled Keras ``Model`` with 22 inputs and one sigmoid output.
    """
    n_channels = 22             # one model input per channel
    input_shape = (1, 55, 114)  # per-channel input shape produced by Generator

    pre_encoder = keras.models.load_model(l_cocl_path)
    pre_encoder2 = keras.models.load_model(g_cocl_path)
    # Both encoders are fine-tuned end to end.
    pre_encoder.trainable = True
    pre_encoder2.trainable = True
    # Give the nested models distinct names (presumably the two saved models
    # share a name, which would clash inside one functional model).
    pre_encoder._name = 'model1'
    pre_encoder2._name = 'model2'

    inputs, feats1, feats2 = [], [], []
    for _ in range(n_channels):
        inp = Input(shape=input_shape)
        inputs.append(inp)
        feats1.append(pre_encoder(inp, training=encoder_training))
        feats2.append(pre_encoder2(inp, training=encoder_training))

    # Each branch stacks only its own encoder's outputs (all 22 channels).
    concat1 = Concatenate(axis=-3)(feats1)
    concat2 = Concatenate(axis=-3)(feats2)
    # assumes each encoder output reshapes into one (55, 114) plane per channel
    # — TODO confirm against the saved encoders.
    reshape1 = Reshape((1, n_channels, 55, 114))(concat1)
    reshape2 = Reshape((1, n_channels, 55, 114))(concat2)

    # Layer factories applied stage by stage to both branches. Each call makes
    # a fresh layer, so the branches do not share weights. BatchNormalization
    # uses axis=1 because the tensors are channels_first (batch, C, D, H, W).
    stages = [
        # C1: kernel depth spans all channels
        lambda: Conv3D(16, (n_channels, 5, 5), strides=(1, 2, 2), padding='valid',
                       activation='relu', data_format="channels_first"),
        lambda: keras.layers.MaxPooling3D(pool_size=(1, 2, 2), data_format="channels_first", padding='same'),
        lambda: BatchNormalization(axis=1),
        # C2
        lambda: Conv3D(32, (1, 3, 3), strides=(1, 1, 1), padding='valid',
                       data_format="channels_first", activation='relu'),
        lambda: keras.layers.MaxPooling3D(pool_size=(1, 2, 2), data_format="channels_first"),
        lambda: BatchNormalization(axis=1),
        # C3
        lambda: Conv3D(64, (1, 3, 3), strides=(1, 1, 1), padding='valid',
                       data_format="channels_first", activation='relu'),
        lambda: keras.layers.MaxPooling3D(pool_size=(1, 2, 2), data_format="channels_first"),
        lambda: BatchNormalization(axis=1),
        lambda: Flatten(),
    ]
    branches = [reshape1, reshape2]
    for make_layer in stages:
        branches = [make_layer()(t) for t in branches]

    concat = Concatenate()(branches)
    hidden = Dense(256, activation='linear')(concat)
    outputs = Dense(1, activation='sigmoid')(hidden)

    ftmodel = Model(inputs, outputs=outputs)

    # Use `learning_rate` rather than the legacy `lr`/`decay` kwargs, which newer
    # Keras optimizers reject. decay=0.0 was the default anyway.
    opt_adam = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    ftmodel.compile(loss='BinaryCrossentropy', optimizer=opt_adam, metrics=['accuracy'])

    return ftmodel

In [None]:
save_dir = './saved_models/SCL_fine_tuning/'
# ModelCheckpoint writes into this directory; make sure it exists.
os.makedirs(save_dir, exist_ok=True)

# Per-fold training histories, keyed by (patient id, fold number).
historys = {}

from sklearn.model_selection import StratifiedKFold
# Fixed seed so fold assignment is reproducible across runs.
SEED = 42
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)


def to_one_or_zero(bool):
    """Return 1 for a truthy value, 0 otherwise."""
    return 1 if bool else 0


def binary_metrics(y_true, y_pred):
    """Compute (sensitivity, specificity, accuracy, confusion matrix) for 0/1 labels.

    ``num_classes=2`` keeps the matrix 2x2 even when a fold contains, or the
    model predicts, only one class. Ratios with a zero denominator are NaN.
    """
    con_mat = tf.math.confusion_matrix(labels=y_true, predictions=y_pred, num_classes=2).numpy()
    tn, fp = con_mat[0]
    fn, tp = con_mat[1]
    sensi = tp / (tp + fn) if (tp + fn) else float('nan')
    speci = tn / (tn + fp) if (tn + fp) else float('nan')
    total = con_mat.sum()
    accu = (tn + tp) / total if total else float('nan')
    return sensi, speci, accu, con_mat


patients = ["01", "02", "03", "05", "09", "10", "18", "21", "23"]

target_sensitive = []
target_accuracy = []
target_specificity = []

# The full DataFrame does not depend on the patient; build it once.
DATA_s = pd.DataFrame(DATA)

for i, patient in enumerate(patients):
    sen = []
    acc = []
    spe = []

    DATA_set = DATA_s[DATA_s["numPAZ"] == patient]
    spec = np.array(DATA_set['specEEG'])
    labels = np.array(DATA_set['LABEL'])

    for index, (train_indices, val_indices) in enumerate(skf.split(spec, labels)):
        fold = index + 1
        print("Training on fold " + str(fold))
        X_train, X_val = spec[train_indices], spec[val_indices]
        y_train, y_val = labels[train_indices], labels[val_indices]
        print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)

        TRAIN = Generator(X_train, y_train, batch_size=64, target_size=(1, 55, 114))
        VALID = Generator(X_val, y_val, batch_size=64, target_size=(1, 55, 114))

        with mirrored_strategy.scope():
            model = createModel_L_G_CoCL()

        # One checkpoint file per patient AND fold; a per-patient path was
        # overwritten by every subsequent fold.
        checkpoint_path = save_dir + patient + '_fold' + str(fold) + '_model.h5'
        early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='auto')
        checkpointer = keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss', verbose=1,
                                                       save_best_only=True, save_weights_only=True)

        # FIT THE MODEL
        history = model.fit(TRAIN,
                            epochs=150,
                            validation_data=VALID,
                            callbacks=[checkpointer, early_stopping])
        historys[(patient, fold)] = history.history

        # Evaluate with the best (lowest val_loss) checkpointed weights rather
        # than whatever the last trained epoch left in the model.
        if os.path.exists(checkpoint_path):
            model.load_weights(checkpoint_path)

        # `DATA_test` is not defined anywhere in this notebook (NameError on a
        # fresh kernel), so evaluate on the held-out fold.
        # NOTE(review): this fold also drives early stopping and checkpoint
        # selection, so these metrics are optimistically biased; switch to a
        # separate test split if one exists.
        X_test, y_test = X_val, y_val
        TEST = Generator(X_test, y_test, batch_size=1, target_size=(1, 55, 114))
        y_scores = model.predict(TEST)
        y_pred = (np.asarray(y_scores).ravel() > .5).astype(int)
        sensi, speci, accu, con_mat = binary_metrics(y_test[:len(y_pred)], y_pred)
        sen.append(sensi)
        spe.append(speci)
        acc.append(accu)
        print("sen : ", sensi)
        print("spe : ", speci)
        print("acc : ", accu)
        print(con_mat)

    target_sen = np.mean(sen)
    target_spe = np.mean(spe)
    target_acc = np.mean(acc)
    sen_std = np.std(sen)
    spe_std = np.std(spe)
    acc_std = np.std(acc)
    print('target_sen:', target_sen, sen_std)
    print('target_spe:', target_spe, spe_std)
    print('target_acc:', target_acc, acc_std)
    target_sensitive.append(target_sen)
    target_specificity.append(target_spe)
    target_accuracy.append(target_acc)