# Disentangling time-variant and time-invariant factors for improved classification of EEG signals

## Abstract

Classifying stimuli from EEG signals poses two main challenges: 1) a low signal-to-noise ratio, caused by time-variant factors that arise during measurement, such as electrical interference from the surroundings, muscle activity, eye movements, and blinks; and 2) variability between individual subjects.

To address these challenges, we propose a novel architecture based on recent developments in disentangled representations and probabilistic sequential modeling. The underlying model is a Conv1dLSTM that uses **only the time-invariant factors** for classification. Our hypothesis is that disentangling the time-varying and time-invariant dynamics present in EEG sequences increases classification accuracy. In our experiments on the MIIR dataset, we achieve a test accuracy of 21.67% (chance level for 12 equally frequent classes is 1/12 ≈ 8.33%), verified using the outer 9-fold cross-validation across subjects as in [1](http://bib.sebastianstober.de/icassp2017.pdf).

# Dataset

The MIIR dataset contains EEG recordings from 64 channels for 9 subjects and 12 audio stimuli, 540 trials in total. Each measurement sequence has length 3518. The sequences have already been normalized to zero mean and range [-1, 1], so no further *normalization/z-filtering* is necessary.
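The normalization claim can be checked on the loaded data before training. The sketch below is a hypothetical helper, not part of the original pipeline; it assumes the features are an array shaped (trials, channels, time) and that "zero mean" holds per channel along the time axis. Both the axis choice and the tolerance are assumptions.

```python
import numpy as np


def is_normalized(x: np.ndarray, time_axis: int = -1, atol: float = 1e-3) -> bool:
    """Check that all values lie in [-1, 1] and every series has ~zero mean over time.

    Hypothetical helper: x is assumed to be shaped (trials, channels, time).
    """
    in_range = bool(np.all(np.abs(x) <= 1.0))
    # The mean over the time axis of every (trial, channel) pair should be ~0.
    zero_mean = bool(np.allclose(x.mean(axis=time_axis), 0.0, atol=atol))
    return in_range and zero_mean


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = rng.uniform(-1.0, 1.0, size=(2, 3, 1000))
    demo -= demo.mean(axis=-1, keepdims=True)  # zero mean per channel
    demo /= np.abs(demo).max()                 # scale into [-1, 1]
    print(is_normalized(demo))        # True
    print(is_normalized(demo + 0.5))  # False: the mean is shifted to 0.5
```

On the HDF5 file used by the training code, the check would be something like `is_normalized(h5['features'][:])`, provided the stored array has the assumed layout.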

# ML considerations

A sequence length of 3518 is much longer than the 250-300 steps typically used with LSTMs in practice. We therefore first apply a Conv1d with a kernel size of 320 and a stride of 160, which reduces each sequence to 20 steps.

We use the factorized disentangled representation for sequential data described in [2](https://arxiv.org/pdf/1803.02991.pdf). Building on techniques similar to those in [3](https://openreview.net/pdf?id=Sy2fzU9gl), [2](https://arxiv.org/pdf/1803.02991.pdf) derives time-variant encodings $z_{1:T}$ and time-invariant features $f$ for sequential data. Our architecture differs from [2](https://arxiv.org/pdf/1803.02991.pdf) in two main ways. 1) We are concerned with classification rather than data generation, so the decoder is replaced with a classifier and the reconstruction loss is replaced with a cross-entropy loss. 2) Most importantly, whereas [2](https://arxiv.org/pdf/1803.02991.pdf) passes ($z_{1:T}$, $f$) to the decoder, we use only $f$ to predict the class. This is intended to discard the time-variant factors and noise that arise during the experiments.
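The reduction from 3518 to 20 steps follows from the output-length formula of a 1-D convolution documented for PyTorch's `nn.Conv1d`, $L_{out} = \lfloor (L_{in} + 2p - d(k-1) - 1)/s \rfloor + 1$. The hypothetical helper below evaluates it; padding $p=0$ and dilation $d=1$ are assumptions about the actual model.

```python
def conv1d_out_len(l_in: int, kernel_size: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 1-D convolution (same formula as torch.nn.Conv1d)."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# EEG trials of 3518 samples, kernel 320, stride 160, no padding (assumed).
print(conv1d_out_len(3518, kernel_size=320, stride=160))  # 20
```

With a stride of half the kernel size, consecutive windows overlap by 160 samples. For comparison, non-overlapping windows (stride 320) would give only 10 steps.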

Since only a small amount of trial data is available, we use a batch size of 1.
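For a single trial (batch size 1), the loss implemented in `loss_fn` can be written as the objective below. The symbol names are ours: $x_{1:T}$ is the Conv1d feature sequence, $y$ the stimulus label, $g$ the classifier applied to a sample of $f$, and $d_z$ the dimensionality of $z_t$. The factorized posterior $q(f \mid x_{1:T}) \prod_t q(z_t \mid x_t)$ is assumed to be what `factorized=True` selects. Note that the implementation averages the $z$ term over time steps and latent dimensions rather than summing it as in the evidence lower bound of [2](https://arxiv.org/pdf/1803.02991.pdf).

$$
\mathcal{L} = \mathrm{CE}\big(y, g(f)\big)
+ \mathrm{KL}\big(q(f \mid x_{1:T}) \,\|\, \mathcal{N}(0, I)\big)
+ \frac{1}{T d_z} \sum_{t=1}^{T} \sum_{j=1}^{d_z} \mathrm{KL}\big(q(z_{t,j} \mid x_t) \,\|\, p(z_{t,j} \mid z_{<t})\big),
\qquad f \sim q(f \mid x_{1:T})
$$

For diagonal Gaussians, the two KL terms have the closed forms used in `kld_f` (summed over the dimensions of $f$) and `kld_z`, where the prior mean $\mu_p$ and variance $\sigma_p^2$ of each $z_t$ come from the LSTM:

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = -\tfrac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)
$$

$$
\mathrm{KL}\big(\mathcal{N}(\mu_q, \sigma_q^2) \,\|\, \mathcal{N}(\mu_p, \sigma_p^2)\big) = \tfrac{1}{2}\left(\log\sigma_p^2 - \log\sigma_q^2 + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{\sigma_p^2} - 1\right)
$$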

# Training and Evaluation scheme 
Evaluation follows the outer 9-fold cross-validation across subjects used in [1](http://bib.sebastianstober.de/icassp2017.pdf): one subject is excluded from training, the data of the remaining subjects are used for training, and the data of the excluded subject are then used for validation. In the code below, the excluded subject is chosen at random.
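A full outer 9-fold evaluation repeats this split once per subject, so that every subject is held out exactly once. A minimal sketch of the fold generation in plain Python follows; `leave_one_subject_out` is a hypothetical helper, and the per-trial subject IDs are assumed to be available as a sequence (e.g. read from the `subjects` dataset of the HDF5 file).

```python
from typing import Iterator, List, Sequence, Tuple


def leave_one_subject_out(subjects: Sequence[int]) -> Iterator[Tuple[int, List[int], List[int]]]:
    """Yield (held_out_subject, train_indices, test_indices), one fold per distinct subject."""
    for held_out in sorted(set(subjects)):
        train_idx = [i for i, s in enumerate(subjects) if s != held_out]
        test_idx = [i for i, s in enumerate(subjects) if s == held_out]
        yield held_out, train_idx, test_idx


# Toy example: 3 subjects with 2 trials each.
for subject, train_idx, test_idx in leave_one_subject_out([0, 0, 1, 1, 2, 2]):
    print(subject, train_idx, test_idx)
# 0 [2, 3, 4, 5] [0, 1]
# 1 [0, 1, 4, 5] [2, 3]
# 2 [0, 1, 2, 3] [4, 5]
```

Averaging the held-out accuracy over the nine folds is one common way to summarize such an evaluation; choosing a single random subject runs only one of these folds.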

```python
import random

import h5py
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data

from model import *  # local module providing DisentangledEEG
import config        # local module providing DATA_PATH and lr

random.seed(0)
torch.manual_seed(0)
np.random.seed(0)


class Miir(data.Dataset):
    """One split of MIIR: all trials of the held-out subject (test) or of all other subjects (train)."""

    def __init__(self, data_path=config.DATA_PATH, train=True, test_subject_id=0):
        self.train = train
        # The held-out subject is passed in (instead of being drawn here) so that the
        # train and test instances exclude/include the *same* subject.
        self.test_subject_id = test_subject_id

        with h5py.File(data_path, 'r') as h5:
            features = h5['features'][:]
            targets = h5['targets'][:]
            subjects = h5['subjects'][:]

        if train:
            keep = [i for i, s in enumerate(subjects) if s != test_subject_id]
        else:
            keep = [i for i, s in enumerate(subjects) if s == test_subject_id]

        self.features = features[keep]
        self.targets = targets[keep]
        self.subjects = subjects[keep]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx], self.subjects[idx]


def loss_fn(target, pred_target, f_mean, f_logvar, z_post_mean, z_post_logvar, z_prior_mean, z_prior_logvar):
    """
    Loss = {CrossEntropy(pred_target, target) + KL_f + KL_z} / batch_size

    - CrossEntropy replaces the reconstruction loss of [2]; it is summed over the batch.
    - KL_f: KL divergence of q(f | x) from a spherical zero-mean, unit-variance Gaussian prior,
      summed over batch and latent dimensions.
    - KL_z: KL divergence of q(z_t | x) from the Gaussian prior p(z_t | z_<t), whose mean and
      variance are given by an LSTM. This term is *averaged* (torch.mean) over batch,
      time steps and latent dimensions, not summed.

    Returns (total_loss, kld_f, kld_z, cross_entropy), each divided by batch_size.
    """
    batch_size = target.size(0)
    # reduction='sum' so that dividing by batch_size below yields the per-trial mean
    # (identical to the default 'mean' for batch_size = 1).
    cross_entropy = F.cross_entropy(pred_target, target, reduction='sum')
    kld_f = -0.5 * torch.sum(1 + f_logvar - torch.pow(f_mean, 2) - torch.exp(f_logvar))
    z_post_var = torch.exp(z_post_logvar)
    z_prior_var = torch.exp(z_prior_logvar)
    kld_z = 0.5 * torch.mean(z_prior_logvar - z_post_logvar
                             + (z_post_var + torch.pow(z_post_mean - z_prior_mean, 2)) / z_prior_var - 1)

    return ((cross_entropy + kld_f + kld_z) / batch_size, kld_f / batch_size,
            kld_z / batch_size, cross_entropy / batch_size)


def save_model(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch + 1,  # next epoch to run when resuming
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()}, path)


def check_accuracy(model, loader):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for features, target, _ in loader:
            target = torch.argmax(target, dim=1)  # one-hot -> class index
            *_, logits = model(features)          # class scores are the last model output
            pred = torch.argmax(logits, dim=1)
            total += target.size(0)
            correct += (pred == target).sum().item()
    model.train()
    return correct / total


def train_classifier(model, optimizer, dataset, epochs, path, test, start=0):
    model.train()
    for epoch in range(start, epochs):
        losses, kld_fs, kld_zs, cross_entropies = [], [], [], []

        for features, target, _ in dataset:
            target = torch.argmax(target, dim=1)  # one-hot -> class index
            optimizer.zero_grad()
            (f_mean, f_logvar, f, z_post_mean, z_post_logvar, z,
             z_prior_mean, z_prior_logvar, pred_target) = model(features)
            loss, kld_f, kld_z, cross_entropy = loss_fn(target, pred_target, f_mean, f_logvar,
                                                        z_post_mean, z_post_logvar,
                                                        z_prior_mean, z_prior_logvar)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            kld_fs.append(kld_f.item())
            kld_zs.append(kld_z.item())
            cross_entropies.append(cross_entropy.item())

        test_accuracy = check_accuracy(model, test)
        if epoch % 20 == 0:  # print results every 20 epochs
            print("Epoch {} : Average Loss: {} KL of f : {} KL of z : {} "
                  "Cross Entropy: {} Test Accuracy: {}".format(
                      epoch + 1, np.mean(losses), np.mean(kld_fs), np.mean(kld_zs),
                      np.mean(cross_entropies), test_accuracy))
        save_model(model, optimizer, epoch, path)


if __name__ == '__main__':
    model = DisentangledEEG(factorized=True, nonlinearity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    # Draw the held-out subject once and share it between both splits.
    test_subject_id = random.randint(0, 8)
    train_data = Miir(config.DATA_PATH, train=True, test_subject_id=test_subject_id)
    test_data = Miir(config.DATA_PATH, train=False, test_subject_id=test_subject_id)
    loader = data.DataLoader(train_data, batch_size=1, shuffle=True, num_workers=1)
    loader_test = data.DataLoader(test_data, batch_size=60, shuffle=False, num_workers=4)
    train_classifier(model=model, optimizer=optimizer, dataset=loader, epochs=200,
                     path='./checkpoint_disentangled_classifier.pth', test=loader_test)
```

Epoch 1 : Average Loss: 3.520056616763274 KL of f : 0.07574303398529689 KL of z : 0.9584266122741004 Cross Entropy: 2.4858869741360348 Test Accuracy: 0.13333333333333333
Epoch 21 : Average Loss: 1.8140081226825715 KL of f : 0.2328883560995261 KL of z : 0.005657989709288813 Cross Entropy: 1.5754617758095264 Test Accuracy: 0.15
Epoch 41 : Average Loss: 1.30403537551562 KL of f : 0.20091642954697211 KL of z : 0.0011133064908790402 Cross Entropy: 1.1020056409140428 Test Accuracy: 0.2
Epoch 61 : Average Loss: 1.1021551149586837 KL of f : 0.11758367288857699 KL of z : 0.0002490390420462063 Cross Entropy: 0.9843224037438632 Test Accuracy: 0.16666666666666666
Epoch 81 : Average Loss: 1.0411095894873141 KL of f : 0.08546531063814958 KL of z : 5.902535857937134e-05 Cross Entropy: 0.9555852508793274 Test Accuracy: 0.2
Epoch 101 : Average Loss: 0.999629090850552 KL of f : 0.059354947817822294 KL of z : 1.890480166556093e-05 Cross Entropy: 0.9402552363773187 Test Accuracy: 0.18333333333333332
Epoch