Data Splitting

In [None]:
# Stratified 70% / 20% / 10% train / validation / test split.
# NOTE: stratification is on the class labels only. This is NOT a subject-wise
# split: nothing here keeps trials of one subject inside a single partition.
# A group-aware splitter would be needed for subject-independent evaluation.
from sklearn.model_selection import StratifiedShuffleSplit

# Step 1: hold out 10% of all samples as the test set.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_val_idx, test_idx = next(sss.split(X, y))
X_temp, X_test = X[train_val_idx], X[test_idx]
y_temp, y_test = y[train_val_idx], y[test_idx]

# Step 2: split the remaining 90% into train and validation.
# Validation must be 20% of the ORIGINAL data, i.e. 0.2 / 0.9 (= 2/9) of the
# remainder. A rounded 0.22 would give 19.8% instead of 20%.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2 / 0.9, random_state=42)
train_idx, val_idx = next(sss.split(X_temp, y_temp))
X_train, X_val = X_temp[train_idx], X_temp[val_idx]
y_train, y_val = y_temp[train_idx], y_temp[val_idx]

# Convert to tensors; unsqueeze(1) inserts a singleton axis at position 1
# (the model's first Conv2d takes one input channel).
X_train = torch.FloatTensor(X_train).unsqueeze(1)
X_val = torch.FloatTensor(X_val).unsqueeze(1)
X_test = torch.FloatTensor(X_test).unsqueeze(1)
y_train = torch.LongTensor(y_train)
y_val = torch.LongTensor(y_val)
y_test = torch.LongTensor(y_test)

# Create datasets
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

print("Enhanced data split completed:")
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"X_train shape: {X_train.shape}")

Model

