In [None]:
import numpy as np
import torch
import torch.nn as nn
import math
import json
import copy
from tqdm import tqdm
import sys
import random

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
import os
from google.colab import drive

# Step 1: Mount Google Drive so the dataset and checkpoint paths under
# /content/drive/MyDrive/ used by the cells below are reachable.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Step 2: Define file path
drive_home = '/content/drive/MyDrive/'
file_path = "/content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz"
dataset_url = "https://osf.io/download/ge87t/"

# Step 3: Create the folder if it doesn't exist
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Step 4: Download the file if it is missing.
# The download goes to a temporary sibling file that is renamed only on success.
# Writing straight to `file_path` (as `wget -O` did) leaves an empty or partial
# file behind when the download fails, and the existence check would then skip
# the download on every later run. Using `file_path` here (instead of repeating
# the literal path) also keeps the download target and the load path in sync.
if not os.path.exists(file_path):
    print("File not found. Downloading...")
    import urllib.request  # only needed on the download path

    tmp_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(dataset_url, tmp_path)
        os.replace(tmp_path, file_path)
    finally:
        # Remove any leftover partial download so a retry starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
else:
    print("File already exists at:", file_path)

File already exists at: /content/drive/MyDrive/datasets/Position_task_with_dots_synchronised_min.npz


### Load data into NumPy arrays

In [None]:
# Read the arrays inside a context manager so the .npz file handle is closed.
# 'labels' is read once: every NpzFile key access re-reads that archive member.
with np.load(file_path) as data:
    trainX = data['EEG']
    labels_raw = data['labels']

# The first column holds the participant Ids; the second and third are the
# position x and y, which are the regression targets.
trainY = labels_raw[:, 1:]
ids = labels_raw[:, 0]  # Participant Ids
print(f"trainX.shape: {trainX.shape}")
print(f"trainY.shape: {trainY.shape}")

### Visualize target positions by quadrant

In [None]:
import matplotlib.pyplot as plt

# Assign each target position to a quadrant (1-4); points on an axis get 0.
x_coords, y_coords = trainY[:, 0], trainY[:, 1]
quadrants = np.select(
    [
        (x_coords > 0) & (y_coords > 0),  # Q1
        (x_coords < 0) & (y_coords > 0),  # Q2
        (x_coords < 0) & (y_coords < 0),  # Q3
        (x_coords > 0) & (y_coords < 0),  # Q4
    ],
    [1, 2, 3, 4],
    default=0,  # on an axis
)

# Plot
colors = ['gray', 'red', 'blue', 'green', 'purple']
labels = ['Axis', 'Q1', 'Q2', 'Q3', 'Q4']

fig, ax = plt.subplots(figsize=(8, 8))
for q, (color, label) in enumerate(zip(colors, labels)):
    idx = quadrants == q
    if not idx.any():
        continue  # skip empty categories so they don't clutter the legend
    # Pass the colour explicitly. Previously `colors` was defined but never used,
    # so matplotlib's default colour cycle was applied instead.
    ax.scatter(trainY[idx, 0], trainY[idx, 1], color=color, label=label, alpha=0.6, s=10)

