In [1]:
import numpy as np
import torch
import torch.nn as nn
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


### Load Data to Numpy Array

In [40]:
data = np.load('data\Position_task_with_dots_synchronised_min.npz')

trainX = data['EEG']
trainY = data['labels'][:,1:] # The first column are the Id-s, the second and third are position x and y which we use
ids = data['labels'][:, 0] # Participant Ids
print(f"trainX.shape: {trainX.shape}")
print(f"trainY.shape: {trainY.shape}")
trainY[0]

  data = np.load('data\Position_task_with_dots_synchronised_min.npz')


trainX.shape: (21464, 500, 129)
trainY.shape: (21464, 2)


array([408.1, 315.1])

### Split Data

In [41]:
import math
import numpy as np

def split(ids, train, val, test):
    # proportions of train, val, test
    assert (train+val+test == 1)
    IDs = np.unique(ids)
    num_ids = len(IDs)

    # priority given to the test/val sets
    test_split = math.ceil(test * num_ids)
    val_split = math.ceil(val * num_ids)
    train_split = num_ids - val_split - test_split

    train = np.where(np.isin(ids, IDs[:train_split]))[0]
    val = np.where(np.isin(ids, IDs[train_split:train_split+val_split]))[0]
    test = np.where(np.isin(ids, IDs[train_split+val_split:]))[0]
    
    return train, val, test

train, val, test = split(ids, 0.7, 0.15, 0.15)
X_train, y_train = trainX[train], trainY[train]
X_val, y_val = trainX[val], trainY[val]
X_test, y_test = trainX[test], trainY[test]

print(f"X_train.shape:{X_train.shape} y_train.shape: {y_train.shape}")
print(f"X_val.shape:{X_val.shape} y_val.shape: {y_val.shape}")
print(f"X_test.shape:{X_test.shape} y_test.shape: {y_test.shape}")

X_train.shape:(15076, 500, 129) y_train.shape: (15076, 2)
X_val.shape:(3134, 500, 129) y_val.shape: (3134, 2)
X_test.shape:(3254, 500, 129) y_test.shape: (3254, 2)


### Create DataLoaders

In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Convert NumPy arrays to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)  # Shape: (N, 2)

X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# Create DataLoaders
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=batch_size)

### Model

In [43]:
import torch
import torch.nn as nn
import sys

class EEGEncoderRaw(nn.Module):
    def __init__(self, pretrained_encoder_encoder):
        super().__init__()

        # Step 1: Temporal Convolution (1 x 36 kernel across time)
        self.conv_temporal = nn.Conv2d(
            in_channels=1, 
            out_channels=129,
            kernel_size=(1, 36),
            stride=(1, 36),
            padding=(0, 2),
            bias=False
        )
        self.batchnorm_temporal = nn.BatchNorm2d(129, affine=False)

        # Step 2: Depthwise Spatial Convolution (8 x 1 kernel over channels)
        self.conv_spatial = nn.Conv2d(
            in_channels=129,
            out_channels=129,
            kernel_size=(8, 1),
            stride=(8, 1),
            groups=129,  # depthwise
            bias=False
        )
        self.batchnorm_spatial = nn.BatchNorm2d(129, affine=False)

        # Transformer encoder expects d_model = 129
        self.encoder = pretrained_encoder_encoder

        # Simple regression head
        self.regressor = nn.Sequential(
            nn.Linear(129, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # Predict (x, y)
        )

    def forward(self, x):
        #print(f"Input x.shape: {x.shape}")  # [B, 1, C, T]

        x = self.conv_temporal(x)
        # print(f"After conv_temporal: {x.shape}")  # [B, 129, C, T']

        x = self.batchnorm_temporal(x)
        # print(f"After batchnorm_temporal: {x.shape}")

        x = self.conv_spatial(x)
        # print(f"After conv_spatial: {x.shape}")  # [B, 129, H', W']

        x = self.batchnorm_spatial(x)
        # print(f"After batchnorm_spatial: {x.shape}")

        B, C, H, W = x.shape
        x = x.view(B, C, -1).permute(0, 2, 1)  # [B, seq_len, 129]
        # print(f"After view & permute: {x.shape}")

        x = self.encoder(x)
        # print(f"After transformer encoder: {x.shape}")  # [B, seq_len, 129]

        pooled = x.mean(dim=1)
        # print(f"After mean pooling: {pooled.shape}")  # [B, 129]

        coords = self.regressor(pooled)
        # print(f"Output coords: {coords.shape}")  # [B, 2]
        # sys.exit()
        return coords


In [44]:
def get_pretrained_transformer_encoder():
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=129,          # same as input_dim
        nhead=3,
        dim_feedforward=512,
        batch_first=True
    )
    transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
    transformer_encoder.load_state_dict(torch.load("pretrained_encoder.pt"))
    return transformer_encoder


batch_size = 64
n_epoch = 15
learning_rate = 1e-4
# Recreate the encoder.encoder module (must match architecture used in training)
pretrained_encoder_encoder = get_pretrained_transformer_encoder()

# Create model
model = EEGEncoderRaw(pretrained_encoder_encoder)
criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

### Training

In [46]:
import torch
import copy
from tqdm import tqdm
import sys

torch.cuda.empty_cache()
model = model.to(device)
criterion = criterion.to(device)

# Initialize lists to store losses
train_losses = []
val_losses = []
test_losses = []
best_val_loss = float('inf')
best_model_wts = None

print('Training...')

