In [35]:
import glob, os
from IPython.display import clear_output
import numpy as np

files = []
for file in glob.glob("/Users/mohamedr/projects/epilespy_connectivity/isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_epilepsy/v2.0.0/*/*/"):
    files.append(file)
print(len(files))
    
from sklearn.model_selection import train_test_split

train_files, test_files = train_test_split(files, test_size=0.1, random_state=2025)
print("num subjects:", len(train_files), len(test_files))

files = []
for folder in test_files:
    for file in glob.glob(folder+"/*/*/*.edf"):
        files.append(file)
print("num test edf:",  len(files))
test_files = files

files = []
for folder in train_files:
    for file in glob.glob(folder+"/*/*/*.edf"):
        files.append(file)
print("num train edf:", len(files))
train_files = files

np.random.shuffle(train_files)

200
num subjects: 180 20
num test edf: 135
num train edf: 2163


In [9]:
"""
TUH Epilipsy Dataset
"""

import glob, os
import numpy as np
from tqdm import tqdm
import torch

from IPython.display import clear_output

#from src.model import EEGModel
from src.read_data import build_data
from src.make_features import train_test, standardize_data, data_loader
#from src.train import train_model, print_acc

from torch_geometric.data import Data, TemporalData, HeteroData
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse

def norm_adj(train_graphs, test_graphs):
    for i in range(train_graphs.shape[0]):
        for j in range(train_graphs.shape[1]):
                min_ = (train_graphs[i, j, :, :]).min()
                max_ = (train_graphs[i, j, :, :]).max()
                train_graphs[i, j, :,  :] = (train_graphs[i, j, :,  :] - min_)/(max_ - min_)
                
    for i in range(test_graphs.shape[0]):
        for j in range(test_graphs.shape[1]):
            min_ = (test_graphs[i, j, :, :]).min()
            max_ = (test_graphs[i, j, :, :]).max()
            test_graphs[i, j, :, :] = (test_graphs[i, j, :,  :] - min_)/(max_ - min_)
            
    return train_graphs, test_graphs
    
SAVED_DATA = True
SAVED_FEATURES = False
CONN_TYPES = ["coh"]#, "ciplv", "coh", "pc"]
NUM_EPOCHS = 30 #This is based on Cross validation experiments
BATCH_SIZE = 128
DEVICE = torch.device("cpu") #torch.device("mps")

for conn in CONN_TYPES:    
    print(conn)
    if SAVED_FEATURES:
        train_X = np.load("../saved_npy_ep/features/train_X.npy") #samples, #epochs, #channels, #time points)
        train_y = np.load("../saved_npy_ep/features/train_y.npy")
        train_graphs = np.load("../saved_npy_ep/features/train_graphs_"+conn+".npy")
        test_y = np.load("../saved_npy_ep/features/test_y.npy")
        test_X = np.load("../saved_npy_ep/features/test_X.npy")
        test_graphs = np.load("../saved_npy_ep/features/test_graphs_"+conn+".npy")
        train_X = np.moveaxis(train_X, 1, -1)
        test_X = np.moveaxis(test_X, 1, -1)
        train_graphs, test_graphs = norm_adj(train_graphs, test_graphs)
        #train_X, test_X = standardize_data(train_X, test_X)
        
    else:
        if SAVED_DATA:
            # Saved arrays on disk
            train_X_files = ["../saved_npy_ep/features/train_X"+str(i)+".npy" for i in range(0, len(train_files), 100)]
            train_y_files = ["../saved_npy_ep/features/train_y"+str(i)+".npy" for i in range(0, len(train_files), 100)]
            train_X = np.vstack([np.load(file).astype(np.float16) for i, file in enumerate(train_X_files)])
            train_y = np.vstack([np.load(file).astype(np.float16) for i, file in enumerate(train_y_files)])
            test_X = np.load("../saved_npy_ep/features/test_X.npy").astype(np.float16)
            test_y = np.load("../saved_npy_ep/features/test_y.npy").astype(np.float16)
            
        else:
            num_windows = 60
            window_size = 200
            for i in range(0, len(train_files), 100):
                train_X, train_y = build_data(train_files[i:i+100], use_windows=False, 
                                              window_size=window_size, num_windows=num_windows,
                                              dataset = "epilepsy")
                np.save("../saved_npy_ep/features/train_X"+str(i)+".npy", train_X.astype(np.float16))
                np.save("../saved_npy_ep/features/train_y"+str(i)+".npy", train_y.astype(np.float16))
            test_X, test_y = build_data(test_files, use_windows=False, window_size=window_size, 
                                        num_windows=num_windows, dataset = "epilepsy")
            np.save("../saved_npy_ep/features/test_X.npy", test_X.astype(np.float16))
            np.save("../saved_npy_ep/features/test_y.npy", test_y.astype(np.float16))
            clear_output()
    
            train_X_files = ["../saved_npy_ep/features/train_X"+str(i)+".npy" for i in range(0, len(train_files), 100)]
            train_y_files = ["../saved_npy_ep/features/train_y"+str(i)+".npy" for i in range(0, len(train_files), 100)]
            train_X = np.vstack([np.load(file) for file in train_X_files])
            train_y = np.vstack([np.load(file) for file in train_y_files])
            test_X = np.load("../saved_npy_ep/features/test_X.npy")
            test_y = np.load("../saved_npy_ep/features/test_y.npy")

        np.save("../saved_npy_ep/features/train_X.npy", train_X)
        np.save("../saved_npy_ep/features/train_y.npy", train_y)
        np.save("../saved_npy_ep/features/test_y.npy", test_y)
        np.save("../saved_npy_ep/features/test_X.npy", test_X)
        print("make features")
        _ , train_graphs, _, _ , test_graphs, _ = train_test(train_X=train_X, 
                                                             test_X=test_X, 
                                                             train_y=train_y, 
                                                             test_y=test_y, 
                                                             method=conn)
        
        np.save("../saved_npy_ep/features/train_X.npy", train_X)
        np.save("../saved_npy_ep/features/train_y.npy", train_y)
        np.save("../saved_npy_ep/features/train_graphs_"+conn+".npy", train_graphs)
        np.save("../saved_npy_ep/features/test_y.npy", test_y)
        np.save("../saved_npy_ep/features/test_X.npy", test_X)
        np.save("../saved_npy_ep/features/test_graphs_"+conn+".npy", test_graphs)

coh
make features
read_data
Gen features
calculating connectivity


100%|██████████| 1871/1871 [1:16:54<00:00,  2.47s/it]


Gen features
calculating connectivity


100%|██████████| 121/121 [05:05<00:00,  2.52s/it]


In [32]:
x = list(train_y.squeeze())
x.count(1), x.count(0)

(1518, 353)

In [30]:
!ls ../

