In [1]:
import numpy as np
import torch
import torch.nn as nn
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


# Task 1: Left-Right 

In [128]:
data = np.load('data\LR_task_with_antisaccade_synchronised_min.npz')
print(data['labels'].shape)
print('Converted to')
print(data['labels'][:, 1].shape)

trainX = data['EEG']
trainY = data['labels'][:, 1]
ids = data['labels'][:, 0] # Participant Ids

  data = np.load('data\LR_task_with_antisaccade_synchronised_min.npz')


(30842, 2)
Converted to
(30842,)


In [129]:
print(trainX.shape)
print(trainY.shape)

(30842, 500, 129)
(30842,)


In [2]:
class SimpleEncoder(nn.Module):
    def __init__(self, input_dim=129, num_layers=2, num_heads=3, dim_feedforward=512):
        super(SimpleEncoder, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Input shape: [batch, seq, dim]
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Classification head: input_dim -> 1 (binary classification)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # For binary output between 0 and 1
        )

    def forward(self, x):
        encoded = self.encoder(x)  # Shape: [batch, seq_len, input_dim]
        pooled = encoded.mean(dim=1)  # Global average pooling over time
        return self.classifier(pooled)  # Shape: [batch, 1]

In [3]:
def split(ids, train, val, test):
    # proportions of train, val, test
    assert (train+val+test == 1)
    IDs = np.unique(ids)
    num_ids = len(IDs)

    # priority given to the test/val sets
    test_split = math.ceil(test * num_ids)
    val_split = math.ceil(val * num_ids)
    train_split = num_ids - val_split - test_split

    train = np.where(np.isin(ids, IDs[:train_split]))[0]
    val = np.where(np.isin(ids, IDs[train_split:train_split+val_split]))[0]
    test = np.where(np.isin(ids, IDs[train_split+val_split:]))[0]
    
    return train, val, test

In [133]:
import math
import numpy as np

train, val, test = split(ids, 0.7, 0.15, 0.15)
X_train, y_train = trainX[train], trainY[train]
X_val, y_val = trainX[val], trainY[val]
X_test, y_test = trainX[test], trainY[test]

print(f"X_train.shape:{X_train.shape} y_train.shape: {y_train.shape}")
print(f"X_val.shape:{X_val.shape} y_val.shape: {y_val.shape}")
print(f"X_test.shape:{X_test.shape} y_test.shape: {y_test.shape}")

X_train.shape:(21042, 500, 129) y_train.shape: (21042,)
X_val.shape:(4980, 500, 129) y_val.shape: (4980,)
X_test.shape:(4820, 500, 129) y_test.shape: (4820,)


In [137]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Convert NumPy arrays to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)  # Shape: (N, 1)

X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# Create DataLoaders
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=batch_size)

In [138]:
# Initialize model, loss function, optimizer
model = SimpleEncoder().to(device)
criterion = nn.BCELoss()  # Binary Cross Entropy
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50
best_val_acc = 0.0
best_model_state = None

In [15]:
import logging
from datetime import datetime
import os

# Create logs directory if not exists
os.makedirs("logs", exist_ok=True)

# Set filename based on current date and time
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"logs/train_log_{timestamp}.log"

# Setup logging
logging.basicConfig(
    filename=log_filename,
    filemode="w",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO
)

# Also print to console
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console.setFormatter(formatter)
logging.getLogger().addHandler(console)

logger = logging.getLogger()


