In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import scipy.io
from scipy.signal import butter, lfilter
import scipy.signal

In [10]:
import numpy as np
import scipy.signal
import scipy.io
import os

# ==== CONFIG ====
# All artefacts live under one project directory; change it here to run on another machine.
project_dir = "E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/"
data_path = project_dir + "data_preprocessed_matlab/"  # must end with "/": file names are appended directly
save_features = project_dir + "de_features.npy"
save_labels = project_dir + "de_labels.npy"

# ==== Frequency bands ====
# (low_hz, high_hz) band-pass edges used for DE extraction. Every upper edge must stay
# below the Nyquist frequency (fs / 2 = 64 Hz at the default fs=128 used for extraction).
# NOTE(review): beta ends at 30 Hz while gamma starts at 31 Hz, so 30-31 Hz is not
# covered by any band — confirm this gap is intended.
freq_bands = {
    "delta": (1, 4),
    "theta": (4, 8),
    "alpha": (8, 14),
    "beta": (14, 30),
    "gamma": (31, 50),
}

# Differential Entropy formula
def compute_de(signal):
    """Gaussian differential entropy along the last axis: 0.5 * ln(2 * pi * e * var).

    Args:
        signal: array whose last axis holds the samples of each segment.

    Returns:
        DE per segment, with every length-1 axis squeezed out (a 0-d array for a
        1-D input). A segment with zero variance yields -inf.
    """
    var = np.var(signal, axis=-1, keepdims=True)
    return np.squeeze(0.5 * np.log(2 * np.pi * np.e * var))

# Feature extraction and save
def extract_and_save_de_features(subject_data, subject_labels, feature_path, label_path, fs=128, window_size=128):
    """Band-pass every channel, compute per-window differential entropy, and save features + labels.

    Args:
        subject_data: array of shape (subjects, trials, channels, samples).
        subject_labels: array of shape (subjects, trials) with one class id per trial.
        feature_path: .npy destination for DE features of shape
            (subjects, trials, channels, len(freq_bands), num_windows).
        label_path: .npy destination for the labels repeated once per window,
            shape (subjects, trials, num_windows).
        fs: sampling rate in Hz used to design the band-pass filters.
        window_size: samples per non-overlapping window. Trailing samples that do
            not fill a whole window are dropped.

    Raises:
        ValueError: if a recording is shorter than one window.
    """
    num_subjects, num_trials, num_channels, num_samples = subject_data.shape
    num_bands = len(freq_bands)
    num_windows = num_samples // window_size
    if num_windows == 0:
        raise ValueError(f"num_samples={num_samples} is shorter than window_size={window_size}")
    usable_samples = num_windows * window_size

    # The filter depends only on the band and fs, so design each SOS once instead of
    # once per (subject, trial, channel).
    band_sos = [
        scipy.signal.butter(4, [low, high], btype="bandpass", fs=fs, output="sos")
        for low, high in freq_bands.values()
    ]

    de_features = np.zeros((num_subjects, num_trials, num_channels, num_bands, num_windows))

    for subj in range(num_subjects):
        for trial in range(num_trials):
            signals = subject_data[subj, trial]  # (channels, samples)
            for b_idx, sos in enumerate(band_sos):
                # sosfilt filters each channel (row) independently along the last axis.
                filtered = scipy.signal.sosfilt(sos, signals, axis=-1)
                # Keep whole windows only -> (channels, num_windows, window_size).
                # Splitting into num_windows equal parts instead would fail, or produce
                # windows longer than window_size, whenever num_samples is not an exact
                # multiple of window_size.
                segmented = filtered[:, :usable_samples].reshape(num_channels, num_windows, window_size)
                # compute_de squeezes length-1 axes; reshape restores (channels, windows).
                de = compute_de(segmented).reshape(num_channels, num_windows)
                de_features[subj, trial, :, b_idx, :] = de

    np.save(feature_path, de_features)
    print(f"✅ DE features saved to: {feature_path}")

    # Repeat each trial's label once per window so labels align with the window axis.
    expanded_labels = np.repeat(subject_labels[:, :, np.newaxis], num_windows, axis=2)  # (subjects, trials, windows)
    np.save(label_path, expanded_labels)
    print(f"✅ Labels saved to: {label_path}")

# ==== Loading raw EEG data ====
# One .mat file per subject (s01.mat ... s32.mat) inside `data_path`.
# NOTE(review): every one of the 40 channels in `data` is turned into features below;
# confirm that all of them are EEG (the DEAP documentation describes non-EEG channels
# after the first 32).
subject_data = []
subject_labels = []

