In [1]:
!pip install pytorch-tcn

Collecting pytorch-tcn
  Downloading pytorch_tcn-1.2.3-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch-tcn)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch-tcn)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->pytorch-tcn)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch-tcn)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch-tcn)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->pytorch-tcn)
  Downloading nvidia

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np
import math

### Helper Functions

In [3]:
import math
import numpy as np

def split(ids, train, val, test):
    # proportions of train, val, test
    assert (train+val+test == 1)
    IDs = np.unique(ids)
    num_ids = len(IDs)

    # priority given to the test/val sets
    test_split = math.ceil(test * num_ids)
    val_split = math.ceil(val * num_ids)
    train_split = num_ids - val_split - test_split

    train = np.where(np.isin(ids, IDs[:train_split]))[0]
    val = np.where(np.isin(ids, IDs[train_split:train_split+val_split]))[0]
    test = np.where(np.isin(ids, IDs[train_split+val_split:]))[0]

    return train, val, test

### Connect to Drive

In [4]:
import os
from google.colab import drive

# Step 1: Mount Google Drive
drive.mount('/content/drive')

# Step 2: Define file path
file_path = "/content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz"

# Step 3: Create the folder if it doesn't exist
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Step 4: Check if file exists, if not, download it
if not os.path.exists(file_path):
    print("File not found. Downloading...")
    !wget -O "/content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz" "https://osf.io/download/ge87t/"
else:
    print("File already exists at:", file_path)

Mounted at /content/drive
File already exists at: /content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz


### Dataset Loader

In [5]:
from torch.utils.data import Dataset
import torch
import numpy as np

class EEGEyeNetDataset(Dataset):
        def __init__(self, data_file, transpose=True):
                self.data_file = data_file
                print('loading data...')
                with np.load(self.data_file) as f:  # Load the data array
                        self.trainX = f['EEG']
                        self.trainY = f['labels']
                # Filter data where y[:,1] is between 0 and 800 and y[:,2] is between 0 and 600
                valid_indices = (self.trainY[:, 1] >= 0) & (self.trainY[:, 1] <= 800) & \
                                        (self.trainY[:, 2] >= 0) & (self.trainY[:, 2] <= 600)
                self.trainX = self.trainX[valid_indices]
                self.trainY = self.trainY[valid_indices]
                if transpose:
                        self.trainX = np.transpose(self.trainX, (0, 2, 1))[:, np.newaxis, :, :]
                print(self.trainY)

        def __getitem__(self, index):
                # Read a single sample of data from the data array
                X = torch.from_numpy(self.trainX[index]).float()
                y = torch.from_numpy(self.trainY[index,1:3]).float()
                # Return the tensor data
                return (X,y,index)

        def __len__(self):
                # Compute the number of samples in the data array
                return len(self.trainX)

### Model

In [6]:
import torch
from torch import nn
from pytorch_tcn import TCN
import transformers
from transformers import ViTModel

class EEGVIT_TCN(nn.Module):
    def __init__(self):
        super().__init__()

        # TCN layer
        self.tcn = TCN(
            num_inputs=129,
            num_channels=[64, 128, 256],  # for three layers
            kernel_size=3,
            dropout=0.75,
            causal=True,
            use_norm='weight_norm',
            activation='relu',
            kernel_initializer='xavier_uniform'
        )

        # Convolutional layers with batch normalization
        self.conv1 = torch.nn.Conv2d(1, 256, kernel_size=(1, 36), stride=(1, 36), padding=(0,2))
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = torch.nn.Conv2d(256, 768, kernel_size=(256, 1), stride=(256, 1), padding=(0,0))
        self.bn2 = nn.BatchNorm2d(768)

        self.relu = nn.ReLU()

        # ViT configuration
        model_name = "google/vit-base-patch16-224"
        config = transformers.ViTConfig.from_pretrained(model_name)
        config.update({'num_channels': 768, 'image_size': (1, 14), 'patch_size': (1, 1)})

        model = transformers.ViTForImageClassification.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)
        model.vit.embeddings.patch_embeddings.projection = torch.nn.Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1), padding=(0,0))
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, 1000, bias=True),
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(1000, 2, bias=True)
        )
        self.ViT = model

    def forward(self, x):
        x = self.tcn(x.squeeze(1))  # Adjust for time dimension

        # Reshape and apply convolutions and batch normalization
        x = x.view(x.size(0), 1, x.size(1), x.size(2))

        x = self.conv1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.relu(x)

        # Pass through ViT
        x = self.ViT(x).logits
        return x

### Config

