In [None]:
# Mount Google Drive at /content/drive so later cells can read files stored there
# (the Kaggle credentials cell copies kaggle.json from /content/drive/MyDrive/...).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install wandb torch torchvision pandas numpy matplotlib seaborn

# Set up Kaggle API
!pip install kaggle

In [None]:
# Upload your kaggle.json to Colab and run:
# Copies kaggle.json from the mounted Drive folder to ~/.kaggle/kaggle.json and
# restricts the copy to owner read/write (chmod 600).
# Requires the Drive mount cell to have been run first.
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json ~/.kaggle/kaggle.json
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the dataset
!kaggle competitions download -c challenges-in-representation-learning-facial-expression-recognition-challenge
!unzip -q challenges-in-representation-learning-facial-expression-recognition-challenge.zip

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import ViTForImageClassification
from transformers import get_cosine_schedule_with_warmup

In [None]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then sets cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Seed Python's built-in generator too, so any code drawing from `random`
    # is repeatable alongside NumPy and PyTorch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer repeatable cuDNN behaviour over autotuned kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)


In [None]:

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")


# 2. Data Loading and Preprocessing

In [None]:

class FER2013Dataset(Dataset):
    """Dataset over a FER2013-style DataFrame.

    Each row must provide a ``pixels`` column (a space-separated string of
    integer intensities that reshapes to 48x48) and an ``emotion`` column
    holding the integer class label.

    Args:
        dataframe: Table with ``pixels`` and ``emotion`` columns.
        transform: Optional callable applied to the image. It receives a
            ``(48, 48, 3)`` ``uint8`` array with values in ``[0, 255]``, the
            input that ``transforms.ToPILImage`` / ``transforms.ToTensor``
            scale to ``[0, 1]`` themselves.
    """

    def __init__(self, dataframe, transform=None):
        self.data = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the wrapped DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        With a transform, ``image`` is whatever the transform returns. Without
        one, ``image`` is a ``(48, 48, 3)`` ``float32`` array scaled to
        ``[0, 1]``. ``label`` is a Python ``int``.
        """
        # Look the row up once instead of once per column.
        row = self.data.iloc[idx]

        gray = np.array(row['pixels'].split(), dtype='uint8').reshape(48, 48, 1)

        # Convert to 3 channels for ViT
        rgb = np.repeat(gray, 3, axis=-1)

        label = int(row['emotion'])

        if self.transform is not None:
            # Hand the transform the raw uint8 image: a float HxWx3 array is not
            # accepted by every torchvision ToPILImage version, and converting
            # float back to uint8 can lose a level to truncation.
            image = self.transform(rgb)
        else:
            image = rgb.astype('float32') / 255.0

        return image, label


In [None]:

# Load and prepare data
# Reads the CSV at data_path into df; later cells use its 'emotion' column
# (labels, stratification) and 'pixels' column (image data).
print("Loading data...")
data_path = 'train.csv'  # Update this path
df = pd.read_csv(data_path)


In [None]:

# Hold out 20% of the rows (stratified by emotion), then cut that hold-out in
# half (again stratified) to get the validation and test sets.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['emotion'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['emotion'],
)

# Training pipeline: resize to 224x224, apply random flip / rotation / colour
# jitter, convert to a tensor, then normalise each channel with mean 0.5, std 0.5.
train_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
train_transform = transforms.Compose(train_steps)

# Evaluation pipeline: the same resize and normalisation, without the random steps.
val_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
val_transform = transforms.Compose(val_steps)


In [None]:

# Wrap each split in a Dataset; validation and test share the non-random transform
train_dataset = FER2013Dataset(train_df, transform=train_transform)
val_dataset = FER2013Dataset(val_df, transform=val_transform)
test_dataset = FER2013Dataset(test_df, transform=val_transform)

# Build the loaders; only the training loader shuffles
batch_size = 32
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


# 3. Model Definition


In [None]:

# 3. Model Definition
class VisionTransformer(nn.Module):
    """Thin ``nn.Module`` wrapper around ``ViTForImageClassification``.

    Args:
        num_classes: Number of output labels requested from the pretrained
            model (passed as ``num_labels``). Defaults to 7.
        model_name: Checkpoint identifier given to ``from_pretrained``.
    """

    def __init__(self, num_classes=7, model_name='google/vit-base-patch16-224'):
        super().__init__()
        # NOTE(review): ignore_mismatched_sizes is presumably needed because the
        # checkpoint's classification head has a different number of outputs
        # than num_classes -- confirm against the transformers documentation.
        pretrained_kwargs = {
            'num_labels': num_classes,
            'ignore_mismatched_sizes': True,
        }
        self.model = ViTForImageClassification.from_pretrained(
            model_name, **pretrained_kwargs
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output object.

        The training and validation functions read ``.logits`` from the result.
        """
        outputs = self.model(x)
        return outputs

