In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix
import wandb
import time
from datetime import datetime
import seaborn as sns


In [None]:

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Hyperparameters for the deeper, batch-normalised CNN run.
CONFIG = dict(
    model_name='deeper_cnn_bn',
    batch_size=32,
    learning_rate=0.001,
    epochs=40,  # more epochs than the simple CNN, since slower convergence is expected
    image_size=48,
    num_classes=7,
    random_seed=42,
    weight_decay=1e-4,  # L2 regularization strength
    dropout_rate=0.5,
)

# Seed NumPy and PyTorch from the same configured value.
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])

# Open a Weights & Biases run named after the model plus a launch timestamp;
# the whole CONFIG dict is attached to the run as its configuration.
wandb.init(
    job_type="training",
    config=CONFIG,
    name=f"{CONFIG['model_name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    project="facial-expression-recognition",
)


In [None]:

# Reuse dataset class from previous experiment
class FERDataset(Dataset):
    """Dataset of square grayscale images stored as pixel strings in a DataFrame.

    Every selected row must provide a ``pixels`` entry (whitespace-separated
    integer intensities in 0-255) and an ``emotion`` entry (integer class id).

    Args:
        dataframe: Source table holding the ``pixels`` and ``emotion`` columns.
        indices: Positional row indices (as accepted by ``DataFrame.iloc``)
            selecting the rows that belong to this split.
        transform: Optional callable applied to the ``(1, H, W)`` float tensor
            after scaling to [0, 1].
        image_size: Side length of the square image. Defaults to 48, which is
            the shape this class previously hard-coded.
    """

    def __init__(self, dataframe, indices, transform=None, image_size=48):
        # Copy the selected rows and renumber them so idx runs 0..len-1.
        self.data = dataframe.iloc[indices].reset_index(drop=True)
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        """Return the number of rows in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Returns:
            image: float32 tensor of shape ``(1, image_size, image_size)``
                scaled to [0, 1] (before ``transform``, if any, is applied).
            label: the row's ``emotion`` value as a Python ``int``.

        Raises:
            ValueError: if the pixel string does not contain exactly
                ``image_size * image_size`` values.
        """
        # Fetch the row once instead of materialising it separately per column.
        row = self.data.iloc[idx]

        flat = np.array(row['pixels'].split(), dtype=np.uint8)
        expected = self.image_size * self.image_size
        if flat.size != expected:
            raise ValueError(
                f"row {idx}: expected {expected} pixel values for a "
                f"{self.image_size}x{self.image_size} image, got {flat.size}"
            )

        image = flat.reshape(self.image_size, self.image_size)
        image = image.astype(np.float32) / 255.0
        # Add the leading channel dimension expected by Conv2d: (H, W) -> (1, H, W).
        image = torch.from_numpy(image).unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        label = int(row['emotion'])
        return image, label


In [None]:

