In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import os
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Create directories
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)
os.makedirs("plots", exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [9]:
class DEAPDataset(Dataset):
    def __init__(self, feature_path, label_path, exclude_subject=None, only_subject=None, normalize=True):
        self.features = np.load(feature_path)  # shape: (32, 40, 40, 5, 63)
        self.labels = np.load(label_path)      # shape: (32, 40, 63)
        
        # Compute statistics for normalization if requested
        if normalize:
            all_features = []
            for subj in range(32):
                if exclude_subject is not None and subj == exclude_subject:
                    continue
                if only_subject is not None and subj != only_subject:
                    continue
                
                # Reshape to get all features for this subject
                subj_features = self.features[subj].transpose(0, 1, 3, 2).reshape(-1, 40, 5)  # Result: (N, 40, 5)
                all_features.append(subj_features)
            
            if all_features:  # Check if list is not empty
                all_features = np.concatenate(all_features, axis=0)
                self.mean = np.mean(all_features, axis=0)
                self.std = np.std(all_features, axis=0) + 1e-8
            else:
                # Default if no subjects selected
                self.mean = 0
                self.std = 1
        else:
            self.mean = 0
            self.std = 1
            
        # Prepare samples
        self.samples = []
        for subj in range(32):
            if exclude_subject is not None and subj == exclude_subject:
                continue
            if only_subject is not None and subj != only_subject:
                continue
                
            for trial in range(40):
                for win in range(63):
                    x = self.features[subj, trial, :, :, win]
                    y = self.labels[subj, trial, win]
                    self.samples.append((x, y, subj))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y, subj = self.samples[idx]
        # Normalize the data
        x = (x - self.mean) / self.std
        return torch.tensor(x, dtype=torch.float32), int(y)

    def get_subject_data(self, subject):
        return [(torch.tensor((x - self.mean) / self.std, dtype=torch.float32), int(y))
                for x, y, subj in self.samples if subj == subject]

In [10]:
class CommonFeatureExtractor(nn.Module):
    def __init__(self, input_dim=200):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 64)
        self.bn3 = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(0.3)
        self.act = nn.LeakyReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # [batch, 40, 5] → [batch, 200]
        x = self.act(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = self.act(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.act(self.bn3(self.fc3(x)))
        return x

class SubjectSpecificMapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 32)
        self.bn = nn.BatchNorm1d(32)
        self.dropout = nn.Dropout(0.2)
        self.act = nn.LeakyReLU()

    def forward(self, x):
        x = self.fc(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.dropout(x)
        return x

class SubjectSpecificClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 4)

    def forward(self, x):
        return self.fc(x)

In [11]:
class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        features = F.normalize(features, dim=1)
        batch_size = features.shape[0]
        
        # Handle case where batch size is 1
        if batch_size <= 1:
            return torch.tensor(0.0, device=features.device, requires_grad=True)
            
        sim_matrix = torch.matmul(features, features.T) / self.temperature
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(features.device)
        logits_mask = torch.ones_like(mask) - torch.eye(batch_size).to(features.device)
        mask = mask * logits_mask
        
        # Handle case where there are no positive pairs
        if mask.sum() == 0:
            return torch.tensor(0.0, device=features.device, requires_grad=True)
            
        exp_sim = torch.exp(sim_matrix) * logits_mask
        log_prob = sim_matrix - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-9)
        loss = - (mask * log_prob).sum() / (mask.sum() + 1e-9)
        return loss

class MMDLoss(nn.Module):
    def __init__(self, kernel_mul=2.0, num_kernels=5):
        super().__init__()
        self.kernel_mul = kernel_mul
        self.num_kernels = num_kernels

    def gaussian_kernel(self, source, target):
        total = torch.cat([source, target], dim=0)
        
        # Handle small batch sizes
        if total.shape[0] <= 1:
            return torch.tensor(0.0, device=source.device, requires_grad=True)
            
        total0 = total.unsqueeze(0)
        total1 = total.unsqueeze(1)
        L2_distance = ((total0 - total1) ** 2).sum(2)
        
        # Prevent division by zero
        bandwidth = torch.mean(L2_distance.detach()) + 1e-8
        bandwidth_list = [bandwidth * (self.kernel_mul ** i) for i in range(self.num_kernels)]
        kernels = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
        return sum(kernels) / len(kernels)

    def forward(self, source, target):
        source = source.view(source.size(0), -1)
        target = target.view(target.size(0), -1)
        
        # Handle empty batches
        if source.shape[0] == 0 or target.shape[0] == 0:
            return torch.tensor(0.0, device=source.device, requires_grad=True)
            
        kernels = self.gaussian_kernel(source, target)
        batch_size = source.size(0)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        return torch.mean(XX + YY - XY - YX)