for subject_no in range(1, 33):
    mat_contents = scipy.io.loadmat(f"{data_path}s{subject_no:02d}.mat")
    trial_signals = mat_contents["data"]  # Shape: (40, 40, 8064)
    ratings = mat_contents["labels"]      # Shape: (40, 4); column 0 = valence, column 1 = arousal

    # 4-class valence/arousal quadrant with a threshold of 5 on each rating:
    #   class = 2 * [valence >= 5] + [arousal >= 5]
    #   0 = low V / low A, 1 = low V / high A, 2 = high V / low A, 3 = high V / high A
    high_valence = (ratings[:, 0] >= 5).astype(int)
    high_arousal = (ratings[:, 1] >= 5).astype(int)
    emotion_class = 2 * high_valence + high_arousal

    subject_data.append(trial_signals)
    subject_labels.append(emotion_class)

subject_data = np.stack(subject_data)      # Shape: (32, 40, 40, 8064)
subject_labels = np.stack(subject_labels)  # Shape: (32, 40)

print("✅ Raw EEG Shape:", subject_data.shape)
print("✅ Emotion Labels Shape:", subject_labels.shape)

# ==== Extract and save DE features + labels ====
extract_and_save_de_features(subject_data, subject_labels, save_features, save_labels)

print("🎉 All processing done! DE features and labels are now saved and ready for training.")


✅ Raw EEG Shape: (32, 40, 40, 8064)
✅ Emotion Labels Shape: (32, 40)
✅ DE features saved to: E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_features.npy
✅ Labels saved to: E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_labels.npy
🎉 All processing done! DE features and labels are now saved and ready for training.


In [11]:
import numpy as np
# Reload both saved arrays to confirm they round-trip with the expected shapes.
de_features, de_labels = (np.load(path) for path in (save_features, save_labels))

for caption, arr in (("✅ DE features loaded:", de_features), ("✅ DE labels loaded:", de_labels)):
    print(caption, arr.shape)

✅ DE features loaded: (32, 40, 40, 5, 63)
✅ DE labels loaded: (32, 40, 63)


In [12]:
import numpy as np

# Reload the per-window labels from disk and count how many windows fall into each class.
de_labels = np.load("E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_labels.npy")

flat_labels = de_labels.flatten()  # 1-D copy: one entry per (subject, trial, window)
classes, counts = np.unique(flat_labels, return_counts=True)

print("🧠 Class distribution (after segmentation):")
for class_id, n_windows in zip(classes, counts):
    print(f"Class {class_id}: {n_windows} samples")

print("\n💡 Total samples:", counts.sum())


🧠 Class distribution (after segmentation):
Class 0: 16380 samples
Class 1: 18648 samples
Class 2: 16758 samples
Class 3: 28854 samples

💡 Total samples: 80640


## DEAP Dataset Class for DE Features & Labels

In [13]:
import torch
from torch.utils.data import Dataset

class DEAPDataset(Dataset):
    """Window-level dataset over precomputed DE features.

    Each item is one (subject, trial, window) triple: the feature slice
    features[subject, trial, :, :, window] as a float32 tensor (channels, bands)
    and its integer label, with an optional `transform` applied to the numpy slice.
    """

    def __init__(self, feature_path, label_path, transform=None):
        self.features = np.load(feature_path)  # expected (subjects, trials, channels, bands, windows), e.g. (32, 40, 40, 5, 63)
        self.labels = np.load(label_path)      # expected (subjects, trials, windows), e.g. (32, 40, 63)
        self.transform = transform

        n_subjects, n_trials, _, _, n_windows = self.features.shape
        # Flat index ordered subject -> trial -> window.
        self.samples = [
            (s, t, w, self.labels[s, t, w])
            for s in range(n_subjects)
            for t in range(n_trials)
            for w in range(n_windows)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s, t, w, y = self.samples[idx]
        x = self.features[s, t, ..., w]  # (channels, bands)

        x = self.transform(x) if self.transform else x

        return torch.tensor(x, dtype=torch.float32), int(y)


In [14]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# Instantiate dataset
dataset = DEAPDataset('E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_features.npy',
                      'E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_labels.npy')

# Every window's label in dataset order (read from the index; no tensor conversion needed).
all_labels = np.array([label for _, _, _, label in dataset.samples], dtype=np.int64)
classes, counts = np.unique(all_labels, return_counts=True)

# Inverse-frequency weight per class, indexed by the label VALUE.
# np.unique's `counts` is indexed by position in `classes`, which only matches the label
# value when every class 0..K-1 is present; bincount keeps weights aligned to label values.
label_counts = np.bincount(all_labels)
class_weights = np.zeros(label_counts.shape, dtype=float)
present = label_counts > 0
class_weights[present] = 1. / label_counts[present]
sample_weights = class_weights[all_labels]

# Create the sampler (draws len(dataset) indices with replacement, class-balanced in expectation)
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)