In [142]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * X_batch.size(0)

    avg_train_loss = total_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            preds = (outputs > 0.5).float()
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
        val_acc = correct / total

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = model.state_dict()

    # Test evaluation each epoch
    with torch.no_grad():
        correct = 0
        total = 0
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            preds = (outputs > 0.5).float()
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
        test_acc = correct / total

    logger.info(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Acc: {val_acc*100:.4f} | Test Acc: {test_acc*100:.4f}")

# Save best model to file
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    torch.save(best_model_state, "best_model.pt")
    logger.info(f"\n✅ Best model saved as 'best_model.pt' with val acc: {best_val_acc*100:.4f}%")

    # Save just the encoder for reuse
    torch.save(model.encoder.state_dict(), "pretrained_encoder.pt")
    logger.info("✅ Encoder weights saved separately to 'pretrained_encoder.pt'")

2025-04-19 04:19:06,636 - INFO - Epoch 1/50 | Train Loss: 0.0931 | Val Acc: 97.8514 | Test Acc: 96.6390
2025-04-19 04:19:25,140 - INFO - Epoch 2/50 | Train Loss: 0.0844 | Val Acc: 97.2691 | Test Acc: 96.9295
2025-04-19 04:19:42,723 - INFO - Epoch 3/50 | Train Loss: 0.0831 | Val Acc: 97.1084 | Test Acc: 95.8091
2025-04-19 04:20:00,850 - INFO - Epoch 4/50 | Train Loss: 0.0734 | Val Acc: 97.0482 | Test Acc: 96.5768
2025-04-19 04:20:18,970 - INFO - Epoch 5/50 | Train Loss: 0.0733 | Val Acc: 97.6305 | Test Acc: 96.8880
2025-04-19 04:20:37,117 - INFO - Epoch 6/50 | Train Loss: 0.0682 | Val Acc: 97.4096 | Test Acc: 96.7635
2025-04-19 04:20:55,465 - INFO - Epoch 7/50 | Train Loss: 0.0710 | Val Acc: 96.8876 | Test Acc: 95.7469
2025-04-19 04:21:13,578 - INFO - Epoch 8/50 | Train Loss: 0.0645 | Val Acc: 97.7309 | Test Acc: 97.1162
2025-04-19 04:21:32,143 - INFO - Epoch 9/50 | Train Loss: 0.0567 | Val Acc: 97.7510 | Test Acc: 96.8050
2025-04-19 04:21:50,279 - INFO - Epoch 10/50 | Train Loss: 0.053

In [None]:
# Load best model and set to eval mode
model.load_state_dict(torch.load("best_model.pt"))
model.eval()

# Move test data to device
X_test_tensor = X_test_tensor.to(device)
y_test_tensor = y_test_tensor.to(device)

# Get predictions on full test set
with torch.no_grad():
    outputs = model(X_test_tensor)
    preds = (outputs > 0.5).float()

# Compute correct and wrong counts
correct_preds = (preds == y_test_tensor).sum().item()
total_preds = y_test_tensor.size(0)
wrong_preds = total_preds - correct_preds

# Print counts
logger.info(f"\n✅ Total Correct Predictions: {correct_preds}")
logger.info(f"❌ Total Wrong Predictions:   {wrong_preds}")
logger.info(f"📊 Test Accuracy:             {(correct_preds / total_preds) * 100:.2f}%")

# Print predictions for first 10 test samples
logger.info("\n📊 Predictions vs Ground Truth for first 10 test samples:\n")
for i in range(10):
    pred_val = preds[i].item()
    actual_val = y_test_tensor[i].item()
    logger.info(f"Sample {i+1:02d} | Predicted: {int(pred_val)} | Actual: {int(actual_val)}")


# TASK 2: Direction

In [18]:
data = np.load('data\Direction_task_with_dots_synchronised_min_15_perc.npz')

trainX = data['EEG']
trainY = data['labels'][:, 1:3]
ids = data['labels'][:, 0] # ID
print(f"trainX.shape: {trainX.shape}")
print(f"trainY.shape: {trainY.shape}")

  data = np.load('data\Direction_task_with_dots_synchronised_min_15_perc.npz')


trainX.shape: (2674, 500, 129)
trainY.shape: (2674, 2)


In [19]:
import math
import numpy as np

train, val, test = split(ids, 0.8, 0.1, 0.1)
X_train, y_train = trainX[train], trainY[train]
X_val, y_val = trainX[val], trainY[val]
X_test, y_test = trainX[test], trainY[test]

print(f"X_train.shape:{X_train.shape} y_train.shape: {y_train.shape}")
print(f"X_val.shape:{X_val.shape} y_val.shape: {y_val.shape}")
print(f"X_test.shape:{X_test.shape} y_test.shape: {y_test.shape}")

X_train.shape:(2157, 500, 129) y_train.shape: (2157, 2)
X_val.shape:(233, 500, 129) y_val.shape: (233, 2)
X_test.shape:(284, 500, 129) y_test.shape: (284, 2)


In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Convert NumPy arrays to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)  # Shape: (N, 2)

X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# Create DataLoaders
batch_size = 64
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=batch_size)