In [None]:

# Initialize model
# Builds the 7-class ViT wrapper, moves it to the selected device, and prints
# the module structure.
model = VisionTransformer(num_classes=7).to(device)
print("Model architecture:")
print(model)


# 4. Training Setup


In [None]:

# Loss function and optimiser over all model parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=1e-4,
)

# Learning rate scheduler: cosine schedule with warm-up over the first 10% of
# steps. train_epoch steps it once per batch, so the total is counted in batches.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * 20  # 20 epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [None]:

# Run configuration recorded with the wandb run and stored inside checkpoints.
# These values mirror the settings used by the model / optimiser / scheduler cells.
config = dict(
    model_name="google/vit-base-patch16-224",
    num_classes=7,
    batch_size=batch_size,
    learning_rate=2e-5,
    weight_decay=1e-4,
    epochs=20,
    warmup_ratio=0.1,
    image_size=224,
    dataset="FER2013",
    architecture="VisionTransformer",
    optimizer="AdamW",
    scheduler="CosineWithWarmup",
    loss_function="CrossEntropyLoss",
)


In [None]:

# Initialize Weights & Biases with more comprehensive config
wandb.init(
    project="facial-expression-recognition",
    config=config,
    name=f"vit-fer2013-{wandb.util.generate_id()}",
    tags=["vision-transformer", "fer2013", "facial-expression-recognition"]
)

# Register how each metric is summarised before anything is logged, so the run
# summary keeps the best value (max accuracy / min loss) rather than the last
# value logged.
wandb.define_metric("val/accuracy", summary="max")
wandb.define_metric("val/loss", summary="min")
wandb.define_metric("train/accuracy", summary="max")
wandb.define_metric("train/loss", summary="min")


In [None]:

# Log the number of samples per emotion label (over the full df) as a bar chart
class_counts = df['emotion'].value_counts().to_dict()
distribution_rows = [[emotion, count] for emotion, count in class_counts.items()]
distribution_table = wandb.Table(
    data=distribution_rows,
    columns=["class", "count"],
)
distribution_chart = wandb.plot.bar(
    distribution_table,
    "class",
    "count",
    title="Class Distribution",
)
wandb.log({"class_distribution": distribution_chart})

In [None]:

# 5. Training Loop
def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable taking ``(logits, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        scheduler: Optional learning-rate scheduler, stepped once per batch.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean per-sample loss and the
        accuracy in percent, both over the samples actually processed.
        ``(0.0, 0.0)`` is returned when the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc="Training")
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        logits = outputs.logits
        loss = criterion(logits, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Explicit None check: do not rely on the scheduler object's truthiness.
        if scheduler is not None:
            scheduler.step()

        # Statistics: weight the batch loss by batch size so the running sum is
        # a per-sample total.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({
            'loss': running_loss / total,
            'acc': 100. * correct / total
        })

    if total == 0:
        return 0.0, 0.0

    # Average over the samples actually seen. This stays correct when the
    # loader drops the last batch or uses a sampler, where the count differs
    # from len(dataloader.dataset).
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc


