<a href="https://colab.research.google.com/github/ekvirika/Facial-Expression-Recognition/blob/main/notebooks/02_baseline_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# Mount Google Drive; the Kaggle credentials are copied from it in a later cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# Install required packages
!pip install wandb torch torchvision pandas numpy matplotlib seaborn scikit-learn

# Set up Kaggle API
!pip install kaggle

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Using cached nvidia_cusolver_cu

In [11]:
# Copy the Kaggle API token from the mounted Google Drive into ~/.kaggle
# (where the kaggle CLI used below presumably looks for it) and restrict the
# file's permissions to owner read/write (chmod 600).
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json ~/.kaggle/kaggle.json
! chmod 600 ~/.kaggle/kaggle.json

In [12]:
# Download the dataset
!kaggle competitions download -c challenges-in-representation-learning-facial-expression-recognition-challenge
!unzip -q challenges-in-representation-learning-facial-expression-recognition-challenge.zip


Downloading challenges-in-representation-learning-facial-expression-recognition-challenge.zip to /content
 88% 250M/285M [00:00<00:00, 408MB/s]
100% 285M/285M [00:00<00:00, 421MB/s]


In [14]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix
import wandb
import time
from datetime import datetime
from sklearn.model_selection import train_test_split

In [2]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Run configuration — every tunable value for this experiment lives here.
CONFIG = dict(
    model_name='simple_cnn',
    batch_size=32,
    learning_rate=0.001,
    epochs=30,
    image_size=48,
    num_classes=7,
    random_seed=42,
)

# Seed numpy and torch so the run can be reproduced.
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])

# Start a wandb run named "<model_name>_<launch timestamp>".
wandb.init(
    config=CONFIG,
    job_type="training",
    name=f"{CONFIG['model_name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    project="facial-expression-recognition",
)