In [8]:
class MultiTaskRegressor(nn.Module):
    def __init__(self, encoder, input_dim=129):
        super(MultiTaskRegressor, self).__init__()
        self.encoder = encoder

        # Shared head for feature extraction
        self.shared_head = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )

        # Separate heads for amplitude and angle
        self.amplitude_head = nn.Linear(128, 1)
        self.angle_head = nn.Linear(128, 1)

    def forward(self, x):
        encoded = self.encoder.encoder(x)  # Use encoder.encoder to get the transformer output
        pooled = encoded.mean(dim=1)       # Same as in SimpleEncoder
        features = self.shared_head(pooled)
        amplitude = self.amplitude_head(features)
        angle = self.angle_head(features)
        return amplitude, angle


# MultiTaskRegressor Setup

In [None]:
# Load encoder and its weights
encoder = SimpleEncoder(input_dim=129).to(device)
state_dict = torch.load("pretrained_encoder.pt", map_location=device)
encoder.encoder.load_state_dict(state_dict)

# Wrap in multitask regressor
model = MultiTaskRegressor(encoder=encoder, input_dim=129).to(device)

# Angle loss with correct angle error formula (torch)
criterion_angle = lambda pred, target: torch.mean(
    torch.square(torch.atan2(torch.sin(target - pred), torch.cos(target - pred)))
)

# Define loss functions
criterion_amplitude = nn.MSELoss()

# Choose weighting method
learn_uncertainty = True  # ← set to False if you want fixed weights

if learn_uncertainty:
    # Learnable log variances for adaptive weighting
    log_sigma_amp = torch.nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))
    log_sigma_ang = torch.nn.Parameter(torch.tensor(0.0, requires_grad=True, device=device))
    optimizer = optim.Adam(
        list(model.parameters()) + [log_sigma_amp, log_sigma_ang], lr=1e-4
    )
else:
    # Manual weights
    w_amp = 1.0
    w_ang = 10000.0  # Tune based on your data
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

# # Define Optimizer
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training setup
num_epochs = 50
best_val_loss = float('inf')
best_model_state = None



# Training for MultiTaskRegressor

