# PyTorch Implementation of EEG-ConvTransformer, Proposed in [1]
@Xin Zhang, SZU.
"""
### README
This notebook presents a demo of training and testing.
Before running it, generate the EEG images from the raw EEG signals by running /preprocess/project2img.ipynb; see Section 3.1 of the cited paper for details. Note that this preprocessing step contains some uncertain code, because some details in [2] are undisclosed. Contributions to refine this repository are welcome.

The method proposed in [1] (EEG-ConvTransformer) is implemented in /model; that part should work as described.

### References
`[1] Bagchi S, Bathula D R. EEG-ConvTransformer for single-trial EEG-based visual stimulus classification[J]. Pattern Recognition, 2022, 129: 108757.`

`[2] Bashivan P, et al. Learning representations from EEG with deep recurrent-convolutional neural networks[C]. International Conference on Learning Representations (ICLR), 2016.`
"""

```python
import warnings

import numpy as np
import scipy.io as sio
import torch
from torch.utils.data import DataLoader, random_split

# Local modules from this repository
from data_load.dataset import EEGImagesDataset
from train_test import train_validate
from model.conv_transformer import ConvTransformer

# Fix random seeds for reproducibility
torch.manual_seed(1234)
np.random.seed(1234)

warnings.simplefilter("ignore")
```

Load the data. The authors of [1] adopted the Azimuthal Equidistant Projection of [2] to visualize EEG signals as images, and their dataset has the same spatial shape, [w=32, h=32]. Since [2] shared its dataset and code, reusing that dataset is the lower-cost option.
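The projection mentioned above maps 3D electrode positions onto a 2D plane so that each electrode's distance from the centre equals its angular distance from the top of the head; per-electrode values for each frequency band are then interpolated onto a 32x32 grid. The following is a minimal sketch of that idea, not the code in /preprocess/project2img.ipynb. It assumes electrode coordinates centred so that the top of the head is (0, 0, 1), and it uses cubic interpolation via SciPy's `griddata` as an assumption; the function names and the synthetic electrodes are hypothetical.

```python
import numpy as np
from scipy.interpolate import griddata


def azimuthal_equidistant(xyz):
    """Map 3D electrode positions to 2D so that the distance from the origin
    equals the angular distance from the top of the head (0, 0, 1)."""
    xyz = np.asarray(xyz, dtype=float)
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    r = np.linalg.norm(xyz, axis=1)
    polar = np.arccos(np.clip(z / r, -1.0, 1.0))  # angle from the +z pole
    azimuth = np.arctan2(y, x)
    return np.column_stack([polar * np.cos(azimuth), polar * np.sin(azimuth)])


def interpolate_image(locs_2d, values, size=32):
    """Interpolate one value per electrode onto a size x size grid.
    Grid points outside the electrodes' convex hull are set to 0."""
    xs = np.linspace(locs_2d[:, 0].min(), locs_2d[:, 0].max(), size)
    ys = np.linspace(locs_2d[:, 1].min(), locs_2d[:, 1].max(), size)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return griddata(locs_2d, values, (grid_x, grid_y), method="cubic", fill_value=0.0)


# Synthetic example (hypothetical): 30 random electrodes on the upper hemisphere
rng = np.random.default_rng(0)
theta = rng.uniform(0, np.pi / 2, 30)
phi = rng.uniform(-np.pi, np.pi, 30)
electrodes = np.column_stack(
    [np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)]
)
band_power = rng.normal(size=(3, 30))  # one row per frequency band, one value per electrode

locs = azimuthal_equidistant(electrodes)
image = np.stack([interpolate_image(locs, band) for band in band_power])
print(image.shape)  # (3, 32, 32)
```

Stacking one interpolated grid per band gives a per-trial array with the same (3, 32, 32) layout as the images loaded below.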

```python
# Load the pre-computed EEG images and their labels
Images = sio.loadmat("sample_data/time_frames.mat")["img"]
print(np.shape(Images))
Label = (sio.loadmat("sample_data/labels.mat")["lab"]).astype(int)
print(np.shape(Label))
```

```
(2670, 3, 32, 32)
(2670, 7, 3, 32, 32)
(2670,)
(2670,)
Choose among the patient : [ 1  2  3  4  6  7  8  9 10 11 12 14 15]
```

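```python
choosen_patient = 9   # patient ID; only used in the training log message below
train_part = 0.8      # fraction of trials used for training
test_part = 0.2       # fraction of trials used for testing
batch_size = 32
```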
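The training log is labelled per patient and the data listing offers a set of patient IDs, but the cells here pass every trial to the dataset. A minimal sketch of restricting the arrays to one patient follows, assuming a hypothetical per-trial `patient_ids` array aligned with `Images` and `Label`; this notebook does not load such an array, and `select_patient` is a hypothetical helper.

```python
import numpy as np


def select_patient(images, labels, patient_ids, patient):
    """Keep only the trials recorded from one patient.

    `patient_ids` is a hypothetical per-trial array of subject IDs aligned
    with `images` and `labels`.
    """
    patient_ids = np.asarray(patient_ids)
    mask = patient_ids == patient
    if not mask.any():
        raise ValueError(f"patient {patient} not found; choose among {np.unique(patient_ids)}")
    return images[mask], labels[mask]


# Synthetic stand-ins with the same per-trial shapes as the arrays loaded above
rng = np.random.default_rng(0)
images = rng.normal(size=(10, 3, 32, 32)).astype(np.float32)
labels = rng.integers(0, 4, size=10)
patient_ids = np.array([1, 9, 9, 2, 9, 3, 1, 9, 2, 3])

sub_images, sub_labels = select_patient(images, labels, patient_ids, patient=9)
print(sub_images.shape, sub_labels.shape)  # (4, 3, 32, 32) (4,)
```

The filtered arrays could then take the place of `Images` and `Label` when constructing `EEGImagesDataset`.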

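```python
EEG = EEGImagesDataset(label=Label, image=Images)
# random_split requires the lengths to sum to len(EEG), so the test set takes the remainder
n_train = int(len(EEG) * train_part)
lengths = [n_train, len(EEG) - n_train]
Train, Test = random_split(EEG, lengths)
train_loader = DataLoader(Train, batch_size=batch_size, shuffle=True)  # reshuffle every epoch
test_loader = DataLoader(Test, batch_size=batch_size)
```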
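`EEGImagesDataset` lives in /data_load and is not shown here. As a rough illustration of what a map-style dataset over these arrays involves, the following sketch pairs each image with its label and splits it reproducibly by passing a seeded `torch.Generator` to `random_split`. `ArrayImageDataset` is hypothetical, not the repository's class, and the demo uses synthetic data and `demo_` names so it does not overwrite the notebook's loaders.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class ArrayImageDataset(Dataset):
    """Hypothetical stand-in for EEGImagesDataset: pairs each EEG image with its label."""

    def __init__(self, label, image):
        self.image = torch.as_tensor(image, dtype=torch.float32)
        self.label = torch.as_tensor(np.asarray(label).reshape(-1), dtype=torch.long)
        if len(self.image) != len(self.label):
            raise ValueError("image and label must have the same number of trials")

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.image[idx], self.label[idx]


# Synthetic data shaped like the arrays above: (N, 3, 32, 32) images, (N,) labels
rng = np.random.default_rng(0)
demo_data = ArrayImageDataset(label=rng.integers(0, 4, size=100), image=rng.normal(size=(100, 3, 32, 32)))

n_train = int(len(demo_data) * 0.8)
demo_train, demo_test = random_split(
    demo_data, [n_train, len(demo_data) - n_train], generator=torch.Generator().manual_seed(1234)
)
demo_loader = DataLoader(demo_train, batch_size=32, shuffle=True)
x, y = next(iter(demo_loader))
print(len(demo_train), len(demo_test), tuple(x.shape), tuple(y.shape))  # 80 20 (32, 3, 32, 32) (32,)
```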

Define the ConvTransformer[1] model and perform training and validation.

```python
model = ConvTransformer(num_classes=16, channels=8, num_heads=2, E=16, F=256, T=32, depth=2)
print('Begin Training for Patient ' + str(choosen_patient))
res = train_validate(model, train_loader, test_loader, n_epoch=60, learning_rate=0.00001, print_epoch=5, opti='Adam')
```

```
Begin Training for Patient 9
[5, 100]	loss: 1.375	Accuracy : 0.296		val-loss: 1.376	val-Accuracy : 0.150
[10, 100]	loss: 1.362	Accuracy : 0.302		val-loss: 1.371	val-Accuracy : 0.225
[15, 100]	loss: 1.352	Accuracy : 0.302		val-loss: 1.368	val-Accuracy : 0.225
[20, 100]	loss: 1.342	Accuracy : 0.302		val-loss: 1.362	val-Accuracy : 0.225
[25, 100]	loss: 1.312	Accuracy : 0.302		val-loss: 1.335	val-Accuracy : 0.225
[30, 100]	loss: 1.191	Accuracy : 0.302		val-loss: 1.250	val-Accuracy : 0.225
[35, 100]	loss: 0.981	Accuracy : 0.586		val-loss: 1.105	val-Accuracy : 0.575
[40, 100]	loss: 0.836	Accuracy : 0.605		val-loss: 1.015	val-Accuracy : 0.650
[45, 100]	loss: 0.760	Accuracy : 0.611		val-loss: 1.008	val-Accuracy : 0.700
[50, 100]	loss: 0.677	Accuracy : 0.654		val-loss: 1.047	val-Accuracy : 0.725
[55, 100]	loss: 0.561	Accuracy : 0.753		val-loss: 1.120	val-Accuracy : 0.725
[60, 100]	loss: 0.421	Accuracy : 0.833		val-loss: 1.253	val-Accuracy : 0.800
[65, 100]	loss: 0.301	Accuracy : 0.895		val-loss: 1.419	val-Accuracy : 0.750
[70, 100]	loss: 0.212	Accuracy : 0.944		val-loss: 1.557	val-Accuracy : 0.825
[75, 100]	loss: 0.145	Accuracy : 0.969		val-loss: 1.810	val-Accuracy : 0.875
[80, 100]	loss: 0.096	Accuracy : 0.981		val-loss: 2.223	val-Accuracy : 0.875
[85, 100]	loss: 0.063	Accuracy : 0.994		val-loss: 2.621	val-Accuracy : 0.875
[90, 100]	loss: 0.043	Accuracy : 0.994		val-loss: 2.955	val-Accuracy : 0.900
[95, 100]	loss: 0.031	Accuracy : 0.994		val-loss: 3.245	val-Accuracy : 0.900
[100, 100]	loss: 0.023	Accuracy : 0.994		val-loss: 3.496	val-Accuracy : 0.900
Finished Training
 loss: 0.023	Accuracy : 0.994		val-loss: 3.496	val-Accuracy : 0.900
```
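`train_validate` comes from train_test.py and is not shown here. To illustrate the kind of loop that produces a log like the one above, the following is a minimal sketch of a generic PyTorch train/validate loop, assuming cross-entropy loss and the Adam optimizer (suggested by `opti='Adam'`). `run_epoch` and `train_validate_sketch` are hypothetical helpers, not the repository's implementation, and the demo trains a tiny linear classifier on random data.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, optimizer=None):
    """One pass over `loader`: trains if an optimizer is given, otherwise evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, count = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * len(y)
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += len(y)
    return total_loss / count, correct / count


def train_validate_sketch(model, train_loader, val_loader, n_epoch, learning_rate, print_epoch):
    """Hypothetical stand-in for train_test.train_validate (Adam + cross-entropy)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    history = []  # (loss, accuracy, val-loss, val-accuracy) per epoch
    for epoch in range(1, n_epoch + 1):
        loss, acc = run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion)
        history.append((loss, acc, val_loss, val_acc))
        if epoch % print_epoch == 0:
            print(f"[{epoch}, {n_epoch}]\tloss: {loss:.3f}\tAccuracy : {acc:.3f}"
                  f"\t\tval-loss: {val_loss:.3f}\tval-Accuracy : {val_acc:.3f}")
    return history


# Tiny synthetic demo: a linear classifier on random 3x32x32 "images" with 4 classes
torch.manual_seed(0)
x = torch.randn(64, 3, 32, 32)
y = torch.randint(0, 4, (64,))
demo_train_loader = DataLoader(TensorDataset(x[:48], y[:48]), batch_size=16, shuffle=True)
demo_val_loader = DataLoader(TensorDataset(x[48:], y[48:]), batch_size=16)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))
history = train_validate_sketch(demo_model, demo_train_loader, demo_val_loader,
                                n_epoch=10, learning_rate=1e-3, print_epoch=5)
```

Returning the per-epoch history makes curves like the one above easy to inspect: there, validation loss is lowest around epoch 45 and then climbs while training loss keeps falling.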