ax.axhline(0, color='black', linewidth=1)
ax.axvline(0, color='black', linewidth=1)
ax.set_title("trainY Distribution Across Quadrants (before outlier removal)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
ax.axis('equal')
plt.show()

In [None]:
# Filter data where trainY[:,0] is between 0 and X_MAX and trainY[:,1] is between
# 0 and Y_MAX (inclusive); samples outside this box are treated as outliers.
X_MAX, Y_MAX = 800, 600
valid_indices = (
    (trainY[:, 0] >= 0) & (trainY[:, 0] <= X_MAX)
    & (trainY[:, 1] >= 0) & (trainY[:, 1] <= Y_MAX)
)
trainX = trainX[valid_indices]
trainY = trainY[valid_indices]
ids = ids[valid_indices]

# (N, 500, 129) -> (N, 1, 129, 500): add a single input plane for the Conv2d
# front-end, with the 129 axis as height and the 500 axis as width.
# The ndim guard keeps this cell idempotent. Re-running it would otherwise
# transpose the already-reshaped array a second time.
if trainX.ndim == 3:
    trainX = np.transpose(trainX, (0, 2, 1))[:, np.newaxis, :, :]
assert len(trainX) == len(trainY) == len(ids), "filtered arrays are misaligned"
print(trainX.shape)
print(trainY.shape)
print(trainY.shape)

(21448, 1, 129, 500)
(21448, 2)


### Target positions after outlier removal (keeping 0 ≤ x ≤ 800 and 0 ≤ y ≤ 600)

In [None]:
import matplotlib.pyplot as plt

# Re-assign quadrants (1-4, 0 = on an axis) for the filtered targets.
x_coords, y_coords = trainY[:, 0], trainY[:, 1]
quadrants = np.select(
    [
        (x_coords > 0) & (y_coords > 0),  # Q1
        (x_coords < 0) & (y_coords > 0),  # Q2
        (x_coords < 0) & (y_coords < 0),  # Q3
        (x_coords > 0) & (y_coords < 0),  # Q4
    ],
    [1, 2, 3, 4],
    default=0,  # on an axis
)

# Plot
colors = ['gray', 'red', 'blue', 'green', 'purple']
labels = ['Axis', 'Q1', 'Q2', 'Q3', 'Q4']

fig, ax = plt.subplots(figsize=(8, 8))
for q, (color, label) in enumerate(zip(colors, labels)):
    idx = quadrants == q
    if not idx.any():
        continue  # after filtering, most quadrants are empty; keep them out of the legend
    # Pass the colour explicitly. Previously `colors` was defined but never used.
    ax.scatter(trainY[idx, 0], trainY[idx, 1], color=color, label=label, alpha=0.6, s=10)

ax.axhline(0, color='black', linewidth=1)
ax.axvline(0, color='black', linewidth=1)
ax.set_title("trainY Distribution Across Quadrants (after outlier removal)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
ax.axis('equal')
plt.show()

### Split Data

In [None]:
import math
import numpy as np

def split(ids, train, val, test):
    """Split sample indices by participant so no participant spans two subsets.

    Unique participant IDs (in ``np.unique`` sorted order) are assigned
    contiguously: first to train, then val, then test. Val/test counts are
    rounded up, so they get priority when the proportions don't divide evenly.

    Args:
        ids: 1-D array of per-sample participant IDs.
        train, val, test: fractions of participants for each subset; must sum to 1.

    Returns:
        (train_idx, val_idx, test_idx): integer arrays of sample indices into ``ids``.

    Raises:
        ValueError: if the fractions do not sum to 1 (within float tolerance), or
            if rounding val/test up leaves a negative number of train participants.
    """
    total = train + val + test
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
    if not math.isclose(total, 1.0):
        raise ValueError(f"train + val + test must sum to 1, got {total}")

    unique_ids = np.unique(ids)
    num_ids = len(unique_ids)

    # priority given to the test/val sets
    n_test = math.ceil(test * num_ids)
    n_val = math.ceil(val * num_ids)
    n_train = num_ids - n_val - n_test
    if n_train < 0:
        raise ValueError(
            f"{num_ids} participant(s) are too few for val={val}, test={test}"
        )

    train_ids = unique_ids[:n_train]
    val_ids = unique_ids[n_train:n_train + n_val]
    test_ids = unique_ids[n_train + n_val:]

    train_idx = np.flatnonzero(np.isin(ids, train_ids))
    val_idx = np.flatnonzero(np.isin(ids, val_ids))
    test_idx = np.flatnonzero(np.isin(ids, test_ids))
    return train_idx, val_idx, test_idx

# Participant-wise 70/15/15 split, then slice features and targets per subset.
train, val, test = split(ids, 0.7, 0.15, 0.15)
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    (trainX[part_idx], trainY[part_idx]) for part_idx in (train, val, test)
)

for split_name, X_part, y_part in (
    ("train", X_train, y_train),
    ("val", X_val, y_val),
    ("test", X_test, y_test),
):
    print(f"X_{split_name}.shape:{X_part.shape} y_{split_name}.shape: {y_part.shape}")

X_train.shape:(15071, 1, 129, 500) y_train.shape: (15071, 2)
X_val.shape:(3132, 1, 129, 500) y_val.shape: (3132, 2)
X_test.shape:(3245, 1, 129, 500) y_test.shape: (3245, 2)


### Create DataLoaders

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Convert NumPy arrays to float32 PyTorch tensors.
# torch.as_tensor reuses the NumPy buffer when the array is already float32.
# torch.tensor always copies, which for X_train (15071 x 1 x 129 x 500 values)
# means several extra GB of host memory.
# Targets stay (N, 2). The previous `.unsqueeze(1)` produced (N, 1, 2),
# contradicting this comment; the training loss reshapes/squeezes either way.
X_train_tensor = torch.as_tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.as_tensor(y_train, dtype=torch.float32)  # Shape: (N, 2)

X_val_tensor = torch.as_tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.as_tensor(y_val, dtype=torch.float32)

X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.as_tensor(y_test, dtype=torch.float32)

# Create DataLoaders (only the training set is shuffled)
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=batch_size)

