In [None]:
# Mount Google Drive into this Colab runtime at /content/drive.
# NOTE(review): the dataset cell reads from the relative path 'data/...', not from
# /content/drive — confirm whether this mount is actually needed.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install wandb torch torchvision timm pandas numpy matplotlib seaborn scikit-learn

# Set up Kaggle API
!pip install kaggle

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
import wandb
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import random

In [None]:
# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Config: every tunable value for the run; the whole dict is also logged to W&B.
CONFIG = {
    'epochs': 10,
    'batch_size': 64,
    'lr': 1e-4,
    'img_size': 224,
    'num_classes': 7,
    'seed': 42,
    'project': 'emotion-recognition',
    'run_name': 'vit-tiny-experiment'
}

# Seed Python's `random` (used to pick logged validation examples) and PyTorch
# (DataLoader shuffling, any torch-based weight initialisation in later cells) so a
# Restart & Run All is repeatable. torch.manual_seed seeds all devices.
# NOTE(review): bit-exact GPU determinism additionally needs cuDNN determinism flags.
random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

# WandB init: open the run that all later wandb.log calls write to.
wandb.init(project=CONFIG['project'], name=CONFIG['run_name'], config=CONFIG)


In [None]:

# Transforms: resize to the model's input resolution, convert to a [0, 1] tensor,
# then Normalize(mean=0.5, std=0.5) per channel, which maps values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Datasets: ImageFolder derives label indices from the sorted sub-folder names of
# EACH root independently, so both splits must have exactly the same class folders,
# or validation labels silently mean something different from training labels.
train_dataset = datasets.ImageFolder('data/train', transform=transform)
val_dataset = datasets.ImageFolder('data/val', transform=transform)

assert train_dataset.classes == val_dataset.classes, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)
assert len(train_dataset.classes) == CONFIG['num_classes'], (
    f"found {len(train_dataset.classes)} class folders but "
    f"CONFIG['num_classes'] is {CONFIG['num_classes']}"
)

# Loaders: only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

# Model: pretrained timm ViT-Tiny (16px patches, 224px input) with its classifier
# sized to CONFIG['num_classes']. Assigning the result (rather than ending the cell
# with a bare `model.to(device)`) avoids dumping the full model repr as cell output.
model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=CONFIG['num_classes'])
model = model.to(device)

In [None]:

# Optimizer and loss
optimizer = torch.optim.Adam(params=model.parameters(), lr=CONFIG['lr'])  # Adam over all model parameters
criterion = nn.CrossEntropyLoss()  # multi-class cross-entropy on the model's logits

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each epoch makes one pass over ``train_loader`` (zero grads, forward, loss,
    backward, optimizer step), prints the training loss/accuracy, runs
    ``evaluate_model`` on ``val_loader``, and logs train AND validation metrics
    to W&B in a single call keyed by epoch.

    Args:
        model: network to train; batches are moved to the global ``device``.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: loader handed to ``evaluate_model`` after every epoch.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Assumes criterion returns the batch MEAN (nn.CrossEntropyLoss default),
            # so re-weight by batch size to get a per-sample average at epoch end.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        # Divide by the samples actually seen (not len(dataset)) so the average
        # stays correct with drop_last or a subset sampler.
        epoch_loss = running_loss / total
        accuracy = 100. * correct / total
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

        val_loss, val_accuracy = evaluate_model(model, val_loader, criterion)

        # One log call per epoch keeps train and validation metrics on the same
        # W&B step; the explicit "epoch" key gives a stable x-axis.
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": epoch_loss,
            "train_accuracy": accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
        })

    return model

In [None]:
def evaluate_model(model, data_loader, criterion, num_samples=5):
    """Evaluate ``model`` on ``data_loader`` and log diagnostics to W&B.

    Computes the per-sample average loss and the accuracy, prints them, and logs
    a confusion-matrix heatmap, an sklearn classification report, and up to
    ``num_samples`` randomly chosen example predictions.

    Args:
        model: network to evaluate; it is put in eval mode (and left there).
        data_loader: iterable of ``(inputs, labels)`` batches whose ``dataset``
            exposes a ``classes`` list (as ``torchvision.datasets.ImageFolder`` does).
        criterion: loss called as ``criterion(outputs, labels)``.
        num_samples: maximum number of example predictions to log (default 5).

    Returns:
        ``(avg_loss, accuracy)``, with accuracy in percent.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    class_names = data_loader.dataset.classes
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    # Choose example positions up front and capture those images from the batches
    # as they stream by, so each image is paired with its OWN prediction even when
    # the loader shuffles. min() avoids random.sample's ValueError on tiny datasets.
    n_items = len(data_loader.dataset)
    sample_positions = sorted(random.sample(range(n_items), min(num_samples, n_items)))
    samples = []  # (image tensor on CPU, true index, predicted index)

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Assumes a mean-reduced criterion; re-weight to a per-sample sum.
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            batch_start = total
            total += batch_size
            for pos in sample_positions:
                if batch_start <= pos < total:
                    j = pos - batch_start
                    # clone() so the kept image does not pin the whole batch in memory
                    samples.append((inputs[j].cpu().clone(), labels[j].item(), predicted[j].item()))

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total == 0:
        raise ValueError("data_loader yielded no samples")

    # Average over samples actually seen, not len(dataset).
    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    _log_confusion_matrix(all_labels, all_preds, class_names)

    # Fixed label list + names: every class gets a row even if absent from this
    # split, and zero_division=0 silences warnings for never-predicted classes.
    class_report = classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    wandb.log({"classification_report": class_report})

    _log_sample_predictions(samples, class_names)

    return avg_loss, accuracy


def _log_confusion_matrix(true_labels, pred_labels, class_names):
    """Log a square, class-named confusion-matrix heatmap to W&B."""
    # labels= keeps the matrix len(class_names) x len(class_names) and aligned
    # with class_names even when some class never appears.
    cm = confusion_matrix(true_labels, pred_labels, labels=list(range(len(class_names))))
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    wandb.log({"confusion_matrix": wandb.Image(fig)})
    plt.close(fig)


def _log_sample_predictions(samples, class_names):
    """Log ``(image, true_idx, pred_idx)`` examples to W&B under one key."""
    if not samples:
        return
    images = []
    for img, label, pred in samples:
        # Undo Normalize(mean=0.5, std=0.5) for display (assumes the transform
        # used those values); clamp removes float round-off outside [0, 1].
        img_np = (img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy()
        fig, ax = plt.subplots()
        ax.imshow(img_np)
        ax.set_title(f"True: {class_names[label]} | Pred: {class_names[pred]}")
        ax.axis('off')
        images.append(wandb.Image(fig))
        plt.close(fig)
    # A single stable key avoids spawning a new W&B panel per random index.
    wandb.log({"sample_predictions": images})

# Run training for the configured number of epochs (validation runs inside each epoch).
model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=CONFIG['epochs'],
)

# Write the final weights next to the notebook, attach the file to the W&B run,
# then close the run.
checkpoint_path = 'vit_tiny_final.pth'
torch.save(model.state_dict(), checkpoint_path)
wandb.save(checkpoint_path)
wandb.finish()
