### Imports

In [None]:
import torch
import scipy.io
import mne
import sklearn
import os 
import random
import time
import scipy.linalg
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import lightgbm as lgb
import pickle
import time

from itertools import chain, product

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.cross_decomposition import CCA
from mne_features.feature_extraction import FeatureExtractor
from torch.utils.data import random_split, DataLoader, Dataset
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from scipy.stats import norm, wasserstein_distance
from torchmetrics.classification import BinaryAccuracy

%load_ext tensorboard


In [None]:
import warnings

# NumPy 2.0 removed the top-level ``np.VisibleDeprecationWarning`` alias; the
# class lives in ``numpy.exceptions`` since NumPy 1.25. Resolve it from
# whichever location the installed NumPy provides, then silence it.
try:
    from numpy.exceptions import VisibleDeprecationWarning
except ImportError:  # NumPy < 1.25
    VisibleDeprecationWarning = np.VisibleDeprecationWarning

warnings.filterwarnings("ignore", category=VisibleDeprecationWarning)

In [None]:
# Report whether PyTorch can use a CUDA GPU, and which one.
if not torch.cuda.is_available():
    print("PyTorch is not using the GPU.")
else:
    print("PyTorch is using the GPU.")
    print("Device name - ", torch.cuda.get_device_name(torch.cuda.current_device()))
    

In [None]:
# Import utility functions from different notebooks
import import_ipynb
from IEEE_data import extract_ieee_data, LazyProperty, data_4class
from CHIST_ERA_data import *
from Utils import *

### Dataset and Model classes

