In [1]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
import mne
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf,read_raw_edf,read_raw_gdf
from mne.datasets import eegbci
from mne.decoding import CSP

In [2]:
def get_data():
    path = "dataset\\bci_dataset\\train"
    files = os.listdir(path)
    data_path = []
    for f in files:
        file = os.path.join(path,f)
        data_path.append(file)
    
    tmin, tmax = -0.5, 4.
    event_id = dict(left=1, right = 2, foot=3,tongue=4)

    raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])

    # strip channel names of "." characters
    raw.rename_channels(lambda x: x.strip('.'))

    # Apply band-pass filter
    #raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge')

    events, _ = events_from_annotations(raw)
    
    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')
    #delete eog band
    picks = np.delete(picks,[22,23,24])
    # Read epochs (train will be done only between 1 and 2s)
    # Testing will be done with a running classifier
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True)
    epochs.crop(tmin=1., tmax=None)
    labels = epochs.events[:, 2] -1
    return epochs.get_data()[:, :, :256], labels


epochs_data, labels = get_data()

Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A04T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 600914  =      0.000 ...  2403.656 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A05T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686119  =      0.000 ...  2744.476 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A06T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 678979  =      0.000 ...  2715.916 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A07T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 681070  =      0.000 ...  2724.280 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A08T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 675269  =      0.000 ...  2701.076 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Extracting EDF parameters from c:\Users\asus\Desktop\Motorimagery_for_gamification\dataset\bci_dataset\train\A09T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 673327  =      0.000 ...  2693.308 secs...


  etmode = np.fromstring(etmode, UINT8).tolist()[0]
  raw = concatenate_raws([read_raw_gdf(f, preload=True) for f in data_path])


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Not setting metadata
Not setting metadata
289 matching events found
No baseline correction applied
0 projection items activated
Loading data for 289 events and 1126 original time points ...
9 bad epochs dropped


In [3]:
# 280 events
# 22 channel
# 256 Time
print(epochs_data.shape)
print(labels.shape)
print(epochs_data.shape[1])
print(labels)

(280, 22, 256)
(280,)
22
[3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [4]:
def amplitude(x):
    """Data augmentation for adding noise to amplitude of sepctral image
    ----------
    X : array, shape (n_channels, n_times)
        The input signals.
    Returns
    -------
    X_t : array, shape (n_channels, n_times)
        reconstructed new time series from inverse STFT
    """
    print("Hello")



In [5]:
import torch   
import torch.optim as optim  
from torch.utils.data import Dataset, DataLoader  
from torch.utils.data import Subset  
from torch import nn  
import torch.nn.functional as F  
from torch.utils.data import RandomSampler  
from torch.utils.data import SequentialSampler  

from common import EpochsDataset  

cv = ShuffleSplit(10, test_size=0.2, random_state=42)
cv_split = cv.split(epochs_data)
train_idx, test_idx = next(cv_split)


def scale(X):
    """Standard scaling of data along the last dimention.
    Parameters
    ----------
    X : array, shape (n_channels, n_times)
        The input signals.
    Returns
    -------
    X_t : array, shape (n_channels, n_times)
        The scaled signals.
    """
    return X / 2e-5

dataset = EpochsDataset(epochs_data, labels, transform=scale)

ds_train, ds_valid = Subset(dataset, train_idx), Subset(dataset, test_idx)

batch_size_train = len(ds_train)
batch_size_valid = len(ds_valid)
sampler_train = RandomSampler(ds_train)
sampler_valid = SequentialSampler(ds_valid)

# create loaders
num_workers = 0
loader_train = \
    DataLoader(ds_train, batch_size=batch_size_train,
               num_workers=num_workers, sampler=sampler_train)
loader_valid = \
    DataLoader(ds_valid, batch_size=batch_size_valid,
               num_workers=num_workers, sampler=sampler_valid)

In [6]:
class SingleNet(nn.Module):
    def __init__(self):
        super().__init__()

        # define model architecture
        # torch.Size([36, 1, 64, 256])
        # temporal convolutional kernel 25 filter with size(11,1)
        #input [224, 1, 22, 256]
        #A block
        self.temporalConv = nn.Conv1d(in_channels=1,out_channels=8,kernel_size=(11,1),stride=1)
        self.spatialConv = nn.Conv1d(in_channels=8,out_channels=8,kernel_size=(12,1),stride=1)
        self.maxPooling = nn.MaxPool2d((1,3))
        
        #B block
        # 1st convo output [25,1,371]
        #input [224, 8, 1, 85]
        self.firstConv = nn.Conv1d(in_channels=8, out_channels=8,kernel_size=(1,1))
        self.secondConv = nn.Conv1d(in_channels=8, out_channels=33,kernel_size=(1,1))
        self.thirdConv = nn.Conv1d(in_channels=33,out_channels=33,kernel_size=(1,1))
        self.maxPooling2 = nn.MaxPool2d((1,3))
        
        #C block
        #input [224, 41, 1, 28]
        self.classify = nn.Conv1d(in_channels=41,out_channels=41,kernel_size=(1,11))
        self.maxPooling3 = nn.MaxPool2d((1,3))
        self.flatten = nn.Flatten()
        self.elu = nn.ELU()
        self.fully = nn.Linear(246,4)
        
        
    def forward(self, x):
        #Conv1d(1,ch,kernel_size=10) on tensor [batch=1, channels=1, time=400]
        x = self.temporalConv(x)
        x= self.spatialConv(x)
        x = self.elu(x)
        x = self.maxPooling(x)
        
        out2 = self.firstConv(x)
        out2 = self.secondConv(out2)
        out2 = self.thirdConv(out2)
        out3 = torch.cat((x,out2),1)
        out3 = self.maxPooling2(out3)
        
        out3 = self.classify(out3)
        out3 = self.maxPooling3(out3)
        out3 = self.flatten(out3)
        out3 = self.elu(out3)
        out3 = self.fully(out3)
        out3 = self.elu(out3)
        out3 = F.log_softmax(out3, dim=1)
        
        return out3

In [7]:
#device = 'cuda'
device = 'cpu'
model = SingleNet()


In [8]:
# Train
from common import train

lr = 1e-4
n_epochs = 50
patience = 100

model.to(device=device) 
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=1e-4)

