<a href="https://colab.research.google.com/github/Gooogr/Brain2Image/blob/master/sub_notebooks/EEG_classificator_custom.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

EEG signal classifier based on the PerceiveLab dataset<br>
Dataset:  http://www.perceivelab.com/dataset/EEG%20Data%20for%20Visual%20Classification

In [None]:
# Imports
import math
import os
import random
import sys
import time
from types import SimpleNamespace

import numpy as np
from scipy.fftpack import fft, rfft, fftfreq, irfft, ifft, rfftfreq
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [None]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
! ln -s "/content/drive/My Drive" "/content/mydrive"

Link to paper:<br>
https://arxiv.org/pdf/1812.07697.pdf

Link to PerceiveLab files:<br>
http://perceive.dieei.unict.it/files/

Link to LSTM script from the article:<br>
http://perceive.dieei.unict.it/files/cvpr_2017_eeg_encoder.py

In [None]:
# NOTE(review): every line of this cell is commented out, so running it defines
# neither `parser` nor `opt`. It only records the option names and defaults
# (apparently taken from the linked cvpr_2017_eeg_encoder.py script -- confirm).
# Cells that read `opt.<name>` need `opt` to be defined elsewhere before they run.
# # Define options
# import argparse
# parser = argparse.ArgumentParser(description="Template")
# # Dataset options
# parser.add_argument('-ed', '--eeg-dataset', default="datasets/eeg_signals_band_all.pth", help="EEG dataset path")
# parser.add_argument('-sp', '--splits-path', default="datasets/splits_by_image.pth", help="splits path")
# parser.add_argument('-sn', '--split-num', default=0, type=int, help="split number")
# # Model options
# parser.add_argument('-ll', '--lstm-layers', default=1, type=int, help="LSTM layers")
# parser.add_argument('-ls', '--lstm-size', default=128, type=int, help="LSTM hidden size")
# parser.add_argument('-es', '--embedding-size', default=128, type=int, help="embedding size")
# parser.add_argument('-nc', '--num-classes', default=40, type=int, help="num classes")

# # Filtering options
# parser.add_argument('-filt', '--filtering', default=True,  help="filter your data")

# # Training options
# parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size")
# parser.add_argument('-o', '--optim', default="Adam", help="optimizer")
# parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, help="learning rate")
# parser.add_argument('-lrdb', '--learning-rate-decay-by', default=0.5, type=float, help="learning rate decay factor")
# parser.add_argument('-lrde', '--learning-rate-decay-every', default=10, type=int, help="learning rate decay period")
# parser.add_argument('-e', '--epochs', default=300, type=int, help="training epochs")
# parser.add_argument('-dw', '--data-workers', default=16, type=int, help="data loading workers")
# # Backend options
# parser.add_argument('--no-cuda', default=False, help="disable CUDA", action="store_true")

# # Parse arguments
# opt = parser.parse_args()


### Prepare dataset

In [None]:
# Locations of the EEG recordings file and of the split-index file on the mounted Drive.
DATASET_PATH = '/content/mydrive/EEG2Image_research/Datasets/perceive_lab/eeg_signals_128_sequential_band_all_with_mean_std.pth'
SPLIT_PATH = '/content/mydrive/EEG2Image_research/Datasets/perceive_lab/block_splits_by_image.pth'

# Run-time options read by the following cells as `opt.<name>`.
# The argparse parser in this notebook is commented out (argparse would try to parse
# the kernel's own command line), so without this definition every cell that touches
# `opt` fails with NameError on a fresh kernel. Values mirror the defaults of that
# commented-out parser, except for the two paths (taken from the constants above)
# and `no_cuda` (derived from GPU availability instead of a fixed False).
opt = SimpleNamespace(
    # Dataset options
    eeg_dataset=DATASET_PATH,
    splits_path=SPLIT_PATH,
    split_num=0,
    # Model options
    lstm_layers=1,
    lstm_size=128,
    embedding_size=128,
    num_classes=40,
    # Filtering options
    filtering=True,
    # Training options
    batch_size=16,
    optim="Adam",
    learning_rate=0.001,
    learning_rate_decay_by=0.5,
    learning_rate_decay_every=10,
    epochs=300,
    data_workers=16,
    # Backend options
    no_cuda=not torch.cuda.is_available(),
)

In [None]:
# !python3 cvpr_2017_eeg_encoder.py -ed $DATASET_PATH -sp $SPLIT_PATH

