ViT Model: significantly overfits and has worse val loss and accuracy than the fine-tuned yolo classification head.  The images would presumably need to be higher definition for this to be more effective.

In [2]:
import os
import copy
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
import timm
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# 1. Paths & hyperparams
DATA_DIR   = Path("all_data")
TRAIN_DIR  = DATA_DIR / "train"
VAL_DIR    = DATA_DIR / "val"
NUM_CLASSES = len([d for d in TRAIN_DIR.iterdir() if d.is_dir()])
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(DEVICE)
BATCH_SIZE  = 32
IMAGE_SIZE  = 224
EPOCHS      = 7
LR          = 3e-4
MODEL_SAVE_PATH = "vit_emotion_best.pth"

mps


In [4]:
from torchvision import transforms

train_tf = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),   # duplicates L→RGB
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])
val_tf = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])

In [5]:
# 3. Datasets & loaders
train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tf)
val_ds   = datasets.ImageFolder(VAL_DIR,   transform=val_tf)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

class_names = train_ds.classes
print("Classes:", class_names)

Classes: ['ahegao', 'angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']


In [6]:
# 4. Build ViT model (timm) with pretrained backbone + new head
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=NUM_CLASSES
)
model.to(DEVICE)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [7]:
# 5. Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [16]:
# 6. Training & validation loops
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(1, EPOCHS+1):
    print(f"\nEpoch {epoch}/{EPOCHS}")
    for phase in ["train", "val"]:
        model.train() if phase=="train" else model.eval()
        loader = train_loader if phase=="train" else val_loader

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in tqdm(loader, desc=phase, leave=False):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase=="train"):
                outputs = model(inputs)
                loss    = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                if phase=="train":
                    loss.backward()
                    optimizer.step()

            running_loss     += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(loader.dataset)
        # <-- fix here: use float() instead of double()
        epoch_acc = running_corrects.float() / len(loader.dataset)

        # get a Python float if you need it
        acc_val = epoch_acc.item()

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {acc_val:.4f}")

        if phase=="val" and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    scheduler.step()


Epoch 1/7


                                                              

train Loss: 0.6281 Acc: 0.7619


                                                        

val Loss: 1.7793 Acc: 0.5460

Epoch 2/7


                                                                 

train Loss: 0.5947 Acc: 0.7729


                                                             

val Loss: 2.4478 Acc: 0.5243

Epoch 3/7


                                                          

KeyboardInterrupt: 

In [13]:
# 7. Save the best model
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"\nBest val Acc: {best_acc:.4f}")
print(f"Saved best model to `{MODEL_SAVE_PATH}`")


Best val Acc: 0.5460
Saved best model to `vit_emotion_best.pth`