[34m__MACOSX[m[m                        [34mmci_ad_dataset_npy[m[m
[34m__pycache__[m[m                     [34mmci_dem_dataset[m[m
AD_mci_dataset.ipynb            model.py
chrononet_pytorch.ipynb         [34mmultiedge[m[m
[34mclean_code[m[m                      node2vec_tuh.ipynb
clean_code.zip                  openneuro.ipynb
corr-caueeg-MCI_chrononet.ipynb plot.py
corr-caueeg-MCI.ipynb           process.py
dementia1.ipynb                 read_data.py
evaluate.py                     [34msaved_models[m[m
explain.py                      [34msaved_npy_ep[m[m
[34meye_tracking[m[m                    [34msaved_npy_tuh[m[m
[34mgithub_repo[m[m                     [34mseizure_raw[m[m
[34mgraphs[m[m                          tgn.ipynb


In [17]:
import numpy as np
train_X = np.load("saved_npy_tuh/train_X.npy")[:10]
test_X = np.load("saved_npy_tuh/test_X.npy")[:10]
train_y = np.load("saved_npy_tuh/train_y.npy")
test_y = np.load("saved_npy_tuh/test_y.npy")

In [35]:
import neurokit2 as nk
from tqdm import tqdm
def cal_ent(signals):
    all_entropy = []
    for ch_signals in signals:
        entropy = []
        for period_signal in ch_signals:
            maxen = nk.entropy_maximum(period_signal)[0]
            diff = nk.entropy_differential(period_signal)[0]
            power = nk.entropy_power(period_signal)[0]
            tsa =  nk.entropy_tsallis(period_signal, q=1)[0]
            shan = nk.entropy_shannon(period_signal, base=np.e)[0]
            appr = nk.entropy_approximate(period_signal)[0]
            #appr = nk.entropy_approximate(period_signal)[0]
            spec = nk.entropy_spectral(period_signal, show=False)[0]
            entropies = [shan, tsa]
            entropy.append(entropies)
        all_entropy.append(entropy)

    return all_entropy

for x in tqdm(train_X):
    maxen = cal_ent(x)


80%|███████████████████████████████████▏        | 8/10 [00:14<00:03,  1.84s/it]

KeyboardInterrupt: 

In [37]:
15/10*2717/60

67.925

In [2]:
import mne
from tqdm import tqdm
from statsmodels.tsa.stattools import grangercausalitytests
import random
import numpy as np
from scipy.signal import savgol_filter, detrend
from scipy import signal

import numpy as np
import scipy.signal as sig

from braindecode.preprocessing import exponential_moving_standardize

from sklearn.decomposition import FastICA

def get_label(file):
    label = file.split("/")[-5]
    if label == "no_epilepsy_edf":
        return [0]
    elif label == "epilepsy_edf":
        return [1]

def multichannel_sliding_window(X, size, step):
    shape = (X.shape[0] - X.shape[0] + 1, (X.shape[1] - size + 1) // step, X.shape[0], size)
    strides = (X.strides[0], X.strides[1] * step, X.strides[0], X.strides[1])
    return np.lib.stride_tricks.as_strided(X, shape, strides)[0]

    
def read_edf_file(file, use_windows=True, num_windows=100): 
    
    try:
        data_raw = mne.io.read_raw_edf(file, preload=True)
        channels_to_use = ["FP1", "FP2", "F7", "F3", "FZ", "F4", "F8",
                                  "T3", "C3", "CZ", "C4", "T4", "T5",
                                  "P3", "PZ", "P4", "T6", "O1", "O2"]
        ch_name_update_func = lambda ch: ch.split(' ')[-1].split('-')[0]
        data_raw.rename_channels(mapping=ch_name_update_func)
        data_raw = data_raw.pick_channels(channels_to_use, ordered=True)
        montage = mne.channels.make_standard_montage('standard_1020')
        data_raw.set_montage(montage, match_case=False, match_alias=True)
        data_raw.filter(l_freq=1.0, h_freq=45.0)
        data_raw.resample(100, npad='auto')
        data_raw = data_raw.get_data()
        #data_raw = np.diff(data_raw, axis=-1)
        data_raw = multichannel_sliding_window(data_raw, 100, 90)
    except:
        return 
    startpoint = 10
    
    if data_raw.shape[0] >= num_windows+startpoint:
        return [data_raw[startpoint:num_windows+startpoint, :, :]]
    else:
        return


def build_data(raw_data, use_windows=True, num_windows=200):
    
    all_data_features = []
    data_labels = []
    data_graphs = []
    
    for file in tqdm(raw_data):
        edf_data = read_edf_file(file, use_windows=use_windows, num_windows=num_windows)
        if not edf_data:
            continue
        else:
            for edf_data1 in edf_data:
                all_data_features.append(edf_data1)
                label = get_label(file)
                data_labels.append(np.array(label))
    
    all_data_features = np.array(all_data_features)
    data_labels = np.array(data_labels)
             
    return all_data_features, data_labels

In [3]:
import copy

from sklearn.preprocessing import OneHotEncoder
#from train import trainer
import torch
#from tgcn import TGCN
from torch.utils.data import TensorDataset,DataLoader
from scipy.stats import pearsonr

from numpy.lib.stride_tricks import sliding_window_view
from scipy import signal
from statsmodels.tsa.stattools import grangercausalitytests
from tqdm import tqdm
import itertools
from scipy.sparse import csgraph
from sklearn.feature_selection import mutual_info_regression
import spkit as sp


def hilphase(x1,x2):
    sig1_hill=sig.hilbert(x1)
    sig2_hill=sig.hilbert(x2)
    pdt=(np.inner(sig1_hill,np.conj(sig2_hill))/(np.sqrt(np.inner(sig1_hill,
               np.conj(sig1_hill))*np.inner(sig2_hill,np.conj(sig2_hill)))))
    phase = np.angle(pdt)
    return phase
    
def gc(x1, x2):
    X = np.vstack([x1, x2]).T
    gc = grangercausalitytests(X, [2], addconst=True, verbose=False)[2][0]['ssr_ftest'][1]
    return gc
    
    
# Coherence - δ
def coherence(eegData,fs):
    coh_res = []
    for ii, jj in itertools.combinations(range(eegData.shape[0]), 2):
        coh_res.append(CoherenceDelta(eegData, ii, jj, fs=fs))
    coh_res = np.array(coh_res)
    return coh_res

# Mutual information
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics import mutual_info_score

# https://github.com/ufvceiec/EEGRAPH/blob/develop-refactor/eegraph/strategy.py
def pli(data_intervals, i, j):
    sig1_phase = signal.hilbert([data_intervals[i]])
    sig2_phase = signal.hilbert([data_intervals[j]])
    phase_diff = sig1_phase[0] - sig2_phase[0]
    phase_diff = phase_diff.astype(float)
    phase_diff = (phase_diff  + np.pi) % (2 * np.pi) - np.pi
    pli = abs(np.mean(np.sign(phase_diff)))
    return pli
    

# Cross correlation 
# https://github.com/ufvceiec/EEGRAPH/blob/develop-refactor/eegraph/strategy.py
def calculate_cc(data_intervals, i, j):
    x = data_intervals[i]
    y = data_intervals[j]
    
    Rxy = signal.correlate(x,y, 'full')
    Rxx = signal.correlate(x,x, 'full')
    Ryy = signal.correlate(y,y, 'full')
    
    lags = np.arange(-len(data_intervals[i]) + 1, len(data_intervals[i]))
    lag_0 = int((np.where(lags==0))[0])

    Rxx_0 = Rxx[lag_0]
    Ryy_0 = Ryy[lag_0]
    
    Rxy_norm = (1/(np.sqrt(Rxx_0*Ryy_0)))* Rxy
    #We use the mean from lag 0 to a 10% displacement. 
    disp = round((len(data_intervals[i])) * 0.10)
    cc_coef = Rxy_norm[lag_0: lag_0 + disp].mean()
    return cc_coef
     
def cal_mi(x1, x2):
    x1 = x1.reshape(x1.shape[0], 1)
    mi = mutual_info_regression(x1, x2)
    return mi


# Mutual information
def calculate2Chan_MI(eegData,ii,jj,bin_min=-200, bin_max=200, binWidth=2):
    H = np.zeros(eegData.shape[2])
    bins = np.arange(bin_min+1, bin_max, binWidth)
    for epoch in range(eegData.shape[2]):
        c_xy = np.histogram2d(eegData[ii,:,epoch],eegData[jj,:,epoch],bins)[0]
        H[epoch] = mutual_info_score(None, None, contingency=c_xy)
    return H


import spkit as sp

def gen_graphs(eegs, num_nodes=19, cal_conn="cc"):
    #eegs(snapshots, bands, timpoints)
    c = []
    for i in range(num_nodes):
        c1 = []
        for j in range(num_nodes):
            if cal_conn == "pearson":
                conn = pearsonr(eegs[i], eegs[j])[0]
            elif cal_conn == "cc":
                conn = calculate_cc(eegs, i, j)
            elif cal_conn == "plv": 
                conn = hilphase(eegs[i], eegs[j])
            elif cal_conn == "pli":
                conn = pli(eegs, i, j)
            elif cal_conn == "gc": 
                conn = gc(eegs[i], eegs[j])
            elif cal_conn == "mi": 
                conn = sp.mutual_info(eegs[i], eegs[j])
            elif cal_conn == "con-entropy": 
                conn = sp.entropy_cond(eegs[i],eegs[j])
            elif cal_conn == "cross-entropy": 
                conn = sp.entropy_cross(eegs[i],eegs[j])
            elif cal_conn == "kld-entropy": 
                conn = sp.entropy_kld(eegs[i],eegs[j])
            elif cal_conn == "joint-entropy": 
                conn = sp.entropy_joint(eegs[i],eegs[j])
            c1.append(conn)
        c.append(c1)
    return c


def gen_features(X, y, device, cal_conn, window_size=100, overlap=0, augment=False, stft=False):
    X_new = np.moveaxis(np.array(X), 1, -1)
    graphs = []
    dynamic_range = X_new.shape[-1]
    threshold = 0.3
    
    print("calculating connectivity")
    for x in tqdm(X_new):
        temp = []
        for i in range(dynamic_range):
            temp_g = gen_graphs(x[:, :, i], cal_conn=cal_conn)
            temp_g = np.array(temp_g).squeeze()
            #temp_g = (temp_g - temp_g.min())/(temp_g.max() - temp_g.min())
            #temp_g[temp_g<threshold] = 0
            #temp_g = csgraph.laplacian(temp_g)
            temp.append(temp_g)
        graphs.append(temp)
        
    
    # calculate STFT
    X_new = np.moveaxis(X_new, -1, 1)
    X = []
    for x in X_new:
        f, t, Zxx = signal.stft(x, 100, nperseg=100, noverlap=10, boundary=None, padded=None)
        X.append(np.abs(Zxx))
        #X.append(x)
        
    X_new = np.array(X).squeeze()
    X_new = np.moveaxis(np.array(X_new), 1, -1)    
    graphs = np.array(graphs)
    
    graphs = np.moveaxis(graphs.squeeze(), 1, -1)
        
    return X_new, graphs, y

import copy
def standardize_data(train_X, test_X):
    train_X_std = copy.deepcopy(train_X)
    test_X_std = copy.deepcopy(test_X)
    for i in range(train_X.shape[2]):
        min_ = np.min(train_X[:, :, i, :])
        max_ = np.max(train_X[:, :, i, :])
        train_X_std[:, :, i, :] = (train_X[:, :, i, :] - min_)/(max_ - min_)
        test_X_std[:, :, i, :] = (test_X[:, :, i, :] - min_)/(max_ - min_)
    return train_X_std, test_X_std

In [4]:
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_max_pool
import torch.nn as nn
from torch_scatter import scatter_add
from torch_geometric.utils import add_self_loops

class TGCN(torch.nn.Module):
    r"""An implementation of the Temporal Graph Convolutional Gated Recurrent Cell.
    For details see this paper: `"T-GCN: A Temporal Graph ConvolutionalNetwork for
    Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        improved (bool): Stronger self loops. Default is False.
        cached (bool): Caching the message weights. Default is False.
        add_self_loops (bool): Adding self-loops for smoothing. Default is True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True,
    ):
        super(TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):

        self.conv_z = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_reset_gate_parameters_and_layers(self):

        self.conv_r = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_candidate_state_parameters_and_layers(self):

        self.conv_h = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=1)
        Z = self.linear_z(Z)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], axis=1)
        R = self.linear_r(R)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], axis=1)
        H_tilde = self.linear_h(H_tilde)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H
    
    
    
class A3TGCN(torch.nn.Module):
    r"""An implementation of the Attention Temporal Graph Convolutional Cell.
    For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
    Network for Traffic Forecasting." <https://arxiv.org/abs/2006.11583>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        periods (int): Number of time periods.
        improved (bool): Stronger self loops (default :obj:`False`).
        cached (bool): Caching the message weights (default :obj:`False`).
        add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        periods: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True
    ):
        super(A3TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.periods = periods
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self._setup_layers()

    def _setup_layers(self):
        self._base_tgcn = TGCN(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
        torch.nn.init.uniform_(self.attention)
        

    def forward(
        self,
        X: torch.FloatTensor,
        A: torch.FloatTensor,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor): Node features for T time periods.
            * **edge_index** (PyTorch Long Tensor): Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional)*: Edge weight vector.
            * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
        """
        
        H_accum = 0
        probs = torch.nn.functional.softmax(self.attention, dim=0)
        for period in range(self.periods):
            Xt = X[:, :, :, period]
            batch_size = Xt.shape[0]
            Xt = Xt.reshape(Xt.shape[0]*Xt.shape[1], Xt.shape[-1])
            At = A[:, :, :, period]
            At = torch.block_diag(*At)
            idx = (At > 0).nonzero().t().contiguous().long().to(X.device)
            row, col = idx
            w = At[row, col].float().to(X.device)            
            temp_emb = self._base_tgcn(Xt, idx, w, H)
            temp_emb = temp_emb.reshape(batch_size, 19, 32)
            
            H_accum = H_accum + probs[period] * temp_emb #([32, 207, 32]
            
        return H_accum
    
class EEGModel(nn.Module):
    def __init__(self, num_nodes, node_features, num_classes, num_windows, device):
        super(EEGModel, self).__init__()
        self.forwardA3TGCN = A3TGCN(in_channels=node_features, out_channels=32, periods=num_windows) # node_features=2, periods=12
        self.backwardA3TGCN = A3TGCN(in_channels=node_features, out_channels=32, periods=num_windows) # node_features=2, periods=12
        self.num_nodes = num_nodes
        self.BN = nn.BatchNorm1d(self.num_nodes)
        self.num_windows = num_windows
        self.fc2 = nn.Linear(self.num_nodes*64, num_classes)
        
    def forward(self, X, A):
        HS1 = self.forwardA3TGCN(X, A)
        X_flip = torch.flip(X, dims=[1])
        A_flip = torch.flip(A, dims=[1])
        HS2 = self.backwardA3TGCN(X_flip, A_flip)
        HS = torch.cat((HS1, HS2), -1)
        HS = nn.functional.relu(HS)
        HS = self.BN(HS)
        HS = HS.reshape(HS.shape[0], self.num_nodes*64)
        out = self.fc2(HS)
        return out

In [5]:
from IPython.display import clear_output
from torch.utils.data import TensorDataset,DataLoader
from sklearn.preprocessing import OneHotEncoder
import copy


def standardize_data(train_X, test_X):
    train_X_std = copy.deepcopy(train_X)
    test_X_std = copy.deepcopy(test_X)
    for i in range(train_X.shape[1]):
        min_ = np.min(train_X[:, i, :, :])
        max_ = np.max(train_X[:, i, :, :])
        train_X_std[:, i, :, :] = (train_X[:, i, :, :] - min_)/(max_ - min_)
        test_X_std[:, i, :, :] = (test_X[:, i, :, :] - min_)/(max_ - min_)
    return train_X_std, test_X_std


def standardize_data(train_X, test_X):
    train_X_std = copy.deepcopy(train_X)
    test_X_std = copy.deepcopy(test_X)
    for i in range(train_X.shape[1]):
        for j in range(train_X.shape[2]):
            min_ = np.min(train_X[:, i, j, :])
            max_ = np.max(train_X[:, i, j, :])
            train_X_std[:, i, j, :] = (train_X[:, i, j, :] - min_)/(max_ - min_)
            test_X_std[:, i, j, :] = (test_X[:, i, j, :] - min_)/(max_ - min_)
    return train_X_std, test_X_std


def train_test(train_X, test_X, train_y, test_y, cal_conn, device, use_test_windows=False):
    print("read_data")
    
    print(train_X.shape, test_X.shape)
    
    train_X, train_graphs, train_y = gen_features(train_X, train_y, device, cal_conn=cal_conn, 
                                                  window_size=100, overlap=90, augment=False)
    test_X, test_graphs, test_y = gen_features(test_X, test_y, device, cal_conn=cal_conn, 
                                               window_size=100, overlap=90, augment=False)
    
    train_X, test_X = standardize_data(train_X, test_X)
    
    clear_output()
    encoder = OneHotEncoder()
    train_y = encoder.fit_transform(train_y).toarray()
    test_y = encoder.transform(test_y).toarray()
    train_X = torch.Tensor(train_X).to(DEVICE)
    test_X = torch.Tensor(test_X).to(DEVICE)
    train_y= torch.Tensor(train_y).to(DEVICE)
    test_y = torch.Tensor(test_y).to(DEVICE)
    test_graphs= torch.Tensor(test_graphs).to(DEVICE)
    train_graphs = torch.Tensor(train_graphs).to(DEVICE)

    batch_size = 64
    data = TensorDataset(train_X, train_graphs, train_y)
    train_iter = torch.utils.data.DataLoader(data, batch_size, shuffle=True)
    data = TensorDataset(test_X, test_graphs, test_y)
    test_iter = torch.utils.data.DataLoader(data, batch_size, shuffle=False)
    
    return train_iter, test_iter

In [6]:
import glob, os
from sklearn.model_selection import KFold
    
def train_kfold(files_kfold):
    all_train_losses = []
    all_val_losses = []
    
    kf = KFold(n_splits=3, shuffle=True, random_state=2024)
    for k, (train_files_, val_files_) in enumerate(kf.split(files_kfold)):
        print("Kfold", k)
        train_files_ = [files_kfold[i] for i in train_files_]
        val_files_ = [files_kfold[i] for i in val_files_]
        train_iter, val_iter = train_test(train_files=train_files_, 
                                          test_files=val_files_,
                                          num_windows=200,
                                          cal_conn="corr",
                                          use_test_windows=True)
        
        DEVICE = torch.device('cpu')
        model = EEGModel(num_nodes=19, node_features=300, num_classes=2, num_windows=332, device=DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    
        print("Training model")
        model.train()
        for epoch in range(10):
            losses = 0
            for idx, (X, A, y) in enumerate(tqdm(train_iter)):
                optimizer.zero_grad()
                out = model(X, A)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                losses += loss.item()*X.shape[0]
            losses = losses/len(train_iter.dataset)
            print("Epoch ", epoch+1, ":")
            print("Kfold train loss", losses)
            all_train_losses.append(losses)
            losses = 0
            
            for idx, (X, A, y) in enumerate(val_iter):
                optimizer.zero_grad()
                out = model(X, A)
                loss = criterion(out, y)
                losses += loss.item()*X.shape[0]
            losses = losses/len(val_iter.dataset)
            print("Kfold val loss", losses)
            all_val_losses.append(losses)
    clear_output()
    return all_train_losses, all_val_losses

In [7]:
import glob, os
from sklearn.model_selection import train_test_split
    
def train_val(files):
    all_train_losses = []
    all_val_losses = []

    train_subset_files, val_files = train_test_split(files, test_size=0.1, random_state=2024)
    train_iter, val_iter = train_test(train_files=train_subset_files, 
                                      test_files=val_files,
                                      num_windows=200,
                                      cal_conn="pearson",
                                      use_test_windows=True)
    
    DEVICE = torch.device('cpu')
    model = EEGModel(num_nodes=19, node_features=51, num_classes=2, num_windows=200, device=DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)

    print("Training model")
    
    for epoch in range(100):
        losses = 0
        for idx, (X, A, y) in enumerate(tqdm(train_iter)):
            optimizer.zero_grad()
            out = model(X, A)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            losses += loss.item()*X.shape[0]
        losses = losses/len(train_iter.dataset)
        print("Epoch ", epoch+1, ":")
        print("train loss", losses)
        all_train_losses.append(losses)
        
        losses = 0
        for idx, (X, A, y) in enumerate(val_iter):
            optimizer.zero_grad()
            out = model(X, A)
            loss = criterion(out, y)
            losses += loss.item()*X.shape[0]
        losses = losses/len(val_iter.dataset)
        print("val loss", losses)
        all_val_losses.append(losses)
    clear_output()
    return all_train_losses, all_val_losses

In [8]:
"""
train_losses, val_losses = train_val(train_files)
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.plot(val_losses)
plt.xlim(-1, 100)
plt.xlabel("epochs")
plt.show()
"""

'\ntrain_losses, val_losses = train_val(train_files)\nimport matplotlib.pyplot as plt\nplt.plot(train_losses)\nplt.plot(val_losses)\nplt.xlim(-1, 100)\nplt.xlabel("epochs")\nplt.show()\n'

In [9]:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, auc, roc_auc_score

def print_acc(model, data_iter):
    outs= []
    ys = []
    
    model.eval()
    with torch.no_grad():
        for X, A, y in data_iter:
            out = model(X, A)
            outs.extend(out.cpu().detach().numpy())
            ys.extend(y.cpu().detach().numpy())
    
    outs = np.array(outs)
    ys = np.array(ys)
    outs = np.argmax(outs, -1)
    ys = np.argmax(ys, -1)

    print("accuracy:", accuracy_score(outs, ys),
          "f1 score:", f1_score(outs, ys),
          "precision:",precision_score(outs, ys),
          "recall:", recall_score(outs, ys),
          "confusion matrix:", confusion_matrix(outs, ys))

    metrics = [accuracy_score(outs, ys), f1_score(outs, ys), 
               precision_score(outs, ys), recall_score(outs, ys),
               confusion_matrix(outs, ys)]
    
    return metrics


def train_model(model, num_epochs, data_iter):
    model.train()
    for epoch in tqdm(range(num_epochs)): 
        losses = 0
        model.train()
        for idx, (X, A, y) in enumerate(data_iter):
            optimizer.zero_grad()
            out = model(X, A)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        print_acc(model, test_iter)
    return model

In [10]:
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, auc, roc_auc_score

def print_acc(model, data_iter):
    outs= []
    ys = []
    
    model.eval()
    with torch.no_grad():
        for X, A, y in data_iter:
            out = model(X, A)
            outs.extend(out.cpu().detach().numpy())
            ys.extend(y.cpu().detach().numpy())
    
    outs = np.array(outs)
    ys = np.array(ys)
    outs = np.argmax(outs, -1)
    ys = np.argmax(ys, -1)

    print("accuracy:", accuracy_score(outs, ys),
          "f1 score:", f1_score(outs, ys),
          "precision:",precision_score(outs, ys),
          "recall:", recall_score(outs, ys),
          "confusion matrix:", confusion_matrix(outs, ys))

    metrics = [accuracy_score(outs, ys), f1_score(outs, ys), 
               precision_score(outs, ys), recall_score(outs, ys),
               confusion_matrix(outs, ys)]
    return metrics

def train_model(model, num_epochs, data_iter):

    model.train()
    for epoch in range(num_epochs): 
        losses = 0
        model.train()
        for idx, (X, A, y) in enumerate(data_iter):
            optimizer.zero_grad()
            out = model(X, A)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        print_acc(model, test_iter)
    return model

In [11]:
num_windows = 50
import copy
    
train_X, train_y = build_data(train_files, use_windows=False, num_windows=num_windows)
test_X, test_y = build_data(test_files, use_windows=False, num_windows=num_windows)

clear_output()

DEVICE = torch.device("cpu")
train_iter, test_iter = train_test(train_X=train_X, 
                                       test_X=test_X, 
                                       train_y=train_y, 
                                       test_y=test_y, 
                                       cal_conn="cc",
                                       device=DEVICE)  

In [16]:
final_model = EEGModel(num_nodes=19, node_features=51, num_classes=2, num_windows=50, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=torch.Tensor([10.0, 1.0]))
optimizer = torch.optim.Adam(final_model.parameters(),lr=1e-3)
final_model.train()
for epoch in range(100): 
    final_model.train()
    for idx, (X, A, y) in enumerate(train_iter):
        optimizer.zero_grad()
        out = final_model(X, A)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
    print_acc(final_model, test_iter)

accuracy: 0.45229681978798586 f1 score: 0.025157232704402517 precision: 0.012987012987012988 recall: 0.4 confusion matrix: [[126 152]
 [  3   2]]
accuracy: 0.5441696113074205 f1 score: 0.7034482758620689 precision: 0.9935064935064936 recall: 0.5444839857651246 confusion matrix: [[  1   1]
 [128 153]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [

KeyboardInterrupt: 

In [46]:
conn_types = ["pearson", "plv", "kld-entropy", "mi"]
num_epochs = 80
results = {}
    
for conn in conn_types:
    print(conn)    
    DEVICE = torch.device("cpu") #torch.device("mps")
    train_iter, test_iter = train_test(train_X=train_X[:, :50], 
                                       test_X=test_X[:, :50], 
                                       train_y=train_y, 
                                       test_y=test_y, 
                                       cal_conn=conn,
                                       device=DEVICE)    
    
    final_model = EEGModel(num_nodes=19, node_features=63, num_classes=2, num_windows=50, device=DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(final_model.parameters(),lr=1e-3)
    model = train_model(final_model, num_epochs, train_iter)
    results[conn] = print_acc(model, test_iter)
clear_output()

pearson
read_data
(1758, 50, 19, 100) (280, 50, 19, 100)
calculating connectivity



 0%|                                                  | 0/1758 [00:04<?, ?it/s]

KeyboardInterrupt: 

In [50]:
results

{'pearson': [0.6882591093117408,
  0.7158671586715867,
  0.7698412698412699,
  0.6689655172413793,
  array([[73, 29],
         [48, 97]])],
 'plv': [0.6356275303643725,
  0.7204968944099379,
  0.9206349206349206,
  0.5918367346938775,
  array([[ 41,  10],
         [ 80, 116]])]}

In [55]:
final_model = EEGModel(num_nodes=19, node_features=63, num_classes=2, num_windows=50, device=DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(final_model.parameters(),lr=1e-3)
model = train_model(final_model, 100, train_iter)

accuracy: 0.5141700404858299 f1 score: 0.6774193548387097 precision: 1.0 recall: 0.5121951219512195 confusion matrix: [[  1   0]
 [120 126]]
accuracy: 0.5101214574898786 f1 score: 0.675603217158177 precision: 1.0 recall: 0.5101214574898786 confusion matrix: [[  0   0]
 [121 126]]
accuracy: 0.5101214574898786 f1 score: 0.675603217158177 precision: 1.0 recall: 0.5101214574898786 confusion matrix: [[  0   0]
 [121 126]]
accuracy: 0.5870445344129555 f1 score: 0.7102272727272726 precision: 0.9920634920634921 recall: 0.5530973451327433 confusion matrix: [[ 20   1]
 [101 125]]
accuracy: 0.5101214574898786 f1 score: 0.675603217158177 precision: 1.0 recall: 0.5101214574898786 confusion matrix: [[  0   0]
 [121 126]]
accuracy: 0.7004048582995951 f1 score: 0.7658227848101266 precision: 0.9603174603174603 recall: 0.6368421052631579 confusion matrix: [[ 52   5]
 [ 69 121]]
accuracy: 0.7206477732793523 f1 score: 0.7752442996742671 precision: 0.9444444444444444 recall: 0.6574585635359116 confusion ma

KeyboardInterrupt: 

In [16]:
train_X.shape

(1722, 100, 19, 125)

In [50]:
list(train_y.squeeze()).count(1)

1497

In [51]:
251/1485

0.16902356902356902

In [None]:
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4763636363636364 f1 score: 0.06493506493506493 precision: 0.033783783783783786 recall: 0.8333333333333334 confusion matrix: [[126 143]
 [  1   5]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.5527272727272727 f1 score: 0.32786885245901637 precision: 0.20270270270270271 recall: 0.8571428571428571 confusion matrix: [[122 118]
 [  5  30]]
accuracy: 0.4690909090909091 f1 score: 0.02666666666666667 precision: 0.013513513513513514 recall: 1.0 confusion matrix: [[127 146]
 [  0   2]]
accuracy: 0.5236363636363637 f1 score: 0.20606060606060606 precision: 0.11486486486486487 recall: 1.0 confusion matrix: [[127 131]
 [  0  17]]
accuracy: 0.5890909090909091 f1 score: 0.4263959390862944 precision: 0.28378378378378377 recall: 0.8571428571428571 confusion matrix: [[120 106]
 [  7  42]]
accuracy: 0.5309090909090909 f1 score: 0.27932960893854747 precision: 0.16891891891891891 recall: 0.8064516129032258 confusion matrix: [[121 123]
 [  6  25]]
accuracy: 0.5418181818181819 f1 score: 0.31521739130434784 precision: 0.19594594594594594 recall: 0.8055555555555556 confusion matrix: [[120 119]
 [  7  29]]
accuracy: 0.5418181818181819 f1 score: 0.26744186046511625 precision: 0.1554054054054054 recall: 0.9583333333333334 confusion matrix: [[126 125]
 [  1  23]]
accuracy: 0.56 f1 score: 0.35978835978835977 precision: 0.22972972972972974 recall: 0.8292682926829268 confusion matrix: [[120 114]
 [  7  34]]
accuracy: 0.5418181818181819 f1 score: 0.27586206896551724 precision: 0.16216216216216217 recall: 0.9230769230769231 confusion matrix: [[125 124]
 [  2  24]]
accuracy: 0.6509090909090909 f1 score: 0.5555555555555556 precision: 0.40540540540540543 recall: 0.8823529411764706 confusion matrix: [[119  88]
 [  8  60]]
accuracy: 0.5418181818181819 f1 score: 0.26744186046511625 precision: 0.1554054054054054 recall: 0.9583333333333334 confusion matrix: [[126 125]
 [  1  23]]
accuracy: 0.5563636363636364 f1 score: 0.33695652173913043 precision: 0.20945945945945946 recall: 0.8611111111111112 confusion matrix: [[122 117]
 [  5  31]]
accuracy: 0.5563636363636364 f1 score: 0.33695652173913043 precision: 0.20945945945945946 recall: 0.8611111111111112 confusion matrix: [[122 117]
 [  5  31]]
accuracy: 0.6 f1 score: 0.4387755102040816 precision: 0.2905405405405405 recall: 0.8958333333333334 confusion matrix: [[122 105]
 [  5  43]]
accuracy: 0.5381818181818182 f1 score: 0.2658959537572254 precision: 0.1554054054054054 recall: 0.92 confusion matrix: [[125 125]
 [  2  23]]
accuracy: 0.6254545454545455 f1 score: 0.4975609756097561 precision: 0.34459459459459457 recall: 0.8947368421052632 confusion matrix: [[121  97]
 [  6  51]]
accuracy: 0.7345454545454545 f1 score: 0.7068273092369478 precision: 0.5945945945945946 recall: 0.8712871287128713 confusion matrix: [[114  60]
 [ 13  88]]
accuracy: 0.5381818181818182 f1 score: 0.2658959537572254 precision: 0.1554054054054054 recall: 0.92 confusion matrix: [[125 125]
 [  2  23]]
accuracy: 0.7163636363636363 f1 score: 0.6776859504132231 precision: 0.5540540540540541 recall: 0.8723404255319149 confusion matrix: [[115  66]
 [ 12  82]]
accuracy: 0.7818181818181819 f1 score: 0.782608695652174 precision: 0.7297297297297297 recall: 0.84375 confusion matrix: [[107  40]
 [ 20 108]]
accuracy: 0.6436363636363637 f1 score: 0.5242718446601942 precision: 0.36486486486486486 recall: 0.9310344827586207 confusion matrix: [[123  94]
 [  4  54]]
accuracy: 0.5490909090909091 f1 score: 0.29545454545454547 precision: 0.17567567567567569 recall: 0.9285714285714286 confusion matrix: [[125 122]
 [  2  26]]
accuracy: 0.7890909090909091 f1 score: 0.8104575163398693 precision: 0.8378378378378378 recall: 0.7848101265822784 confusion matrix: [[ 93  24]
 [ 34 124]]
accuracy: 0.6618181818181819 f1 score: 0.5592417061611374 precision: 0.39864864864864863 recall: 0.9365079365079365 confusion matrix: [[123  89]
 [  4  59]]
accuracy: 0.7236363636363636 f1 score: 0.6724137931034483 precision: 0.527027027027027 recall: 0.9285714285714286 confusion matrix: [[121  70]
 [  6  78]]
accuracy: 0.6436363636363637 f1 score: 0.5196078431372549 precision: 0.3581081081081081 recall: 0.9464285714285714 confusion matrix: [[124  95]
 [  3  53]]
accuracy: 0.5745454545454546 f1 score: 0.36065573770491804 precision: 0.22297297297297297 recall: 0.9428571428571428 confusion matrix: [[125 115]
 [  2  33]]
accuracy: 0.5927272727272728 f1 score: 0.4105263157894737 precision: 0.2635135135135135 recall: 0.9285714285714286 confusion matrix: [[124 109]
 [  3  39]]
accuracy: 0.8036363636363636 f1 score: 0.8187919463087249 precision: 0.8243243243243243 recall: 0.8133333333333334 confusion matrix: [[ 99  26]
 [ 28 122]]
accuracy: 0.5890909090909091 f1 score: 0.39572192513368987 precision: 0.25 recall: 0.9487179487179487 confusion matrix: [[125 111]
 [  2  37]]
accuracy: 0.6763636363636364 f1 score: 0.5898617511520737 precision: 0.43243243243243246 recall: 0.927536231884058 confusion matrix: [[122  84]
 [  5  64]]
accuracy: 0.5927272727272728 f1 score: 0.40425531914893614 precision: 0.25675675675675674 recall: 0.95 confusion matrix: [[125 110]
 [  2  38]]
accuracy: 0.5854545454545454 f1 score: 0.3870967741935484 precision: 0.24324324324324326 recall: 0.9473684210526315 confusion matrix: [[125 112]
 [  2  36]]
accuracy: 0.7236363636363636 f1 score: 0.6859504132231405 precision: 0.5608108108108109 recall: 0.8829787234042553 confusion matrix: [[116  65]
 [ 11  83]]
accuracy: 0.7818181818181819 f1 score: 0.7794117647058824 precision: 0.7162162162162162 recall: 0.8548387096774194 confusion matrix: [[109  42]
 [ 18 106]]
accuracy: 0.8072727272727273 f1 score: 0.8166089965397924 precision: 0.7972972972972973 recall: 0.8368794326241135 confusion matrix: [[104  30]
 [ 23 118]]
accuracy: 0.72 f1 score: 0.6831275720164609 precision: 0.5608108108108109 recall: 0.8736842105263158 confusion matrix: [[115  65]
 [ 12  83]]
accuracy: 0.6 f1 score: 0.4329896907216495 precision: 0.28378378378378377 recall: 0.9130434782608695 confusion matrix: [[123 106]
 [  4  42]]
accuracy: 0.6 f1 score: 0.4329896907216495 precision: 0.28378378378378377 recall: 0.9130434782608695 confusion matrix: [[123 106]
 [  4  42]]
accuracy: 0.76 f1 score: 0.7421875 precision: 0.6418918918918919 recall: 0.8796296296296297 confusion matrix: [[114  53]
 [ 13  95]]
accuracy: 0.7636363636363637 f1 score: 0.7547169811320755 precision: 0.6756756756756757 recall: 0.8547008547008547 confusion matrix: [[110  48]
 [ 17 100]]
accuracy: 0.7927272727272727 f1 score: 0.7956989247311828 precision: 0.75 recall: 0.8473282442748091 confusion matrix: [[107  37]
 [ 20 111]]
accuracy: 0.6836363636363636 f1 score: 0.6167400881057269 precision: 0.47297297297297297 recall: 0.8860759493670886 confusion matrix: [[118  78]
 [  9  70]]
accuracy: 0.7490909090909091 f1 score: 0.7335907335907336 precision: 0.6418918918918919 recall: 0.8558558558558559 confusion matrix: [[111  53]
 [ 16  95]]
accuracy: 0.7018181818181818 f1 score: 0.6583333333333333 precision: 0.5337837837837838 recall: 0.8586956521739131 confusion matrix: [[114  69]
 [ 13  79]]
accuracy: 0.7963636363636364 f1 score: 0.8 precision: 0.7567567567567568 recall: 0.8484848484848485 confusion matrix: [[107  36]
 [ 20 112]]
accuracy: 0.6654545454545454 f1 score: 0.5892857142857143 precision: 0.44594594594594594 recall: 0.868421052631579 confusion matrix: [[117  82]
 [ 10  66]]
accuracy: 0.6545454545454545 f1 score: 0.5581395348837209 precision: 0.40540540540540543 recall: 0.8955223880597015 confusion matrix: [[120  88]
 [  7  60]]
accuracy: 0.7454545454545455 f1 score: 0.7865853658536586 precision: 0.8716216216216216 recall: 0.7166666666666667 confusion matrix: [[ 76  19]
 [ 51 129]]
accuracy: 0.7781818181818182 f1 score: 0.779783393501805 precision: 0.7297297297297297 recall: 0.8372093023255814 confusion matrix: [[106  40]
 [ 21 108]]

In [None]:
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4416961130742049 f1 score: 0.07058823529411765 precision: 0.03896103896103896 recall: 0.375 confusion matrix: [[119 148]
 [ 10   6]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.5477031802120141 f1 score: 0.7064220183486238 precision: 1.0 recall: 0.5460992907801419 confusion matrix: [[  1   0]
 [128 154]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.5053003533568905 f1 score: 0.3 precision: 0.19480519480519481 recall: 0.6521739130434783 confusion matrix: [[113 124]
 [ 16  30]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.519434628975265 f1 score: 0.30612244897959184 precision: 0.19480519480519481 recall: 0.7142857142857143 confusion matrix: [[117 124]
 [ 12  30]]
accuracy: 0.45936395759717313 f1 score: 0.025477707006369428 precision: 0.012987012987012988 recall: 0.6666666666666666 confusion matrix: [[128 152]
 [  1   2]]
accuracy: 0.5159010600706714 f1 score: 0.29743589743589743 precision: 0.18831168831168832 recall: 0.7073170731707317 confusion matrix: [[117 125]
 [ 12  29]]
accuracy: 0.5335689045936396 f1 score: 0.2903225806451613 precision: 0.17532467532467533 recall: 0.84375 confusion matrix: [[124 127]
 [  5  27]]
accuracy: 0.4628975265017668 f1 score: 0.02564102564102564 precision: 0.012987012987012988 recall: 1.0 confusion matrix: [[129 152]
 [  0   2]]
accuracy: 0.45936395759717313 f1 score: 0.025477707006369428 precision: 0.012987012987012988 recall: 0.6666666666666666 confusion matrix: [[128 152]
 [  1   2]]
accuracy: 0.519434628975265 f1 score: 0.22727272727272727 precision: 0.12987012987012986 recall: 0.9090909090909091 confusion matrix: [[127 134]
 [  2  20]]
accuracy: 0.5123674911660777 f1 score: 0.20689655172413793 precision: 0.11688311688311688 recall: 0.9 confusion matrix: [[127 136]
 [  2  18]]
accuracy: 0.5371024734982333 f1 score: 0.2994652406417112 precision: 0.18181818181818182 recall: 0.8484848484848485 confusion matrix: [[124 126]
 [  5  28]]
accuracy: 0.568904593639576 f1 score: 0.46956521739130436 precision: 0.35064935064935066 recall: 0.7105263157894737 confusion matrix: [[107 100]
 [ 22  54]]
accuracy: 0.558303886925795 f1 score: 0.37185929648241206 precision: 0.24025974025974026 recall: 0.8222222222222222 confusion matrix: [[121 117]
 [  8  37]]
accuracy: 0.5441696113074205 f1 score: 0.3027027027027027 precision: 0.18181818181818182 recall: 0.9032258064516129 confusion matrix: [[126 126]
 [  3  28]]
accuracy: 0.5547703180212014 f1 score: 0.36363636363636365 precision: 0.23376623376623376 recall: 0.8181818181818182 confusion matrix: [[121 118]
 [  8  36]]
accuracy: 0.5406360424028268 f1 score: 0.3157894736842105 precision: 0.19480519480519481 recall: 0.8333333333333334 confusion matrix: [[123 124]
 [  6  30]]
accuracy: 0.6219081272084805 f1 score: 0.5114155251141552 precision: 0.36363636363636365 recall: 0.8615384615384616 confusion matrix: [[120  98]
 [  9  56]]
accuracy: 0.5441696113074205 f1 score: 0.29508196721311475 precision: 0.17532467532467533 recall: 0.9310344827586207 confusion matrix: [[127 127]
 [  2  27]]
accuracy: 0.5512367491166078 f1 score: 0.31351351351351353 precision: 0.18831168831168832 recall: 0.9354838709677419 confusion matrix: [[127 125]
 [  2  29]]
accuracy: 0.6855123674911661 f1 score: 0.6147186147186147 precision: 0.461038961038961 recall: 0.922077922077922 confusion matrix: [[123  83]
 [  6  71]]
accuracy: 0.5971731448763251 f1 score: 0.42424242424242425 precision: 0.2727272727272727 recall: 0.9545454545454546 confusion matrix: [[127 112]
 [  2  42]]
accuracy: 0.6007067137809188 f1 score: 0.4321608040201005 precision: 0.2792207792207792 recall: 0.9555555555555556 confusion matrix: [[127 111]
 [  2  43]]
accuracy: 0.6183745583038869 f1 score: 0.47572815533980584 precision: 0.3181818181818182 recall: 0.9423076923076923 confusion matrix: [[126 105]
 [  3  49]]
accuracy: 0.6148409893992933 f1 score: 0.4682926829268293 precision: 0.3116883116883117 recall: 0.9411764705882353 confusion matrix: [[126 106]
 [  3  48]]
accuracy: 0.5971731448763251 f1 score: 0.42424242424242425 precision: 0.2727272727272727 recall: 0.9545454545454546 confusion matrix: [[127 112]
 [  2  42]]
accuracy: 0.784452296819788 f1 score: 0.7813620071684588 precision: 0.7077922077922078 recall: 0.872 confusion matrix: [[113  45]
 [ 16 109]]
accuracy: 0.8021201413427562 f1 score: 0.8145695364238411 precision: 0.7987012987012987 recall: 0.831081081081081 confusion matrix: [[104  31]
 [ 25 123]]
accuracy: 0.7985865724381626 f1 score: 0.8054607508532423 precision: 0.7662337662337663 recall: 0.8489208633093526 confusion matrix: [[108  36]
 [ 21 118]]
accuracy: 0.6537102473498233 f1 score: 0.5663716814159292 precision: 0.4155844155844156 recall: 0.8888888888888888 confusion matrix: [[121  90]
 [  8  64]]
accuracy: 0.7703180212014135 f1 score: 0.7547169811320755 precision: 0.6493506493506493 recall: 0.9009009009009009 confusion matrix: [[118  54]
 [ 11 100]]
accuracy: 0.6713780918727915 f1 score: 0.5829596412556054 precision: 0.42207792207792205 recall: 0.9420289855072463 confusion matrix: [[125  89]
 [  4  65]]
accuracy: 0.6784452296819788 f1 score: 0.5955555555555555 precision: 0.43506493506493504 recall: 0.9436619718309859 confusion matrix: [[125  87]
 [  4  67]]
accuracy: 0.5901060070671378 f1 score: 0.42 precision: 0.2727272727272727 recall: 0.9130434782608695 confusion matrix: [[125 112]
 [  4  42]]
accuracy: 0.5901060070671378 f1 score: 0.40816326530612246 precision: 0.2597402597402597 recall: 0.9523809523809523 confusion matrix: [[127 114]
 [  2  40]]
accuracy: 0.6607773851590106 f1 score: 0.5636363636363636 precision: 0.4025974025974026 recall: 0.9393939393939394 confusion matrix: [[125  92]
 [  4  62]]
accuracy: 0.7561837455830389 f1 score: 0.7315175097276264 precision: 0.6103896103896104 recall: 0.912621359223301 confusion matrix: [[120  60]
 [  9  94]]
accuracy: 0.6996466431095406 f1 score: 0.6443514644351465 precision: 0.5 recall: 0.9058823529411765 confusion matrix: [[121  77]
 [  8  77]]
accuracy: 0.8197879858657244 f1 score: 0.8235294117647058 precision: 0.7727272727272727 recall: 0.8814814814814815 confusion matrix: [[113  35]
 [ 16 119]]
accuracy: 0.6819787985865724 f1 score: 0.6153846153846154 precision: 0.4675324675324675 recall: 0.9 confusion matrix: [[121  82]
 [  8  72]]
accuracy: 0.7067137809187279 f1 score: 0.6693227091633466 precision: 0.5454545454545454 recall: 0.865979381443299 confusion matrix: [[116  70]
 [ 13  84]]
accuracy: 0.6501766784452296 f1 score: 0.547945205479452 precision: 0.38961038961038963 recall: 0.9230769230769231 confusion matrix: [[124  94]
 [  5  60]]
accuracy: 0.6537102473498233 f1 score: 0.5504587155963303 precision: 0.38961038961038963 recall: 0.9375 confusion matrix: [[125  94]
 [  4  60]]
accuracy: 0.7208480565371025 f1 score: 0.680161943319838 precision: 0.5454545454545454 recall: 0.9032258064516129 confusion matrix: [[120  70]
 [  9  84]]
accuracy: 0.8056537102473498 f1 score: 0.8028673835125448 precision: 0.7272727272727273 recall: 0.896 confusion matrix: [[116  42]
 [ 13 112]]
accuracy: 0.6360424028268551 f1 score: 0.5164319248826291 precision: 0.35714285714285715 recall: 0.9322033898305084 confusion matrix: [[125  99]
 [  4  55]]
accuracy: 0.6643109540636042 f1 score: 0.5814977973568282 precision: 0.42857142857142855 recall: 0.9041095890410958 confusion matrix: [[122  88]
 [  7  66]]
accuracy: 0.657243816254417 f1 score: 0.5650224215246636 precision: 0.4090909090909091 recall: 0.9130434782608695 confusion matrix: [[123  91]
 [  6  63]]
accuracy: 0.6890459363957597 f1 score: 0.6302521008403361 precision: 0.487012987012987 recall: 0.8928571428571429 confusion matrix: [[120  79]
 [  9  75]]
accuracy: 0.7950530035335689 f1 score: 0.7986111111111112 precision: 0.7467532467532467 recall: 0.8582089552238806 confusion matrix: [[110  39]
 [ 19 115]]
accuracy: 0.7102473498233216 f1 score: 0.6796875 precision: 0.564935064935065 recall: 0.8529411764705882 confusion matrix: [[114  67]
 [ 15  87]]
accuracy: 0.7208480565371025 f1 score: 0.680161943319838 precision: 0.5454545454545454 recall: 0.9032258064516129 confusion matrix: [[120  70]
 [  9  84]]
accuracy: 0.7067137809187279 f1 score: 0.6527196652719666 precision: 0.5064935064935064 recall: 0.9176470588235294 confusion matrix: [[122  76]
 [  7  78]]
accuracy: 0.7243816254416962 f1 score: 0.6829268292682927 precision: 0.5454545454545454 recall: 0.9130434782608695 confusion matrix: [[121  70]
 [  8  84]]
accuracy: 0.6360424028268551 f1 score: 0.5209302325581395 precision: 0.36363636363636365 recall: 0.9180327868852459 confusion matrix: [[124  98]
 [  5  56]]
accuracy: 0.607773851590106 f1 score: 0.44776119402985076 precision: 0.2922077922077922 recall: 0.9574468085106383 confusion matrix: [[127 109]
 [  2  45]]
accuracy: 0.6678445229681979 f1 score: 0.584070796460177 precision: 0.42857142857142855 recall: 0.9166666666666666 confusion matrix: [[123  88]
 [  6  66]]
accuracy: 0.657243816254417 f1 score: 0.5650224215246636 precision: 0.4090909090909091 recall: 0.9130434782608695 confusion matrix: [[123  91]
 [  6  63]]
accuracy: 0.6819787985865724 f1 score: 0.6186440677966102 precision: 0.474025974025974 recall: 0.8902439024390244 confusion matrix: [[120  81]
 [  9  73]]
accuracy: 0.6431095406360424 f1 score: 0.5429864253393665 precision: 0.38961038961038963 recall: 0.8955223880597015 confusion matrix: [[122  94]
 [  7  60]]
accuracy: 0.5936395759717314 f1 score: 0.42786069651741293 precision: 0.2792207792207792 recall: 0.9148936170212766 confusion matrix: [[125 111]
 [  4  43]]
accuracy: 0.6890459363957597 f1 score: 0.6206896551724138 precision: 0.4675324675324675 recall: 0.9230769230769231 confusion matrix: [[123  82]
 [  6  72]]
accuracy: 0.6643109540636042 f1 score: 0.5739910313901345 precision: 0.4155844155844156 recall: 0.927536231884058 confusion matrix: [[124  90]
 [  5  64]]
accuracy: 0.7632508833922261 f1 score: 0.7490636704119851 precision: 0.6493506493506493 recall: 0.8849557522123894 confusion matrix: [[116  54]
 [ 13 100]]
accuracy: 0.7067137809187279 f1 score: 0.6527196652719666 precision: 0.5064935064935064 recall: 0.9176470588235294 confusion matrix: [[122  76]
 [  7  78]]
accuracy: 0.6855123674911661 f1 score: 0.6244725738396625 precision: 0.4805194805194805 recall: 0.891566265060241 confusion matrix: [[120  80]
 [  9  74]]
accuracy: 0.6607773851590106 f1 score: 0.5932203389830508 precision: 0.45454545454545453 recall: 0.8536585365853658 confusion matrix: [[117  84]
 [ 12  70]]
accuracy: 0.6325088339222615 f1 score: 0.5229357798165137 precision: 0.37012987012987014 recall: 0.890625 confusion matrix: [[122  97]
 [  7  57]]
accuracy: 0.7385159010600707 f1 score: 0.7153846153846154 precision: 0.6038961038961039 recall: 0.8773584905660378 confusion matrix: [[116  61]
 [ 13  93]]
accuracy: 0.6678445229681979 f1 score: 0.5877192982456141 precision: 0.43506493506493504 recall: 0.9054054054054054 confusion matrix: [[122  87]
 [  7  67]]
accuracy: 0.6607773851590106 f1 score: 0.5789473684210527 precision: 0.42857142857142855 recall: 0.8918918918918919 confusion matrix: [[121  88]
 [  8  66]]
accuracy: 0.8021201413427562 f1 score: 0.8108108108108109 precision: 0.7792207792207793 recall: 0.8450704225352113 confusion matrix: [[107  34]
 [ 22 120]]
accuracy: 0.7208480565371025 f1 score: 0.6926070038910506 precision: 0.577922077922078 recall: 0.8640776699029126 confusion matrix: [[115  65]
 [ 14  89]]
accuracy: 0.7031802120141343 f1 score: 0.6440677966101694 precision: 0.4935064935064935 recall: 0.926829268292683 confusion matrix: [[123  78]
 [  6  76]]
accuracy: 0.6855123674911661 f1 score: 0.6147186147186147 precision: 0.461038961038961 recall: 0.922077922077922 confusion matrix: [[123  83]
 [  6  71]]
accuracy: 0.784452296819788 f1 score: 0.7765567765567766 precision: 0.6883116883116883 recall: 0.8907563025210085 confusion matrix: [[116  48]
 [ 13 106]]
accuracy: 0.6607773851590106 f1 score: 0.5752212389380531 precision: 0.42207792207792205 recall: 0.9027777777777778 confusion matrix: [[122  89]
 [  7  65]]
accuracy: 0.6678445229681979 f1 score: 0.6050420168067226 precision: 0.4675324675324675 recall: 0.8571428571428571 confusion matrix: [[117  82]
 [ 12  72]]
accuracy: 0.7879858657243817 f1 score: 0.7945205479452054 precision: 0.7532467532467533 recall: 0.8405797101449275 confusion matrix: [[107  38]
 [ 22 116]]
accuracy: 0.7950530035335689 f1 score: 0.795774647887324 precision: 0.7337662337662337 recall: 0.8692307692307693 confusion matrix: [[112  41]
 [ 17 113]]
accuracy: 0.7137809187279152 f1 score: 0.6872586872586872 precision: 0.577922077922078 recall: 0.8476190476190476 confusion matrix: [[113  65]
 [ 16  89]]
accuracy: 0.6678445229681979 f1 score: 0.5877192982456141 precision: 0.43506493506493504 recall: 0.9054054054054054 confusion matrix: [[122  87]
 [  7  67]]
accuracy: 0.7137809187279152 f1 score: 0.6823529411764706 precision: 0.564935064935065 recall: 0.8613861386138614 confusion matrix: [[115  67]
 [ 14  87]]
accuracy: 0.6996466431095406 f1 score: 0.6530612244897959 precision: 0.5194805194805194 recall: 0.8791208791208791 confusion matrix: [[118  74]
 [ 11  80]]
accuracy: 0.6749116607773852 f1 score: 0.6101694915254238 precision: 0.4675324675324675 recall: 0.8780487804878049 confusion matrix: [[119  82]
 [ 10  72]]
accuracy: 0.6678445229681979 f1 score: 0.5948275862068966 precision: 0.44805194805194803 recall: 0.8846153846153846 confusion matrix: [[120  85]
 [  9  69]]
accuracy: 0.6925795053003534 f1 score: 0.6419753086419753 precision: 0.5064935064935064 recall: 0.8764044943820225 confusion matrix: [[118  76]
 [ 11  78]]
accuracy: 0.7597173144876325 f1 score: 0.7424242424242424 precision: 0.6363636363636364 recall: 0.8909090909090909 confusion matrix: [[117  56]
 [ 12  98]]
accuracy: 0.7137809187279152 f1 score: 0.6693877551020408 precision: 0.5324675324675324 recall: 0.9010989010989011 confusion matrix: [[120  72]
 [  9  82]]
accuracy: 0.6431095406360424 f1 score: 0.5429864253393665 precision: 0.38961038961038963 recall: 0.8955223880597015 confusion matrix: [[122  94]
 [  7  60]]
accuracy: 0.696113074204947 f1 score: 0.6446280991735537 precision: 0.5064935064935064 recall: 0.8863636363636364 confusion matrix: [[119  76]
 [ 10  78]]
accuracy: 0.773851590106007 f1 score: 0.7681159420289855 precision: 0.6883116883116883 recall: 0.8688524590163934 confusion matrix: [[113  48]
 [ 16 106]]
accuracy: 0.7420494699646644 f1 score: 0.7159533073929961 precision: 0.5974025974025974 recall: 0.8932038834951457 confusion matrix: [[118  62]
 [ 11  92]]
