In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np
import math

### Helper Functions

In [None]:
def split(ids, train, val, test):
    """Split sample indices into train/val/test so that no ID spans two sets.

    Args:
        ids: 1-D array with one group ID per sample (e.g. the participant ID).
        train, val, test: proportions of the *unique IDs* assigned to each set;
            they must sum to 1 (up to floating-point rounding).

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of integer index arrays into
        ``ids``. IDs are assigned in sorted order (``np.unique``): the first
        chunk goes to train, the next to val and the remainder to test.

    Raises:
        AssertionError: if the proportions do not sum to 1.
    """
    # Compare with a tolerance: e.g. 0.6 + 0.3 + 0.1 is 0.9999999999999999 in
    # binary floating point and would fail an exact `== 1` check.
    assert math.isclose(train + val + test, 1.0, rel_tol=0.0, abs_tol=1e-9), \
        "train, val and test proportions must sum to 1"

    unique_ids = np.unique(ids)
    num_ids = len(unique_ids)

    # Priority given to the test set, then val; train takes what is left.
    # Each count is capped by the IDs still available so that train_count can
    # never become negative (a negative slice bound such as IDs[:-1] would
    # silently put test IDs into the training set).
    test_count = min(math.ceil(test * num_ids), num_ids)
    val_count = min(math.ceil(val * num_ids), num_ids - test_count)
    train_count = num_ids - val_count - test_count

    val_end = train_count + val_count
    train_idx = np.where(np.isin(ids, unique_ids[:train_count]))[0]
    val_idx = np.where(np.isin(ids, unique_ids[train_count:val_end]))[0]
    test_idx = np.where(np.isin(ids, unique_ids[val_end:]))[0]

    return train_idx, val_idx, test_idx

### Connect to Drive

In [None]:
import os
from google.colab import drive

# Step 1: Mount Google Drive
drive.mount('/content/drive')

# Step 2: Define file path
file_path = "/content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz"

# Step 3: Create the folder if it doesn't exist
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Step 4: Check if file exists, if not, download it
if not os.path.exists(file_path):
    print("File not found. Downloading...")
    !wget -O "/content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz" "https://osf.io/download/ge87t/"
else:
    print("File already exists at:", file_path)

Mounted at /content/drive
File already exists at: /content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz


### Dataset Loader

In [None]:
from torch.utils.data import Dataset
import torch
import numpy as np

class EEGEyeNetDataset(Dataset):
    """Dataset over the 'EEG' and 'labels' arrays stored in an ``.npz`` file.

    Samples whose label columns 1 and 2 fall outside [0, 800] and [0, 600]
    respectively are dropped at load time. The surviving arrays are exposed as
    ``trainX`` (inputs) and ``trainY`` (full label rows).
    """

    def __init__(self, data_file, transpose=True):
        """Load, filter and optionally re-layout the data.

        Args:
            data_file: path to the ``.npz`` archive with keys 'EEG' and 'labels'.
            transpose: when true, swap the last two axes of the EEG array and
                insert a singleton axis at position 1.
        """
        self.data_file = data_file
        print('loading data...')
        with np.load(self.data_file) as archive:
            eeg = archive['EEG']
            labels = archive['labels']

        # Keep only rows whose column 1 is within 0..800 and column 2 within 0..600.
        col1, col2 = labels[:, 1], labels[:, 2]
        keep = (col1 >= 0) & (col1 <= 800) & (col2 >= 0) & (col2 <= 600)
        eeg = eeg[keep]
        labels = labels[keep]

        if transpose:
            eeg = np.transpose(eeg, (0, 2, 1))[:, None]

        self.trainX = eeg
        self.trainY = labels
        print(self.trainY)

    def __getitem__(self, index):
        """Return ``(X, y, index)``: float32 input, float32 label columns 1-2, and the index."""
        sample = torch.from_numpy(self.trainX[index]).float()
        target = torch.from_numpy(self.trainY[index, 1:3]).float()
        return sample, target, index

    def __len__(self):
        """Number of samples that survived the label filter."""
        return len(self.trainX)

### Model

In [None]:
import torch
from torch import nn
from transformers import SwinModel, SwinConfig

