In [1]:
import kagglehub
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image  # Import PIL for image handling

In [2]:
# Fetch the FER2013 dataset through kagglehub and show where the files ended up
path = kagglehub.dataset_download("msambare/fer2013")
print("Path to dataset files:", path)

# Emotion name for each class index (0-6), built from an ordered list of names.
# NOTE(review): FERDataset derives its label indices from the folder listing,
# not from this table -- confirm both use the same ordering before decoding
# predictions with it.
emotion_labels = dict(enumerate([
    'angry',
    'disgust',
    'fear',
    'happy',
    'sad',
    'surprise',
    'neutral',
]))

# Model configuration
model_name = "google/vit-base-patch16-224"
NUM_CLASSES = len(emotion_labels)  # FER2013 has 7 emotion classes

Downloading from https://www.kaggle.com/api/v1/datasets/download/msambare/fer2013?dataset_version_number=1...


100%|██████████| 60.3M/60.3M [00:04<00:00, 15.7MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/msambare/fer2013/versions/1


In [3]:
# Seed torch's global RNG before the model is built: the 7-way classifier head
# is freshly (randomly) initialised, so without a seed a Restart & Run All
# does not reproduce the same starting point.
SEED = 42
torch.manual_seed(SEED)

# Load the pretrained ViT. The checkpoint ships a 1000-class head, so
# ignore_mismatched_sizes=True is required to swap in a NUM_CLASSES-way head.
vit_model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(vit_model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

# Per-channel normalisation used when google/vit-base-patch16-224 was
# pre-trained (mean 0.5 / std 0.5 per its model card), NOT the usual ImageNet
# statistics. Feeding the backbone the input distribution it was trained on
# keeps the pretrained features meaningful.
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]

# Training-time preprocessing: resize to the ViT input size, apply light
# random augmentation, then convert to a normalised tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=VIT_IMAGE_MEAN, std=VIT_IMAGE_STD)
])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Custom dataset class
class FERDataset(Dataset):
    """In-memory image dataset read from a ``<folder>/<class name>/<image>`` tree.

    Every sub-directory of ``folder_path`` is treated as one class. Class
    indices are assigned from the *alphabetically sorted* directory names, so
    two datasets built from folders with the same class names (e.g. a train
    and a test split) always agree on which index means which class,
    regardless of the order in which the OS lists the directories.

    Args:
        folder_path: Directory containing one sub-directory per class.
        transform: Optional callable applied to each image in ``__getitem__``.

    Attributes:
        data: Loaded images, converted to RGB.
        labels: Integer class index for each entry of ``data``.
        classes: Sorted class (sub-directory) names; ``classes[i]`` is the
            name of class index ``i``.
        class_to_idx: Mapping from class name to class index.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.data = []
        self.labels = []
        self.classes = []
        self.class_to_idx = {}
        self._load_data()

    def _load_data(self):
        """Populate ``data``/``labels`` (and the class tables) from disk.

        Files that cannot be opened as images are reported and skipped rather
        than aborting the whole load.
        """
        # Only directories are classes; stray files at the top level are ignored.
        self.classes = sorted(
            entry for entry in os.listdir(self.folder_path)
            if os.path.isdir(os.path.join(self.folder_path, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        for emotion_name, emotion_idx in self.class_to_idx.items():
            emotion_path = os.path.join(self.folder_path, emotion_name)
            # Sorted so the sample order is deterministic across runs/machines.
            for img_file in sorted(os.listdir(emotion_path)):
                img_path = os.path.join(emotion_path, img_file)
                try:
                    # convert() returns an independent in-memory copy, so the
                    # underlying file handle can be closed right away.
                    with Image.open(img_path) as img:
                        rgb_img = img.convert("RGB")
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    continue
                # Append image and label together so the lists never drift apart.
                self.data.append(rgb_img)
                self.labels.append(emotion_idx)

    def __len__(self):
        """Return the number of successfully loaded images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        img = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# Prepare datasets and dataloaders
train_path = os.path.join(path, "train")
test_path = os.path.join(path, "test")

# Evaluation must be deterministic: reuse the training pipeline but drop the
# random augmentation steps, so the same test image always produces the same
# input tensor and the reported metrics are comparable between epochs.
# Deriving it from `transform` keeps resize/normalisation in sync with training.
RANDOM_AUGMENTATIONS = (
    transforms.RandomHorizontalFlip,
    transforms.RandomRotation,
    transforms.RandomCrop,
)
eval_transform = transforms.Compose([
    step for step in transform.transforms
    if not isinstance(step, RANDOM_AUGMENTATIONS)
])

train_dataset = FERDataset(train_path, transform=transform)
test_dataset = FERDataset(test_path, transform=eval_transform)

# Shuffle only the training data; keep evaluation order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
# Training loop
def train_one_epoch(dataloader):
    """Run one optimisation pass over ``dataloader`` and report epoch metrics.

    Uses the notebook-level ``vit_model``, ``optimizer``, ``criterion`` and
    ``device``. The model is put in training mode and its parameters are
    updated once per batch.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample average loss and the
        fraction of correctly classified samples over the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    vit_model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        # Clear stale gradients, then forward / backward / update
        optimizer.zero_grad()
        outputs = vit_model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size

        # Compute accuracy
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute training metrics")

    avg_loss = total_loss / total
    accuracy = correct / total
    print(f"Training Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy

# Validation loop
def validate(dataloader):
    """Evaluate the model on ``dataloader`` without updating its weights.

    Uses the notebook-level ``vit_model``, ``criterion`` and ``device``. The
    model is switched to eval mode and gradients are disabled for the pass.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample average loss and the
        fraction of correctly classified samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    vit_model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            # Forward pass
            outputs = vit_model(images).logits

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the overall average.
            loss = criterion(outputs, labels)
            total_loss += loss.item() * batch_size

            # Compute accuracy
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute validation metrics")

    avg_loss = total_loss / total
    accuracy = correct / total
    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy

In [None]:
# Training process
num_epochs = 10
CHECKPOINT_PATH = "vit_finetuned.pth"

# Keep every epoch's metrics instead of overwriting them, so learning curves
# can be inspected or plotted after the loop finishes.
history = []
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_loss, train_acc = train_one_epoch(train_loader)
    val_loss, val_acc = validate(test_loader)
    history.append({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

# Save the trained model (weights only)
torch.save(vit_model.state_dict(), CHECKPOINT_PATH)
print("Model saved successfully!")

# Load the model for inference.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
vit_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
vit_model.to(device)
vit_model.eval()
print("Model loaded and ready for inference.")

Epoch 1/10