# Train the model
for epoch in range(n_epoch):
    model.train()
    epoch_train_loss = 0.0

    for i, (inputs, targets) in tqdm(enumerate(train_loader), desc=f"Epoch {epoch}/{n_epoch}"):
        inputs = inputs.to(device).unsqueeze(1).permute(0, 1, 3, 2)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), targets.squeeze())

        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()

        # Optional: print loss every 5 batches
        # if i % 5 == 0:
        #     print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

    epoch_train_loss /= len(train_loader)
    train_losses.append(epoch_train_loss)
    print(f"Epoch {epoch}, Train Loss: {epoch_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device).unsqueeze(1).permute(0, 1, 3, 2)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), targets.squeeze())
            val_loss += loss.item()

        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")

        # Save best model based on val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

    # Test
    test_loss = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device).unsqueeze(1).permute(0, 1, 3, 2)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), targets.squeeze())
            test_loss += loss.item()

        test_loss /= len(test_loader)
        test_losses.append(test_loss)
        rmse = test_loss ** 0.5
        print(f"Epoch {epoch}, Test Loss (MSE): {test_loss:.4f}, RMSE: {rmse:.4f}")

    if scheduler is not None:
        scheduler.step()

# Load best model weights
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)
    print("Best model loaded with val loss:", best_val_loss)

# Save best model
torch.save(model.state_dict(), "best_model_abs_pos.pt")
print("Best model saved as 'best_model_abs_pos.pt'.")

# Save encoder weights
if hasattr(model, 'encoder'):
    torch.save(model.encoder.state_dict(), "encoder_weights_abs_pos.pt")
    print("Encoder weights saved as 'encoder_weights_abs_pos.pt'.")
else:
    print("Model does not have an 'encoder' attribute to save.")

Training...


Epoch 0/15: 236it [00:05, 39.83it/s]


Epoch 0, Train Loss: 36608.2921
Epoch 0, Val Loss: 36629.5651
Epoch 0, Test Loss (MSE): 37566.8305, RMSE: 193.8216


Epoch 1/15: 236it [00:05, 42.16it/s]


Epoch 1, Train Loss: 36588.3983
Epoch 1, Val Loss: 36607.3192
Epoch 1, Test Loss (MSE): 37540.3856, RMSE: 193.7534


Epoch 2/15: 236it [00:05, 42.36it/s]


Epoch 2, Train Loss: 36579.5176
Epoch 2, Val Loss: 36584.4030
Epoch 2, Test Loss (MSE): 37513.0541, RMSE: 193.6829


Epoch 3/15: 236it [00:05, 42.26it/s]


Epoch 3, Train Loss: 36557.9608
Epoch 3, Val Loss: 36582.8209
Epoch 3, Test Loss (MSE): 37510.4892, RMSE: 193.6762


Epoch 4/15: 236it [00:05, 42.20it/s]


Epoch 4, Train Loss: 36572.7492
Epoch 4, Val Loss: 36580.1362
Epoch 4, Test Loss (MSE): 37507.5886, RMSE: 193.6688


Epoch 5/15: 236it [00:05, 42.09it/s]


Epoch 5, Train Loss: 36548.9888
Epoch 5, Val Loss: 36577.4642
Epoch 5, Test Loss (MSE): 37504.8498, RMSE: 193.6617


Epoch 6/15: 236it [00:05, 41.78it/s]


Epoch 6, Train Loss: 36565.2146
Epoch 6, Val Loss: 36575.5962
Epoch 6, Test Loss (MSE): 37501.8849, RMSE: 193.6540


Epoch 7/15: 236it [00:05, 42.04it/s]


Epoch 7, Train Loss: 36558.2975
Epoch 7, Val Loss: 36571.9887
Epoch 7, Test Loss (MSE): 37498.6473, RMSE: 193.6457


Epoch 8/15: 236it [00:05, 41.57it/s]


Epoch 8, Train Loss: 36547.6038
Epoch 8, Val Loss: 36569.3184
Epoch 8, Test Loss (MSE): 37495.5831, RMSE: 193.6378


Epoch 9/15: 236it [00:05, 41.73it/s]


Epoch 9, Train Loss: 36547.1070
Epoch 9, Val Loss: 36570.1582
Epoch 9, Test Loss (MSE): 37495.5486, RMSE: 193.6377


Epoch 10/15: 236it [00:05, 41.99it/s]


Epoch 10, Train Loss: 36553.9963
Epoch 10, Val Loss: 36569.6699
Epoch 10, Test Loss (MSE): 37495.2926, RMSE: 193.6370


Epoch 11/15: 236it [00:05, 41.76it/s]


Epoch 11, Train Loss: 36540.7189
Epoch 11, Val Loss: 36569.1493
Epoch 11, Test Loss (MSE): 37494.9010, RMSE: 193.6360


Epoch 12/15: 236it [00:05, 42.03it/s]


Epoch 12, Train Loss: 36549.5697
Epoch 12, Val Loss: 36569.0646
Epoch 12, Test Loss (MSE): 37494.6138, RMSE: 193.6353


Epoch 13/15: 236it [00:05, 41.95it/s]


Epoch 13, Train Loss: 36529.4564
Epoch 13, Val Loss: 36569.1344
Epoch 13, Test Loss (MSE): 37494.3856, RMSE: 193.6347


Epoch 14/15: 236it [00:05, 42.11it/s]


Epoch 14, Train Loss: 36531.6994
Epoch 14, Val Loss: 36567.9583
Epoch 14, Test Loss (MSE): 37494.0500, RMSE: 193.6338
Best model loaded with val loss: 36567.9583466199
Best model saved as 'best_model_abs_pos.pt'.
Encoder weights saved as 'encoder_weights_abs_pos.pt'.