class EEGSwin(nn.Module):
    """Conv feature extractor + Swin Transformer backbone + 2-D regression head.

    Pipeline (for an input of shape [B, 1, 129, 500]):
      1. A (1, 36) convolution with stride 36 along the last axis produces a
         [B, 256, 129, 14] feature map, followed by batch normalisation.
      2. The map is bilinearly resized to ``target_size`` and passed to a
         randomly initialised ``SwinModel`` as a 256-channel "image".
      3. The Swin pooled output is mapped to two regression outputs.

    Args:
        target_size: (H, W) the feature map is resized to before the backbone.
            The same value is used for the Swin ``image_size`` so the config
            and the tensor actually fed to the backbone always agree.
    """

    def __init__(self, target_size=(128, 16)):
        super().__init__()
        self.target_size = tuple(target_size)

        # Step 1: EEG feature extractor
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=256,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0, 2),
            bias=False
        )
        self.batchnorm1 = nn.BatchNorm2d(256, affine=False)

        # Step 2: Swin Transformer configuration
        # NOTE(review): with window_size=1 every attention window appears to hold a
        # single token, so token mixing would come only from patch merging --
        # confirm this is intended before tuning further.
        config = SwinConfig(
            image_size=self.target_size,  # H x W, matches the resize in forward()
            num_channels=256,             # Matches the conv output channels
            patch_size=(1, 1),            # One patch per feature-map location
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=1,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_path_rate=0.2
        )

        # Step 3: Swin Transformer backbone (trained from scratch, no pretrained weights)
        self.swin = SwinModel(config)

        # Step 4: Regression head for (x, y) output
        self.regressor = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2)  # output x, y
        )

    def forward(self, x):
        """Map a batch of inputs ``[B, 1, H, T]`` to predictions ``[B, 2]``."""
        # x: [B, 1, 129, 500]
        x = self.conv1(x)           # -> [B, 256, 129, 14]
        x = self.batchnorm1(x)

        # Resize to the fixed spatial size the Swin config was built for.
        x = nn.functional.interpolate(
            x, size=self.target_size, mode='bilinear', align_corners=False
        )                           # -> [B, 256, *target_size]

        # Swin has no CLS token: index 0 of last_hidden_state is merely the first
        # patch token. Use the pooled output (average over all final-stage
        # tokens) as the sequence representation instead.
        pooled = self.swin(x).pooler_output  # -> [B, hidden_size]

        out = self.regressor(pooled)  # -> [B, 2]
        return out


### Config

In [None]:
# Model: EEGSwin is the architecture defined in this notebook (the previous
# `ViTBase()` is not defined anywhere here and fails on a fresh kernel).
model = EEGSwin()

# Dataset: load once. Later cells (shape printout, train) read the global
# `EEGEyeNet`; the guard avoids re-reading the large archive when this cell is
# re-run just to reset the model or hyper-parameters.
if "EEGEyeNet" not in globals():
    EEGEyeNet = EEGEyeNetDataset(file_path)

batch_size = 64
n_epoch = 15
learning_rate = 1e-4

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Multiplies the learning rate by 0.1 every 6 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

In [None]:
# Re-creates the StepLR scheduler on the current optimizer with the same
# settings as the config cell (step_size=6, gamma=0.1). Only needed after the
# optimizer has been rebuilt; the training cell below does not step this
# scheduler.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
from torchinfo import summary

# Per-sample shape (1, 129, 500) matches EEGEyeNet.trainX; the batch dimension
# reuses the configured batch_size instead of a second hard-coded 64.
summary(model, input_size=(batch_size, 1, 129, 500))

torch.Size([64, 256, 129, 14])
torch.Size([64, 256, 128, 16])
torch.Size([64, 256, 128, 16])


Layer (type:depth-idx)                                            Output Shape              Param #
EEGSwinScratch                                                    [64, 2]                   --
├─Conv2d: 1-1                                                     [64, 256, 129, 14]        9,216
├─BatchNorm2d: 1-2                                                [64, 256, 129, 14]        --
├─SwinModel: 1-3                                                  [64, 768]                 --
│    └─SwinEmbeddings: 2-1                                        [64, 2048, 96]            --
│    │    └─SwinPatchEmbeddings: 3-1                              [64, 2048, 96]            24,672
│    │    └─LayerNorm: 3-2                                        [64, 2048, 96]            192
│    │    └─Dropout: 3-3                                          [64, 2048, 96]            --
│    └─SwinEncoder: 2-2                                           [64, 32, 768]             --
│    │    └─ModuleList: 3-4          

In [None]:
# Report the shapes of the loaded input and label arrays, one per line.
print(EEGEyeNet.trainX.shape, EEGEyeNet.trainY.shape, sep="\n")

(21448, 1, 129, 500)
(21448, 3)


In [None]:
import sys
def _mean_batch_loss(model, loader, criterion, device):
    '''Return the mean per-batch loss of `model` over `loader`, without gradient tracking.'''
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets, _ in loader:
            # Move the inputs and targets to the GPU (if available)
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs.squeeze(), targets.squeeze()).item()
    return total_loss / len(loader)