# Create DataLoader
train_loader = DataLoader(dataset, batch_size=64, sampler=sampler)

print("✅ Balanced DataLoader ready for training.")


✅ Balanced DataLoader ready for training.


## MODEL ARCHITECTURE

In [15]:
import torch
import torch.nn as nn

class CommonFeatureExtractor(nn.Module):
    """Shared MLP encoder: flattened DE features -> 64-d embedding.

    Three linear layers (input_dim -> 256 -> 128 -> 64), each followed by LeakyReLU.
    The default input_dim=200 corresponds to 40 channels x 5 bands.
    """

    def __init__(self, input_dim=200):
        super().__init__()
        # fc1/fc2/fc3 names become the state_dict keys stored in checkpoints.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.act = nn.LeakyReLU()  # As per paper, not ReLU

    def forward(self, x):
        h = x.view(x.size(0), -1)  # (batch, 40, 5) -> (batch, 200)
        for layer in (self.fc1, self.fc2, self.fc3):
            h = self.act(layer(h))
        return h


In [16]:
class SubjectSpecificMapper(nn.Module):
    """Single linear projection followed by LeakyReLU (default 64 -> 32)."""

    def __init__(self, input_dim=64, output_dim=32):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.act = nn.LeakyReLU()

    def forward(self, x):
        projected = self.fc(x)
        return self.act(projected)


In [17]:
class SubjectSpecificClassifier(nn.Module):
    """Linear classification head returning raw logits (no activation).

    Defaults: 32-d input, 4 classes.
    """

    def __init__(self, input_dim=32, num_classes=4):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits


## Contrastive Loss 1 (Supervised Contrastive Loss, L_con1)

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive loss on L2-normalised embeddings.

    With s_ij = <z_i, z_j> / temperature, every ordered pair (i, j), i != j, that
    shares a label contributes log( exp(s_ij) / sum_{k != i} exp(s_ik) ). The loss is
    the negative mean of these terms over all such pairs (pairs, not anchors, are
    averaged). A batch with no same-label pair yields 0.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        Args:
            features: Tensor of shape [batch_size, feature_dim]
            labels:   Tensor of shape [batch_size] (int class labels)
        Returns:
            loss: scalar tensor
        """
        dev = features.device
        z = F.normalize(features, dim=1)
        n = z.shape[0]

        logits = z @ z.T / self.temperature  # cosine similarity / temperature

        label_col = labels.contiguous().view(-1, 1)
        same_label = (label_col == label_col.T).float().to(dev)

        # Exclude self-pairs from both the positives and the denominator.
        not_self = torch.ones_like(same_label) - torch.eye(n).to(dev)
        positives = same_label * not_self

        # Denominator: every other sample in the batch.
        denom = (torch.exp(logits) * not_self).sum(dim=1, keepdim=True)
        log_prob = logits - torch.log(denom + 1e-9)

        return -(positives * log_prob).sum() / (positives.sum() + 1e-9)


## MMD Loss (Maximum Mean Discrepancy)

In [20]:
import torch
import torch.nn as nn

class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two batches with a multi-bandwidth RBF kernel.

    Args:
        kernel_type: stored for backward compatibility; only the RBF kernel is
            implemented and this value is not otherwise used.
        kernel_mul: ratio between consecutive kernel bandwidths.
        num_kernels: number of Gaussian kernels averaged together.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, num_kernels=5):
        super(MMDLoss, self).__init__()
        self.kernel_type = kernel_type
        self.kernel_mul = kernel_mul
        self.num_kernels = num_kernels

    def gaussian_kernel(self, source, target):
        """Average of `num_kernels` RBF kernels over the rows of [source; target].

        Returns an [(n + m), (n + m)] kernel matrix.
        """
        total = torch.cat([source, target], dim=0)  # [n + m, d]
        total0 = total.unsqueeze(0)  # [1, n+m, d]
        total1 = total.unsqueeze(1)  # [n+m, 1, d]
        L2_distance = ((total0 - total1) ** 2).sum(2)  # [n+m, n+m] squared distances

        # Base bandwidth = mean squared distance, treated as a constant (no gradient).
        # Clamped so a batch of identical points (all distances 0) gives kernel value 1
        # instead of 0/0 = NaN.
        bandwidth = torch.mean(L2_distance).detach().clamp(min=1e-12)
        bandwidth_list = [bandwidth * (self.kernel_mul ** i) for i in range(self.num_kernels)]
        kernels = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
        return sum(kernels) / len(kernels)

    def forward(self, source, target):
        """Biased MMD^2 estimate between `source` [n, ...] and `target` [m, ...].

        Inputs are flattened per sample; n and m may differ.
        """
        source = source.view(source.size(0), -1)  # [n, d]
        target = target.view(target.size(0), -1)  # [m, d]
        kernels = self.gaussian_kernel(source, target)

        n = source.size(0)
        XX = kernels[:n, :n]  # source-source
        YY = kernels[n:, n:]  # target-target
        XY = kernels[:n, n:]  # source-target
        YX = kernels[n:, :n]  # target-source

        # Average each block separately: equals mean(XX + YY - XY - YX) when n == m,
        # and stays valid when the two batches differ in size.
        return XX.mean() + YY.mean() - XY.mean() - YX.mean()


## Contrastive Loss 2 (Prototype vs. Memory-Queue Contrastive Loss, L_con2)

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLossLcon2(nn.Module):
    """Prototype-vs-queue contrastive loss on pseudo-labelled target embeddings.

    Each embedding z (L2-normalised) gets one positive logit, its cosine similarity to
    the prototype of its pseudo-label, and one negative logit per row of a FIFO memory
    queue of past embeddings. All logits are divided by `tau` and scored with
    cross-entropy (positive at index 0), and the result is scaled by `gamma`.

    `prototypes` is an nn.Parameter of this module: it is only learned if this
    module's parameters are given to an optimizer. The module (prototypes and queue
    buffer) must live on the same device as the embeddings.
    """

    def __init__(self, feature_dim=32, num_classes=4, tau=0.3, gamma=0.5, queue_size=1024):
        super(ContrastiveLossLcon2, self).__init__()
        self.tau = tau
        self.gamma = gamma
        self.num_classes = num_classes
        self.queue_size = queue_size

        # Class prototypes (mu_c), shape [num_classes, feature_dim].
        self.prototypes = nn.Parameter(torch.randn(num_classes, feature_dim))

        # Memory queue of negative embeddings with unit-norm rows, shape [queue_size, feature_dim].
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, feature_dim), dim=-1))

    def forward(self, z_t, pseudo_labels):
        """
        Args:
            z_t: target embeddings, shape (B, feature_dim)
            pseudo_labels: class indices, shape (B,)
        Returns:
            gamma-scaled scalar loss. Side effect: the normalised, detached z_t is
            pushed into the queue after the loss is computed.
        """
        z_t = F.normalize(z_t, dim=-1)

        # 1. Positive logits: cosine similarity with the assigned prototype. The
        #    prototypes are normalised too, otherwise the logit would scale with the
        #    (unconstrained) prototype norm rather than being a cosine.
        protos = F.normalize(self.prototypes, dim=-1)
        pos_logits = torch.sum(z_t * protos[pseudo_labels], dim=-1) / self.tau  # (B,)

        # 2. Negative logits against the memory queue.
        neg_logits = torch.matmul(z_t, self.queue.T) / self.tau  # (B, Q)

        # 3. Positive at column 0, negatives after it.
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)  # (B, 1+Q)
        labels = torch.zeros(z_t.size(0), dtype=torch.long, device=z_t.device)

        # 4. Compute loss
        loss = F.cross_entropy(logits, labels)

        # 5. Update the queue (after the loss, so the batch is not its own negative).
        self._dequeue_and_enqueue(z_t)

        return self.gamma * loss  # gamma-scaled

    @torch.no_grad()
    def _dequeue_and_enqueue(self, embeddings):
        """FIFO update: drop the oldest rows and append `embeddings` (detached)."""
        embeddings = embeddings.detach()
        batch_size = embeddings.size(0)

        if batch_size >= self.queue_size:
            self.queue = embeddings[-self.queue_size:]
        else:
            self.queue = torch.cat([self.queue[batch_size:], embeddings], dim=0)


## Generalized Cross-Entropy (GCE) Loss

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeneralizedCrossEntropy(nn.Module):
    """Generalized cross-entropy (GCE): batch mean of (1 - p_y ** q) / q.

    p_y is the softmax probability of the target class. q == 1 gives 1 - p_y.
    q == 0 is handled as its limit, ordinary cross-entropy -log p_y, because the
    closed form would divide by zero there.
    """

    def __init__(self, q=0.7):
        super(GeneralizedCrossEntropy, self).__init__()
        self.q = q

    def forward(self, logits, targets):
        """
        logits: [B, C]
        targets: [B] (class indices)
        Returns the batch-mean loss as a scalar tensor.
        """
        if self.q == 0:
            # lim_{q -> 0} (1 - p^q) / q = -log p; use the numerically stable CE.
            return F.cross_entropy(logits, targets)

        probs = F.softmax(logits, dim=1)
        # Row index built on the logits' device so indexing never mixes devices.
        rows = torch.arange(logits.size(0), device=logits.device)
        probs_true = probs[rows, targets]  # pick p_y

        if self.q == 1.0:
            loss = 1.0 - probs_true
        else:
            loss = (1.0 - probs_true.pow(self.q)) / self.q

        return loss.mean()


## Training Loop (Leave-One-Subject-Out)

In [27]:
# Full MSCL Training Loop (Unified Script)
# DEAP Dataset + Phase 1 Losses + Model + LOSO

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
from sklearn.metrics import accuracy_score

# === DEAP Dataset Loader ===
class DEAPDataset(Dataset):
    """Window-level DE-feature dataset with leave-one-subject-out filtering.

    Each sample is (features[subj, trial, :, :, win], labels[subj, trial, win]).

    Args:
        feature_path: .npy of shape (subjects, trials, channels, bands, windows),
            e.g. (32, 40, 40, 5, 63).
        label_path: .npy of shape (subjects, trials, windows), e.g. (32, 40, 63).
        exclude_subject: if given, this subject index is left out (training split).
        only_subject: if given, only this subject index is kept (test split).

    Raises:
        ValueError: if the label array does not match the feature array's
            (subjects, trials, windows) dimensions.
    """

    def __init__(self, feature_path, label_path, exclude_subject=None, only_subject=None):
        self.features = np.load(feature_path)
        self.labels = np.load(label_path)

        # Sizes come from the data rather than hard-coded 32/40/63.
        num_subjects, num_trials, _, _, num_windows = self.features.shape
        if self.labels.shape != (num_subjects, num_trials, num_windows):
            raise ValueError(
                f"label shape {self.labels.shape} does not match features "
                f"(subjects, trials, windows) = {(num_subjects, num_trials, num_windows)}")

        self.samples = []
        # subject -> positions in self.samples, so per-subject lookups skip a full scan.
        self._subject_index = {}
        for subj in range(num_subjects):
            if exclude_subject is not None and subj == exclude_subject:
                continue
            if only_subject is not None and subj != only_subject:
                continue
            positions = self._subject_index.setdefault(subj, [])
            for trial in range(num_trials):
                for win in range(num_windows):
                    x = self.features[subj, trial, :, :, win]  # (channels, bands) view
                    y = self.labels[subj, trial, win]
                    positions.append(len(self.samples))
                    self.samples.append((x, y, subj))

        # Lazily filled cache for get_subject_data(), which is called once per training step.
        self._subject_cache = {}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y, _ = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), int(y)

    def get_subject_data(self, subject):
        """Return [(float32 feature tensor, int label), ...] for one subject ([] if not present).

        The tensors are built once per subject and cached. A new list is returned on
        every call, so callers may reorder it, but the tensors themselves are shared
        between calls and must not be modified in place.
        """
        if subject not in self._subject_cache:
            self._subject_cache[subject] = [
                (torch.tensor(self.samples[i][0], dtype=torch.float32), int(self.samples[i][1]))
                for i in self._subject_index.get(subject, [])
            ]
        return list(self._subject_cache[subject])

# === Models ===
class CommonFeatureExtractor(nn.Module):
    """Shared encoder: flatten the (channels, bands) DE input and map it to a 64-d embedding.

    Layers: input_dim -> 256 -> 128 -> 64, each followed by LeakyReLU.
    """

    def __init__(self, input_dim=200):
        super().__init__()
        # Layer attribute names are the checkpoint state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.act = nn.LeakyReLU()

    def forward(self, x):
        flat = x.view(x.size(0), -1)  # [batch, 40, 5] -> [batch, 200]
        hidden = self.act(self.fc1(flat))
        hidden = self.act(self.fc2(hidden))
        return self.act(self.fc3(hidden))

class SubjectSpecificMapper(nn.Module):
    """Fixed 64 -> 32 linear projection followed by LeakyReLU."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 32)
        self.act = nn.LeakyReLU()

    def forward(self, x):
        out = self.fc(x)
        return self.act(out)