# Define Deeper CNN with Batch Normalization
class DeeperCNN(nn.Module):
    """Seven-conv-layer CNN with batch normalisation and a three-layer classifier head.

    Layout: four convolutional stages (2, 2, 2 and 1 conv layers with 32, 64,
    128 and 256 output channels), each conv followed by BatchNorm2d + ReLU and
    each stage closed by 2x2 max pooling; then global average pooling and a
    256 -> 512 -> 256 -> num_classes fully connected head with BatchNorm1d,
    ReLU and dropout on the two hidden layers.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Drop probability used after each hidden FC layer.
    """

    def __init__(self, num_classes=7, dropout_rate=0.5):
        super().__init__()

        # (in_channels, out_channels) for conv1..conv7; every conv is 3x3 with
        # padding 1 and is paired with a BatchNorm2d named bn<k>. Layers are
        # registered in the order conv1, bn1, conv2, bn2, ...
        channel_plan = [
            (1, 32), (32, 32),       # stage 1
            (32, 64), (64, 64),      # stage 2
            (64, 128), (128, 128),   # stage 3
            (128, 256),              # stage 4
        ]
        for layer_no, (c_in, c_out) in enumerate(channel_plan, start=1):
            setattr(self, f'conv{layer_no}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'bn{layer_no}', nn.BatchNorm2d(c_out))

        # Parameter-free layers shared across stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.dropout = nn.Dropout(dropout_rate)

        # Classifier head
        self.fc1 = nn.Linear(256, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a batch of single-channel images to class logits.

        For a 48x48 input the spatial size goes 48 -> 24 -> 12 -> 6 -> 3
        through the four pooled stages before global average pooling.
        """
        stages = (
            ((self.conv1, self.bn1), (self.conv2, self.bn2)),
            ((self.conv3, self.bn3), (self.conv4, self.bn4)),
            ((self.conv5, self.bn5), (self.conv6, self.bn6)),
            ((self.conv7, self.bn7),),
        )
        for stage in stages:
            for conv, norm in stage:
                x = F.relu(norm(conv(x)))
            x = self.pool(x)  # halve the spatial resolution after each stage

        # Collapse the remaining spatial extent to 1x1 and flatten to (batch, 256).
        x = self.adaptive_pool(x)
        x = x.flatten(start_dim=1)

        # Hidden FC layers: linear -> batch norm -> ReLU -> dropout.
        for dense, norm in ((self.fc1, self.bn_fc1), (self.fc2, self.bn_fc2)):
            x = self.dropout(F.relu(norm(dense(x))))

        return self.fc3(x)


In [None]:

# Load data: the full training table plus the precomputed train/validation row indices.
print("Loading data...")
train_df = pd.read_csv('train.csv')
train_indices = np.load('train_indices.npy')
val_indices = np.load('val_indices.npy')

# Create datasets with same transforms as baseline for fair comparison.
# Normalize(mean=0.5, std=0.5) maps the [0, 1] pixel range to [-1, 1].
train_transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

val_transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Both splits are drawn from the same table; only the row indices differ.
train_dataset = FERDataset(train_df, train_indices, transform=train_transform)
val_dataset = FERDataset(train_df, val_indices, transform=val_transform)

# The model's BatchNorm1d layers raise in training mode when a batch holds a
# single sample. That happens if the training set size leaves a remainder of
# exactly one, so drop the trailing batch in that case only; every other split
# size keeps the previous behaviour (all samples used each epoch).
drop_singleton_batch = len(train_dataset) % CONFIG['batch_size'] == 1

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    drop_last=drop_singleton_batch,
)
# Validation runs in eval mode (batch norm uses running statistics), so a
# one-sample batch is fine there and nothing is dropped.
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


# Training

## Initialize model

In [None]:

# Initialize model and move it to the selected device
model = DeeperCNN(num_classes=CONFIG['num_classes'], dropout_rate=CONFIG['dropout_rate']).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model Parameters: {total_params:,} total, {trainable_params:,} trainable")
# 18000 is a rough, hand-entered parameter count for the earlier simple CNN.
print(f"Parameter increase vs Simple CNN: ~{total_params/18000:.1f}x")  # Approximate simple CNN params

# Calculate class weights for balanced loss.
# Pull the label column in one vectorised selection instead of building a row
# Series per sample.
train_labels = train_df['emotion'].iloc[train_indices].astype(int).tolist()

# minlength guarantees one count per class even if the highest class ids are
# missing from this split; otherwise the weight vector would be too short for
# CrossEntropyLoss.
class_counts = np.bincount(train_labels, minlength=CONFIG['num_classes'])

# Balanced weighting: n_samples / (n_classes * count). A class with zero
# training samples would divide by zero, so its count is clamped to 1; its
# weight is never used in training because no sample carries that label.
class_weights = len(train_labels) / (len(class_counts) * np.maximum(class_counts, 1))
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

print(f"Class weights: {class_weights}")


## Loss and optimizer

In [None]:

# Define loss and optimizer with weight decay.
# The loss is weighted per class so that rare classes contribute as much as
# frequent ones; Adam's weight_decay adds L2 regularisation.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

# Learning rate scheduler: halve the LR when the monitored metric (validation
# accuracy, hence mode='max') has not improved for 5 consecutive epochs.
# `verbose=True` is intentionally not passed: the argument is deprecated since
# PyTorch 2.2 and is rejected by releases that removed it. The training loop
# prints and logs the current learning rate every epoch, so LR changes stay
# visible.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

# Log model info
wandb.log({
    "model_parameters": total_params,
    "trainable_parameters": trainable_params,
    "class_weights": class_weights.tolist()
})

# Training function with gradient clipping
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode here.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss
        values and the accuracy in percent over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()

        # Rescale gradients so their global norm is at most 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Accumulate statistics for the epoch summary.
        loss_sum += loss.item()
        n_seen += labels.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()

        # Progress line every 100 batches (including the first).
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

    return loss_sum / len(train_loader), 100.0 * n_correct / n_seen

# Validation function (same as before)
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_preds, all_targets)``: mean of the
        per-batch losses, accuracy in percent, and flat lists with the
        predicted and true class id of every sample in iteration order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    all_preds, all_targets = [], []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()

            batch_preds = logits.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (batch_preds == labels).sum().item()

            # Keep per-sample results on the CPU for later reporting.
            all_preds.extend(batch_preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    return loss_sum / len(val_loader), 100.0 * n_correct / n_seen, all_preds, all_targets

# Training loop: train/validate once per epoch, track history, checkpoint the
# best model by validation accuracy and stop early after a run of epochs
# without improvement.
print("\nStarting training...")
train_losses, val_losses = [], []
train_accs, val_accs = [], []
learning_rates = []
best_val_acc = 0.0
patience_counter = 0       # consecutive epochs without a new best val accuracy
early_stop_patience = 10   # stop once the counter reaches this value

start_time = time.time()

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()

    # One pass over the training data, then one over the validation data.
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_preds, val_targets = validate_epoch(model, val_loader, criterion, device)

    # The plateau scheduler monitors validation accuracy; read back the LR it left in place.
    scheduler.step(val_acc)
    current_lr = optimizer.param_groups[0]['lr']

    # Per-epoch history used by the plots after training.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    learning_rates.append(current_lr)

    epoch_time = time.time() - epoch_start

    # Log to wandb
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_accuracy': train_acc,
        'val_loss': val_loss,
        'val_accuracy': val_acc,
        'epoch_time': epoch_time,
        'learning_rate': current_lr
    })

    print(f'Epoch [{epoch+1}/{CONFIG["epochs"]}]')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print(f'LR: {current_lr:.6f}, Time: {epoch_time:.2f}s')
    print('-' * 60)

    if val_acc > best_val_acc:
        # New best: checkpoint the weights and reset the patience counter.
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'deeper_cnn_best.pth')
        print(f'New best validation accuracy: {best_val_acc:.2f}%')
        patience_counter = 0
    else:
        # No improvement: the counter can only reach the limit on this path.
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

total_time = time.time() - start_time
actual_epochs = epoch + 1  # epochs actually run (smaller than CONFIG['epochs'] after early stopping)

print(f'\nTraining completed in {total_time:.2f}s ({actual_epochs} epochs)')
print(f'Best validation accuracy: {best_val_acc:.2f}%')

# Comprehensive training visualization: a 2x2 grid with loss, accuracy,
# train/validation accuracy gap and learning-rate history.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
(loss_ax, acc_ax), (gap_ax, lr_ax) = axes

# Top-left: loss per epoch.
loss_ax.plot(train_losses, label='Train Loss', color='blue', alpha=0.7)
loss_ax.plot(val_losses, label='Val Loss', color='red', alpha=0.7)
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')

# Top-right: accuracy per epoch.
acc_ax.plot(train_accs, label='Train Acc', color='blue', alpha=0.7)
acc_ax.plot(val_accs, label='Val Acc', color='red', alpha=0.7)
acc_ax.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Training and Validation Accuracy')

# Bottom-left: train minus validation accuracy, with reference lines at 0 and
# at the 10-point gap used here as the overfitting threshold.
acc_gap = np.array(train_accs) - np.array(val_accs)
gap_ax.plot(acc_gap, label='Train - Val Accuracy', color='green', alpha=0.7)
gap_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
gap_ax.axhline(y=10, color='red', linestyle='--', alpha=0.5, label='Overfitting Threshold')
gap_ax.set(xlabel='Epoch', ylabel='Accuracy Gap (%)', title='Overfitting Indicator')

# Bottom-right: learning rate per epoch on a log scale.
lr_ax.plot(learning_rates, label='Learning Rate', color='orange', alpha=0.7)
lr_ax.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
lr_ax.set_yscale('log')

# Shared finishing touches for all four panels.
for panel in (loss_ax, acc_ax, gap_ax, lr_ax):
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('deeper_cnn_training_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# Load best model for evaluation.
# map_location remaps the checkpoint tensors onto the device used in this
# session, so the load also works if the file was written from another device.
model.load_state_dict(torch.load('deeper_cnn_best.pth', map_location=device))
final_val_loss, final_val_acc, final_preds, final_targets = validate_epoch(
    model, val_loader, criterion, device
)

print(f'Final validation accuracy: {final_val_acc:.2f}%')

# Detailed analysis: class id -> human-readable expression name.
expression_mapping = {
    0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy',
    4: 'Sad', 5: 'Surprise', 6: 'Neutral'
}

# Derive ids and names from the mapping itself rather than a hard-coded count.
class_ids = sorted(expression_mapping)
class_names = [expression_mapping[i] for i in class_ids]

# Passing `labels` pins the report to the full class list. Without it,
# scikit-learn infers the classes from the data and raises a ValueError when a
# class is absent from both targets and predictions, because the number of
# inferred classes no longer matches `target_names`.
class_report = classification_report(
    final_targets, final_preds, labels=class_ids, target_names=class_names, output_dict=True
)

print("\nClassification Report:")
print(classification_report(final_targets, final_preds, labels=class_ids, target_names=class_names))

# Enhanced confusion matrix.
# `labels` fixes the matrix to one row/column per entry of class_names, so the
# tick labels below always line up even if a class never occurs in the
# validation targets or predictions.
cm = confusion_matrix(final_targets, final_preds, labels=list(range(len(class_names))))

# Row-normalised matrix: each row shows how the samples of one true class were
# distributed over the predicted classes. Rows with no samples stay at 0
# instead of producing NaN from a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

# Plot both absolute and normalized matrices side by side. The figure is
# created only here; an earlier extra plt.figure() call left an empty figure
# behind that was shown alongside the real one.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax1)
ax1.set_title('Confusion Matrix - Absolute Values')
ax1.set_xlabel('Predicted')
ax1.set_ylabel('Actual')

sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax2)
ax2.set_title('Confusion Matrix - Normalized')
ax2.set_xlabel('Predicted')
ax2.set_ylabel('Actual')

plt.tight_layout()
plt.savefig('deeper_cnn_confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.show()

# Per-class performance analysis: collect precision, recall and accuracy (all
# in percent) for every class from the classification report dictionary.
per_class_acc = []
per_class_precision = []
per_class_recall = []

for i in range(len(class_names)):
    class_metrics = class_report[class_names[i]]
    class_precision = class_metrics['precision'] * 100
    class_recall = class_metrics['recall'] * 100
    # Per-class accuracy is the share of a class's own samples that were
    # predicted correctly, which is exactly the recall. It was previously
    # copied from precision, so per_class_acc duplicated per_class_precision.
    class_acc = class_recall

    per_class_acc.append(class_acc)
    per_class_precision.append(class_precision)
    per_class_recall.append(class_recall)

    print(f'{class_names[i]}: Precision {class_precision:.1f}%, Recall {class_recall:.1f}%')

# Performance comparison visualization
x = np.arange(len(class_names))
width = 0.25

fig, ax = plt.subplots(figsize=(14, 8))
bars1 = ax.bar(x - width, per_class_precision, width, label='Precision', alpha=0.8)
bars2 = ax.bar(x, per_class_recall, width, label='Recall', alpha=0.8)
bars3 = ax.bar(x + width, [class_report[name]['f1