def train(model, optimizer, scheduler=None):
    '''
        Train `model` with MSE loss and report validation/test loss after every epoch.

        Reads the notebook globals `EEGEyeNet` (dataset), `batch_size` and `n_epoch`.
        Samples are split 70/15/15 by the IDs in `EEGEyeNet.trainY[:, 0]`, so no
        ID appears in more than one of the train/val/test sets.

        model: model to train
        optimizer: optimizer to update weights
        scheduler: optional learning-rate scheduler; when given, it is stepped
            once at the end of every epoch

        Returns (train_losses, val_losses, test_losses): one mean per-batch loss
        per epoch for each of the three sets.
    '''
    torch.cuda.empty_cache()
    train_indices, val_indices, test_indices = split(EEGEyeNet.trainY[:,0],0.7,0.15,0.15)
    print('create dataloader...')
    criterion = nn.MSELoss()

    # Distinct local names: a local called `train` would shadow this function.
    train_set = Subset(EEGEyeNet, indices=train_indices)
    val_set = Subset(EEGEyeNet, indices=val_indices)
    test_set = Subset(EEGEyeNet, indices=test_indices)

    # Shuffle only the training data: the split returns indices grouped by ID,
    # so unshuffled batches would each come from a single ID at a time.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    if torch.cuda.is_available():
        gpu_id = 0  # Change this to the desired GPU ID if you have multiple GPUs
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # Wrap the model with DataParallel

    model = model.to(device)
    criterion = criterion.to(device)

    # Per-epoch loss history
    train_losses = []
    val_losses = []
    test_losses = []
    print('training...')
    for epoch in range(n_epoch):
        model.train()
        epoch_train_loss = 0.0

        # Wrap the loader (not the enumerate) so tqdm knows the total batch count.
        for i, (inputs, targets, _) in enumerate(tqdm(train_loader)):
            # Move the inputs and targets to the GPU (if available)
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Compute the outputs and loss for the current batch
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), targets.squeeze())

            # Compute the gradients and update the parameters
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

            # Print the loss for the current batch
            if i % 100 == 0:
                print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

        epoch_train_loss /= len(train_loader)
        train_losses.append(epoch_train_loss)

        # Evaluate the model on the validation set
        val_loss = _mean_batch_loss(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"Epoch {epoch}, Val Loss: {val_loss}")

        # Evaluate the model on the test set
        test_loss = _mean_batch_loss(model, test_loader, criterion, device)
        test_losses.append(test_loss)
        print(f"Epoch {epoch}, test Loss: {test_loss}")

        if scheduler is not None:
            scheduler.step()

    return train_losses, val_losses, test_losses

# No scheduler is passed, so the learning rate stays constant as in the recorded
# run; use train(model, optimizer=optimizer, scheduler=scheduler) to enable decay.
train_losses, val_losses, test_losses = train(model, optimizer=optimizer)

create dataloader...
HI
training...


2it [00:00,  6.36it/s]

Epoch 0, Batch 0, Loss: 168953.71875


102it [00:15,  6.72it/s]

Epoch 0, Batch 100, Loss: 141320.640625


202it [00:30,  6.68it/s]

Epoch 0, Batch 200, Loss: 109290.765625


236it [00:35,  6.73it/s]


Epoch 0, Val Loss: 99276.81457270408
Epoch 0, test Loss: 102510.78569240196


1it [00:00,  6.72it/s]

Epoch 1, Batch 0, Loss: 103957.921875


102it [00:15,  6.74it/s]

Epoch 1, Batch 100, Loss: 65892.84375


202it [00:30,  6.74it/s]

Epoch 1, Batch 200, Loss: 43912.18359375


236it [00:34,  6.74it/s]


Epoch 1, Val Loss: 42067.19387755102
Epoch 1, test Loss: 43107.77022058824


1it [00:00,  6.77it/s]

Epoch 2, Batch 0, Loss: 43780.2421875


102it [00:15,  6.73it/s]

Epoch 2, Batch 100, Loss: 39832.87890625


202it [00:29,  6.76it/s]

Epoch 2, Batch 200, Loss: 34653.015625


236it [00:34,  6.74it/s]


Epoch 2, Val Loss: 32371.688815369896
Epoch 2, test Loss: 32568.79465379902


1it [00:00,  6.86it/s]

Epoch 3, Batch 0, Loss: 32640.90625


102it [00:15,  6.78it/s]

Epoch 3, Batch 100, Loss: 31950.80859375


202it [00:29,  6.74it/s]

Epoch 3, Batch 200, Loss: 30435.466796875


236it [00:34,  6.75it/s]


Epoch 3, Val Loss: 29264.857342155614
Epoch 3, test Loss: 28937.95063572304


1it [00:00,  6.84it/s]

Epoch 4, Batch 0, Loss: 28997.55078125


102it [00:15,  6.76it/s]

Epoch 4, Batch 100, Loss: 30530.15625


202it [00:29,  6.73it/s]