class ContrastiveLossLcon2(nn.Module):
    def __init__(self, feature_dim=32, num_classes=4, tau=0.1, gamma=0.5, queue_size=1024):
        super().__init__()
        self.tau = tau
        self.gamma = gamma
        self.prototypes = nn.Parameter(torch.randn(num_classes, feature_dim))
        self.register_buffer("queue", torch.randn(queue_size, feature_dim))
        self.queue = F.normalize(self.queue, dim=-1)

    def forward(self, z_t, pseudo_labels):
        # Handle empty batches
        if z_t.shape[0] == 0:
            return torch.tensor(0.0, device=z_t.device, requires_grad=True)
            
        z_t = F.normalize(z_t, dim=-1)
        device = z_t.device
        pseudo_labels = pseudo_labels.to(device)
        
        pos_proto = self.prototypes.to(device)[pseudo_labels]
        
        # Compute positive and negative logits
        pos_logits = torch.sum(z_t * pos_proto, dim=-1) / self.tau
        
        # Handle case where queue is empty
        if self.queue.shape[0] == 0:
            return self.gamma * F.cross_entropy(pos_logits.unsqueeze(1), torch.zeros(z_t.size(0), dtype=torch.long, device=device))
            
        neg_logits = torch.matmul(z_t, self.queue.to(device).T) / self.tau
        logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        labels = torch.zeros(z_t.size(0), dtype=torch.long).to(device)
        
        loss = F.cross_entropy(logits, labels)
        self._dequeue_and_enqueue(z_t)
        return self.gamma * loss

    @torch.no_grad()
    def _dequeue_and_enqueue(self, embeddings):
        embeddings = embeddings.detach().to(self.queue.device)
        batch_size = embeddings.size(0)
        queue_size = self.queue.size(0)
        
        if batch_size >= queue_size:
            self.queue = embeddings[-queue_size:]
        else:
            self.queue = torch.cat([self.queue[batch_size:], embeddings], dim=0)

class GeneralizedCrossEntropy(nn.Module):
    def __init__(self, q=0.7, weight=None):
        super().__init__()
        self.q = q
        self.weight = weight  # class weights (tensor)

    def forward(self, logits, targets):
        probs = F.softmax(logits, dim=1)
        
        # Handle empty batches
        if targets.shape[0] == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
            
        targets_onehot = F.one_hot(targets, num_classes=probs.shape[1]).float()
        probs = torch.sum(probs * targets_onehot, dim=1)

        if self.weight is not None:
            weights = self.weight[targets]
            loss = (1 - probs ** self.q) / self.q
            return (weights * loss).mean()
        else:
            return ((1 - probs ** self.q) / self.q).mean()

In [12]:
def get_target_batch(dataset, exclude, batch_size=64):
    """Get a batch of data from target subjects"""
    for subj in range(32):
        if subj != exclude:
            data = dataset.get_subject_data(subj)
            if len(data) >= batch_size:
                indices = torch.randperm(len(data))[:batch_size]
                x, y = zip(*[data[i] for i in indices])
                return torch.stack(x), torch.tensor(y)
    
    # Fallback: Return a small batch if no subject has enough data
    all_data = []
    for subj in range(32):
        if subj != exclude:
            all_data.extend(dataset.get_subject_data(subj))
    
    if len(all_data) == 0:
        # Empty tensor with correct shape as a fallback
        empty_sample = next(iter(dataset))
        return torch.zeros((0, *empty_sample[0].shape), dtype=torch.float32), torch.zeros(0, dtype=torch.long)
        
    indices = torch.randperm(len(all_data))[:min(batch_size, len(all_data))]
    x, y = zip(*[all_data[i] for i in indices])
    return torch.stack(x), torch.tensor(y)