### Model

In [None]:
import torch
import transformers
from transformers import ViTModel
import torch
from torch import nn
import transformers

class EEGViT_pretrained(nn.Module):
    """Regress a 2-D position from EEG using a conv front-end and a pretrained ViT.

    Pipeline. Shapes assume the (B, 1, 129, 500) input prepared earlier in this
    notebook (TODO confirm for other callers):
      conv1: (1, 36) kernel/stride with (0, 2) padding -> (B, 256, 129, 14),
             since (500 + 2*2 - 36) // 36 + 1 == 14.
      batchnorm1: normalises the 256 feature maps.
      ViT: google/vit-base-patch16-224 weights, re-configured for 256-channel
           129x14 inputs with (8, 1) patches. Its classifier head is replaced by
           Linear(768, 1000) -> Dropout(0.1) -> Linear(1000, 2).
    forward() returns the head's output, i.e. 2 values per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=256,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0, 2),
            bias=False,
        )
        # The original code called `nn.BatchNorm2d(256, False)`. BatchNorm2d's second
        # positional argument is `eps`, so this silently set eps=0, which divides by
        # zero for any zero-variance feature map. Use the default eps instead.
        # Parameter/buffer names are unchanged, so existing checkpoints still load.
        self.batchnorm1 = nn.BatchNorm2d(256)

        model_name = "google/vit-base-patch16-224"
        config = transformers.ViTConfig.from_pretrained(model_name)
        config.update({
            'num_channels': 256,
            'image_size': (129, 14),
            'patch_size': (8, 1),
        })

        vit = transformers.ViTForImageClassification.from_pretrained(
            model_name, config=config, ignore_mismatched_sizes=True
        )
        # Depthwise (groups=256) patch projection: each of the 256 input maps gets
        # its own 768 / 256 = 3 output channels, computed over 8x1 patches.
        vit.vit.embeddings.patch_embeddings.projection = torch.nn.Conv2d(
            256, 768, kernel_size=(8, 1), stride=(8, 1), padding=(0, 0), groups=256
        )
        vit.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, 1000, bias=True),
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(1000, 2, bias=True),
        )
        self.ViT = vit

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        # Call the module itself (not .forward) so any registered hooks run.
        # `.logits` is the output of the replaced classifier head.
        return self.ViT(x).logits

### Config: evaluation metric, random seed, and training setup

In [None]:
import torch
import torch.nn as nn

class MeanEuclideanDistance(nn.Module):
    """Mean Euclidean distance between predicted and true points.

    The distance is taken over the last dimension (the coordinate axis). Inputs
    of shape (B, 2), (B, 1, 2) or a single (2,) point therefore all work. The
    result is the mean distance over all remaining positions.
    """

    def forward(self, y_pred, y_true):
        # dim=-1 (rather than dim=1) keeps the norm on the coordinate axis even
        # when tensors carry an extra singleton dim such as (B, 1, 2). With dim=1,
        # that case silently returned a mean absolute error instead.
        return torch.linalg.norm(y_true - y_pred, dim=-1).mean()

In [None]:
# Set Seed
def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: value passed to every RNG.
        deterministic: if True, also force deterministic cuDNN kernels and turn
            off cuDNN autotuning. This is slower but gives reproducible
            convolutions. The default False leaves cuDNN settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU (multi-GPU included). Silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)


In [None]:
model = EEGViT_pretrained()

# Load the saved state dict. map_location="cpu" lets a checkpoint saved on a GPU
# load on any machine; the model is still on the CPU at this point anyway.
checkpoint = torch.load(
    "/content/drive/MyDrive/trained_models/encoder_weights_direction_task.pt",
    map_location="cpu",
)

# Transfer only the entries whose keys start with 'ViT.'. Everything else in the
# model (conv1, batchnorm1) keeps its fresh initialisation.
vit_weights = {k: v for k, v in checkpoint.items() if k.startswith('ViT.')}
if not vit_weights:
    # e.g. a checkpoint saved from an nn.DataParallel wrapper has 'module.ViT.' keys;
    # without this check nothing would be transferred and training would silently
    # start from the Hugging Face weights instead.
    raise ValueError("Checkpoint contains no 'ViT.'-prefixed weights; nothing to transfer.")

model_dict = model.state_dict()
model_dict.update(vit_weights)
model.load_state_dict(model_dict)
print(f"Transferred {len(vit_weights)} ViT tensors from the checkpoint")

batch_size = 64  # NOTE: the DataLoaders were already built with their own batch_size
n_epoch = 15
learning_rate = 1e-4

torch.cuda.empty_cache()
if torch.cuda.is_available():
    gpu_id = 0  # Change this to the desired GPU ID if you have multiple GPUs
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
if torch.cuda.device_count() > 1:
    print("Multiple GPUs Available")
    model = nn.DataParallel(model)  # Wrap the model with DataParallel

model = model.to(device)
criterion = nn.MSELoss().to(device)

# Build the optimizer only after the model is on its final device, as the
# torch.optim documentation recommends.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- vit.embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 256, 8, 1]) in the model instantiated
- vit.embeddings.position_embeddings: found shape torch.Size([1, 197, 768]) in the checkpoint and torch.Size([1, 225, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training

In [None]:
# Initialize lists to store per-epoch losses
train_losses = []
val_losses = []
test_losses = []
best_val_loss = float('inf')
best_model_wts = None

MODEL_PATH = "/content/drive/MyDrive/trained_models/abs_pos_EEGViT_OnlyVitWeightsPretrained.pt"
LOSS_LOG_PATH = "/content/drive/MyDrive/trained_models/loss_logs_abs_pos_EEGViT_OnlyVitWeightsPretrained.json"


def _batch_loss(inputs, targets):
    """Run `model` on one batch (moved to `device`); return (loss, batch size).

    Targets are reshaped to the output shape, so the loss is computed element-wise
    whether the DataLoader yields (B, 2) or (B, 1, 2) targets.
    """
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    return criterion(outputs, targets.reshape(outputs.shape)), inputs.size(0)


def _rmse(mse):
    """Convert an MSE value into the "RMSE" figure printed in the log.

    NOTE(review): the `/ 2` is carried over unchanged from the original prints.
    It is presumably a pixel -> mm conversion; confirm. Also, nn.MSELoss averages
    over both coordinates, so sqrt(MSE) is a per-coordinate RMS error, not a
    Euclidean distance.
    """
    return (mse ** 0.5) / 2


def _evaluate(loader):
    """Return the sample-weighted mean loss of `model` over `loader`, without gradients."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            loss, n = _batch_loss(inputs, targets)
            total += loss.item() * n
            count += n
    return total / count