Epoch 4, Batch 200, Loss: 29217.40625


236it [00:34,  6.77it/s]


Epoch 4, Val Loss: 28417.538464604593
Epoch 4, test Loss: 27973.214116115196


1it [00:00,  6.83it/s]

Epoch 5, Batch 0, Loss: 27702.3046875


102it [00:15,  6.80it/s]

Epoch 5, Batch 100, Loss: 30044.91015625


202it [00:29,  6.73it/s]

Epoch 5, Batch 200, Loss: 30179.4453125


236it [00:34,  6.79it/s]


Epoch 5, Val Loss: 27658.463289221938
Epoch 5, test Loss: 27208.336818321077


1it [00:00,  6.81it/s]

Epoch 6, Batch 0, Loss: 27475.72265625


102it [00:15,  6.79it/s]

Epoch 6, Batch 100, Loss: 27375.015625


202it [00:29,  6.78it/s]

Epoch 6, Batch 200, Loss: 25996.84765625


236it [00:34,  6.78it/s]


Epoch 6, Val Loss: 26202.570990114797
Epoch 6, test Loss: 25788.845511642157


1it [00:00,  6.81it/s]

Epoch 7, Batch 0, Loss: 25631.58203125


102it [00:15,  6.81it/s]

Epoch 7, Batch 100, Loss: 25414.96484375


202it [00:29,  6.76it/s]

Epoch 7, Batch 200, Loss: 26075.82421875


236it [00:34,  6.80it/s]


Epoch 7, Val Loss: 25025.795320471938
Epoch 7, test Loss: 24300.337239583332


1it [00:00,  6.75it/s]

Epoch 8, Batch 0, Loss: 25051.1484375


102it [00:15,  6.73it/s]

Epoch 8, Batch 100, Loss: 23802.658203125


202it [00:29,  6.80it/s]

Epoch 8, Batch 200, Loss: 26145.9140625


236it [00:34,  6.78it/s]


Epoch 8, Val Loss: 24456.245336415817
Epoch 8, test Loss: 23537.96124387255


1it [00:00,  6.85it/s]

Epoch 9, Batch 0, Loss: 23619.7421875


102it [00:15,  6.80it/s]

Epoch 9, Batch 100, Loss: 24327.8125


202it [00:29,  6.78it/s]

Epoch 9, Batch 200, Loss: 33815.6484375


236it [00:34,  6.78it/s]


Epoch 9, Val Loss: 24884.343909438776
Epoch 9, test Loss: 23820.26171875


1it [00:00,  6.82it/s]

Epoch 10, Batch 0, Loss: 27746.435546875


102it [00:15,  6.78it/s]

Epoch 10, Batch 100, Loss: 21386.12890625


202it [00:29,  6.80it/s]

Epoch 10, Batch 200, Loss: 34015.7265625


236it [00:34,  6.78it/s]


Epoch 10, Val Loss: 23791.287228954083
Epoch 10, test Loss: 22924.85577512255


1it [00:00,  6.86it/s]

Epoch 11, Batch 0, Loss: 25251.4296875


102it [00:15,  6.77it/s]

Epoch 11, Batch 100, Loss: 23054.474609375


202it [00:29,  6.74it/s]

Epoch 11, Batch 200, Loss: 32833.015625


236it [00:34,  6.76it/s]


Epoch 11, Val Loss: 23808.06979432398
Epoch 11, test Loss: 22830.021407781864


1it [00:00,  6.81it/s]

Epoch 12, Batch 0, Loss: 24677.20703125


102it [00:15,  6.76it/s]

Epoch 12, Batch 100, Loss: 21904.197265625


202it [00:29,  6.78it/s]

Epoch 12, Batch 200, Loss: 33882.10546875


236it [00:34,  6.79it/s]


Epoch 12, Val Loss: 23754.199178890307
Epoch 12, test Loss: 22679.16988357843


1it [00:00,  6.67it/s]

Epoch 13, Batch 0, Loss: 24149.19140625


102it [00:15,  6.75it/s]

Epoch 13, Batch 100, Loss: 21546.154296875


202it [00:30,  6.74it/s]

Epoch 13, Batch 200, Loss: 32619.4375


236it [00:34,  6.75it/s]


Epoch 13, Val Loss: 23397.00031887755
Epoch 13, test Loss: 22420.07923560049


1it [00:00,  6.80it/s]

Epoch 14, Batch 0, Loss: 23283.70703125


102it [00:15,  6.73it/s]

Epoch 14, Batch 100, Loss: 20158.294921875


202it [00:29,  6.76it/s]

Epoch 14, Batch 200, Loss: 32745.224609375


236it [00:34,  6.75it/s]


Epoch 14, Val Loss: 23195.046835140307
Epoch 14, test Loss: 22118.709712009804