def plot_confusion_matrix(y_true, y_pred, subject_idx):
    """Create and save confusion matrix plot"""
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Confusion Matrix - Subject {subject_idx+1}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(f'plots/confmat_subject{subject_idx+1}.png')
    plt.close()

In [13]:
def train_model(feature_path, label_path, num_subjects=32, num_epochs=100, 
                warmup_epochs=15, batch_size=64):
    """Main training function for all folds"""
    
    # Hyperparameters
    lambda_1 = 1.0  # For contrastive_loss_con1 during warmup
    lambda_2 = 0.5  # For contrastive_loss_con2 after warmup
    lambda_3_init = 0.1  # Initial value for MMD loss
    
    # Logging containers
    fold_accuracies = []
    global_true = []
    global_pred = []
    
    for test_subject in range(num_subjects):
        print(f"\n🚀 Starting Fold {test_subject+1}/{num_subjects} - Test Subject: s{test_subject+1:02d}")
        
        # Create datasets and loaders
        train_dataset = DEAPDataset(feature_path, label_path, exclude_subject=test_subject, normalize=True)
        test_dataset = DEAPDataset(feature_path, label_path, only_subject=test_subject, normalize=True)
        
        # Create balanced sampler
        all_labels = np.array([label for _, label in train_dataset])
        class_sample_counts = np.bincount(all_labels)
        class_weights = 1. / class_sample_counts
        sample_weights = [class_weights[label] for label in all_labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
        
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        
        # Initialize models
        cfe = CommonFeatureExtractor().to(device)
        sfe = SubjectSpecificMapper().to(device)
        ssc = SubjectSpecificClassifier().to(device)
        
        # Initialize losses
        contrastive_loss_con1 = SupervisedContrastiveLoss(temperature=0.07).to(device)
        contrastive_loss_con2 = ContrastiveLossLcon2(tau=0.1).to(device)
        mmd_loss = MMDLoss().to(device)
        gce_loss = GeneralizedCrossEntropy(q=0.7).to(device)
        
        # Initialize optimizer and scheduler
        all_params = list(cfe.parameters()) + list(sfe.parameters()) + list(ssc.parameters())
        optimizer = torch.optim.AdamW(all_params, lr=1e-3, weight_decay=1e-4)
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
        
        # Training tracking
        best_acc = 0
        transition_start = warmup_epochs // 2
        
        # Track loss components
        loss_history = {
            'total': [], 'con1': [], 'con2': [], 'mmd': [], 'cls': [],
            'train_acc': [], 'val_acc': []
        }
        
        # Training loop
        for epoch in range(num_epochs):
            cfe.train()
            sfe.train()
            ssc.train()
            
            epoch_losses = {'total': 0, 'con1': 0, 'con2': 0, 'mmd': 0, 'cls': 0}
            total_correct, total_samples = 0, 0
            
            for x_s, y_s in train_loader:
                x_s, y_s = x_s.to(device), y_s.to(device)
                
                # Get target batch
                x_t, y_t = get_target_batch(train_dataset, exclude=test_subject, batch_size=batch_size)
                x_t, y_t = x_t.to(device), y_t.to(device)
                
                # Skip this batch if either source or target is empty
                if x_s.shape[0] == 0 or x_t.shape[0] == 0:
                    continue
                
                # Match batch sizes
                min_size = min(x_s.size(0), x_t.size(0))
                x_s = x_s[:min_size]
                y_s = y_s[:min_size]
                x_t = x_t[:min_size]
                y_t = y_t[:min_size]
                
                optimizer.zero_grad()
                
                # Extract features
                z_s_common = cfe(x_s)
                z_t_common = cfe(x_t)
                
                # Always extract subject-specific features
                z_s_subject = sfe(z_s_common)
                z_t_subject = sfe(z_t_common)
                
                # Domain adaptation loss (always applied)
                loss_mmd = mmd_loss(z_s_common, z_t_common)
                
                # Adjusted MMD weight - gradually increase during warmup
                lambda_3 = lambda_3_init
                if epoch < warmup_epochs:
                    lambda_3 = lambda_3_init * (epoch + 1) / warmup_epochs
                
                # Different training phases
                if epoch < transition_start:
                    # Pure contrastive phase
                    loss_con1 = contrastive_loss_con1(z_s_common, y_s)
                    loss = lambda_1 * loss_con1 + lambda_3 * loss_mmd
                    
                    # Track loss components
                    epoch_losses['con1'] += loss_con1.item()
                    epoch_losses['mmd'] += loss_mmd.item()
                    
                elif epoch < warmup_epochs:
                    # Transition phase - gradually introduce classification
                    transition_factor = (epoch - transition_start) / (warmup_epochs - transition_start)
                    
                    loss_con1 = contrastive_loss_con1(z_s_common, y_s)
                    logits = ssc(z_s_subject)
                    loss_cls = gce_loss(logits, y_s)
                    
                    loss = (1 - transition_factor) * (lambda_1 * loss_con1) + \
                           transition_factor * loss_cls + lambda_3 * loss_mmd
                    
                    # Track loss components
                    epoch_losses['con1'] += loss_con1.item()
                    epoch_losses['cls'] += loss_cls.item()
                    epoch_losses['mmd'] += loss_mmd.item()
                    
                    # Calculate accuracy
                    preds = torch.argmax(logits, dim=1)
                    total_correct += (preds == y_s).sum().item()
                    total_samples += y_s.size(0)
                    
                else:
                    # Full classification phase
                    logits = ssc(z_s_subject)
                    loss_cls = gce_loss(logits, y_s)
                    
                    # Use pseudo labels for target domain
                    with torch.no_grad():
                        pseudo_logits = ssc(z_t_subject)
                        pseudo_labels = torch.argmax(pseudo_logits, dim=1)
                    
                    loss_con2 = contrastive_loss_con2(z_t_subject, pseudo_labels)
                    loss = loss_cls + lambda_2 * loss_con2 + lambda_3 * loss_mmd
                    
                    # Track loss components
                    epoch_losses['cls'] += loss_cls.item()
                    epoch_losses['con2'] += loss_con2.item()
                    epoch_losses['mmd'] += loss_mmd.item()
                    
                    # Calculate accuracy
                    preds = torch.argmax(logits, dim=1)
                    total_correct += (preds == y_s).sum().item()
                    total_samples += y_s.size(0)
                
                # Update total loss
                epoch_losses['total'] += loss.item()
                
                # Backpropagation with gradient clipping
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                optimizer.step()
            
            # Update scheduler
            scheduler.step()
            
            # Calculate average losses
            num_batches = len(train_loader)
            avg_total_loss = epoch_losses['total'] / num_batches
            avg_train_acc = total_correct / total_samples if total_samples > 0 else 0
            
            # Update loss history
            loss_history['total'].append(avg_total_loss)
            loss_history['train_acc'].append(avg_train_acc)
            
            for key in ['con1', 'con2', 'mmd', 'cls']:
                if epoch_losses[key] > 0:
                    loss_history[key].append(epoch_losses[key] / num_batches)
                else:
                    loss_history[key].append(0)
            
            # Print progress
            progress_str = f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_total_loss:.4f}"
            if epoch < warmup_epochs:
                progress_str += f" - Con1: {loss_history['con1'][-1]:.4f}"
            else:
                progress_str += f" - Con2: {loss_history['con2'][-1]:.4f}"
            
            progress_str += f" - MMD: {loss_history['mmd'][-1]:.4f}"
            
            if epoch >= transition_start:
                progress_str += f" - Cls: {loss_history['cls'][-1]:.4f}"
                
            progress_str += f" - Train Acc: {avg_train_acc:.4f}"
            print(progress_str)
            
            # Save checkpoint if accuracy improves
            if avg_train_acc > best_acc:
                best_acc = avg_train_acc
                print(f"🏆 New best accuracy for fold {test_subject+1}: {best_acc:.4f}")
                torch.save({
                    'epoch': epoch + 1,
                    'cfe_state_dict': cfe.state_dict(),
                    'sfe_state_dict': sfe.state_dict(),
                    'ssc_state_dict': ssc.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': avg_total_loss,
                    'accuracy': avg_train_acc
                }, f"checkpoints/fold{test_subject+1}_best.pt")
            
            # Evaluation after each epoch
            if epoch % 5 == 0 or epoch == num_epochs - 1:
                cfe.eval()
                sfe.eval()
                ssc.eval()
                
                all_preds, all_labels = [], []
                
                with torch.no_grad():
                    for x_test, y_test in test_loader:
                        x_test = x_test.to(device)
                        z_common = cfe(x_test)
                        z_subject = sfe(z_common)
                        logits = ssc(z_subject)
                        preds = torch.argmax(logits, dim=1)
                        
                        all_preds.extend(preds.cpu().numpy())
                        all_labels.extend(y_test.numpy())
                
                val_acc = accuracy_score(all_labels, all_preds)
                loss_history['val_acc'].append(val_acc)
                print(f"📊 Validation Accuracy on Subject {test_subject+1}: {val_acc:.4f}")
        
        # Final evaluation
        cfe.eval()
        sfe.eval()
        ssc.eval()
        
        all_preds, all_labels = [], []
        
        with torch.no_grad():
            for x_test, y_test in test_loader:
                x_test = x_test.to(device)
                z_common = cfe(x_test)
                z_subject = sfe(z_common)
                logits = ssc(z_subject)
                preds = torch.argmax(logits, dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(y_test.numpy())
        
        final_acc = accuracy_score(all_labels, all_preds)
        fold_accuracies.append(final_acc)
        print(f"🎯 Final Accuracy on Subject {test_subject+1}: {final_acc:.4f}")
        
        # Plot confusion matrix
        plot_confusion_matrix(all_labels, all_preds, test_subject)
        global_true.extend(all_labels)
        global_pred.extend(all_preds)
        
        # Save loss history for this fold
        np.save(f"logs/fold{test_subject+1}_loss_history.npy", loss_history)
    
    # Final report
    print("\n✅ Training complete.")
    print("Subject-wise accuracies:", fold_accuracies)
    print("Average Accuracy:", np.mean(fold_accuracies))
    print("Overall Accuracy:", accuracy_score(global_true, global_pred))
    
    # Plot overall confusion matrix
    plot_confusion_matrix(global_true, global_pred, 32)  # Use index 32 for global
    
    return fold_accuracies, global_true, global_pred

if __name__ == "__main__":
    # Set paths to your data
    feature_path = 'E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_features.npy'
    label_path = 'E:/FYP/Finalise Fyp/EEg-based-Emotion-Recognition/de_labels.npy'
    
    # Run training
    fold_accuracies, global_true, global_pred = train_model(
        feature_path=feature_path,
        label_path=label_path,
        num_subjects=32,
        num_epochs=100,
        warmup_epochs=15,
        batch_size=64
    )


🚀 Starting Fold 1/32 - Test Subject: s01
Epoch [1/100] - Loss: 4.5787 - Con1: 4.5787 - MMD: 0.0039 - Train Acc: 0.0000
📊 Validation Accuracy on Subject 1: 0.3230
Epoch [2/100] - Loss: 4.1389 - Con1: 4.1388 - MMD: 0.0053 - Train Acc: 0.0000
Epoch [3/100] - Loss: 4.1335 - Con1: 4.1334 - MMD: 0.0044 - Train Acc: 0.0000
Epoch [4/100] - Loss: 4.1283 - Con1: 4.1282 - MMD: 0.0042 - Train Acc: 0.0000
Epoch [5/100] - Loss: 4.1245 - Con1: 4.1244 - MMD: 0.0037 - Train Acc: 0.0000
Epoch [6/100] - Loss: 4.1188 - Con1: 4.1187 - MMD: 0.0031 - Train Acc: 0.0000
📊 Validation Accuracy on Subject 1: 0.3206
Epoch [7/100] - Loss: 4.1159 - Con1: 4.1158 - MMD: 0.0029 - Train Acc: 0.0000
Epoch [8/100] - Loss: 4.1098 - Con1: 4.1096 - MMD: 0.0026 - Cls: 0.8992 - Train Acc: 0.2523
🏆 New best accuracy for fold 1: 0.2523
Epoch [9/100] - Loss: 3.6853 - Con1: 4.1083 - MMD: 0.0025 - Cls: 0.7226 - Train Acc: 0.4818
🏆 New best accuracy for fold 1: 0.4818
Epoch [10/100] - Loss: 3.2505 - Con1: 4.1064 - MMD: 0.0026 - Cls

KeyboardInterrupt: 