In [None]:
# Dataset class
class EEGDataset:
    """Map-style dataset over the EEG recordings stored in a ``.pth`` file.

    Each item is a ``(eeg, label)`` pair: the recording is standardised with the
    stored means/stddevs, optionally frequency-filtered, transposed from CxT to
    TxC and cropped to time steps ``CROP_START:CROP_END``.

    Args:
        eeg_signals_path: path handed to ``torch.load``; the loaded object must
            provide the keys "dataset", "labels", "images", "means", "stddevs".
        filtering: ``True``/``False`` to force the frequency filter on or off.
            ``None`` (default) keeps the previous behaviour of reading the
            module-level ``opt.filtering`` flag every time an item is fetched.
    """

    # Sample spacing handed to ``rfftfreq``; 1 ms means the frequency thresholds
    # below are in Hz only if the recordings are sampled at 1 kHz -- TODO confirm.
    SAMPLE_SPACING = 1.0 / 1000.0
    # Window of time steps kept from every recording (applied after the transpose).
    CROP_START = 20
    CROP_END = 460

    # Constructor
    def __init__(self, eeg_signals_path, filtering=None):
        # Load EEG signals
        # NOTE(review): torch.load unpickles the file -- only load trusted dataset files.
        loaded = torch.load(eeg_signals_path)
        self.data = loaded["dataset"]
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.means = loaded["means"]
        self.stddevs = loaded["stddevs"]
        # Explicit filter switch; None defers to the global `opt.filtering`
        self.filtering = filtering
        # Compute size
        self.size = len(self.data)

    # Get size
    def __len__(self):
        """Return the number of recordings in the loaded file."""
        return self.size

    # Get item
    def __getitem__(self, i):
        """Return ``(eeg, label)`` for recording ``i``; ``eeg`` is TxC after cropping."""
        # Standardise the recording (still CxT at this point)
        eeg = (self.data[i]["eeg"].float() - self.means) / self.stddevs
        # Resolve the filter switch: explicit constructor value wins, otherwise
        # fall back to the global options object as the original code did.
        use_filter = opt.filtering if self.filtering is None else self.filtering
        if use_filter:
            eeg = self._frequency_filter(eeg)
        # Transpose to TxC
        eeg = eeg.t()
        eeg = eeg[self.CROP_START:self.CROP_END, :]
        # Get label
        label = self.data[i]["label"]
        # Return
        return eeg, label

    def _frequency_filter(self, eeg):
        """Zero selected frequency bins of a CxT tensor and return the filtered tensor.

        Bins below 15, between 47 and 53 (exclusive), and above 71 are removed.
        ``scipy.fftpack.rfftfreq`` returns one frequency per ``rfft`` output
        coefficient, so the boolean masks index the spectrum directly.
        """
        n_samples = eeg.size(1)
        # Frequency axis
        freqs = rfftfreq(n_samples, self.SAMPLE_SPACING)
        # FFT along the last (time) axis
        spectrum = rfft(eeg.numpy())
        # Filter
        spectrum[:, freqs < 15] = 0
        spectrum[:, np.bitwise_and(freqs > 47, freqs < 53)] = 0
        spectrum[:, freqs > 71] = 0
        # Back to the time domain and to a tensor
        return torch.tensor(irfft(spectrum))

# Splitter class
class Splitter:
    """Restricts an EEG dataset to the indices of one named split.

    The split file is read with ``torch.load`` and must expose
    ``["splits"][split_num][split_name]`` as an iterable of dataset indices.
    Indices whose recording has a second-dimension length outside [450, 600]
    are discarded.
    """

    def __init__(self, dataset, split_path, split_num=0, split_name="train"):
        # Keep a handle on the wrapped EEG dataset
        self.dataset = dataset
        # Read the split definitions and pick the requested index list
        split_file = torch.load(split_path)
        candidates = split_file["splits"][split_num][split_name]
        # Keep only the recordings whose length is within the accepted range
        kept = []
        for idx in candidates:
            n_steps = self.dataset.data[idx]["eeg"].size(1)
            if n_steps >= 450 and n_steps <= 600:
                kept.append(idx)
        self.split_idx = kept
        # Number of samples that survived the length check
        self.size = len(kept)

    def __len__(self):
        """Number of samples in this split after the length check."""
        return self.size

    def __getitem__(self, i):
        """Fetch the ``i``-th sample of the split from the wrapped dataset."""
        # Translate the split-local position into a dataset index
        dataset_index = self.split_idx[i]
        sample, target = self.dataset[dataset_index]
        return sample, target

In [None]:
# Read the EEG recordings once; the same dataset object backs every split
dataset = EEGDataset(opt.eeg_dataset)
# One shuffled DataLoader per split name; incomplete final batches are dropped
loaders = dict(
    (
        split,
        DataLoader(
            Splitter(
                dataset,
                split_path=opt.splits_path,
                split_num=opt.split_num,
                split_name=split,
            ),
            batch_size=opt.batch_size,
            drop_last=True,
            shuffle=True,
        ),
    )
    for split in ["train", "val", "test"]
)

