# In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation, ColorJitter
from timm import create_model
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from PIL import Image
from torch.amp import autocast, GradScaler  # Updated GradScaler

# ========================= Configuration =========================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset locations; each directory holds one sub-directory per class name.
FER_TRAIN_DIR = "/kaggle/input/fer2013/train/"
FER_TEST_DIR = "/kaggle/input/fer2013/test/"

# The position of a name in this list is the integer label used for training.
CLASS_NAMES = ["angry", "disgust", "fear", "happy", "neutral", "sad", "surprise"]
# Derived from CLASS_NAMES so the two can never drift apart (FER2013 has 7 emotions).
NUM_CLASSES = len(CLASS_NAMES)

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# ========================= Dataset Preparation =========================
class FERDataset(Dataset):
    """Image-folder dataset with one sub-directory per emotion class.

    Every regular, non-hidden file found in ``img_dir/<class_name>/`` becomes
    one sample whose integer label is the position of ``class_name`` in the
    class list.

    Args:
        img_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to the loaded PIL image.
        class_names: Optional sequence of sub-directory names, in label
            order. Defaults to the module-level ``CLASS_NAMES``.

    Raises:
        FileNotFoundError: If a class sub-directory does not exist
            (propagated from ``os.listdir``).
    """

    def __init__(self, img_dir, transform=None, class_names=None):
        self.img_dir = img_dir
        self.transform = transform
        self.class_names = list(CLASS_NAMES if class_names is None else class_names)
        self.image_paths = []
        self.labels = []

        # Collect all image paths and their corresponding labels
        for label, class_name in enumerate(self.class_names):
            class_dir = os.path.join(img_dir, class_name)
            # os.listdir returns entries in arbitrary order; sort so that a
            # given index always refers to the same image across runs.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip hidden files and nested directories (e.g. notebook
                # checkpoint folders), which cannot be opened as images.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")  # Ensure all images are RGB
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# ========================= Transformations =========================
# Training pipeline: resize, light geometric/photometric augmentation, then
# conversion to a tensor normalised with mean 0.5 and std 0.5 per channel.
train_transform = Compose(
    [
        Resize(size=(224, 224)),
        RandomHorizontalFlip(0.5),
        RandomRotation(10),
        ColorJitter(0.2, 0.2, 0.2),
        ToTensor(),
        Normalize((0.5,) * 3, (0.5,) * 3),
    ]
)

# Evaluation pipeline: same resize and normalisation, no augmentation.
test_transform = Compose(
    [
        Resize(size=(224, 224)),
        ToTensor(),
        Normalize((0.5,) * 3, (0.5,) * 3),
    ]
)

# ========================= Data Loaders =========================
# Index the image folders once; each split gets its own transform pipeline.
train_dataset = FERDataset(FER_TRAIN_DIR, train_transform)
test_dataset = FERDataset(FER_TEST_DIR, test_transform)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=False,
)

# ========================= ViT Model =========================
class ViTEmotionModel(nn.Module):
    """Pretrained ``vit_base_patch16_224`` with its head replaced for ``num_classes`` outputs."""

    def __init__(self, num_classes):
        super().__init__()
        backbone = create_model('vit_base_patch16_224', pretrained=True)
        # Replace the pretrained classification head with a fresh linear layer
        # that maps the same input features to ``num_classes`` outputs.
        head_in_features = backbone.head.in_features
        backbone.head = nn.Linear(head_in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        logits = self.model(x)
        return logits

# Build the classifier, then move its parameters onto the selected device.
model = ViTEmotionModel(num_classes=NUM_CLASSES)
model = model.to(DEVICE)

# ========================= Training and Validation Functions =========================
def train_one_epoch(model, loader, criterion, optimizer, scaler):
    """Run one training pass over ``loader`` and update ``model`` in place.

    Args:
        model: Module to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scaler: Gradient scaler used to scale the loss and step the optimizer.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the average of the per-batch losses
        and the accuracy over every sample seen during the epoch.
    """
    model.train()
    # Request autocast for the device the tensors actually live on instead of
    # a hard-coded "cuda"; mixed precision is only enabled on CUDA so that
    # CPU runs stay in full precision.
    amp_enabled = DEVICE.type == "cuda"
    running_loss = 0.0
    all_labels = []
    all_preds = []
    for images, labels in tqdm(loader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Clear stale gradients before this batch's forward/backward pass.
        optimizer.zero_grad()

        with autocast(device_type=DEVICE.type, enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scale the loss before backward; the scaler then steps the optimizer
        # and adjusts its scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    return running_loss / len(loader), accuracy

def evaluate_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode here.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(mean_loss, accuracy, all_labels, all_preds)``: the average of
        the per-batch losses, the overall accuracy, and the per-sample true
        and predicted labels in loader order.
    """
    model.eval()
    # Request autocast for the device the tensors actually live on instead of
    # a hard-coded "cuda"; mixed precision is only enabled on CUDA so that
    # CPU runs stay in full precision.
    amp_enabled = DEVICE.type == "cuda"
    running_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            with autocast(device_type=DEVICE.type, enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, labels)

            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    return running_loss / len(loader), accuracy, all_labels, all_preds

# ========================= Class Weights for Imbalance =========================
# Number of directory entries per class folder, in CLASS_NAMES order.
class_counts = [
    len(os.listdir(os.path.join(FER_TRAIN_DIR, class_name)))
    for class_name in CLASS_NAMES
]

# Inverse-frequency weights: classes with fewer entries get a larger weight.
counts_tensor = torch.tensor(class_counts, dtype=torch.float32)
class_weights = (1.0 / counts_tensor).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# ========================= Optimizer and Scheduler =========================
# AdamW over all model parameters, with a cosine learning-rate schedule
# spanning the full number of epochs.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Tie the gradient scaler to the selected device; loss scaling is only
# enabled when training on CUDA and is a pass-through otherwise.
scaler = GradScaler(DEVICE.type, enabled=DEVICE.type == "cuda")

# ========================= Training =========================
# Per-epoch history of losses and accuracies.
train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")

    # One optimisation pass over the training data.
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler)
    # Only the loss and accuracy are needed here; labels/predictions are dropped.
    val_loss, val_acc = evaluate_model(model, test_loader, criterion)[:2]

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

# ========================= Evaluation =========================
# Final pass over the test loader, this time keeping labels and predictions.
val_loss, val_acc, val_labels, val_preds = evaluate_model(model, test_loader, criterion)
print(f"Test Accuracy: {val_acc:.4f}")

# Pin the label set explicitly so the report rows and confusion-matrix axes
# always line up with CLASS_NAMES, even if a class never occurs in the true
# labels or the predictions.
class_indices = list(range(len(CLASS_NAMES)))

# Classification Report
print("Classification Report:")
print(classification_report(val_labels, val_preds, labels=class_indices, target_names=CLASS_NAMES))

# Confusion Matrix
cm = confusion_matrix(val_labels, val_preds, labels=class_indices)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, cmap="Blues")
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.show()

# ========================= Save Model =========================
# Persist only the learned parameters (state dict), not the module object.
torch.save(obj=model.state_dict(), f="vit_emotion_model.pth")