In [7]:
model = EEGVIT_TCN()
EEGEyeNet = EEGEyeNetDataset(file_path)
batch_size = 64
n_epoch = 15
learning_rate = 1e-4

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- vit.embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 768, 1, 1]) in the model instantiated
- vit.embeddings.position_embeddings: found shape torch.Size([1, 197, 768]) in the checkpoint and torch.Size([1, 15, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loading data...
[[  1.  408.1 315.1]
 [  1.  640.7 519.1]
 [  1.  404.2 118.8]
 ...
 [177.  115.5 306.1]
 [177.  732.  310.3]
 [177.  632.2 353.6]]


In [8]:
import sys

def train(model, optimizer, scheduler = None):
    '''
        model: model to train
        optimizer: optimizer to update weights
        scheduler: scheduling learning rate, used when finetuning pretrained models
    '''
    torch.cuda.empty_cache()
    train_indices, val_indices, test_indices = split(EEGEyeNet.trainY[:,0],0.7,0.15,0.15)  # indices for the training set
    print('create dataloader...')
    criterion = nn.MSELoss()

    train = Subset(EEGEyeNet,indices=train_indices)
    val = Subset(EEGEyeNet,indices=val_indices)
    test = Subset(EEGEyeNet,indices=test_indices)

    train_loader = DataLoader(train, batch_size=batch_size)
    val_loader = DataLoader(val, batch_size=batch_size)
    test_loader = DataLoader(test, batch_size=batch_size)

    if torch.cuda.is_available():
        gpu_id = 0  # Change this to the desired GPU ID if you have multiple GPUs
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # Wrap the model with DataParallel
    print("HI")

    model = model.to(device)
    criterion = criterion.to(device)

    # Initialize lists to store losses
    train_losses = []
    val_losses = []
    test_losses = []
    print('training...')
    # Train the model
    for epoch in range(n_epoch):
        model.train()
        epoch_train_loss = 0.0

        for i, (inputs, targets, index) in tqdm(enumerate(train_loader)):
            # Move the inputs and targets to the GPU (if available)
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Compute the outputs and loss for the current batch
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), targets.squeeze())

            # Compute the gradients and update the parameters
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

            # Print the loss and accuracy for the current batch
            if i % 100 == 0:
                print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

        epoch_train_loss /= len(train_loader)
        train_losses.append(epoch_train_loss)

        # Evaluate the model on the validation set
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for inputs, targets, index in val_loader:
                # Move the inputs and targets to the GPU (if available)
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Compute the outputs and loss for the current batch
                outputs = model(inputs)
                # print(outputs)
                loss = criterion(outputs.squeeze(), targets.squeeze())
                val_loss += loss.item()


            val_loss /= len(val_loader)
            val_losses.append(val_loss)

            print(f"Epoch {epoch}, Val Loss: {val_loss}")

        with torch.no_grad():
            val_loss = 0.0
            for inputs, targets, index in test_loader:
                # Move the inputs and targets to the GPU (if available)
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Compute the outputs and loss for the current batch
                outputs = model(inputs)

                loss = criterion(outputs.squeeze(), targets.squeeze())
                val_loss += loss.item()

            val_loss /= len(test_loader)
            test_losses.append(val_loss)

            print(f"Epoch {epoch}, test Loss: {val_loss}")

        if scheduler is not None:
            scheduler.step()

In [9]:
train(model,optimizer=optimizer, scheduler=scheduler)

create dataloader...
HI
training...


2it [00:01,  2.06it/s]

Epoch 0, Batch 0, Loss: 169026.0


102it [00:08, 13.12it/s]

Epoch 0, Batch 100, Loss: 69001.71875


202it [00:16, 13.34it/s]

Epoch 0, Batch 200, Loss: 29923.1953125


236it [00:18, 12.44it/s]


Epoch 0, Val Loss: 27288.122967155614
Epoch 0, test Loss: 26864.89928002451


2it [00:00, 12.71it/s]

Epoch 1, Batch 0, Loss: 25510.51953125


102it [00:07, 13.48it/s]

Epoch 1, Batch 100, Loss: 27375.69921875


202it [00:15, 13.48it/s]

Epoch 1, Batch 200, Loss: 25715.236328125


236it [00:17, 13.32it/s]


Epoch 1, Val Loss: 23199.78419961735
Epoch 1, test Loss: 22153.415785845587


2it [00:00, 12.90it/s]

Epoch 2, Batch 0, Loss: 22146.173828125


102it [00:07, 13.38it/s]

Epoch 2, Batch 100, Loss: 23243.9765625


202it [00:15, 12.93it/s]

Epoch 2, Batch 200, Loss: 21491.5


236it [00:17, 13.12it/s]


Epoch 2, Val Loss: 19147.771643813776
Epoch 2, test Loss: 17086.710822610294


2it [00:00, 13.23it/s]

Epoch 3, Batch 0, Loss: 17067.37109375


102it [00:07, 13.48it/s]

Epoch 3, Batch 100, Loss: 15583.966796875


202it [00:15, 13.13it/s]

Epoch 3, Batch 200, Loss: 18129.04296875


236it [00:17, 13.12it/s]


Epoch 3, Val Loss: 17102.41015625
Epoch 3, test Loss: 14222.252412683823


2it [00:00, 13.61it/s]

Epoch 4, Batch 0, Loss: 13403.134765625


102it [00:07, 12.46it/s]

Epoch 4, Batch 100, Loss: 13847.7939453125


202it [00:15, 13.38it/s]

Epoch 4, Batch 200, Loss: 16503.42578125


236it [00:17, 13.15it/s]