### Setting up model

In [None]:
# Define model
class Model(nn.Module):
    """LSTM encoder followed by a ReLU-activated embedding layer and a linear classifier.

    Args:
        input_size: number of features per time step fed to the LSTM.
        lstm_size: LSTM hidden size.
        lstm_layers: number of stacked LSTM layers.
        embedding_size: output size of the embedding layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, lstm_size, lstm_layers, embedding_size, num_classes):
        # Call parent
        super().__init__()
        # Define parameters
        self.input_size = input_size
        self.lstm_size = lstm_size
        self.lstm_layers = lstm_layers
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        # Define internal modules
        self.lstm = nn.LSTM(input_size, lstm_size, num_layers=lstm_layers, batch_first=True)
        self.embedding = nn.Linear(lstm_size, embedding_size)
        self.classifier = nn.Linear(embedding_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch-first input.

        Args:
            x: tensor laid out as (batch, time, input_size), as required by the
                ``batch_first=True`` LSTM.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised scores.
        """
        # The LSTM starts from its default zero state; keep only the output of
        # the last time step.
        last_step = self.lstm(x)[0][:, -1, :]
        # Forward embedding
        embedded = F.relu(self.embedding(last_step))
        # Forward classifier
        return self.classifier(embedded)

In [None]:
# Build the classifier; 128 is the per-time-step input size of the LSTM
# (presumably the number of EEG channels -- confirm against the dataset).
model = Model(
    128,
    opt.lstm_size,
    opt.lstm_layers,
    opt.embedding_size,
    opt.num_classes,
)
# Look up the optimizer class in torch.optim by the name stored in opt.optim
optimizer = getattr(torch.optim, opt.optim)(
    model.parameters(),
    lr=opt.learning_rate,
)

In [None]:
# Setup CUDA
use_cuda = not opt.no_cuda
if use_cuda:
    model.cuda()
    print("Copied to CUDA")


def _epoch_mean(total, count):
    """Average an accumulated per-batch metric; NaN when the split produced no batches."""
    # A split can yield zero batches (the loaders drop incomplete batches), which
    # would otherwise raise ZeroDivisionError only after a whole epoch of work.
    return total / count if count else float("nan")


# Start training
for epoch in range(1, opt.epochs+1):
    # Initialize loss/accuracy variables
    losses = {"train": 0, "val": 0, "test": 0}
    accuracies = {"train": 0, "val": 0, "test": 0}
    counts = {"train": 0, "val": 0, "test": 0}
    # Adjust learning rate for SGD (step decay every `learning_rate_decay_every` epochs)
    if opt.optim == "SGD":
        lr = opt.learning_rate * (opt.learning_rate_decay_by ** (epoch // opt.learning_rate_decay_every))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # Process each split
    for split in ("train", "val", "test"):
        is_train = split == "train"
        # Set network mode: train(True) for training, train(False) == eval() otherwise
        model.train(is_train)
        # Scope the autograd switch to this split so gradients are not left
        # globally disabled once the loop finishes on the "test" split.
        with torch.set_grad_enabled(is_train):
            # Process all split batches
            for inputs, target in loaders[split]:
                # Check CUDA
                if use_cuda:
                    # `async` became a reserved word in Python 3.7; the keyword
                    # argument for asynchronous copies is `non_blocking`.
                    inputs = inputs.cuda(non_blocking=True)
                    target = target.cuda(non_blocking=True)
                # Forward
                output = model(inputs)
                loss = F.cross_entropy(output, target)
                losses[split] += loss.item()
                # Compute accuracy: fraction of samples whose arg-max logit matches the target
                _, pred = output.detach().max(1)
                correct = pred.eq(target).sum().item()
                accuracies[split] += correct / inputs.size(0)
                counts[split] += 1
                # Backward and optimize
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
    # Print info at the end of the epoch
    print("Epoch {0}: TrL={1:.4f}, TrA={2:.4f}, VL={3:.4f}, VA={4:.4f}, TeL={5:.4f}, TeA={6:.4f}".format(
        epoch,
        _epoch_mean(losses["train"], counts["train"]),
        _epoch_mean(accuracies["train"], counts["train"]),
        _epoch_mean(losses["val"], counts["val"]),
        _epoch_mean(accuracies["val"], counts["val"]),
        _epoch_mean(losses["test"], counts["test"]),
        _epoch_mean(accuracies["test"], counts["test"])))
