# PyTorch implementation of EEG-ConvTransformer, as proposed in [1]
"""
Created by Xin Zhang, SZU.
### README
This notebook is a demo of training and testing.
Before running it, generate the visualized images from the EEG signals by running /preprocess/project2img.ipynb; see Section 3.1 of [1] for details. Note that some of the code in that part is uncertain, because [1] does not disclose all the details. Contributions that help refine this repository are welcome.

The method proposed in [1] (EEG-ConvTransformer) is implemented in /model. That part is expected to work without problems.

### Ref
`[1] Bagchi S, Bathula D R. EEG-ConvTransformer for single-trial EEG-based visual stimulus classification[J]. Pattern Recognition, 2022, 129: 108757.`

`[2] Bashivan, et al. "Learning Representations from EEG with Deep Recurrent-Convolutional Neural Networks." International conference on learning representations (2016).`
"""

In [8]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from data_load.dataset import EEGImagesDataset
from model.conv_transformer import ConvTransformer
torch.manual_seed(1234)
np.random.seed(1234)

Load the data.
First, download the dataset from https://purl.stanford.edu/bq914sc3730
The authors of [1] refer to the Azimuthal Equidistant Projection used in [2] for EEG visualization.
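
To make the projection step concrete, here is a minimal sketch of an azimuthal equidistant projection that maps a 3-D electrode position to 2-D image-plane coordinates. It is an illustration under stated assumptions, not the code in /preprocess: the function names and the example coordinates are hypothetical, and the pole of the projection is assumed to be the +z axis (top of the head).

```python
import math


def cart2sph(x, y, z):
    """Convert Cartesian coordinates to (radius, elevation, azimuth).

    Elevation is measured from the xy-plane towards +z.
    """
    r = math.sqrt(x * x + y * y + z * z)
    elev = math.atan2(z, math.hypot(x, y))
    az = math.atan2(y, x)
    return r, elev, az


def azim_proj(pos):
    """Project a 3-D point onto the plane tangent to the +z pole.

    The angular distance from the pole becomes the planar radius, so
    distances from the projection centre are preserved.
    """
    x, y, z = pos
    _, elev, az = cart2sph(x, y, z)
    rho = math.pi / 2 - elev  # angular distance from the pole
    return rho * math.cos(az), rho * math.sin(az)


# Hypothetical electrode positions on a unit sphere (not taken from the dataset).
electrodes = {"top": (0.0, 0.0, 1.0), "front": (1.0, 0.0, 0.0), "left": (0.0, 1.0, 0.0)}
projected = {name: azim_proj(p) for name, p in electrodes.items()}
```

An electrode at the pole lands at the origin of the image plane, and electrodes on the equator land on a circle of radius pi/2 around it.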

In [9]:
batch_size = 64
learning_rate = 0.002
epochs = 15

Load the dataset.

In [27]:
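# Directory holding the pickled EEG images (adjust to your local path)
dataset = EEGImagesDataset(path='E:/Datasets/Stanford_digital_repository/img_pkl')
total_x = len(dataset)
# DataLoader is already imported from torch.utils.data above
loader = DataLoader(dataset, batch_size=batch_size, num_workers=6, shuffle=True)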
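
Because the loader is created with the default `drop_last=False`, it yields a final, smaller batch whenever the dataset size is not a multiple of `batch_size`. The number of steps per epoch is therefore the ceiling of the ratio, not the floor. A minimal sketch of the arithmetic follows; the helper names and the dataset size of 1000 are made up for illustration, since the real size depends on the generated images.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches one pass over the data produces."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch when partial batches are kept."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


# Hypothetical dataset size, with batch_size = 64 as in the hyperparameter cell.
n_steps = steps_per_epoch(1000, 64)          # 16 batches
n_steps_dropped = steps_per_epoch(1000, 64, drop_last=True)  # 15 batches
tail = last_batch_size(1000, 64)             # 40 samples in the last batch
```

On a real loader, `len(loader)` returns the same batch count directly.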

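Define the ConvTransformer [1] model and the optimizer, then train. Note that the loop below reports accuracy on the current training batch only; no separate validation split is evaluated.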

In [28]:
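# .cuda() assumes a CUDA-capable GPU is available
model = ConvTransformer(num_classes=6, channels=8, num_heads=2, E=16, F=256, T=32, depth=2).cuda()
# Reuse learning_rate from the hyperparameter cell instead of repeating the literal
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)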

In [None]:
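if __name__ == '__main__':
    steps_per_epoch = len(loader)  # number of batches, including a smaller final one
    global_step = 0
    model.train()  # training mode only needs to be set once
    for epoch in range(1, epochs + 1):  # run exactly `epochs` passes over the data
        for step, (x, y) in enumerate(loader, start=1):
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()
            y_ = model(x)  # logits, one row per sample
            loss = torch.nn.functional.cross_entropy(y_, y)
            loss.backward()
            optimizer.step()

            global_step += 1
            if step % 50 == 0:
                # Accuracy on the current training batch. Divide by the actual
                # batch size, because the last batch may be smaller than batch_size.
                corrects = (torch.argmax(y_, dim=1) == y)
                accuracy = corrects.sum().item() / y.size(0)
                print('epoch:{}/{} step:{}/{} global_step:{} '
                      'loss={:.5f} acc={:.3f}'.format(epoch, epochs, step, steps_per_epoch, global_step,
                                                      loss.item(), accuracy))
    print('done')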
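
The loop above prints accuracy for a single batch every 50 steps. For an epoch-level figure, correct predictions can be accumulated across batches and divided by the number of samples actually seen, which stays correct when the final batch is smaller than `batch_size`. A minimal sketch follows; `AccuracyMeter` is a hypothetical helper, not part of this repository, and the example labels are made up.

```python
class AccuracyMeter:
    """Accumulates classification accuracy over batches of any size."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, targets):
        """Add one batch of predicted and true class indices."""
        if len(predictions) != len(targets):
            raise ValueError("predictions and targets must have the same length")
        self.correct += sum(int(p == t) for p, t in zip(predictions, targets))
        self.total += len(targets)

    @property
    def value(self):
        """Fraction of correct predictions seen so far (0.0 if nothing was added)."""
        return self.correct / self.total if self.total else 0.0


meter = AccuracyMeter()
meter.update([0, 1, 2, 3], [0, 1, 2, 0])  # full batch: 3 of 4 correct
meter.update([1, 1], [1, 0])              # smaller final batch: 1 of 2 correct
epoch_accuracy = meter.value              # 4 correct out of 6 samples
```

Inside the training loop, the meter could be fed with `torch.argmax(y_, dim=1).tolist()` and `y.tolist()`, and reset at the start of each epoch.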