In [150]:
logger = logging.getLogger()
for epoch in range(num_epochs):
    model.train()
    train_loss, train_amp_loss, train_ang_loss = 0, 0, 0

    for X_batch, y_batch in train_loader:
        y_batch = y_batch.squeeze(1)
        amp_batch, ang_batch = y_batch[:, 0], y_batch[:, 1]

        X_batch = X_batch.to(device)
        amp_batch = amp_batch.to(device)
        ang_batch = ang_batch.to(device)

        optimizer.zero_grad()
        pred_amp, pred_ang = model(X_batch)

        loss_amp = criterion_amplitude(pred_amp, amp_batch)
        loss_ang = criterion_angle(pred_ang, ang_batch)

        if learn_uncertainty:
            loss = (1 / (2 * torch.exp(log_sigma_amp))) * loss_amp + \
                   (1 / (2 * torch.exp(log_sigma_ang))) * loss_ang + \
                   0.5 * (log_sigma_amp + log_sigma_ang)
        else:
            loss = w_amp * loss_amp + w_ang * loss_ang

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * X_batch.size(0)
        train_amp_loss += loss_amp.item() * X_batch.size(0)
        train_ang_loss += loss_ang.item() * X_batch.size(0)

    n_train = len(train_loader.dataset)
    avg_train_loss = train_loss / n_train
    avg_train_amp_loss = train_amp_loss / n_train
    avg_train_ang_loss = train_ang_loss / n_train

    # Validation phase
    model.eval()
    val_loss, val_amp_loss, val_ang_loss = 0, 0, 0
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            y_batch = y_batch.squeeze(1)
            amp_batch, ang_batch = y_batch[:, 0], y_batch[:, 1]

            X_batch = X_batch.to(device)
            amp_batch = amp_batch.to(device)
            ang_batch = ang_batch.to(device)

            pred_amp, pred_ang = model(X_batch)

            loss_amp = criterion_amplitude(pred_amp, amp_batch)
            loss_ang = criterion_angle(pred_ang, ang_batch)

            if learn_uncertainty:
                loss = (1 / (2 * torch.exp(log_sigma_amp))) * loss_amp + \
                       (1 / (2 * torch.exp(log_sigma_ang))) * loss_ang + \
                       0.5 * (log_sigma_amp + log_sigma_ang)
            else:
                loss = w_amp * loss_amp + w_ang * loss_ang

            val_loss += loss.item() * X_batch.size(0)
            val_amp_loss += loss_amp.item() * X_batch.size(0)
            val_ang_loss += loss_ang.item() * X_batch.size(0)

    n_val = len(val_loader.dataset)
    avg_val_loss = val_loss / n_val
    avg_val_amp_loss = val_amp_loss / n_val
    avg_val_ang_loss = val_ang_loss / n_val

    # Test phase
    test_loss, test_amp_loss, test_ang_loss = 0, 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_batch = y_batch.squeeze(1)
            amp_batch, ang_batch = y_batch[:, 0], y_batch[:, 1]

            X_batch = X_batch.to(device)
            amp_batch = amp_batch.to(device)
            ang_batch = ang_batch.to(device)

            pred_amp, pred_ang = model(X_batch)

            loss_amp = criterion_amplitude(pred_amp, amp_batch)
            loss_ang = criterion_angle(pred_ang, ang_batch)

            if learn_uncertainty:
                loss = (1 / (2 * torch.exp(log_sigma_amp))) * loss_amp + \
                       (1 / (2 * torch.exp(log_sigma_ang))) * loss_ang + \
                       0.5 * (log_sigma_amp + log_sigma_ang)
            else:
                loss = w_amp * loss_amp + w_ang * loss_ang

            test_loss += loss.item() * X_batch.size(0)
            test_amp_loss += loss_amp.item() * X_batch.size(0)
            test_ang_loss += loss_ang.item() * X_batch.size(0)

    n_test = len(test_loader.dataset)
    avg_test_loss = test_loss / n_test
    avg_test_amp_loss = test_amp_loss / n_test
    avg_test_ang_loss = test_ang_loss / n_test

    logger.info(f"Epoch {epoch+1}/{num_epochs}")
    logger.info(f"🔹 Train Loss: {avg_train_loss:.4f} (Amplitude: {avg_train_amp_loss:.4f}, Angle: {avg_train_ang_loss:.4f})")
    logger.info(f"🔸 Val   Loss: {avg_val_loss:.4f} (Amplitude: {avg_val_amp_loss:.4f}, Angle: {avg_val_ang_loss:.4f})")
    logger.info(f"🔻 Test  Loss: {avg_test_loss:.4f} (Amplitude: {avg_test_amp_loss:.4f}, Angle: {avg_test_ang_loss:.4f})")
    if learn_uncertainty:
        logger.info(f"   ↪ log_sigma_amp: {log_sigma_amp.item():.4f}, log_sigma_ang: {log_sigma_ang.item():.4f}")
    logger.info("-" * 80)

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_state = model.state_dict()

# Save best model and encoder state
if best_model_state is not None:
    torch.save(best_model_state, "best_multitask_model.pt")
    logger.info(f"\n✅ Best multitask model saved as 'best_multitask_model.pt' with val loss: {best_val_loss:.4f}")

    torch.save(model.encoder.encoder.state_dict(), "best_finetuned_encoder_task2.pt")
    logger.info("🧠 Best fine-tuned encoder saved as 'best_finetuned_encoder_task2.pt'")

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
2025-04-19 05:16:45,888 - INFO - Epoch 1/50
2025-04-19 05:16:46,251 - INFO - 🔹 Train Loss: 59230.2428 (Amplitude: 119765.0651, Angle: 3.2758)
2025-04-19 05:16:46,252 - INFO - 🔸 Val   Loss: 53101.6296 (Amplitude: 108577.6045, Angle: 2.8294)
2025-04-19 05:16:46,253 - INFO - 🔻 Test  Loss: 46237.5209 (Amplitude: 94541.8933, Angle: 3.0339)
2025-04-19 05:16:46,277 - INFO -    ↪ log_sigma_amp: 0.0221, log_sigma_ang: 0.0201
2025-04-19 05:16:46,278 - INFO - --------------------------------------------------------------------------------
2025-04-19 05:17:08,583 - INFO - Epoch 2/50
2025-04-19 05:17:08,591 - INFO - 🔹 Train Loss: 51394.1584 (Amplitude: 106136.5641, Angle: 2.9565)
2025-04-19 05:17:08,591 - INFO - 🔸 Val   Loss: 43457.0700 (Amplitude: 90673.4280