train(model, loader_train, loader_valid, optimizer, n_epochs, patience, device)


Starting epoch 1 / 50


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
avg train loss: 1.3512: 100%|██████████| 1/1 [00:00<00:00,  2.72it/s]
avg val loss: 1.3289: 100%|██████████| 1/1 [00:00<00:00, 29.49it/s]


---  Accuracy : 0.7678571343421936 


best val loss inf -> 1.3289

Starting epoch 2 / 50


avg train loss: 1.3263: 100%|██████████| 1/1 [00:00<00:00,  5.65it/s]
avg val loss: 1.3031: 100%|██████████| 1/1 [00:00<00:00, 40.11it/s]


---  Accuracy : 0.9285714030265808 


best val loss 1.3289 -> 1.3031

Starting epoch 3 / 50


avg train loss: 1.3016: 100%|██████████| 1/1 [00:00<00:00,  5.51it/s]
avg val loss: 1.2774: 100%|██████████| 1/1 [00:00<00:00, 52.77it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.3031 -> 1.2774

Starting epoch 4 / 50


avg train loss: 1.2771: 100%|██████████| 1/1 [00:00<00:00,  4.50it/s]
avg val loss: 1.2518: 100%|██████████| 1/1 [00:00<00:00, 31.33it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.2774 -> 1.2518

Starting epoch 5 / 50


avg train loss: 1.2527: 100%|██████████| 1/1 [00:00<00:00,  3.14it/s]
avg val loss: 1.2263: 100%|██████████| 1/1 [00:00<00:00, 35.81it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.2518 -> 1.2263

Starting epoch 6 / 50


avg train loss: 1.2283: 100%|██████████| 1/1 [00:00<00:00,  5.04it/s]
avg val loss: 1.2009: 100%|██████████| 1/1 [00:00<00:00, 45.58it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.2263 -> 1.2009

Starting epoch 7 / 50


avg train loss: 1.2041: 100%|██████████| 1/1 [00:00<00:00,  6.16it/s]
avg val loss: 1.1755: 100%|██████████| 1/1 [00:00<00:00, 45.58it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.2009 -> 1.1755

Starting epoch 8 / 50


avg train loss: 1.1800: 100%|██████████| 1/1 [00:00<00:00,  5.51it/s]
avg val loss: 1.1503: 100%|██████████| 1/1 [00:00<00:00, 38.57it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.1755 -> 1.1503

Starting epoch 9 / 50


avg train loss: 1.1560: 100%|██████████| 1/1 [00:00<00:00,  5.31it/s]
avg val loss: 1.1250: 100%|██████████| 1/1 [00:00<00:00, 40.11it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.1503 -> 1.1250

Starting epoch 10 / 50


avg train loss: 1.1320: 100%|██████████| 1/1 [00:00<00:00,  5.62it/s]
avg val loss: 1.0998: 100%|██████████| 1/1 [00:00<00:00, 44.55it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.1250 -> 1.0998

Starting epoch 11 / 50


avg train loss: 1.1081: 100%|██████████| 1/1 [00:00<00:00,  5.05it/s]
avg val loss: 1.0747: 100%|██████████| 1/1 [00:00<00:00, 41.78it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.0998 -> 1.0747

Starting epoch 12 / 50


avg train loss: 1.0842: 100%|██████████| 1/1 [00:00<00:00,  5.60it/s]
avg val loss: 1.0496: 100%|██████████| 1/1 [00:00<00:00, 40.11it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.0747 -> 1.0496

Starting epoch 13 / 50


avg train loss: 1.0605: 100%|██████████| 1/1 [00:00<00:00,  5.70it/s]
avg val loss: 1.0245: 100%|██████████| 1/1 [00:00<00:00, 35.81it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.0496 -> 1.0245

Starting epoch 14 / 50


avg train loss: 1.0367: 100%|██████████| 1/1 [00:00<00:00,  6.41it/s]
avg val loss: 0.9995: 100%|██████████| 1/1 [00:00<00:00, 44.54it/s]


---  Accuracy : 0.9642857313156128 


best val loss 1.0245 -> 0.9995

Starting epoch 15 / 50


avg train loss: 1.0131: 100%|██████████| 1/1 [00:00<00:00,  5.70it/s]
avg val loss: 0.9745: 100%|██████████| 1/1 [00:00<00:00, 52.77it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.9995 -> 0.9745

Starting epoch 16 / 50


avg train loss: 0.9895: 100%|██████████| 1/1 [00:00<00:00,  6.55it/s]
avg val loss: 0.9496: 100%|██████████| 1/1 [00:00<00:00, 55.70it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.9745 -> 0.9496

Starting epoch 17 / 50


avg train loss: 0.9660: 100%|██████████| 1/1 [00:00<00:00,  7.21it/s]
avg val loss: 0.9247: 100%|██████████| 1/1 [00:00<00:00, 52.78it/s]

---  Accuracy : 0.9642857313156128 


best val loss 0.9496 -> 0.9247






Starting epoch 18 / 50


avg train loss: 0.9426: 100%|██████████| 1/1 [00:00<00:00,  6.28it/s]
avg val loss: 0.8999: 100%|██████████| 1/1 [00:00<00:00, 50.13it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.9247 -> 0.8999

Starting epoch 19 / 50


avg train loss: 0.9193: 100%|██████████| 1/1 [00:00<00:00,  5.96it/s]
avg val loss: 0.8752: 100%|██████████| 1/1 [00:00<00:00, 50.13it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.8999 -> 0.8752

Starting epoch 20 / 50


avg train loss: 0.8961: 100%|██████████| 1/1 [00:00<00:00,  6.65it/s]
avg val loss: 0.8506: 100%|██████████| 1/1 [00:00<00:00, 58.98it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.8752 -> 0.8506

Starting epoch 21 / 50


avg train loss: 0.8731: 100%|██████████| 1/1 [00:00<00:00,  6.82it/s]
avg val loss: 0.8262: 100%|██████████| 1/1 [00:00<00:00, 62.67it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.8506 -> 0.8262

Starting epoch 22 / 50


avg train loss: 0.8503: 100%|██████████| 1/1 [00:00<00:00,  6.15it/s]
avg val loss: 0.8019: 100%|██████████| 1/1 [00:00<00:00, 47.74it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.8262 -> 0.8019

Starting epoch 23 / 50


avg train loss: 0.8276: 100%|██████████| 1/1 [00:00<00:00,  5.20it/s]
avg val loss: 0.7779: 100%|██████████| 1/1 [00:00<00:00, 27.85it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.8019 -> 0.7779

Starting epoch 24 / 50


avg train loss: 0.8052: 100%|██████████| 1/1 [00:00<00:00,  5.24it/s]
avg val loss: 0.7541: 100%|██████████| 1/1 [00:00<00:00, 37.14it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.7779 -> 0.7541

Starting epoch 25 / 50


avg train loss: 0.7831: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]
avg val loss: 0.7306: 100%|██████████| 1/1 [00:00<00:00, 38.57it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.7541 -> 0.7306

Starting epoch 26 / 50


avg train loss: 0.7612: 100%|██████████| 1/1 [00:00<00:00,  4.07it/s]
avg val loss: 0.7074: 100%|██████████| 1/1 [00:00<00:00, 40.00it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.7306 -> 0.7074

Starting epoch 27 / 50


avg train loss: 0.7397: 100%|██████████| 1/1 [00:00<00:00,  5.98it/s]
avg val loss: 0.6845: 100%|██████████| 1/1 [00:00<00:00, 41.78it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.7074 -> 0.6845

Starting epoch 28 / 50


avg train loss: 0.7186: 100%|██████████| 1/1 [00:00<00:00,  5.44it/s]
avg val loss: 0.6620: 100%|██████████| 1/1 [00:00<00:00, 43.47it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.6845 -> 0.6620

Starting epoch 29 / 50


avg train loss: 0.6978: 100%|██████████| 1/1 [00:00<00:00,  4.81it/s]
avg val loss: 0.6400: 100%|██████████| 1/1 [00:00<00:00, 50.07it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.6620 -> 0.6400

Starting epoch 30 / 50


avg train loss: 0.6775: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]
avg val loss: 0.6184: 100%|██████████| 1/1 [00:00<00:00, 40.83it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.6400 -> 0.6184

Starting epoch 31 / 50


avg train loss: 0.6577: 100%|██████████| 1/1 [00:00<00:00,  5.37it/s]
avg val loss: 0.5973: 100%|██████████| 1/1 [00:00<00:00, 41.78it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.6184 -> 0.5973

Starting epoch 32 / 50


avg train loss: 0.6384: 100%|██████████| 1/1 [00:00<00:00,  5.89it/s]
avg val loss: 0.5767: 100%|██████████| 1/1 [00:00<00:00, 45.44it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5973 -> 0.5767

Starting epoch 33 / 50


avg train loss: 0.6196: 100%|██████████| 1/1 [00:00<00:00,  6.10it/s]
avg val loss: 0.5567: 100%|██████████| 1/1 [00:00<00:00, 54.40it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5767 -> 0.5567

Starting epoch 34 / 50


avg train loss: 0.6014: 100%|██████████| 1/1 [00:00<00:00,  5.79it/s]
avg val loss: 0.5372: 100%|██████████| 1/1 [00:00<00:00, 48.79it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5567 -> 0.5372

Starting epoch 35 / 50


avg train loss: 0.5837: 100%|██████████| 1/1 [00:00<00:00,  5.41it/s]
avg val loss: 0.5184: 100%|██████████| 1/1 [00:00<00:00, 45.58it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5372 -> 0.5184

Starting epoch 36 / 50


avg train loss: 0.5667: 100%|██████████| 1/1 [00:00<00:00,  5.01it/s]
avg val loss: 0.5002: 100%|██████████| 1/1 [00:00<00:00, 41.45it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5184 -> 0.5002

Starting epoch 37 / 50


avg train loss: 0.5503: 100%|██████████| 1/1 [00:00<00:00,  6.28it/s]
avg val loss: 0.4826: 100%|██████████| 1/1 [00:00<00:00, 45.58it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.5002 -> 0.4826

Starting epoch 38 / 50


avg train loss: 0.5345: 100%|██████████| 1/1 [00:00<00:00,  6.09it/s]
avg val loss: 0.4657: 100%|██████████| 1/1 [00:00<00:00, 52.78it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4826 -> 0.4657

Starting epoch 39 / 50


avg train loss: 0.5194: 100%|██████████| 1/1 [00:00<00:00,  6.80it/s]
avg val loss: 0.4495: 100%|██████████| 1/1 [00:00<00:00, 58.98it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4657 -> 0.4495

Starting epoch 40 / 50


avg train loss: 0.5050: 100%|██████████| 1/1 [00:00<00:00,  6.91it/s]
avg val loss: 0.4339: 100%|██████████| 1/1 [00:00<00:00, 52.88it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4495 -> 0.4339

Starting epoch 41 / 50


avg train loss: 0.4912: 100%|██████████| 1/1 [00:00<00:00,  6.75it/s]
avg val loss: 0.4190: 100%|██████████| 1/1 [00:00<00:00, 58.98it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4339 -> 0.4190

Starting epoch 42 / 50


avg train loss: 0.4781: 100%|██████████| 1/1 [00:00<00:00,  6.62it/s]
avg val loss: 0.4048: 100%|██████████| 1/1 [00:00<00:00, 47.75it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4190 -> 0.4048

Starting epoch 43 / 50


avg train loss: 0.4657: 100%|██████████| 1/1 [00:00<00:00,  6.99it/s]
avg val loss: 0.3913: 100%|██████████| 1/1 [00:00<00:00, 50.13it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.4048 -> 0.3913

Starting epoch 44 / 50


avg train loss: 0.4539: 100%|██████████| 1/1 [00:00<00:00,  6.88it/s]
avg val loss: 0.3785: 100%|██████████| 1/1 [00:00<00:00, 58.98it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3913 -> 0.3785

Starting epoch 45 / 50


avg train loss: 0.4429: 100%|██████████| 1/1 [00:00<00:00,  7.18it/s]
avg val loss: 0.3664: 100%|██████████| 1/1 [00:00<00:00, 58.98it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3785 -> 0.3664

Starting epoch 46 / 50


avg train loss: 0.4324: 100%|██████████| 1/1 [00:00<00:00,  7.01it/s]
avg val loss: 0.3549: 100%|██████████| 1/1 [00:00<00:00, 47.75it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3664 -> 0.3549

Starting epoch 47 / 50


avg train loss: 0.4226: 100%|██████████| 1/1 [00:00<00:00,  6.23it/s]
avg val loss: 0.3440: 100%|██████████| 1/1 [00:00<00:00, 45.58it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3549 -> 0.3440

Starting epoch 48 / 50


avg train loss: 0.4134: 100%|██████████| 1/1 [00:00<00:00,  5.48it/s]
avg val loss: 0.3338: 100%|██████████| 1/1 [00:00<00:00, 43.59it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3440 -> 0.3338

Starting epoch 49 / 50


avg train loss: 0.4048: 100%|██████████| 1/1 [00:00<00:00,  5.90it/s]
avg val loss: 0.3241: 100%|██████████| 1/1 [00:00<00:00, 47.75it/s]


---  Accuracy : 0.9642857313156128 


best val loss 0.3338 -> 0.3241

Starting epoch 50 / 50


avg train loss: 0.3968: 100%|██████████| 1/1 [00:00<00:00,  6.31it/s]
avg val loss: 0.3151: 100%|██████████| 1/1 [00:00<00:00, 50.13it/s]

---  Accuracy : 0.9642857313156128 


best val loss 0.3241 -> 0.3151





SingleNet(
  (temporalConv): Conv1d(1, 8, kernel_size=(11, 1), stride=(1,))
  (spatialConv): Conv1d(8, 8, kernel_size=(12, 1), stride=(1,))
  (maxPooling): MaxPool2d(kernel_size=(1, 3), stride=(1, 3), padding=0, dilation=1, ceil_mode=False)
  (firstConv): Conv1d(8, 8, kernel_size=(1, 1), stride=(1,))
  (secondConv): Conv1d(8, 33, kernel_size=(1, 1), stride=(1,))
  (thirdConv): Conv1d(33, 33, kernel_size=(1, 1), stride=(1,))
  (maxPooling2): MaxPool2d(kernel_size=(1, 3), stride=(1, 3), padding=0, dilation=1, ceil_mode=False)
  (classify): Conv1d(41, 41, kernel_size=(1, 11), stride=(1,))
  (maxPooling3): MaxPool2d(kernel_size=(1, 3), stride=(1, 3), padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (elu): ELU(alpha=1.0)
  (fully): Linear(in_features=246, out_features=4, bias=True)
)

In [9]:
# Test model works:

n_samples_test = 10
y_test = torch.randint(0, 2, (n_samples_test,))
y_pred = model.forward(torch.randn(n_samples_test, 1, *epochs_data.shape[1:]))
output = F.nll_loss(y_pred, y_test)
_, top_class = y_pred.topk(1, dim=1)


In [10]:
print(top_class)

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]])