print('training...')
for epoch in range(n_epoch):
    model.train()
    running_loss, seen = 0.0, 0

    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch}/{n_epoch}"):
        optimizer.zero_grad()
        loss, n = _batch_loss(inputs, targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the smaller final batch does not count as much
        # as a full one in the epoch average.
        running_loss += loss.item() * n
        seen += n

    epoch_train_loss = running_loss / seen
    train_losses.append(epoch_train_loss)
    print(f"Epoch {epoch}, Train Loss: {epoch_train_loss:.4f}, RMSE: {_rmse(epoch_train_loss):.4f}")

    # Validation: the only loss used for model selection.
    val_loss = _evaluate(val_loader)
    val_losses.append(val_loss)
    print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())

    # Test: logged for reference only; never used to pick the model.
    test_loss = _evaluate(test_loader)
    test_losses.append(test_loss)
    rmse = _rmse(test_loss)
    print(f"Epoch {epoch}, Test Loss (MSE): {test_loss:.4f}, RMSE: {rmse:.4f}")

    if scheduler is not None:
        scheduler.step()

# Load best model weights
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)
    print("Best model loaded with val loss:", best_val_loss)

# Save the unwrapped module's weights. Under nn.DataParallel the keys would
# otherwise carry a 'module.' prefix and not load into a plain EEGViT_pretrained.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), MODEL_PATH)
print("Best model saved as 'abs_pos_EEGViT_OnlyVitWeightsPretrained.pt'.")

loss_dict = {
    "train_losses": train_losses,
    "val_losses": val_losses,
    "test_losses": test_losses
}

with open(LOSS_LOG_PATH, "w") as f:
    json.dump(loss_dict, f, indent=2)

