In [9]:
import torch
from torch import nn
from torch.utils.data import Dataset
import transformers
from transformers import ViTModel
import numpy as np
import pandas as pd
import tqdm
import matplotlib
import math
import random

Set Seed

In [10]:
SEED = 2
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

Dataset Loader

In [29]:
class EEGEyeNetDataset(Dataset):
    def __init__(self, data_file,transpose = True):
        self.data_file = data_file
        print('loading data...')
        with np.load(self.data_file) as f: # Load the data array
            self.trainX = f['EEG']
            self.trainY = f['labels']
        print(self.trainY)
        if transpose:
            self.trainX = np.transpose(self.trainX, (0,2,1))[:,np.newaxis,:,:]

    def __getitem__(self, index):
        # Read a single sample of data from the data array
        X = torch.from_numpy(self.trainX[index]).float()
        y = torch.from_numpy(self.trainY[index,1:3]).float()
        # Return the tensor data
        return (X,y,index)

    def __len__(self):
        # Compute the number of samples in the data array
        return len(self.trainX)

Model Class

In [20]:
class EEGViT_pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, 
            out_channels=256,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0,2),
            bias=False
        )
        self.batchnorm1 = nn.BatchNorm2d(256, False)
        model_name = "google/vit-base-patch16-224"
        config = transformers.ViTConfig.from_pretrained(model_name)
        config.update({'num_channels': 256})
        config.update({'image_size': (129,14)})
        config.update({'patch_size': (8,1)})

        model = transformers.ViTForImageClassification.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)
        model.vit.embeddings.patch_embeddings.projection = torch.nn.Conv2d(256, 768, kernel_size=(8, 1), stride=(8, 1), padding=(0,0), groups=256)
        model.classifier=torch.nn.Sequential(torch.nn.Linear(768,1000,bias=True),
                                     torch.nn.Dropout(p=0.1),
                                     torch.nn.Linear(1000,2,bias=True))
        self.ViT = model
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.batchnorm1(x)
        x=self.ViT.forward(x).logits
        
        return x

Config

In [30]:
model = EEGViT_pretrained()
EEGEyeNet = EEGEyeNetDataset('./data/Position_task_with_dots_synchronised_min_5_perc.npz')
batch_size = 8
n_epoch = 15
learning_rate = 1e-4

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- vit.embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 256, 8, 1]) in the model instantiated
- vit.embeddings.position_embeddings: found shape torch.Size([1, 197, 768]) in the checkpoint and torch.Size([1, 225, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loading data...
[[  1.  408.1 315.1]
 [  1.  640.7 519.1]
 [  1.  404.2 118.8]
 ...
 [  9.   94.2 140.7]
 [  9.  165.4 528.9]
 [  9.  152.   81.2]]


In [31]:
print(EEGEyeNet.trainX.shape)
print(EEGEyeNet.trainY.shape)

(1073, 1, 129, 500)
(1073, 3)


In [35]:
EEGEyeNet.trainX[0]

array([[[  1.07926917,   6.18672376,  11.43926464, ...,   4.67739071,
           3.05745696,   1.30875489],
        [-48.59606099, -44.37909193, -40.18544235, ...,  14.79268803,
          14.67349957,  14.03868004],
        [-37.82922673, -36.00171537, -34.60119476, ...,  24.99104847,
          25.77515744,  26.30749031],
        ...,
        [ -1.27373337,  -0.1342227 ,   0.85801707, ...,   3.84694447,
           3.0548663 ,   2.1669273 ],
        [  0.80119968,   0.74955063,   0.34021102, ...,   3.65751995,
           2.5850315 ,   1.26671835],
        [  9.31587291,   8.48757296,   7.60863267, ...,  -4.1388763 ,
          -3.57427646,  -3.28026298]]])