In [None]:
# EEGNet + attention and residual connections
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over the positions of a 4-D feature map.

    A ``(B, C, H, W)`` input is flattened to ``H * W`` tokens of ``C``
    features each, self-attention is applied across those tokens, and the
    result is reshaped back to ``(B, C, H, W)``. The attention matrix is
    ``(H * W) x (H * W)`` per head.

    Args:
        embed_dim: Feature size of each token; must equal the channel count
            ``C`` of the input.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Without this check head_dim is truncated and the reshapes in
        # forward() can silently succeed with a wrong token count instead of
        # failing, producing meaningless attention.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One projection produces queries, keys and values in a single matmul.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        """Apply self-attention across spatial positions.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == embed_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        num_tokens = H * W
        x = x.flatten(2).transpose(1, 2)  # B, HW, C

        # (B, HW, 3C) -> (3, B, heads, HW, head_dim)
        qkv = self.qkv(x).reshape(B, num_tokens, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Merge the heads back into a single feature axis: B, HW, C
        out = (attn @ v).transpose(1, 2).reshape(B, num_tokens, C)
        out = self.proj(out)

        return out.transpose(1, 2).reshape(B, C, H, W)

class ChannelAttention(nn.Module):
    """Channel-wise gating computed from globally pooled descriptors.

    The input is average-pooled and max-pooled to ``1 x 1``; both descriptors
    go through the same two-layer 1x1-conv bottleneck, are summed, squashed
    with a sigmoid and used to rescale the input channels.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction factor. The hidden width is
            ``in_channels // reduction``, clamped to at least 1 so that small
            channel counts do not produce a zero-width layer.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # in_channels < reduction would otherwise give 0 hidden channels.
        hidden_channels = max(1, in_channels // reduction)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by a gate in ``(0, 1)``."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Position-wise gating built from channel-pooled statistics.

    The per-position mean and maximum over the channel axis are stacked into
    a 2-channel map, reduced to one channel by a single convolution, passed
    through a sigmoid and multiplied back onto the input.

    Args:
        kernel_size: Side length of the convolution kernel; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel gate in ``(0, 1)``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        gate = self.sigmoid(self.conv(pooled))
        return x * gate

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive skip connection.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the skip projection).

    When the stride is not 1 or the channel counts differ, the skip path is a
    1x1 convolution followed by batch norm; otherwise it is an empty
    ``nn.Sequential`` that passes the input through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + skip)

class StateOfTheArtEEGNet(nn.Module):
    """EEGNet-style convolutional classifier with attention and residual stages.

    Pipeline: temporal conv -> depthwise conv over the electrode axis ->
    channel attention -> separable conv -> two residual blocks -> spatial
    attention -> multi-head self-attention -> global average pooling -> two
    classification heads.

    Args:
        input_shape: ``(input_channels, input_time_points)`` of one sample.
        num_classes: Number of output classes for both heads.
        dropout_rate: Dropout probability (the auxiliary head uses half of it).
        F1: Number of temporal filters in the first convolution.
        D: Depth multiplier of the depthwise convolution.
        F2: Number of feature maps after the separable convolution.

    The first convolution takes a single input channel, so the forward input
    is presumably ``(batch, 1, input_channels, time)`` -- confirm with caller.
    """

    def __init__(self, input_shape, num_classes=4, dropout_rate=0.25, F1=32, D=2, F2=64):
        super().__init__()

        n_electrodes, n_samples = input_shape
        self.input_channels = n_electrodes
        self.input_time_points = n_samples
        self.F1, self.D, self.F2 = F1, D, F2

        depth_channels = D * F1

        # --- Block 1: temporal filtering along the time axis ---
        self.conv1 = nn.Conv2d(1, F1, kernel_size=(1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        # --- Block 2: depthwise conv spanning all electrodes ---
        self.conv2 = nn.Conv2d(
            F1, depth_channels, kernel_size=(self.input_channels, 1), groups=F1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(depth_channels)
        self.elu1 = nn.ELU()
        self.avgpool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout2d(dropout_rate)

        # Channel re-weighting applied to the block-2 output
        self.channel_att1 = ChannelAttention(depth_channels)

        # --- Block 3: separable conv = depthwise (conv3) + pointwise (conv4) ---
        self.conv3 = nn.Conv2d(
            depth_channels,
            depth_channels,
            kernel_size=(1, 16),
            padding=(0, 8),
            groups=depth_channels,
            bias=False,
        )
        self.conv4 = nn.Conv2d(depth_channels, F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.elu2 = nn.ELU()
        self.avgpool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout2d(dropout_rate)

        # --- Refinement: residual blocks and attention ---
        self.res_block1 = ResidualBlock(F2, F2)
        self.res_block2 = ResidualBlock(F2, F2)
        self.spatial_att = SpatialAttention()
        self.mha = MultiHeadAttention(F2, num_heads=4)

        # Collapse the remaining spatial extent to one value per feature map
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        # --- Heads: both consume the pooled F2-dimensional feature vector ---
        feature_size = F2
        main_layers = [
            nn.Linear(feature_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*main_layers)

        # Smaller auxiliary head, used as an extra training signal
        aux_layers = [
            nn.Linear(feature_size, 32),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(32, num_classes),
        ]
        self.aux_classifier = nn.Sequential(*aux_layers)

    def forward(self, x):
        """Run the network and return ``(main_output, aux_output)`` logits."""
        # Block 1 (no non-linearity between the temporal conv and block 2)
        x = self.bn1(self.conv1(x))

        # Block 2 followed by channel attention
        block2 = (
            self.conv2,
            self.bn2,
            self.elu1,
            self.avgpool1,
            self.dropout1,
            self.channel_att1,
        )
        for stage in block2:
            x = stage(x)

        # Block 3
        block3 = (
            self.conv3,
            self.conv4,
            self.bn3,
            self.elu2,
            self.avgpool2,
            self.dropout2,
        )
        for stage in block3:
            x = stage(x)

        # Residual refinement, then spatial and multi-head attention
        refinement = (self.res_block1, self.res_block2, self.spatial_att, self.mha)
        for stage in refinement:
            x = stage(x)

        # Pool to (batch, F2, 1, 1) and flatten to (batch, F2)
        pooled = self.global_avg_pool(x)
        features = pooled.flatten(1)

        main_output = self.classifier(features)
        aux_output = self.aux_classifier(features)
        return main_output, aux_output

Trainer Setup

In [None]:
# Training with multiple loss functions and optimization
class FocalLoss(nn.Module):
    """Cross-entropy scaled by ``alpha * (1 - p_t) ** gamma``.

    ``p_t`` is recovered from the per-sample cross-entropy as ``exp(-ce)``.

    Args:
        alpha: Constant multiplier applied to every sample.
        gamma: Exponent of the ``(1 - p_t)`` factor.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against class indices ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        modulating = (1 - prob_true) ** self.gamma
        per_sample = self.alpha * modulating * per_sample_ce

        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample

class LabelSmoothingLoss(nn.Module):
    """KL divergence between softmax predictions and smoothed one-hot targets.

    Args:
        num_classes: Number of classes; the smoothing mass is spread evenly
            over the ``num_classes - 1`` non-target classes.
        smoothing: Probability mass taken away from the target class.
    """

    def __init__(self, num_classes, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the batch-mean KL divergence for logits ``pred`` and class indices ``target``."""
        off_value = self.smoothing / (self.num_classes - 1)
        on_value = 1.0 - self.smoothing

        # The soft target distribution is a constant; keep it out of autograd.
        with torch.no_grad():
            soft_target = torch.full_like(pred, off_value)
            soft_target.scatter_(1, target.unsqueeze(1), on_value)

        log_probs = F.log_softmax(pred, dim=1)
        return F.kl_div(log_probs, soft_target, reduction='batchmean')

class ATrainer:
    """Training loop for a model that returns ``(main_output, aux_output)``.

    Combines cross-entropy, focal and label-smoothing losses on the main head
    with a cross-entropy loss on the auxiliary head, optimises with AdamW and
    a cosine warm-restart schedule, uses mixed precision when the device is
    CUDA, and stops early when validation accuracy stops improving. The best
    weights are written to ``best_model.pth`` in the working directory.

    Args:
        model: Module whose forward pass returns ``(main_logits, aux_logits)``.
        device: ``torch.device`` that batches are moved to.
        num_classes: Number of classes, passed to the label-smoothing loss.
    """

    def __init__(self, model, device, num_classes=4):
        self.model = model
        self.device = device
        self.num_classes = num_classes

        # Loss terms combined in _compute_loss
        self.criterion_ce = nn.CrossEntropyLoss()
        self.criterion_focal = FocalLoss(alpha=1, gamma=2)
        self.criterion_smooth = LabelSmoothingLoss(num_classes, smoothing=0.1)

        # Optimizer and per-epoch learning-rate schedule
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )

        # Early stopping state
        self.best_val_acc = 0
        self.patience = 15
        self.patience_counter = 0

        # Mixed precision is only enabled on CUDA devices
        self.scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    def _compute_loss(self, main_output, aux_output, target):
        """Return the weighted sum of the four training loss terms."""
        loss_ce = self.criterion_ce(main_output, target)
        loss_focal = self.criterion_focal(main_output, target)
        loss_smooth = self.criterion_smooth(main_output, target)
        aux_loss = self.criterion_ce(aux_output, target)
        return 0.4 * loss_ce + 0.3 * loss_focal + 0.2 * loss_smooth + 0.1 * aux_loss

    def train_epoch(self, train_loader, epoch):
        """Train for one pass over ``train_loader``.

        Args:
            train_loader: Iterable of ``(data, target)`` batches with a length.
            epoch: Epoch index, used only in progress messages.

        Returns:
            ``(mean batch loss, accuracy in percent)``; both are 0.0 when the
            loader yields no batches.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            if self.scaler is not None:
                with torch.cuda.amp.autocast():
                    main_output, aux_output = self.model(data)
                    loss = self._compute_loss(main_output, aux_output, target)

                self.scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale at this
                # point; unscale them first so the clipping threshold means the
                # same thing as on the full-precision path below.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                main_output, aux_output = self.model(data)
                loss = self._compute_loss(main_output, aux_output, target)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            total_loss += loss.item()
            pred = main_output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

            if batch_idx % 50 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, '
                      f'Acc: {100.*correct/total:.2f}%')

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        return total_loss / max(len(train_loader), 1), 100. * correct / max(total, 1)

    def validate(self, val_loader):
        """Evaluate the main head on ``val_loader`` without gradient tracking.

        Returns:
            ``(mean batch cross-entropy, accuracy in percent)``; both are 0.0
            when the loader yields no batches.
        """
        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                main_output, _ = self.model(data)

                val_loss += self.criterion_ce(main_output, target).item()
                pred = main_output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        val_acc = 100. * correct / max(total, 1)
        return val_loss / max(len(val_loader), 1), val_acc

    def train(self, train_loader, val_loader, epochs=100):
        """Run the full training loop with early stopping.

        Args:
            train_loader: Loader of training batches.
            val_loader: Loader of validation batches.
            epochs: Maximum number of epochs.

        Returns:
            Dict with per-epoch lists under the keys ``'train_losses'``,
            ``'train_accs'``, ``'val_losses'`` and ``'val_accs'``.
        """
        train_losses, train_accs = [], []
        val_losses, val_accs = [], []

        for epoch in range(epochs):
            # Training
            train_loss, train_acc = self.train_epoch(train_loader, epoch)
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation
            val_loss, val_acc = self.validate(val_loader)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            # The scheduler is stepped once per epoch
            self.scheduler.step()

            print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

            # Checkpoint on strict improvement, otherwise count towards patience
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                torch.save(self.model.state_dict(), 'best_model.pth')
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.patience:
                    print(f'Early stopping at epoch {epoch}')
                    break

        return {
            'train_losses': train_losses,
            'train_accs': train_accs,
            'val_losses': val_losses,
            'val_accs': val_accs
        }

Model Initialization and Training

In [None]:
# Build the network; its input geometry is read from axes 2 and 3 of X_train
input_shape = tuple(X_train.shape[2:4])  # (channels, time_points)
model = StateOfTheArtEEGNet(
    input_shape,
    num_classes=4,
    dropout_rate=0.3,
    F1=64,
    D=2,
    F2=128,
)
model = model.to(device)

# Show the layer structure
print("Model architecture:")
print(model)

# Parameter counts: all parameters, and those that receive gradients
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(torch.numel(p) for p in model.parameters() if p.requires_grad)
print(f'\nTotal parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

In [None]:
# Data loading with inverse-frequency weighted sampling for class imbalance
from torch.utils.data import WeightedRandomSampler

# Each sample is drawn with probability proportional to 1 / (size of its class).
class_counts = np.bincount(y_train.cpu().numpy())
# np.maximum avoids a divide-by-zero for label values that never occur; those
# weights are never looked up because no sample carries such a label.
class_weights = 1.0 / np.maximum(class_counts, 1)
sample_weights = class_weights[y_train.cpu().numpy()]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

batch_size = 32

# Pinned host memory only helps host-to-GPU copies; skip it when CUDA is absent.
use_pinned_memory = torch.cuda.is_available()

# The classifier heads contain BatchNorm1d layers, which raise in training mode
# on a batch holding a single sample. Drop the trailing batch only in exactly
# that case, so every other dataset size is batched as before.
drop_single_sample_batch = (
    len(train_dataset) > batch_size and len(train_dataset) % batch_size == 1
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                          num_workers=0, pin_memory=use_pinned_memory,
                          drop_last=drop_single_sample_batch)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=0, pin_memory=use_pinned_memory)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=0, pin_memory=use_pinned_memory)


print("Data loaders created:")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Train the model, then restore the weights of the best validation epoch
trainer = ATrainer(model, device, num_classes=4)

print("Starting advanced training...")
history = trainer.train(train_loader, val_loader, epochs=80)

# Load best model. map_location places the checkpoint tensors on the current
# device, so loading does not depend on the device they were saved from.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
print(f"Best validation accuracy: {trainer.best_val_acc:.2f}%")