training...


Epoch 0/15: 236it [01:56,  2.03it/s]


Epoch 0, Train Loss: 26175.6750, RMSE: 161.7890
Epoch 0, Val Loss: 19018.2867
Epoch 0, Test Loss (MSE): 16653.0510, RMSE: 129.0467


Epoch 1/15: 236it [01:56,  2.02it/s]


Epoch 1, Train Loss: 18296.4746, RMSE: 135.2645
Epoch 1, Val Loss: 18411.7602
Epoch 1, Test Loss (MSE): 15160.2576, RMSE: 123.1270


Epoch 2/15: 236it [01:56,  2.02it/s]


Epoch 2, Train Loss: 16432.1804, RMSE: 128.1881
Epoch 2, Val Loss: 18261.9425
Epoch 2, Test Loss (MSE): 15456.6580, RMSE: 124.3248


Epoch 3/15: 236it [01:56,  2.02it/s]


Epoch 3, Train Loss: 15709.4877, RMSE: 125.3375
Epoch 3, Val Loss: 16428.8419
Epoch 3, Test Loss (MSE): 13370.9353, RMSE: 115.6328


Epoch 4/15: 236it [01:56,  2.02it/s]


Epoch 4, Train Loss: 14723.5241, RMSE: 121.3405
Epoch 4, Val Loss: 18497.2147
Epoch 4, Test Loss (MSE): 15272.7162, RMSE: 123.5828


Epoch 5/15: 236it [01:56,  2.02it/s]


Epoch 5, Train Loss: 13587.7986, RMSE: 116.5667
Epoch 5, Val Loss: 16806.3903
Epoch 5, Test Loss (MSE): 13262.4611, RMSE: 115.1628


Epoch 6/15: 236it [01:57,  2.02it/s]


Epoch 6, Train Loss: 11213.2569, RMSE: 105.8927
Epoch 6, Val Loss: 16490.7170
Epoch 6, Test Loss (MSE): 12839.2260, RMSE: 113.3103


Epoch 7/15: 236it [01:57,  2.02it/s]


Epoch 7, Train Loss: 10355.9040, RMSE: 101.7640
Epoch 7, Val Loss: 16636.9557
Epoch 7, Test Loss (MSE): 13034.6936, RMSE: 114.1696


Epoch 8/15: 236it [01:56,  2.02it/s]


Epoch 8, Train Loss: 9747.9310, RMSE: 98.7316
Epoch 8, Val Loss: 16641.3889
Epoch 8, Test Loss (MSE): 13279.7555, RMSE: 115.2378


Epoch 9/15: 236it [01:57,  2.02it/s]


Epoch 9, Train Loss: 9273.4044, RMSE: 96.2985
Epoch 9, Val Loss: 16983.9535
Epoch 9, Test Loss (MSE): 13477.2005, RMSE: 116.0913


Epoch 10/15: 236it [01:57,  2.02it/s]


Epoch 10, Train Loss: 8664.1676, RMSE: 93.0815
Epoch 10, Val Loss: 17214.5759
Epoch 10, Test Loss (MSE): 13929.4198, RMSE: 118.0230


Epoch 11/15: 236it [01:57,  2.02it/s]


Epoch 11, Train Loss: 8066.4877, RMSE: 89.8136
Epoch 11, Val Loss: 17405.1405
Epoch 11, Test Loss (MSE): 14100.2522, RMSE: 118.7445


Epoch 12/15: 236it [01:57,  2.02it/s]


Epoch 12, Train Loss: 7331.7094, RMSE: 85.6254
Epoch 12, Val Loss: 17609.5500
Epoch 12, Test Loss (MSE): 14057.3775, RMSE: 118.5638


Epoch 13/15: 236it [01:57,  2.02it/s]


Epoch 13, Train Loss: 7231.9425, RMSE: 85.0408
Epoch 13, Val Loss: 17782.2136
Epoch 13, Test Loss (MSE): 14025.6345, RMSE: 118.4299


Epoch 14/15: 236it [01:57,  2.02it/s]


Epoch 14, Train Loss: 7151.0008, RMSE: 84.5636
Epoch 14, Val Loss: 17897.3158
Epoch 14, Test Loss (MSE): 14466.7253, RMSE: 120.2777
Best model loaded with val loss: 16428.841876594386
Best model saved as 'abs_pos_EEGViT_OnlyVitWeightsPretrained.pt'.