In [None]:

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_preds, all_labels)``: the mean
        per-sample loss, the accuracy in percent (both over the samples
        actually processed), and flat lists of predicted and true class ids in
        iteration order. ``(0.0, 0.0, [], [])`` is returned when the
        dataloader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            logits = outputs.logits
            loss = criterion(logits, labels)

            # Statistics: weight the batch loss by batch size so the running
            # sum is a per-sample total.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({
                'loss': running_loss / total,
                'acc': 100. * correct / total
            })

    if total == 0:
        return 0.0, 0.0, all_preds, all_labels

    # Average over the samples actually seen rather than len(dataloader.dataset),
    # so the result is right for any iterable of batches.
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc, all_preds, all_labels


# 5. Training Loop


In [None]:

# Training loop
# Trains for config["epochs"] epochs, logs metrics to wandb after every epoch,
# logs sample predictions every 5th epoch, and checkpoints the model whenever
# validation accuracy improves.
print("Starting training...")
best_val_acc = 0.0
best_epoch = 0  # stays 0 until some epoch beats best_val_acc
num_epochs = config["epochs"]

# Human-readable names indexed by the integer emotion label. Defined here
# because the sample captions below need them before the evaluation section.
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# Local file the evaluation cell looks for when loading the best model.
best_checkpoint_path = "best_vit_model.pth"

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 10)

    # Train for one epoch
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, scheduler
    )

    # Validate
    val_loss, val_acc, _, _ = validate(
        model, val_loader, criterion, device
    )

    # Log metrics
    metrics = {
        "epoch": epoch + 1,
        "train/loss": train_loss,
        "train/accuracy": train_acc,
        "val/loss": val_loss,
        "val/accuracy": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr']
    }

    # Log some sample predictions every 5 epochs
    if (epoch + 1) % 5 == 0:
        # Get sample predictions from the first validation batch
        model.eval()
        sample_inputs, sample_labels = next(iter(val_loader))
        sample_inputs = sample_inputs[:8].to(device)
        sample_labels = sample_labels[:8].to(device)

        with torch.no_grad():
            outputs = model(sample_inputs)
            _, preds = torch.max(outputs.logits, 1)

        # Log sample images with predictions
        sample_images = []
        for i in range(len(sample_inputs)):
            img = sample_inputs[i].cpu().numpy().transpose(1, 2, 0)
            # Rescale to [0, 1] for display; a constant image has zero range,
            # so map it to all zeros instead of dividing by zero.
            value_range = img.max() - img.min()
            if value_range > 0:
                img = (img - img.min()) / value_range
            else:
                img = np.zeros_like(img)
            pred_name = class_names[preds[i].item()]
            true_name = class_names[sample_labels[i].item()]
            sample_images.append(wandb.Image(
                img,
                caption=f"Pred: {pred_name}, True: {true_name}"
            ))

        metrics["samples"] = sample_images

    wandb.log(metrics)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1

        # Save model checkpoint
        checkpoint_path = f"vit_model_epoch_{epoch+1}.pth"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_accuracy': val_acc,
            'train_accuracy': train_acc,
            'config': config
        }, checkpoint_path)

        # Create and log model artifact
        model_artifact = wandb.Artifact(
            "vit-fer2013-model",
            type="model",
            description=f"Vision Transformer trained on FER2013 - Epoch {epoch+1} - Val Acc: {val_acc:.2f}%",
            metadata={"val_accuracy": val_acc, "epoch": epoch + 1, **config}
        )
        model_artifact.add_file(checkpoint_path)
        wandb.log_artifact(model_artifact)

        # Keep the newest best checkpoint on disk under the name the evaluation
        # cell checks for, instead of deleting it. Without this the evaluation
        # cell cannot find a local best model.
        if os.path.exists(checkpoint_path):
            os.replace(checkpoint_path, best_checkpoint_path)

        print(f"Saved best model at epoch {epoch+1} with val acc: {val_acc:.2f}%")


# 6. Evaluation