Using device: cpu


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mellekvirikashvili[0m ([33mellekvirikashvili-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Custom Dataset Class

In [3]:
# Custom Dataset Class
class FERDataset(Dataset):
    """Grayscale face images stored as space-separated pixel strings.

    Args:
        dataframe: table with a ``'pixels'`` column (space-separated integer
            pixel values, row-major) and an ``'emotion'`` column (class label).
        indices: positional row indices of ``dataframe`` to include.
        transform: optional callable applied to each ``(1, H, W)`` float tensor.
        image_size: side length of the square images (48 for this dataset).

    Each item is ``(image, label)``. ``image`` is a float32 tensor of shape
    ``(1, image_size, image_size)`` scaled to ``[0, 1]`` before ``transform``,
    and ``label`` is a Python ``int``.
    """

    def __init__(self, dataframe, indices, transform=None, image_size=48):
        self.data = dataframe.iloc[indices].reset_index(drop=True)
        self.transform = transform
        self.image_size = image_size

        # Parse every pixel string once here rather than on each __getitem__ call.
        # Re-splitting the strings every epoch otherwise dominates data loading.
        # A malformed row now fails at construction time instead of mid-epoch.
        if len(self.data):
            self.images = np.stack([
                np.array(pixels.split(), dtype=np.uint8).reshape(image_size, image_size)
                for pixels in self.data['pixels']
            ])
        else:
            self.images = np.empty((0, image_size, image_size), dtype=np.uint8)
        self.labels = self.data['emotion'].astype(int).to_numpy()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Scale uint8 pixels to [0, 1] float32 and add the channel dimension.
        image = self.images[idx].astype(np.float32) / 255.0
        image = torch.from_numpy(image).unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        label = int(self.labels[idx])

        return image, label

# Define Simple CNN Model


In [4]:
class SimpleCNN(nn.Module):
    """Three-block convolutional baseline for 48x48 single-channel images.

    Each block is a 3x3 same-padded convolution, ReLU, then 2x2 max-pooling.
    Pooling halves the spatial size: 48 -> 24 -> 12 -> 6. The resulting
    128 x 6 x 6 feature map is flattened and classified by a two-layer MLP,
    with dropout between the two linear layers.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # Feature extractor: channels grow 1 -> 32 -> 64 -> 128.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Shared, parameter-free layers.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Classifier head over the flattened 128 x 6 x 6 feature map.
        flat_features = 128 * 6 * 6
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a ``(N, 1, 48, 48)`` batch to ``(N, num_classes)`` logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        x = torch.flatten(x, start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)


# Load data

In [15]:
# Load data
print("Loading data...")
train_df = pd.read_csv('train.csv')
missing_columns = {'emotion', 'pixels'} - set(train_df.columns)
assert not missing_columns, f"train.csv is missing columns: {missing_columns}"

# Stratified 80/20 split so class proportions match across train and validation.
# The split is seeded from CONFIG so changing the run seed also changes the split.
train_indices, val_indices = train_test_split(
    range(len(train_df)),
    test_size=0.2,
    stratify=train_df['emotion'],
    random_state=CONFIG['random_seed']
)


# Create datasets (no augmentation for baseline).
# Both splits only rescale [0, 1] -> [-1, 1]; they are kept as separate
# transforms so augmentation can later be added to training only.
train_transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

val_transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = FERDataset(train_df, train_indices, transform=train_transform)
val_dataset = FERDataset(train_df, val_indices, transform=val_transform)

# Create data loaders; pinned host memory only helps host-to-GPU copies.
use_pin_memory = device.type == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Loading data...
Training samples: 22967
Validation samples: 5742


# Training

## Initialize model


In [16]:
model = SimpleCNN(num_classes=CONFIG['num_classes']).to(device)

# Parameter counts: everything, and only the tensors that receive gradients.
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    param.numel() for param in filter(lambda t: t.requires_grad, model.parameters())
)

print(f"Model Parameters: {total_params:,} total, {trainable_params:,} trainable")

# Record model size in the wandb run.
wandb.log(dict(model_parameters=total_params, trainable_parameters=trainable_params))


Model Parameters: 2,456,071 total, 2,456,071 trainable


In [17]:

# Adam over all model parameters at the configured learning rate, and
# cross-entropy computed on the raw logits returned by SimpleCNN.
optimizer = optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])
criterion = nn.CrossEntropyLoss()

## Training Function

In [18]:
def train_epoch(model, train_loader, criterion, optimizer, device, log_interval=100):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: loss returning the *mean* loss of a batch (e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        log_interval: print the batch loss every this many batches;
            ``0``/``None`` disables progress output.

    Returns:
        ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the loss averaged over
        samples and ``epoch_acc`` is the accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    num_batches = len(train_loader)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # criterion returns a batch mean, so weight it by batch size.
        # This keeps a smaller final batch from skewing the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (output.argmax(dim=1) == target).sum().item()
        total += batch_size

        if log_interval and batch_idx % log_interval == 0:
            print(f'Batch {batch_idx}/{num_batches}, Loss: {loss.item():.4f}')

    if total == 0:
        raise ValueError('train_loader yielded no samples')

    epoch_loss = loss_sum / total
    epoch_acc = 100.0 * correct / total

    return epoch_loss, epoch_acc


## Validation function

In [19]:
# Validation function
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(data, target)`` batches.
        criterion: loss returning the *mean* loss of a batch.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_targets)``:
        - ``epoch_loss``: loss averaged over samples.
        - ``epoch_acc``: accuracy in percent.
        - ``all_preds``, ``all_targets``: lists of Python ints in loader order.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            predicted = output.argmax(dim=1)

            batch_size = target.size(0)
            # Weight the batch-mean loss by batch size so the epoch loss is a per-sample mean.
            loss_sum += loss.item() * batch_size
            correct += (predicted == target).sum().item()
            total += batch_size

            all_preds.extend(predicted.cpu().tolist())
            all_targets.extend(target.cpu().tolist())

    if total == 0:
        raise ValueError('val_loader yielded no samples')

    epoch_loss = loss_sum / total
    epoch_acc = 100.0 * correct / total

    return epoch_loss, epoch_acc, all_preds, all_targets


## Training Loop

In [20]:
print("\nStarting training...")
train_losses, train_accs = [], []
val_losses, val_accs = [], []
best_val_acc = 0.0
best_epoch = 0

start_time = time.time()

# try/finally: if the loop is stopped early (e.g. a KeyboardInterrupt), total_time
# is still defined and the metric lists hold the completed epochs.
# The plotting and analysis cells below can then run on the partial results.
try:
    for epoch in range(CONFIG['epochs']):
        epoch_start = time.time()

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_preds, val_targets = validate_epoch(model, val_loader, criterion, device)

        # Store metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        epoch_time = time.time() - epoch_start

        # Log to wandb
        wandb.log({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'epoch_time': epoch_time,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        print(f'Epoch [{epoch+1}/{CONFIG["epochs"]}]')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Time: {epoch_time:.2f}s')
        print('-' * 50)

        # Save best model (strict '>' keeps the earliest epoch on ties)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            torch.save(model.state_dict(), 'simple_cnn_best.pth')
            print(f'New best validation accuracy: {best_val_acc:.2f}%')
finally:
    total_time = time.time() - start_time
    epochs_completed = len(val_accs)
    if epochs_completed == CONFIG['epochs']:
        print(f'\nTraining completed in {total_time:.2f}s')
    else:
        print(f'\nTraining stopped after {epochs_completed}/{CONFIG["epochs"]} epochs ({total_time:.2f}s)')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')


Starting training...
Batch 0/718, Loss: 1.9443
Batch 100/718, Loss: 1.8132
Batch 200/718, Loss: 1.5460
Batch 300/718, Loss: 1.5465
Batch 400/718, Loss: 1.8359
Batch 500/718, Loss: 1.6811
Batch 600/718, Loss: 1.4701
Batch 700/718, Loss: 1.5407
Epoch [1/30]
Train Loss: 1.6369, Train Acc: 35.04%
Val Loss: 1.4597, Val Acc: 44.03%
Time: 204.96s
--------------------------------------------------
New best validation accuracy: 44.03%
Batch 0/718, Loss: 1.5783
Batch 100/718, Loss: 1.4506
Batch 200/718, Loss: 1.2007
Batch 300/718, Loss: 1.3132
Batch 400/718, Loss: 1.4726
Batch 500/718, Loss: 1.2295
Batch 600/718, Loss: 1.2094
Batch 700/718, Loss: 1.6433
Epoch [2/30]
Train Loss: 1.3985, Train Acc: 46.47%
Val Loss: 1.3030, Val Acc: 50.23%
Time: 194.68s
--------------------------------------------------
New best validation accuracy: 50.23%
Batch 0/718, Loss: 1.4325
Batch 100/718, Loss: 1.1674
Batch 200/718, Loss: 0.9965
Batch 300/718, Loss: 1.3791
Batch 400/718, Loss: 1.2431
Batch 500/718, Loss: 1

KeyboardInterrupt: 

In [None]:

# Plot training curves.
# x values are 1-based epoch numbers so the axis matches the "Epoch [k/N]" log lines.
epoch_numbers = np.arange(1, len(train_losses) + 1)

fig, (ax_loss, ax_acc, ax_gap) = plt.subplots(1, 3, figsize=(15, 5))

ax_loss.plot(epoch_numbers, train_losses, label='Train Loss', color='blue')
ax_loss.plot(epoch_numbers, val_losses, label='Val Loss', color='red')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training and Validation Loss')
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(epoch_numbers, train_accs, label='Train Acc', color='blue')
ax_acc.plot(epoch_numbers, val_accs, label='Val Acc', color='red')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Training and Validation Accuracy')
ax_acc.legend()
ax_acc.grid(True)

# Positive values mean training accuracy exceeds validation accuracy at that epoch.
ax_gap.plot(epoch_numbers, np.array(train_accs) - np.array(val_accs), label='Acc Gap', color='green')
ax_gap.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax_gap.set_xlabel('Epoch')
ax_gap.set_ylabel('Train - Val Accuracy (%)')
ax_gap.set_title('Overfitting Indicator')
ax_gap.legend()
ax_gap.grid(True)

fig.tight_layout()
fig.savefig('simple_cnn_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

## Load best model

In [None]:

# Load best model for final evaluation.
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
# weights_only=True restricts unpickling to tensors/containers; the file is a plain state_dict.
model.load_state_dict(torch.load('simple_cnn_best.pth', map_location=device, weights_only=True))

# Final validation with best model
final_val_loss, final_val_acc, final_preds, final_targets = validate_epoch(
    model, val_loader, criterion, device
)

print(f'Final validation accuracy: {final_val_acc:.2f}%')

# Classification report
expression_mapping = {
    0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy',
    4: 'Sad', 5: 'Surprise', 6: 'Neutral'
}

class_names = [expression_mapping[i] for i in sorted(expression_mapping)]
# Pass labels explicitly so the report always has one row per entry in
# class_names. Otherwise sklearn errors if a class is absent from both targets
# and predictions. zero_division=0 silences the warning for such classes.
report_labels = list(range(len(class_names)))
class_report = classification_report(
    final_targets, final_preds,
    labels=report_labels,
    target_names=class_names,
    output_dict=True,
    zero_division=0
)

print("\nClassification Report:")
print(classification_report(
    final_targets, final_preds,
    labels=report_labels,
    target_names=class_names,
    zero_division=0
))


In [None]:

# Confusion Matrix.
# Drawn with matplotlib directly: seaborn (`sns`) is never imported in this notebook.
# labels= keeps the matrix square over all classes even if one never appears.
class_labels = list(range(len(class_names)))
cm = confusion_matrix(final_targets, final_preds, labels=class_labels)

fig, ax = plt.subplots(figsize=(10, 8))
heatmap = ax.imshow(cm, cmap='Blues')
fig.colorbar(heatmap, ax=ax)
ax.set_xticks(class_labels)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticks(class_labels)
ax.set_yticklabels(class_names)

# Annotate each cell with its count; use white text on dark cells for contrast.
text_threshold = cm.max() / 2
for row in range(cm.shape[0]):
    for col in range(cm.shape[1]):
        ax.text(col, row, str(int(cm[row, col])), ha='center', va='center',
                color='white' if cm[row, col] > text_threshold else 'black')

ax.set_title('Confusion Matrix - Simple CNN')
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
fig.tight_layout()
fig.savefig('simple_cnn_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Analyze per-class performance: accuracy of class i = correct / samples of class i
# (diagonal over row sum of the confusion matrix).
# NOTE(review): classes with no validation samples are skipped. The analysis
# below indexes class_names by position in per_class_acc, which assumes every
# class is present. The stratified split is intended to ensure that.
per_class_acc = []
class_support = cm.sum(axis=1)
for i, name in enumerate(class_names):
    if class_support[i] > 0:
        class_acc = float(cm[i, i]) / class_support[i] * 100
        per_class_acc.append(class_acc)
        print(f'{name}: {class_acc:.1f}% ({class_support[i]} samples)')

# Log final results to wandb
wandb.log({
    'final_val_accuracy': final_val_acc,
    'final_val_loss': final_val_loss,
    'training_curves': wandb.Image('simple_cnn_training_curves.png'),
    'confusion_matrix': wandb.Image('simple_cnn_confusion_matrix.png'),
    'classification_report': class_report,
    'total_training_time': total_time,
    'best_epoch': int(np.argmax(val_accs)) + 1
})

# Analysis

In [None]:

# Analysis of results
# Thresholds shared by the console report and the logged summary so the two agree.
OVERFIT_GAP = 10      # same-epoch train - val accuracy gap (pp) above which we call it overfitting
MILD_OVERFIT_GAP = 5  # gap (pp) below which generalization is considered good
UNDERFIT_ACC = 40     # best-checkpoint validation accuracy (%) below which the model underfits
DECLINE_RATIO = 0.95  # last-epoch val acc below this fraction of the best => declining

print("\n" + "="*60)
print("SIMPLE CNN BASELINE ANALYSIS")
print("="*60)

# Training may stop before CONFIG['epochs'] (e.g. when interrupted), so
# per-epoch figures use the number of epochs actually recorded.
epochs_run = len(val_accs)
best_epoch = int(np.argmax(val_accs)) + 1

# Check for overfitting/underfitting.
# Compare train and validation accuracy from the SAME (last) epoch. final_val_acc
# is measured on the best checkpoint, which may come from an earlier epoch;
# mixing the two would understate the gap.
final_train_acc = train_accs[-1]
last_val_acc = val_accs[-1]
acc_gap = final_train_acc - last_val_acc
underfitting = final_val_acc < UNDERFIT_ACC

print("\nOverfitting Analysis:")
print(f"Final Training Accuracy: {final_train_acc:.2f}%")
print(f"Final Validation Accuracy: {last_val_acc:.2f}%")
print(f"Best-checkpoint Validation Accuracy: {final_val_acc:.2f}%")
print(f"Accuracy Gap: {acc_gap:.2f}%")

if acc_gap > OVERFIT_GAP:
    print("🔴 OVERFITTING DETECTED: Large gap between train and validation accuracy")
    print("   - Model memorizing training data")
    print("   - Need regularization (dropout, data augmentation)")
elif acc_gap < MILD_OVERFIT_GAP:
    print("🟡 GOOD GENERALIZATION: Small gap between train and validation")
    if underfitting:
        print("🔴 UNDERFITTING: Both accuracies are low")
        print("   - Model too simple for the task")
        print("   - Need more capacity or different architecture")
    else:
        print("✅ BALANCED MODEL: Good generalization")
else:
    print("🟡 MILD OVERFITTING: Moderate gap, could be improved")

# Learning curve analysis
print("\nLearning Curve Analysis:")
print(f"Best epoch: {best_epoch}")
print(f"Early stopping could have saved {epochs_run - best_epoch} epochs")

declining = val_accs[-1] < max(val_accs) * DECLINE_RATIO
if declining:
    print("🔴 VALIDATION ACCURACY DECLINING: Clear overfitting in later epochs")
else:
    print("✅ STABLE LEARNING: No significant decline in validation performance")

# Performance insights
worst_class = class_names[int(np.argmin(per_class_acc))]
best_class = class_names[int(np.argmax(per_class_acc))]

print("\nPer-class Performance:")
print(f"Best performing class: {best_class} ({max(per_class_acc):.1f}%)")
print(f"Worst performing class: {worst_class} ({min(per_class_acc):.1f}%)")
print(f"Performance variance: {np.std(per_class_acc):.1f}%")

# Model complexity analysis
print("\nModel Complexity:")
print(f"Total parameters: {total_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB (float32)")
print(f"Training time: {total_time:.1f}s ({total_time/epochs_run:.1f}s/epoch)")

# Recommendations for next experiments
print("\n🔬 RECOMMENDATIONS FOR NEXT EXPERIMENTS:")
print("1. Architecture: Try deeper CNN with more layers")
print("2. Regularization: Add batch normalization, more dropout")
print("3. Data: Implement data augmentation to reduce overfitting")
print("4. Optimization: Try different learning rates, schedulers")
print("5. Loss: Consider weighted loss for class imbalance")

# Save experiment summary (plain Python types so it serializes cleanly)
experiment_summary = {
    'model_name': CONFIG['model_name'],
    'final_val_accuracy': float(final_val_acc),
    'best_val_accuracy': float(best_val_acc),
    'final_train_accuracy': float(final_train_acc),
    'overfitting_gap': float(acc_gap),
    'total_parameters': total_params,
    'training_time': total_time,
    'epochs_run': epochs_run,
    'best_epoch': best_epoch,
    'per_class_accuracy': {name: float(acc) for name, acc in zip(class_names, per_class_acc)},
    'key_findings': {
        'overfitting_detected': bool(acc_gap > OVERFIT_GAP),
        'underfitting_detected': bool(underfitting),
        'early_stopping_beneficial': bool(declining),
        'class_imbalance_impact': bool(max(per_class_acc) - min(per_class_acc) > 20)
    }
}

# Log summary
wandb.log({'experiment_summary': experiment_summary})

print("\n✅ Simple CNN baseline complete!")
print("📊 Results logged to Wandb")
print("💾 Best model saved as 'simple_cnn_best.pth'")

wandb.finish()

In [None]:

# Template for README documentation.
# Derived values and the conditional wording are computed up front so the
# template itself stays readable.
epochs_run = len(val_accs)            # may be < CONFIG['epochs'] if training stopped early
best_epoch = int(np.argmax(val_accs)) + 1
last_train_acc = train_accs[-1]
last_val_acc = val_accs[-1]
# Same-epoch train/val gap; final_val_acc comes from the best checkpoint instead.
same_epoch_gap = last_train_acc - last_val_acc
class_spread = max(per_class_acc) - min(per_class_acc)

if same_epoch_gap > 10:
    fit_heading = '🔴 **OVERFITTING DETECTED**'
    gap_phrase = 'significantly higher than'
    fit_detail = 'Model is memorizing training data rather than learning generalizable features'
elif same_epoch_gap > 5:
    fit_heading = '🟡 **MILD OVERFITTING**'
    gap_phrase = 'moderately higher than'
    fit_detail = 'Some overfitting present but manageable'
else:
    fit_heading = '✅ **GOOD GENERALIZATION**'
    gap_phrase = 'close to'
    fit_detail = 'Good balance between fitting and generalization'

if same_epoch_gap > 10:
    finding_label, finding_action = 'Overfitting Detected', 'Need regularization techniques'
elif final_val_acc < 40:
    finding_label, finding_action = 'Underfitting Detected', 'Need more model capacity'
else:
    finding_label, finding_action = 'Balanced Fit', 'Good starting point'

spread_note = ('suggests class imbalance impact' if class_spread > 20
               else 'indicates fairly even per-class performance')

print("\n" + "="*60)
print("DOCUMENTATION FOR README:")
print("="*60)
print(f"""
## Experiment 1: Simple CNN Baseline

### Hypothesis
Start with a minimal CNN architecture to establish a baseline performance.
Expected: Likely to underfit due to limited model capacity for complex facial expressions.

### Architecture
```python
SimpleCNN(
  (conv1): Conv2d(1, 32, kernel_size=3, padding=1)
  (conv2): Conv2d(32, 64, kernel_size=3, padding=1)
  (conv3): Conv2d(64, 128, kernel_size=3, padding=1)
  (pool): MaxPool2d(2, 2)
  (dropout): Dropout(0.5)
  (fc1): Linear(4608, 512)
  (fc2): Linear(512, 7)
)
```
- **Parameters**: {total_params:,}
- **Layers**: 3 conv + 2 FC layers
- **Regularization**: Dropout (0.5)

### Hyperparameters
- **Learning Rate**: {CONFIG['learning_rate']}
- **Batch Size**: {CONFIG['batch_size']}
- **Epochs**: {CONFIG['epochs']} (completed: {epochs_run})
- **Optimizer**: Adam
- **Loss**: CrossEntropyLoss

### Results
- **Best Validation Accuracy**: {best_val_acc:.2f}%
- **Final Training Accuracy**: {last_train_acc:.2f}%
- **Overfitting Gap**: {same_epoch_gap:.2f}%
- **Training Time**: {total_time:.1f}s

### Analysis
{fit_heading}
- Training accuracy ({last_train_acc:.1f}%) {gap_phrase} validation ({last_val_acc:.1f}%)
- {fit_detail}

**Per-class Performance**:
- Best: {best_class} ({max(per_class_acc):.1f}%)
- Worst: {worst_class} ({min(per_class_acc):.1f}%)
- Per-class std ({np.std(per_class_acc):.1f}%) {spread_note}

### Key Findings
1. **Baseline Established**: {final_val_acc:.1f}% accuracy provides lower bound
2. **{finding_label}**: {finding_action}
3. **Class Imbalance Impact**: {class_spread:.1f}% performance gap
4. **Training Efficiency**: {total_time/epochs_run:.1f}s per epoch, converged around epoch {best_epoch}

### Next Steps
1. **Regularization**: Add batch normalization, data augmentation
2. **Architecture**: Increase depth and width gradually
3. **Data Strategy**: Address class imbalance with weighted loss
4. **Optimization**: Experiment with learning rate scheduling
""")