## Import libraries

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import timm

## Load dataset

In [None]:
# Root folders of the train / test splits (one sub-folder per class, as ImageFolder expects)
train_data_dir = "../data/fer-2013/train"
test_data_dir = "../data/fer-2013/test"

# Every image is resized to this (height, width)
image_size = (48, 48)


def _build_transform(*augmentations):
    """Return the shared preprocessing pipeline with optional augmentations spliced in.

    The pipeline is: resize to ``image_size`` -> single-channel grayscale ->
    ``augmentations`` (in the given order) -> tensor -> normalize with
    mean 0.5 / std 0.5, which maps pixel values from [0, 1] to [-1, 1].
    """
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.Grayscale(num_output_channels=1),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])


# Training images get random flips and small rotations; test images are left untouched
train_transform = _build_transform(
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
)
test_transform = _build_transform()

# Datasets: labels are inferred from the class sub-folder names
train_dataset = ImageFolder(root=train_data_dir, transform=train_transform)
test_dataset = ImageFolder(root=test_data_dir, transform=test_transform)

# DataLoaders share batch size and worker count; only the training set is shuffled
batch_size = 64
_loader_options = {"batch_size": batch_size, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Class names and the label index ImageFolder assigned to each
class_names = train_dataset.classes
class_indices = train_dataset.class_to_idx

print("Class Names:", class_names)
print("Class Indices:", class_indices)


In [None]:
# Function to show a batch of images
def show_images(images):
    """Plot up to the first 8 entries of ``images`` in a single row, in grayscale.

    Each entry is indexed with ``[0]`` before plotting, i.e. only its first
    channel is shown. Axes are hidden and the figure is displayed immediately.
    """
    count = min(len(images), 8)
    plt.figure(figsize=(10, 5))
    for position, img in enumerate(images[:count], start=1):
        plt.subplot(1, count, position)
        plt.imshow(img[0], cmap='gray')
        plt.axis('off')
    plt.show()

# Preview: pull only the first batch from the training loader and display it
# (nothing is shown if the loader yields no batches)
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    batch_images, _ = first_batch
    show_images(batch_images)

## Model Creation

In [None]:
# Load a pre-trained timm model for transfer learning
model_name = "resnet18"
# One output per class folder found by ImageFolder (7 for the FER-2013 split), so
# the classifier head always matches the label range produced by the datasets.
num_classes = len(class_names)
model = timm.create_model(model_name, pretrained=True)

# Change the first layer to accept 1 input channel (instead of 3 for RGB).
# Rather than starting this layer from random weights (which throws away the
# pretrained low-level filters), initialise it from the pretrained RGB kernels
# summed over the channel axis: for a grayscale image this gives the same
# response the original layer would produce for that image copied into R, G and B.
pretrained_conv1 = model.conv1
model.conv1 = nn.Conv2d(
    1,
    pretrained_conv1.out_channels,
    kernel_size=pretrained_conv1.kernel_size,
    stride=pretrained_conv1.stride,
    padding=pretrained_conv1.padding,
    bias=False,
)
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_conv1.weight.sum(dim=1, keepdim=True))

# Change the last fully connected layer for our task
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

## Model Training

In [None]:
# Step 3: Model Training
# Train on a GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: the batch size is fixed by the DataLoaders created earlier; redefining
# `batch_size` here would not affect them, so it is not repeated.
epochs = 10
lr = 0.001
log_interval = 50  # print the current batch loss every `log_interval` batches

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

num_batches = len(train_loader)

# Training loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion averages over the batch (default reduction), so weight by
        # batch size to get a per-sample mean over the whole epoch below.
        running_loss += loss.item() * images.size(0)
        if (batch_idx + 1) % log_interval == 0 or batch_idx + 1 == num_batches:
            print(f"Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{num_batches}], Loss: {loss.item():.4f}")

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}")

# Step 4: Model Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Predicted class = index of the highest logit per sample
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
else:
    # Guard against ZeroDivisionError when the test loader yields no samples
    accuracy = float("nan")
    print("Test set is empty; accuracy not computed.")


In [None]:
# Save model (the whole pickled model object, architecture + weights).
# The file name records the number of epochs trained (`epochs`). The loop variable
# `epoch` stops at epochs - 1, and it is undefined when epochs == 0.
# NOTE(review): PyTorch recommends saving `model.state_dict()` instead; loading a
# fully pickled model may require torch.load(..., weights_only=False) on recent
# PyTorch versions — confirm for the installed version.
torch.save(model, f"{model_name}_{epochs}_{lr}.pth")

## Prediction on a New Image of Niki :)

In [None]:
# Step 5: Prediction on a new picture
# Classify the single image file at `image_path` with the trained model.
from PIL import Image

image_path = '../data/test_predict/happy.png'

# Inference must use exactly the deterministic preprocessing used for evaluation
# (resize, grayscale, tensor, normalize — no augmentation), so reuse it rather
# than maintaining a second copy that could drift.
new_transform = test_transform

# ImageFolder numbers classes by their sorted folder names, so the predicted index
# must be mapped back through the dataset's own class list. A hand-written list
# (e.g. Angry, Disgust, Fear, Happy, Sad, Surprise, Neutral) does not follow that
# order and would mislabel predictions.
emotion_labels = [name.capitalize() for name in class_names]

model.eval()
with torch.no_grad(), Image.open(image_path) as pil_image:  # closes the file when done
    image = new_transform(pil_image).unsqueeze(0).to(device)  # add a batch dimension
    output = model(image)
    predicted = output.argmax(dim=1)
    print(f"Predicted Emotion: {emotion_labels[predicted.item()]}")