In [1]:
import pickle
import argparse
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss, BCELoss
from torch.optim import Adam
from sklearn.metrics import balanced_accuracy_score

from utils.training import train
from utils.evaluation import score
from utils.data import Data, SpikeDetectionDataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration: values a reader might tune live here.
expe_seed = 0  # Seed for the global NumPy / PyTorch RNGs (applied below)
method = 'Fukumori'  # Model architecture to train
parameters = [0]  # Forwarded as-is to training.train; NOTE(review): meaning not visible here
sfreq = 100  # Sampling frequency
batch_size = 16

# `expe_seed` used to be defined but never applied, so each run produced a
# different random stream. Seed the global generators up front.
np.random.seed(expe_seed)
torch.manual_seed(expe_seed)


In [3]:
# `import scipy` alone does not guarantee that the `scipy.io` submodule is
# loaded (older SciPy releases do not import submodules on attribute access),
# so import the submodule explicitly before calling scipy.io.loadmat.
import scipy.io

# Channel description file for the Neuropoly MEEG database.
channel_fname = 'database/Neuropoly_MEEG_database/channel_ctf_acc1.mat'
channel_mat = scipy.io.loadmat(channel_fname, chars_as_strings=True)

In [4]:
# Collect the column indices of every entry in channel_mat['Channel'] whose
# 'Name' field equals `wanted_channel_type`.
wanted_channels = []
wanted_channel_type = 'C3'
for i in range(channel_mat['Channel'].shape[1]):
    # The 'Name' field is unpacked to a plain string via .tolist()[0].
    channel_name = channel_mat['Channel'][0, i]['Name'].tolist()[0]
    # Exact comparison: the previous `channel_name in wanted_channel_type` was a
    # substring test on a str, so names like 'C', '3' or '' matched as well.
    if channel_name == wanted_channel_type:
        wanted_channels.append(i)

In [5]:
# Display the matched column indices into channel_mat['Channel'] (computed in the previous cell).
wanted_channels

[326]

In [6]:
# Build the Data helper over the database root and collect its datasets.
# NOTE(review): the positional arguments ('saw_EST', 'EEG', True) are opaque
# here — confirm their meaning against utils.data.Data.
path_root = 'database/'
dataset = Data(
    path_root,
    'saw_EST',
    'EEG',
    True,
)
all_dataset = dataset.all_datasets()

2022-05-19 14:21:18.948 | INFO     | utils.data:get_all_datasets:196 - Recover data for Neuropoly_MEEG_database


[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7') array(['spaced_out'], dtype='<U10')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7') array(['spaced_out'], dtype='<U10')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_E

2022-05-19 14:21:45.458 | INFO     | utils.data:get_dataset:170 - Label creation: No Spike / Spikes mapped on labels [0 1]


[[array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7') array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['cardiac'], dtype='<U7')
  array(['transient_notch'], dtype='<U15')]]
[[array(['saw_EST'], dtype='<U7') array(['cardiac'], dtype='<U7')
  array(['spaced_out_5'], dtype='<U12')
  array(['transient_notch'], dtype='<U15')]]


In [7]:
# Training hyper-parameters.
n_epochs = 1
patience = 10  # NOTE(review): forwarded to training.train; presumably early-stopping patience — confirm

# Must be a 1-tuple: `("Fukumori")` is just a parenthesised str, which turned
# `in` into a substring test (e.g. method='Fuku' would have passed).
assert method in ("Fukumori",), f"Unsupported method: {method!r}"

num_workers = 0  # Number of processes to use for the data loading process; 0 is the main Python process

results = []


In [8]:
# First element of all_dataset, keyed by database name; paired with the labels taken from all_dataset[1] in the next cell.
data = all_dataset[0]['Neuropoly_MEEG_database']


In [9]:
# Second element of all_dataset, keyed by database name; the loader log reports "No Spike / Spikes mapped on labels [0 1]".
labels = all_dataset[1]['Neuropoly_MEEG_database']