In [None]:
class convolution_AE(LightningModule):
    """1-D convolutional auto-encoder with a task head and a day head.

    The encoder compresses ``(batch, input_channels, time)`` signals with three
    strided ``Conv1d`` layers; the decoder mirrors it with ``ConvTranspose1d``
    layers. Two linear heads are trained alongside:

    * ``classiffier_task`` predicts the task label from the signal encoding;
    * ``classiffier_days`` predicts the recording day from the encoding (by
      ``res_encoder``) of the reconstruction residual ``x - decoder(encoder(x))``.

    Training phases are switched in ``on_train_epoch_end``: after epoch 0 and
    then every 20 epochs the model alternates between 'task' mode (decoder
    frozen; task + day losses) and 'reconstruction' mode (encoder frozen; MSE
    loss). Once the epoch index exceeds 450 everything is unfrozen and all
    losses are optimised together ('all' mode).

    Args:
        input_channels: number of input channels (Conv1d in-channels).
        days_labels_N: number of day classes.
        task_labels_N: number of task classes.
        learning_rate: Adam learning rate.
        filters_n: output channels of the three encoder conv layers.
        mode: initial training mode. Only 'task', 'reconstruction' and 'all'
            produce a loss; any other value (including the default
            'supervised') makes ``training_step`` return ``None``.
        classifier_in_features: flattened size of one encoding, i.e.
            ``filters_n[2] * encoded_length``. The default 1120 equals 32 * 35,
            which is what the encoder produces for 768-sample inputs with
            ``filters_n[2] == 32`` (presumably 6 s at 128 Hz — TODO confirm).
    """

    def __init__(self, input_channels, days_labels_N, task_labels_N, learning_rate=1e-3,
                 filters_n=(32, 16, 4), mode='supervised', classifier_in_features=1120):
        super().__init__()
        self.input_channels = input_channels
        self.filters_n = filters_n
        self.learning_rate = learning_rate
        self.l1_filters, self.l2_filters, self.l3_filters = self.filters_n
        self.mode = mode
        # Flipped by on_train_epoch_end to alternate between training phases.
        self.switcher = False

        # Encoder: three strided convolutions.
        self.encoder = nn.Sequential(
            nn.Conv1d(self.input_channels, self.l1_filters, kernel_size=25, stride=5, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(self.l1_filters, self.l2_filters, kernel_size=10, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(self.l2_filters, self.l3_filters, kernel_size=5, stride=2, padding=1),
            nn.LeakyReLU(),
        )

        # Decoder: mirrors the encoder with transposed convolutions.
        # NOTE(review): the output_padding values must make the reconstruction
        # exactly as long as the input (for 768 samples: 35 -> 72 -> 150 -> 768).
        # An earlier note said output_padding=1 on the first layer is needed for
        # the IEEE dataset and also gave 1 for CHIST-ERA — TODO confirm for
        # other input lengths.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(self.l3_filters, self.l2_filters, kernel_size=5, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(self.l2_filters, self.l1_filters, kernel_size=10, stride=2, padding=1, output_padding=0),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(self.l1_filters, self.input_channels, kernel_size=25, stride=5, padding=1, output_padding=0),
        )

        # Residual encoder: same architecture as the encoder, separate weights.
        self.res_encoder = nn.Sequential(
            nn.Conv1d(self.input_channels, self.l1_filters, kernel_size=25, stride=5, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(self.l1_filters, self.l2_filters, kernel_size=10, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(self.l2_filters, self.l3_filters, kernel_size=5, stride=2, padding=1),
            nn.LeakyReLU(),
        )

        # Day classifier (applied to the residual encoding).
        self.classiffier_days = nn.Sequential(
            nn.Flatten(),
            nn.Linear(classifier_in_features, days_labels_N),
            nn.Dropout(0.5),
        )

        # Task classifier (applied to the signal encoding).
        self.classiffier_task = nn.Sequential(
            nn.Flatten(),
            nn.Linear(classifier_in_features, task_labels_N),
            nn.Dropout(0.5),
        )

        # Cast to float32 *after* the layers exist; calling it before (as
        # previously) had nothing to convert.
        self.float()

    def forward(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def on_train_epoch_end(self):
        """Switch the training phase (see class docstring)."""
        if self.current_epoch > 450:
            self.unfreeze_decoder()
            self.unfreeze_encoder()
            self.mode = 'all'
        elif self.current_epoch % 20 == 0:
            self.switcher = not self.switcher
            if self.switcher:
                # Encoder phase: decoder frozen, optimise task + day losses.
                self.freeze_decoder()
                self.unfreeze_encoder()
                self.mode = 'task'
            else:
                # Decoder phase: encoder frozen, optimise reconstruction.
                self.freeze_encoder()
                self.unfreeze_decoder()
                self.mode = 'reconstruction'

    @staticmethod
    def _onehot_accuracy(logits, targets):
        """Fraction of rows where ``argmax(logits) == argmax(targets)`` (as a tensor)."""
        return (logits.argmax(dim=1) == targets.argmax(dim=1)).float().mean()

    def training_step(self, batch, batch_idx):
        """Compute and log every loss/accuracy; return the loss for ``self.mode``.

        ``batch`` is ``(x, y, days_y)``; ``y`` and ``days_y`` are compared by
        ``argmax`` over dim 1, so they are expected to be one-hot (or
        probability) targets of shape ``(batch, n_classes)``.
        """
        x, y, days_y = batch

        # Task head on the signal encoding.
        encoded = self.encode(x)
        preds_task = self.classiffier_task(encoded)
        task_loss = F.cross_entropy(preds_task, y)
        task_acc = self._onehot_accuracy(preds_task, y)
        self.log('task_loss', task_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('task_accuracy', task_acc, prog_bar=True, on_step=False, on_epoch=True)

        # Day head on the encoding of the reconstruction residual.
        reconstructed = self.decoder(encoded)
        residuals_compact = self.res_encoder(x - reconstructed)
        preds_days = self.classiffier_days(residuals_compact)
        days_loss = F.cross_entropy(preds_days, days_y)
        reconstruction_loss = F.mse_loss(reconstructed, x)
        days_acc = self._onehot_accuracy(preds_days, days_y)
        self.log('days_loss', days_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('reconstruction_loss', reconstruction_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('days_accuracy', days_acc, prog_bar=True, on_step=False, on_epoch=True)

        if self.mode == 'task':
            return task_loss + days_loss
        if self.mode == 'reconstruction':
            return reconstruction_loss
        if self.mode == 'all':
            return reconstruction_loss + days_loss + task_loss
        # NOTE(review): any other mode (including the constructor default
        # 'supervised', active until the end of epoch 0) yields None, which
        # makes Lightning skip the optimizer step for this batch.
        return None

    @staticmethod
    def get_lr(optimizer):
        """Return the learning rate of the optimizer's first param group (None if it has none)."""
        return next((group['lr'] for group in optimizer.param_groups), None)

    @staticmethod
    def _set_requires_grad(module, flag):
        """Enable/disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def freeze_encoder(self):
        self._set_requires_grad(self.encoder, False)

    def unfreeze_encoder(self):
        self._set_requires_grad(self.encoder, True)

    def freeze_decoder(self):
        self._set_requires_grad(self.decoder, False)

    def unfreeze_decoder(self):
        self._set_requires_grad(self.decoder, True)

    def change_mode(self, mode):
        """Set the training mode used by ``training_step``."""
        self.mode = mode

    def configure_optimizers(self):
        """Adam over all parameters (frozen ones simply receive no gradient)."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    

## Parameters

Channel names:
['FC3', 'C1', 'C3', 'C5', 'CP3', 'O1', 'FC4', 'C2', 'C4', 'C6', 'CP4']

In [None]:
# --- Subject / recording selection ---
subID = '201'    # subject ID as a string: '201', '205' or '206'
eyesFlag = 'CC'  # 'CC' -> eyes closed, 'OO' -> eyes open
block = [1]      # subject 201 has a single recording block

# --- Paths ---
dataDir = '../data'
results_data_dir = '../results_dir'
logdir = '../tb_logs'

# --- Days to load ---
# Subject 201: every available day on disk, in ascending order.
dayNumber = get_all_days(dataDir, subID, eyesFlag)
dayNumber.sort()
# For subjects 205 & 206 an explicit range works better, e.g.:
# dayNumber = range(1,9)

# --- Preprocessing ---
trialLen = 6           # trial length in seconds
filterLim = [1, 40]    # filter limits in Hz
elec_idxs = range(11)  # electrode indices 0-10, in channel-name order

# --- Training ---
train_days = [0, 40]
ae_learning_rt = 3e-4
n_epochs = 250
batch_sz = 16
convolution_filters = [8, 16, 32]  # filters per encoder layer (exactly 3 values)

In [None]:
# Convert relative path to absolute path
# Convert the relative data path to an absolute path (resolved against the current working directory)
dataDir = os.path.abspath(dataDir)


### Load the files - CHIST ERA

In [None]:
# Load all relevant days files into list
# Load every requested day's recording(s) for this subject / eye condition.
dataList = getRecording(dataDir, subID, eyesFlag, dayNumber, block)

# Per day: extract, filter, keep the selected electrodes, then segment into trials.
dictList = []
for dayData in dataList:
    interData = extractData(dayData)

    # Some subject-201 files are corrupted; they are recognised by a
    # non-float64 EEG array, so only float64 recordings are processed.
    if interData['EEG'].dtype == np.dtype('float64'):
        interData['EEG'] = eegFilters(interData['EEG'], interData['fs'], filterLim)[elec_idxs, :]
        dictList.append(segmentEEG(interData, trialLen, printFlag=0))

# Stack blocks recorded on the same day.
dictListStacked = stackBlocks(dictList, len(block))


# Training loop function

In [None]:
def training_loop(train_days, dictListStacked, ae_learning_rt, convolution_filters, batch_sz, epoch_n,
                  log_dir='../tb_logs', accelerator='gpu', devices=-1):
    """Train one convolutional AE and compare CSP scores with and without it.

    The days are put in a random order. This is done on a copy, so
    ``dictListStacked`` itself is not reordered.
    ``EEGDataSet_signal_by_day(days, train_days)`` provides the training data.
    The days from ``train_days[1]`` to the end form the test set.

    Args:
        train_days: ``[start, stop]`` day range used for training.
        dictListStacked: list of per-day data dictionaries.
        ae_learning_rt: Adam learning rate for the AE.
        convolution_filters: the three encoder filter counts.
        batch_sz: training batch size.
        epoch_n: number of training epochs.
        log_dir: TensorBoard log directory.
        accelerator, devices: forwarded to ``pl.Trainer``. The defaults match
            the previous hard-coded GPU settings.

    Returns:
        Tuple ``(bench_same_day_score, AE_same_day_score, ws_test, bs_test,
        bs_ae_test, day_zero_AE)``:

        * ``bench_same_day_score`` / ``AE_same_day_score``: CSP CV score
          (``cv_N=5``) on the training days, raw / AE-reconstructed signals.
        * ``ws_test``: CSP CV score within the test days.
        * ``bs_test`` / ``bs_ae_test``: the training-day classifiers applied
          to the raw / AE-reconstructed test days.
        * ``day_zero_AE``: the trained model.
    """
    logger = TensorBoardLogger(log_dir, name='EEG_Logger')
    # Shuffle a copy so the caller's list keeps its original day order.
    days = random.sample(dictListStacked, len(dictListStacked))

    # Training data.
    signal_data = EEGDataSet_signal_by_day(days, train_days)
    signal_data_loader = DataLoader(dataset=signal_data, batch_size=batch_sz, shuffle=True, num_workers=0)
    x, y, _days_y = signal_data.getAllItems()
    y = np.argmax(y, -1)  # one-hot -> class index

    # Train the AE; Lightning handles device placement during fit.
    day_zero_AE = convolution_AE(signal_data.n_channels, signal_data.days_labels_N, signal_data.task_labels_N,
                                 ae_learning_rt, filters_n=convolution_filters, mode='supervised')
    trainer = pl.Trainer(max_epochs=epoch_n, logger=logger, accelerator=accelerator, devices=devices)
    trainer.fit(day_zero_AE, train_dataloaders=signal_data_loader)

    # CV on the training set, with and without the AE (inference only: no autograd graph).
    with torch.no_grad():
        rec_train = day_zero_AE(x).numpy()
    score_ae, day_zero_AE_clf = csp_score(np.float64(rec_train), y, cv_N=5, classifier=False)
    score_bench, day_zero_bench_clf = csp_score(np.float64(x.detach().numpy()), y, cv_N=5, classifier=False)

    # Test data: every day after the training range.
    test_days = [train_days[1], len(days)]
    signal_test_data = EEGDataSet_signal(days, test_days)
    signal_test, y_test = signal_test_data.getAllItems()
    with torch.no_grad():
        rec_signal_zero = day_zero_AE(signal_test).numpy()

    # Within-session CV on the test days.
    ws_test, _ = csp_score(np.float64(signal_test.detach().numpy()), y_test, cv_N=5, classifier=False)
    # Training-day classifier applied to the test days, raw and AE-reconstructed.
    bs_test = csp_score(np.float64(signal_test.detach().numpy()), y_test, cv_N=5, classifier=day_zero_bench_clf)
    bs_ae_test = csp_score(rec_signal_zero, y_test, cv_N=5, classifier=day_zero_AE_clf)

    return score_bench, score_ae, ws_test, bs_test, bs_ae_test, day_zero_AE

In [None]:
# Baseline: within-session CSP CV accuracy of each day on its own, averaged
# over 100 repetitions, then saved as a pickle in results_data_dir.
per_day_score = []
timestr = time.strftime("%Y%m%d-%H%M%S")

for d in dictListStacked:
    temp_score = []
    for i in range(100):
        score_bench, _ = csp_score(np.float64(d['segmentedEEG']), d['labels'], cv_N=5, classifier=False)
        temp_score.append(score_bench)
    per_day_score.append(np.mean(temp_score))

# Create the results directory on first use instead of failing inside open().
os.makedirs(results_data_dir, exist_ok=True)
with open(os.path.join(results_data_dir, 'per_day_score_' + timestr + '.pickle'), 'wb') as f:
    pickle.dump(per_day_score, f)

In [None]:
# Single training run with the notebook parameters (n_epochs epochs).
train_days = [0, 40]

# training_loop takes (train_days, days, learning rate, filters, batch size,
# epochs) and returns six values. The previous call also passed the sampling
# rate and unpacked only four values, which raised TypeError / ValueError.
(bench_same_day_score, AE_same_day_score, ws_test_score,
 bench_diff_day_score, AE_diff_day_score, model) = training_loop(
    train_days, dictListStacked, ae_learning_rt, convolution_filters, batch_sz, n_epochs)

# Training for several days loop function

In [None]:
def score_over_number_of_days(start_day, epoch_n, dictListStacked, fs, ae_learning_rt, \
                              convolution_filters, batch_sz, max_delta=30,
                              log_dir='../tb_logs', accelerator='gpu', devices=-1):
    """Score AE vs. benchmark CSP classifiers as the number of training days grows.

    The loop runs ``delta`` over ``1 .. min(len(dictListStacked) - start_day - 1, max_delta)``.
    For each ``delta``:

    * the days are put in a fresh random order (on a copy);
    * an AE is trained on the days selected by ``[start_day, start_day + delta]``;
    * the days from ``start_day + delta`` to the end are used as the test set.

    Args:
        start_day: first index of the training-day range.
        epoch_n: training epochs per AE.
        dictListStacked: per-day data dictionaries. Not reordered; a shuffled
            copy is used.
        fs: unused; kept for call compatibility.
        ae_learning_rt, convolution_filters, batch_sz: AE hyper-parameters.
        max_delta: largest number of training days to try.
        log_dir, accelerator, devices: TensorBoard directory and ``pl.Trainer``
            device settings. The defaults match the previous hard-coded values.

    Returns:
        ``(bench_same_day_score_mean, bench_diff_day_score_mean,
        AE_diff_day_score_mean, ae_train_score, ws_train_score)``, one entry
        per ``delta``. The first three are NumPy arrays and the last two are
        lists. The AE train scores come *before* the WS train scores.
    """
    bench_diff_day_score_mean = []
    AE_diff_day_score_mean = []
    bench_same_day_score_mean = []
    ws_train_score = []
    ae_train_score = []

    # Same deltas as looping to len - start_day and breaking once delta > max_delta.
    for delta in range(1, min(len(dictListStacked) - start_day, max_delta + 1)):
        train_days = [start_day, start_day + delta]

        logger = TensorBoardLogger(log_dir, name='EEG_Logger')
        # Shuffle a copy so the caller's list keeps its day order.
        days = random.sample(dictListStacked, len(dictListStacked))

        # --- Train the AE on the training days ---
        signal_data = EEGDataSet_signal_by_day(days, train_days)
        signal_data_loader = DataLoader(dataset=signal_data, batch_size=batch_sz, shuffle=True, num_workers=0)
        x, y, _days_y = signal_data.getAllItems()
        y = np.argmax(y, -1)  # one-hot -> class index

        day_zero_AE = convolution_AE(signal_data.n_channels, signal_data.days_labels_N,
                                     signal_data.task_labels_N, ae_learning_rt,
                                     filters_n=convolution_filters, mode='supervised')
        trainer = pl.Trainer(max_epochs=epoch_n, logger=logger, accelerator=accelerator, devices=devices)
        trainer.fit(day_zero_AE, train_dataloaders=signal_data_loader)

        # --- CV on the training days, with and without the AE ---
        with torch.no_grad():
            rec_train = day_zero_AE(x).numpy()
        score_ae, day_zero_AE_clf = csp_score(np.float64(rec_train), y, cv_N=5, classifier=False)
        score_bench, day_zero_bench_clf = csp_score(np.float64(x.detach().numpy()), y, cv_N=5, classifier=False)
        ws_train_score.append(score_bench)
        ae_train_score.append(score_ae)

        # --- Evaluate on the remaining days ---
        signal_test_data = EEGDataSet_signal(days, [train_days[1], len(days)])
        signal_test, y_test = signal_test_data.getAllItems()
        with torch.no_grad():
            rec_signal_zero = day_zero_AE(signal_test).numpy()

        # Training-day classifiers applied to the test days (raw / AE-reconstructed).
        bench_diff_day = csp_score(np.float64(signal_test.detach().numpy()), y_test, cv_N=5, classifier=day_zero_bench_clf)
        AE_diff_day = csp_score(np.float64(rec_signal_zero), y_test, cv_N=5, classifier=day_zero_AE_clf)
        # Within-session CV on the test days.
        score_test_cv, _ = csp_score(np.float64(signal_test.detach().numpy()), y_test, cv_N=5, classifier=False)

        bench_diff_day_score_mean.append(bench_diff_day)
        AE_diff_day_score_mean.append(AE_diff_day)
        bench_same_day_score_mean.append(score_test_cv)

    return (np.asarray(bench_same_day_score_mean), np.asarray(bench_diff_day_score_mean),
            np.asarray(AE_diff_day_score_mean), ae_train_score, ws_train_score)

In [None]:
# First day index to include in the plot.
plot_from = 1

days_axis = range(plot_from, plot_from + len(AE_diff_day_score[plot_from:]))
curves = (
    (AE_diff_day_score, 'AE diff day', 'g'),
    (bench_diff_day_score, 'bench diff day', 'r'),
    (bench_same_day_score, 'bench same day', 'b'),
)

# One solid curve per score series...
for scores, curve_label, colour in curves:
    plt.plot(days_axis, scores[plot_from:], label=curve_label, color=colour)
# ...then a dashed horizontal line at each series' mean.
for scores, curve_label, colour in curves:
    plt.axhline(y=np.mean(scores[plot_from:]), color=colour, linestyle='--')

plt.title('Accuracy Over Days - Using Day 0 Classifier')
plt.xlabel('Day #')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
def score_over_number_of_days_train_cv(start_day, epoch_n, dictListStacked, fs, ae_learning_rt, \
                              convolution_filters, batch_sz, max_delta=30):
    """Within-session CSP CV score on the training days, per number of training days.

    The loop runs ``delta`` over ``1 .. min(len(dictListStacked) - start_day - 1, max_delta)``.
    For each ``delta`` the days are put in a fresh random order (on a copy;
    the caller's list is not reordered). The CV score (``cv_N=5``) is then
    computed on the days selected by ``[start_day, start_day + delta]``.

    ``epoch_n``, ``fs``, ``ae_learning_rt``, ``convolution_filters`` and
    ``batch_sz`` are unused. They are kept so the signature matches
    ``score_over_number_of_days``.

    Returns:
        List of CV scores, one per ``delta``.
    """
    ws_train_list = []

    # Same deltas as looping to len - start_day and breaking once delta > max_delta.
    for delta in range(1, min(len(dictListStacked) - start_day, max_delta + 1)):
        days = random.sample(dictListStacked, len(dictListStacked))
        signal_data = EEGDataSet_signal_by_day(days, [start_day, start_day + delta])
        x, y, _days_y = signal_data.getAllItems()
        y = np.argmax(y, -1)  # one-hot -> class index

        score_bench, _clf = csp_score(np.float64(x.detach().numpy()), y, cv_N=5, classifier=False)
        ws_train_list.append(score_bench)

    return ws_train_list

In [None]:
# Plot a reconstruction of the days selected by [3, 4] with the trained AE.
# `day_zero_AE` only exists inside the training functions; the model returned
# by `training_loop` is bound to `model` in the single-run training cell.
signal_data = EEGDataSet_signal(dictListStacked, [3,4])
signal_data_loader = DataLoader(dataset=signal_data, batch_size=batch_sz, shuffle=True, num_workers=0)
plotSignal(0, model, signal_data_loader)

## Several days realizations (long time)

In [None]:
# Display the absolute path of the TensorBoard log directory.
os.path.abspath(logdir)

In [None]:
# Repeat the multi-day experiment 50 times and re-save all results after every
# iteration, so a long run can be interrupted without losing finished work.
ws_list = []
bs_list = []
ae_list = []
ws_train_list = []
ae_train_list = []

# Current time for the file name.
timestr = time.strftime("%Y%m%d-%H%M%S")
os.makedirs(results_data_dir, exist_ok=True)
results_path = os.path.join(results_data_dir, '201_results_500_epoch' + timestr + '.pickle')

for i in range(50):
    print('Iter: ', i)
    # score_over_number_of_days returns the AE train scores *before* the WS
    # train scores. They were previously unpacked in swapped order, so the
    # "WS train" and "AE train" entries of earlier result files are exchanged.
    (bench_same_day_score_mean, bench_diff_day_score_mean, AE_diff_day_score_mean,
     ae_train_score, ws_train_score) = score_over_number_of_days(
        0, 500, dictListStacked, 128, ae_learning_rt,
        convolution_filters, batch_sz, max_delta=130)
    ws_list.append(bench_same_day_score_mean)
    bs_list.append(bench_diff_day_score_mean)
    ae_list.append(AE_diff_day_score_mean)
    ws_train_list.append(ws_train_score)
    ae_train_list.append(ae_train_score)

    # Save everything collected so far.
    save_obj = (ws_list, bs_list, ae_list, ws_train_list, ae_train_list)
    with open(results_path, 'wb') as f:
        pickle.dump(save_obj, f)

In [None]:
# Aggregate saved results. Each matching pickle holds
# (ws_list, bs_list, ae_list, ws_train_list, ae_train_list), each a list with
# one entry per realization.
ws_list_all = []
bs_list_all = []
ae_list_all = []
ws_train_all = []
ae_train_all = []

# Files are visited in sorted order so the result is deterministic. When
# several 'per_day_score' files exist, the last one in this order wins (with
# timestamped names, the newest one).
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
for file in sorted(os.listdir(results_data_dir)):
    if not file.endswith(".pickle"):
        continue
    path = os.path.join(results_data_dir, file)

    # Only names containing 'new' but not 'uns' are aggregated (for the
    # unsupervised runs use: 'uns' in file). Files written by the multi-run
    # loop are named '201_results_500_epoch…' and will not match unless renamed.
    if 'new' in file and 'uns' not in file:
        with open(path, "rb") as f:
            load_obj = pickle.load(f)
        ws_list_all.append(load_obj[0])
        bs_list_all.append(load_obj[1])
        ae_list_all.append(load_obj[2])
        ws_train_all.append(load_obj[3])
        ae_train_all.append(load_obj[4])
    if 'per_day_score' in file:
        with open(path, "rb") as f:
            per_day_score = pickle.load(f)


def _mean_over_realizations(per_file):
    """Flatten [file][realization] -> [realization] and average over axis 0."""
    return np.mean(np.asarray(list(chain.from_iterable(per_file))), axis=0)


ws_means = _mean_over_realizations(ws_list_all)
bs_means = _mean_over_realizations(bs_list_all)
ae_means = _mean_over_realizations(ae_list_all)
ws_train_means = _mean_over_realizations(ws_train_all)
ae_train_means = _mean_over_realizations(ae_train_all)

In [None]:
# Inspect the shape of the flattened training-set WS scores — presumably
# (total realizations across files, number of training-day counts); TODO confirm.
np.asarray([j for i in ws_train_all for j in i]).shape

In [None]:
# Training-set CV accuracy (WS vs. AE) as a function of the number of training days.
x_ax = range(1, 1 + len(ws_means))
train_curves = [(ws_train_means, 'WS train', 'b'), (ae_train_means, 'AE train', 'g')]
for curve, curve_label, curve_color in train_curves:
    plt.plot(x_ax, curve, label=curve_label, color=curve_color)

plt.axhline(0.5, label='Chance level', color='k', linestyle='--')

# Figure decorations
plt.title('Mean Accuracy Score Over Days As Function Of Number Of Training Days')
plt.xlabel('Number of Training Days')
plt.ylabel('Mean Accuracy')
plt.legend()
plt.show()

In [None]:
# Test-day accuracy (AE and benchmark day-0 classifiers) next to training-set
# WS accuracy, as a function of the number of training days.
x_ax = range(1, 1 + len(ws_means))
score_curves = (
    (ae_means, 'AE score', 'g'),
    (bs_means, 'BS score', 'r'),
    (ws_train_means, 'WS on train', 'b'),
)
for curve, curve_label, curve_color in score_curves:
    plt.plot(x_ax, curve, label=curve_label, color=curve_color)
# Optional: within-session CV on the test days.
# plt.plot(x_ax, ws_means, label='WS on test', color='teal')

reference_lines = (
    (np.mean(per_day_score), 'Mean per day score', 'orange'),
    (0.5, 'Chance level', 'k'),
)
for level, line_label, line_color in reference_lines:
    plt.axhline(level, label=line_label, color=line_color, linestyle='--')

# Figure decorations
plt.title('Mean Accuracy Score Over Days As Function Of Number Of Training Days')
plt.xlabel('Number of Training Days')
plt.ylabel('Mean Accuracy')
plt.legend()
plt.show()

In [None]:
# Keep the current AE means under a second name for later comparison —
# presumably the unsupervised ('uns') run, given the suffix; TODO confirm.
# This binds the same array object (no copy); rebinding `ae_means` later does not affect it.
ae_means_uns = ae_means

In [None]:
# Compare the first (up to) 100 points of the current and the stored AE curves.
# Both series and the x axis are cut to the same length, since plotting
# x_ax[:100] against the full ae_means_uns fails when it is longer.
n_show = min(100, len(x_ax), len(ae_means), len(ae_means_uns))
plt.plot(x_ax[:n_show], ae_means[:n_show], label='AE score', color='g')
plt.plot(x_ax[:n_show], ae_means_uns[:n_show], label='AE score (uns)', color='b')
plt.legend()
plt.show()


In [None]:
# import pickle
# import time
# timestr = time.strftime("%Y%m%d-%H%M%S")

# save_obj = (ws_list, bs_list, ae_list)
# # Save the lists
# with open('201_results' + timestr + '.pickle', 'wb') as f:
#     pickle.dump(save_obj, f)

# # Load the lists
# with open('201_results' + timestr + '.pickle', 'rb') as f:
#     loaded_obj = pickle.load(f)    


In [None]:
# Average the in-memory results of the multi-run loop over realizations (axis 0).
ws_means, bs_means, ae_means = (
    np.mean(np.asarray(realizations), axis=0)
    for realizations in (ws_list, bs_list, ae_list)
)


In [None]:
# Summarize how many trials (labels) each day contributes.
trials_n_per_day = [len(day['labels']) for day in dictListStacked]
n_days = len(dictListStacked)
print('Total number of days- ', n_days,
      '\nMean Trials count per day- ', np.mean(trials_n_per_day),
      '\nTrials count std- ', np.std(trials_n_per_day))