In [None]:
print("\nEvaluating on test set...")
# Load the best model from the local checkpoint if it exists, otherwise try to
# download it from wandb; fall back to the in-memory model if neither works.
best_model_path = "best_vit_model.pth"
if not os.path.exists(best_model_path):
    # Try to download the best model from wandb
    try:
        api = wandb.Api()
        # The name must match the artifact logged by the training loop
        # ("vit-fer2013-model"); any other name can never resolve.
        artifact = api.artifact(f'{wandb.run.entity}/{wandb.run.project}/vit-fer2013-model:latest')
        artifact_dir = artifact.download()
        best_model_path = os.path.join(artifact_dir, os.listdir(artifact_dir)[0])
    except Exception as exc:
        # Best-effort download: report why it failed instead of hiding the error.
        print(f"Could not load model from wandb ({exc}), using local model if exists")

if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']} with val acc: {checkpoint['val_accuracy']:.2f}%")
else:
    print("No saved model found, using current model")

# Run the (possibly reloaded) model over the held-out test split
test_loss, test_acc, test_preds, test_labels = validate(
    model, test_loader, criterion, device
)


In [None]:

# Classification report
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
# Explicit label ids keep rows/columns aligned with class_names even if some
# class is missing from the test labels or predictions.
class_ids = list(range(len(class_names)))
print("\nClassification Report:")
print(classification_report(test_labels, test_preds, labels=class_ids, target_names=class_names))

# Confusion matrix
def plot_confusion_matrix(cm, class_names):
    """Draw ``cm`` as an annotated heatmap and show the figure.

    Args:
        cm: Square matrix of integer counts (rows are true classes, columns
            are predicted classes, matching the axis labels set below).
        class_names: Tick labels used for both axes.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()

cm = confusion_matrix(test_labels, test_preds, labels=class_ids)
plot_confusion_matrix(cm, class_names)

# Collect test results for wandb; they are logged by a later cell
test_metrics = {
    "test/loss": test_loss,
    "test/accuracy": test_acc,
    "test/confusion_matrix": wandb.plot.confusion_matrix(
        preds=test_preds,
        y_true=test_labels,
        class_names=class_names,
        title="Test Confusion Matrix"
    )
}


In [None]:

# Log classification report
# Passing the label ids explicitly keeps every entry of class_names present in
# the report even if a class does not occur in the test labels or predictions.
report = classification_report(
    test_labels,
    test_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
    output_dict=True,
)
report_rows = [
    [
        name,
        report[name]["precision"],
        report[name]["recall"],
        report[name]["f1-score"],
        report[name]["support"],
    ]
    for name in class_names
    if name in report
]
wandb.log({"test/classification_report": wandb.Table(
    data=report_rows,
    columns=["class", "precision", "recall", "f1-score", "support"]
)})

# Log per-class metrics as individual scalars alongside the overall test metrics
for cls in class_names:
    if cls in report:
        test_metrics.update({
            f"test/{cls}/precision": report[cls]["precision"],
            f"test/{cls}/recall": report[cls]["recall"],
            f"test/{cls}/f1": report[cls]["f1-score"]
        })

wandb.log(test_metrics)


In [None]:

# Save the model currently in memory, upload it as a wandb artifact, then
# remove the local file.
final_model_path = "final_vit_model.pth"
final_checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'test_accuracy': test_acc,
    'config': config
}
torch.save(final_checkpoint, final_model_path)

final_metadata = {"test_accuracy": test_acc, **config}
final_artifact = wandb.Artifact(
    "vit-fer2013-final-model",
    type="model",
    description="Final Vision Transformer model trained on FER2013",
    metadata=final_metadata,
)
final_artifact.add_file(final_model_path)
wandb.log_artifact(final_artifact)

# Clean up
if os.path.exists(final_model_path):
    os.remove(final_model_path)

# Log hyperparameter optimization summary
# NOTE(review): these rules are registered after all metrics have been logged;
# confirm they still affect the run summary at this point.
summary_rules = (
    ("val/accuracy", "max"),
    ("val/loss", "min"),
    ("train/accuracy", "max"),
    ("train/loss", "min"),
)
for metric_name, aggregation in summary_rules:
    wandb.define_metric(metric_name, summary=aggregation)




In [None]:
# Mark the run as completed
# Closes the wandb run opened by the wandb.init() cell; run this cell last.
wandb.finish()