Epoch 4, Val Loss: 16631.54095583546
Epoch 4, test Loss: 13167.391486672794


2it [00:00, 13.44it/s]

Epoch 5, Batch 0, Loss: 13365.615234375


102it [00:07, 13.31it/s]

Epoch 5, Batch 100, Loss: 13395.642578125


202it [00:15, 13.48it/s]

Epoch 5, Batch 200, Loss: 16671.5546875


236it [00:17, 13.32it/s]


Epoch 5, Val Loss: 16338.854153380102
Epoch 5, test Loss: 12593.027113970587


2it [00:00, 12.77it/s]

Epoch 6, Batch 0, Loss: 13093.22265625


102it [00:07, 13.01it/s]

Epoch 6, Batch 100, Loss: 11503.02734375


202it [00:15, 12.94it/s]

Epoch 6, Batch 200, Loss: 16704.107421875


236it [00:17, 13.14it/s]


Epoch 6, Val Loss: 15187.594148596938
Epoch 6, test Loss: 11238.759076286764


2it [00:00, 13.51it/s]

Epoch 7, Batch 0, Loss: 11789.21484375


102it [00:07, 13.31it/s]

Epoch 7, Batch 100, Loss: 11067.2451171875


202it [00:15, 13.25it/s]

Epoch 7, Batch 200, Loss: 15782.017578125


236it [00:17, 13.31it/s]


Epoch 7, Val Loss: 15158.212691326531
Epoch 7, test Loss: 11091.584989659927


2it [00:00, 13.63it/s]

Epoch 8, Batch 0, Loss: 11145.818359375


102it [00:07, 13.30it/s]

Epoch 8, Batch 100, Loss: 11268.068359375


202it [00:15, 13.26it/s]

Epoch 8, Batch 200, Loss: 15496.560546875


236it [00:17, 13.28it/s]


Epoch 8, Val Loss: 15194.325992506378
Epoch 8, test Loss: 11054.942038143383


2it [00:00, 12.94it/s]

Epoch 9, Batch 0, Loss: 11112.263671875


102it [00:07, 13.23it/s]

Epoch 9, Batch 100, Loss: 11029.853515625


202it [00:15, 13.19it/s]

Epoch 9, Batch 200, Loss: 15630.236328125


236it [00:17, 13.21it/s]


Epoch 9, Val Loss: 15067.530253507653
Epoch 9, test Loss: 10912.686762791054


2it [00:00, 12.87it/s]

Epoch 10, Batch 0, Loss: 11049.7109375


102it [00:07, 13.14it/s]

Epoch 10, Batch 100, Loss: 10405.171875


202it [00:15, 13.23it/s]

Epoch 10, Batch 200, Loss: 15509.16015625


236it [00:17, 13.18it/s]


Epoch 10, Val Loss: 14976.881257971938
Epoch 10, test Loss: 10777.307923560049


2it [00:00, 13.39it/s]

Epoch 11, Batch 0, Loss: 10874.568359375


102it [00:07, 13.19it/s]

Epoch 11, Batch 100, Loss: 10589.6591796875


202it [00:15, 13.28it/s]

Epoch 11, Batch 200, Loss: 15154.9306640625


236it [00:17, 13.20it/s]


Epoch 11, Val Loss: 14963.420738998724
Epoch 11, test Loss: 10746.743470435049


2it [00:00, 13.59it/s]

Epoch 12, Batch 0, Loss: 10570.0859375


102it [00:07, 13.21it/s]

Epoch 12, Batch 100, Loss: 9886.4736328125


202it [00:15, 13.46it/s]

Epoch 12, Batch 200, Loss: 14878.51953125


236it [00:17, 13.31it/s]


Epoch 12, Val Loss: 14468.624192841198
Epoch 12, test Loss: 10242.464546951593


2it [00:00, 12.90it/s]

Epoch 13, Batch 0, Loss: 11189.57421875


102it [00:07, 13.38it/s]

Epoch 13, Batch 100, Loss: 9924.8115234375


202it [00:15, 13.24it/s]

Epoch 13, Batch 200, Loss: 14278.2568359375


236it [00:17, 13.22it/s]


Epoch 13, Val Loss: 14402.554109534438
Epoch 13, test Loss: 10196.552198223038


2it [00:00, 13.39it/s]

Epoch 14, Batch 0, Loss: 11187.2763671875


102it [00:07, 13.40it/s]

Epoch 14, Batch 100, Loss: 9950.623046875


202it [00:15, 13.21it/s]

Epoch 14, Batch 200, Loss: 15236.7724609375


236it [00:17, 13.13it/s]


Epoch 14, Val Loss: 14390.121801259565
Epoch 14, test Loss: 10186.32502297794


In [10]:
# Save the final model
torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), "/content/drive/MyDrive/trained_models/EEGViT-TCNetBaseLine.pt")
print("Model saved to EEGViT-TCNetBaseLine.pt")


Model saved to EEGViT-TCNetBaseLine.pt


In [11]:
#stop runtime
from google.colab import runtime
runtime.unassign()