class SubjectSpecificClassifier(nn.Module):
    """Linear head: 32-d embedding -> 4 raw class logits (no activation)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 4)

    def forward(self, x):
        logits = self.fc(x)
        return logits

# === Losses ===
class SupervisedContrastiveLoss(nn.Module):
    """Pair-averaged supervised contrastive loss on L2-normalised embeddings.

    Returns the negative mean, over same-label pairs (i != j), of
    s_ij - log(sum_{k != i} exp(s_ik) + 1e-9), where s = cosine similarity / temperature.
    Returns 0 when the batch contains no same-label pair.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        dev = features.device
        unit = F.normalize(features, dim=1)
        n = unit.shape[0]
        scores = torch.matmul(unit, unit.T) / self.temperature

        lab = labels.contiguous().view(-1, 1)
        same = torch.eq(lab, lab.T).float().to(dev)
        off_diag = torch.ones_like(same) - torch.eye(n).to(dev)  # drop self-pairs
        pos_pairs = same * off_diag

        row_denominator = (torch.exp(scores) * off_diag).sum(dim=1, keepdim=True)
        log_prob = scores - torch.log(row_denominator + 1e-9)
        return - (pos_pairs * log_prob).sum() / (pos_pairs.sum() + 1e-9)

class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy with an average of `num_kernels` RBF kernels.

    Bandwidths are mean_sq_dist * kernel_mul ** i for i in range(num_kernels).
    """

    def __init__(self, kernel_mul=2.0, num_kernels=5):
        super().__init__()
        self.kernel_mul = kernel_mul
        self.num_kernels = num_kernels

    def gaussian_kernel(self, source, target):
        """Averaged RBF kernel matrix over the rows of [source; target], shape [(n+m), (n+m)]."""
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0)
        total1 = total.unsqueeze(1)
        L2_distance = ((total0 - total1) ** 2).sum(2)
        # Constant (detached) base bandwidth, clamped so identical points give kernel 1
        # rather than 0/0 = NaN.
        bandwidth = torch.mean(L2_distance).detach().clamp(min=1e-12)
        bandwidth_list = [bandwidth * (self.kernel_mul ** i) for i in range(self.num_kernels)]
        kernels = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
        return sum(kernels) / len(kernels)

    def forward(self, source, target):
        """Biased MMD^2 estimate between `source` [n, ...] and `target` [m, ...]; n and m may differ."""
        source = source.view(source.size(0), -1)
        target = target.view(target.size(0), -1)
        kernels = self.gaussian_kernel(source, target)
        n = source.size(0)
        XX = kernels[:n, :n]
        YY = kernels[n:, n:]
        XY = kernels[:n, n:]
        YX = kernels[n:, :n]
        # Per-block means: identical to mean(XX + YY - XY - YX) when n == m, and
        # still valid for unequal batch sizes.
        return XX.mean() + YY.mean() - XY.mean() - YX.mean()

class ContrastiveLossLcon2(nn.Module):
    """Prototype-vs-queue contrastive loss on pseudo-labelled target embeddings.

    Positive logit: cosine similarity between the normalised embedding and the
    normalised prototype of its pseudo-label. Negative logits: cosine similarities to
    a FIFO queue of past embeddings. All are divided by `tau`, scored with
    cross-entropy (positive at index 0), and scaled by `gamma`.

    `prototypes` is an nn.Parameter (only learned if passed to an optimizer); the
    module must be on the same device as the embeddings.
    """

    def __init__(self, feature_dim=32, num_classes=4, tau=0.3, gamma=0.5, queue_size=1024):
        super().__init__()
        self.tau = tau
        self.gamma = gamma
        self.prototypes = nn.Parameter(torch.randn(num_classes, feature_dim))
        # Queue rows are unit-norm from the start.
        self.register_buffer("queue", F.normalize(torch.randn(queue_size, feature_dim), dim=-1))

    def forward(self, z_t, pseudo_labels):
        """Return the gamma-scaled loss; pushes the normalised, detached z_t into the queue."""
        z_t = F.normalize(z_t, dim=-1)
        # Normalise prototypes so the positive logit is a true cosine similarity,
        # independent of the (unconstrained) prototype norm.
        protos = F.normalize(self.prototypes, dim=-1)
        pos_logits = torch.sum(z_t * protos[pseudo_labels], dim=-1) / self.tau
        neg_logits = torch.matmul(z_t, self.queue.T) / self.tau
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        labels = torch.zeros(z_t.size(0), dtype=torch.long, device=z_t.device)
        loss = F.cross_entropy(logits, labels)
        # Enqueue after computing the loss so the batch is not its own negative.
        self._dequeue_and_enqueue(z_t)
        return self.gamma * loss

    @torch.no_grad()
    def _dequeue_and_enqueue(self, embeddings):
        """FIFO update: drop the oldest rows and append `embeddings` (detached)."""
        embeddings = embeddings.detach()
        batch_size = embeddings.size(0)
        queue_len = self.queue.size(0)
        if batch_size >= queue_len:
            self.queue = embeddings[-queue_len:]
        else:
            self.queue = torch.cat([self.queue[batch_size:], embeddings], dim=0)

class GeneralizedCrossEntropy(nn.Module):
    """Generalized cross-entropy: batch mean of (1 - p_y ** q) / q.

    q == 1 gives 1 - p_y; q == 0 is handled as its limit, cross-entropy -log p_y
    (the closed form would divide by zero).
    """

    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        if self.q == 0:
            # lim_{q -> 0} (1 - p^q) / q = -log p; use the numerically stable CE.
            return F.cross_entropy(logits, targets)
        probs = F.softmax(logits, dim=1)
        # Row index on the logits' device so indexing never mixes CPU and GPU tensors.
        probs_true = probs[torch.arange(logits.size(0), device=logits.device), targets]
        if self.q == 1.0:
            loss = 1.0 - probs_true
        else:
            loss = (1.0 - probs_true.pow(self.q)) / self.q
        return loss.mean()

# === Helper ===
def get_target_batch(dataset, exclude, batch_size=64, num_subjects=32):
    """Draw `batch_size` random samples from one randomly chosen subject other than `exclude`.

    Subjects 0 .. num_subjects - 1 are tried in random order (torch RNG). The first
    one that is not `exclude` and has at least `batch_size` samples is used, so
    successive calls spread over all eligible subjects instead of always returning
    the lowest-numbered one.

    Args:
        dataset: object with get_subject_data(subject) -> list of (x_tensor, int_label).
        exclude: subject index that must not be used.
        batch_size: number of samples drawn without replacement.
        num_subjects: number of subject indices to search.

    Returns:
        (x, y): stacked features [batch_size, ...] and an int64 label tensor [batch_size].

    Raises:
        ValueError: if no eligible subject has at least `batch_size` samples.
    """
    # NOTE(review): with the LOSO training split, `exclude` is the held-out subject, so
    # this "target" batch comes from another *training* subject, not the test subject.
    # Confirm this is the intended domain-adaptation setup.
    for subj in torch.randperm(num_subjects).tolist():
        if subj == exclude:
            continue
        data = dataset.get_subject_data(subj)
        if len(data) >= batch_size:
            indices = torch.randperm(len(data))[:batch_size].tolist()
            x, y = zip(*[data[i] for i in indices])
            return torch.stack(x), torch.tensor(y)
    raise ValueError(f"no subject other than {exclude} has at least {batch_size} samples")

# === Training Loop ===
# The LOSO training loop is defined in the next cell.


In [29]:
# Full Training Loop for MSCL Model (Phase 1) - DEAP Dataset

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import numpy as np
import os

# Create checkpoint directory
os.makedirs("checkpoints", exist_ok=True)

# === Assume these are already defined (previous cell) ===
# - CommonFeatureExtractor
# - SubjectSpecificMapper
# - SubjectSpecificClassifier
# - SupervisedContrastiveLoss
# - ContrastiveLossLcon2
# - MMDLoss
# - GeneralizedCrossEntropy
# - DEAPDataset (with LOSO support)
# - get_target_batch()

# Input files written by the feature-extraction cell.
feature_file = 'E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_features.npy'
label_file = 'E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_labels.npy'

# Hyperparameters
num_epochs = 600
phase1_epochs = 400        # epochs before the subject-specific mapper/classifier are trained
lambda_1, lambda_2, lambda_3 = 0.1, 0.1, 0.1
num_subjects = 32
batch_size = 64
checkpoint_every = 50      # save every N epochs plus the final one (1 = every epoch)
seed = 0

# Reproducibility: weight init, the weighted sampler and target batches use torch's RNG.
torch.manual_seed(seed)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging containers (epoch logs are appended across all folds)
fold_accuracies = []
epoch_losses = []
epoch_accuracies = []

for test_subject in range(num_subjects):
    print(f"\n🚀 Starting Fold {test_subject+1}/{num_subjects} - Test Subject: s{test_subject+1:02d}")

    # === Create datasets and loaders ===
    train_dataset = DEAPDataset(feature_file, label_file, exclude_subject=test_subject)
    test_dataset = DEAPDataset(feature_file, label_file, only_subject=test_subject)

    # Class-balanced sampling: weight each window by the inverse frequency of its class.
    # Labels are read from the sample index instead of materialising every item as a tensor.
    train_labels = np.array([y for _, y, _ in train_dataset.samples], dtype=np.int64)
    class_sample_counts = np.bincount(train_labels)
    class_weights = 1. / np.maximum(class_sample_counts, 1)  # absent classes are never indexed
    sample_weights = class_weights[train_labels]
    sampler = torch.utils.data.WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # === Initialize model and losses ===
    cfe = CommonFeatureExtractor().to(device)
    sfe = SubjectSpecificMapper().to(device)
    ssc = SubjectSpecificClassifier().to(device)

    contrastive_loss_con1 = SupervisedContrastiveLoss()
    # L_con2 owns a learnable prototype Parameter and a queue buffer: it must be on the
    # same device as the embeddings and its parameters must be optimised.
    contrastive_loss_con2 = ContrastiveLossLcon2().to(device)
    mmd_loss = MMDLoss()
    gce_loss = GeneralizedCrossEntropy(q=0.7)

    optimizer = torch.optim.Adam(
        list(cfe.parameters()) + list(sfe.parameters()) + list(ssc.parameters())
        + list(contrastive_loss_con2.parameters()),
        lr=1e-3)

    # === Training ===
    for epoch in range(num_epochs):
        for module in (cfe, sfe, ssc):
            module.train()
        total_loss, total_correct, total_samples = 0.0, 0, 0

        for x_s, y_s in train_loader:
            # Target batch: only its features are used (its labels are ignored).
            x_t, _ = get_target_batch(train_dataset, exclude=test_subject, batch_size=batch_size)
            x_s, y_s, x_t = x_s.to(device), y_s.to(device), x_t.to(device)

            # Keep source/target batches the same size for MMD; trim the labels with x_s.
            min_size = min(x_s.size(0), x_t.size(0))
            x_s, y_s, x_t = x_s[:min_size], y_s[:min_size], x_t[:min_size]

            optimizer.zero_grad()

            z_s_common = cfe(x_s)
            z_t_common = cfe(x_t)
            loss_mmd = mmd_loss(z_s_common, z_t_common)

            if epoch < phase1_epochs:
                # Phase 1: shape the common feature space only.
                loss_con1 = contrastive_loss_con1(z_s_common, y_s)
                loss = lambda_1 * loss_con1 + lambda_3 * loss_mmd
            else:
                # Phase 2: subject-specific mapping + classifier, pseudo-labelled target contrast.
                z_s_subject = sfe(z_s_common)
                z_t_subject = sfe(z_t_common)
                logits = ssc(z_s_subject)
                loss_cls = gce_loss(logits, y_s)
                pseudo_labels = torch.argmax(ssc(z_t_subject), dim=1)
                loss_con2 = contrastive_loss_con2(z_t_subject, pseudo_labels)
                loss = lambda_2 * loss_con2 + lambda_3 * loss_mmd + loss_cls

                preds = torch.argmax(logits, dim=1)
                total_correct += (preds == y_s).sum().item()
                total_samples += y_s.size(0)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        # Train accuracy is only measured in phase 2; phase 1 reports 0.
        acc = total_correct / total_samples if total_samples > 0 else 0
        epoch_losses.append(avg_loss)
        epoch_accuracies.append(acc)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f} - Train Acc: {acc:.4f}")

        # Periodic checkpoint (every epoch of every fold would be num_subjects * num_epochs files).
        if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == num_epochs:
            checkpoint = {
                'epoch': epoch,
                'cfe_state_dict': cfe.state_dict(),
                'sfe_state_dict': sfe.state_dict(),
                'ssc_state_dict': ssc.state_dict(),
                'con2_state_dict': contrastive_loss_con2.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss
            }
            torch.save(checkpoint, f"checkpoints/fold{test_subject+1}_epoch{epoch+1}.pt")

    # === Evaluation on the held-out subject ===
    for module in (cfe, sfe, ssc):
        module.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x_test, y_test in test_loader:
            logits = ssc(sfe(cfe(x_test.to(device))))
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y_test.numpy())

    acc = accuracy_score(all_labels, all_preds)
    fold_accuracies.append(acc)
    print(f"🎯 Accuracy on Subject {test_subject+1}: {acc:.4f}")

# === Final Report ===
print("\n✅ Training complete.")
print("Average Accuracy:", np.mean(fold_accuracies))



🚀 Starting Fold 1/32 - Test Subject: s01
Epoch [1/600] - Loss: 0.4156 - Train Acc: 0.0000
Epoch [2/600] - Loss: 0.4156 - Train Acc: 0.0000
Epoch [3/600] - Loss: 0.4156 - Train Acc: 0.0000
Epoch [4/600] - Loss: 0.4156 - Train Acc: 0.0000
Epoch [5/600] - Loss: 0.4155 - Train Acc: 0.0000


KeyboardInterrupt: 