In [31]:
# Append a trailing singleton axis, giving (n_samples, n_points, 1); the first
# axis matches len(labels). NOTE(review): axis 1 is presumably time — confirm.
# Guarded so re-running this cell does not keep adding axes: the previous
# unguarded reassignment grew `data` by one dimension per execution.
if data.ndim == 2:
    data = data[:, :, np.newaxis]

In [32]:
from sklearn.model_selection import train_test_split

# Two-stage split, each using sklearn's default test_size of 0.25: first hold
# out the test set, then carve the validation set out of the remaining data.
# `random_state` makes the split reproducible (it was unseeded before), and
# `stratify` keeps the spike / no-spike label ratio the same in every split.
data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels, random_state=expe_seed, stratify=labels
)
data_train, data_val, labels_train, labels_val = train_test_split(
    data_train, labels_train, random_state=expe_seed, stratify=labels_train
)

In [33]:
# Inspect the array shape after the channel axis was added.
np.shape(data)

(655, 201, 1)

In [34]:
# Wrap each (inputs, labels) split in a SpikeDetectionDataset, built in
# train -> validation -> test order.
dataset_train, dataset_val, dataset_test = (
    SpikeDetectionDataset(split_data, split_labels)
    for split_data, split_labels in (
        (data_train, labels_train),
        (data_val, labels_val),
        (data_test, labels_test),
    )
)

In [35]:
# One label per sample: should match the first axis of `data`.
np.shape(labels)

(655,)

In [36]:
# Batch the three splits (DataLoader is already imported in the first cell).
# Only the training loader shuffles, which it previously never did, so every
# epoch saw the same batch order. A seeded generator keeps the order
# reproducible. Evaluation loaders keep a fixed order. `num_workers` was
# configured earlier but never passed through.
loaders_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(expe_seed),
    num_workers=num_workers,
)
loaders_val = DataLoader(dataset_val, batch_size=batch_size, num_workers=num_workers)
loaders_test = DataLoader(dataset_test, batch_size=batch_size, num_workers=num_workers)


In [37]:
# Second-axis size of the first training item's input, used as the model's input width.
# NOTE(review): assumes dataset items are (input, label) pairs — confirm in SpikeDetectionDataset.
input_size = np.shape(loaders_train.dataset[0][0])[1]



In [54]:
import importlib

# Alias the module as `model_module` so binding the network to `model` below
# no longer shadows it. Previously an unsupported `method` left `model` bound
# to the module, and `model.parameters()` failed with a confusing AttributeError.
import utils.model as model_module
import utils.training as training

# Pick up edits to the local modules without restarting the kernel.
importlib.reload(model_module)
importlib.reload(training)

# Seed right before construction so initial weights are reproducible even
# when this cell is re-run on its own.
torch.manual_seed(expe_seed)

if method == "Fukumori":
    model = model_module.fukumori2021RNN(input_size=input_size)
else:
    raise ValueError(f"Unknown method: {method!r}")

lr = 1e-3  # Learning rate
optimizer = Adam(params=model.parameters(), lr=lr)

# NOTE(review): BCELoss expects probabilities in [0, 1]; assumes the model ends
# with a sigmoid — confirm, otherwise BCEWithLogitsLoss is the safer choice.
criterion = BCELoss()

In [55]:
# Train Model
best_model, history = training.train(
    model,
    method,
    loaders_train,
    loaders_val,
    optimizer,
    criterion,
    parameters,
    1,
    patience,
)


epoch 	 train_loss 	 valid_loss 	 train_perf 	 valid_perf
--------------------------------------------------------------------------------
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16, 201, 8])
torch.Size([16, 8, 50])
torch.Size([16, 50, 8])
torch.Size([16, 50, 8])
torch.Size([16, 8, 12])
torch.Size([16, 1])
torch.Size([16,