In [1]:
import os
import zipfile
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader,random_split
from torchvision import datasets, transforms
from transformers import SwinForImageClassification
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from google.colab import drive

In [2]:
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
base_dir = '/content/drive/MyDrive/MiniProject/Face_Model'
folders = {
    'base': base_dir,
    'data': os.path.join(base_dir, 'data'),
    'models': os.path.join(base_dir, 'models'),
    'outputs': os.path.join(base_dir, 'outputs'),
}
for f in folders.values():
    os.makedirs(f, exist_ok=True)

print("Folder structure created")

Folder structure created


In [4]:
datasets_zip = {
    "fer2013": os.path.join(folders['data'], 'dataset.zip'),
    "ckplus": os.path.join(folders['data'], 'CK+48.zip'),
}

extract_dir = "/content/datasets"
os.makedirs(extract_dir, exist_ok=True)

for name, zip_path in datasets_zip.items():
    if os.path.exists(zip_path):
        print(f"Extracting {name} ...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(extract_dir, name))
    else:
        print(f"Missing: {zip_path}")

Extracting fer2013 ...
Extracting ckplus ...


In [5]:
transform_train = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

In [6]:
fer_path = os.path.join(extract_dir, "fer2013")
ck_path = os.path.join(extract_dir, "ckplus/CK+48")

train_dataset = datasets.ImageFolder(root=os.path.join(fer_path, "dataset/train"), transform=transform_train)
val_dataset = datasets.ImageFolder(root=os.path.join(fer_path, "dataset/valid"), transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
print("FER2013 classes:", train_dataset.classes)

FER2013 classes: ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']


In [7]:
ck_class_mapping = {
    'anger':'angry',
    'disgust':'disgust',
    'fear':'fear',
    'happy':'happy',
    'sadness':'sad',
    'surprise':'surprise',
    'contempt':'neutral'
}

for old_name, new_name in ck_class_mapping.items():
    old_path = os.path.join(ck_path, old_name)
    new_path = os.path.join(ck_path, new_name)
    if os.path.exists(old_path) and not os.path.exists(new_path):
        os.rename(old_path, new_path)

ck_dataset = datasets.ImageFolder(root=ck_path, transform=transform_train)
train_size = int(0.8 * len(ck_dataset))
val_size   = len(ck_dataset) - train_size
ck_train, ck_val = random_split(ck_dataset, [train_size, val_size])

ck_train_loader = DataLoader(ck_train, batch_size=16, shuffle=True, num_workers=2)
ck_val_loader   = DataLoader(ck_val, batch_size=16, shuffle=False, num_workers=2)

print("CK+ classes:", ck_dataset.classes)

CK+ classes: ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = len(train_dataset.classes)
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
best_acc = 0.0
patience = 10
no_improve_epochs = 0

for epoch in range(20):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"FER2013 Epoch {epoch+1}/20"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    train_acc = 100*correct/total
    print(f"Train Acc: {train_acc:.2f}%, Loss: {total_loss/total:.4f}")

    # Validation
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs).logits
            preds = torch.argmax(outputs, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
    val_acc = 100*val_correct/val_total
    print(f"Validation Acc: {val_acc:.2f}%")

    if val_acc >= 95.0:
      torch.save(model.state_dict(), os.path.join(folders['models'], 'best_swin_fer.pth'))
      print(f"Model saved at {val_acc:.2f}% accuracy!")
    else:
        no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print("Early stopping triggered.")
            break

FER2013 Epoch 1/20: 100%|██████████| 898/898 [06:40<00:00,  2.24it/s]

Train Acc: 87.31%, Loss: 0.3547





Validation Acc: 69.98%


FER2013 Epoch 2/20: 100%|██████████| 898/898 [06:34<00:00,  2.28it/s]

Train Acc: 88.90%, Loss: 0.3092





Validation Acc: 69.27%


FER2013 Epoch 3/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 89.91%, Loss: 0.2814





Validation Acc: 70.23%


FER2013 Epoch 4/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 91.00%, Loss: 0.2482





Validation Acc: 69.80%


FER2013 Epoch 5/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 91.73%, Loss: 0.2333





Validation Acc: 69.84%


FER2013 Epoch 6/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 92.40%, Loss: 0.2110





Validation Acc: 69.11%


FER2013 Epoch 7/20: 100%|██████████| 898/898 [06:34<00:00,  2.28it/s]

Train Acc: 93.19%, Loss: 0.1897





Validation Acc: 69.02%


FER2013 Epoch 8/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 93.50%, Loss: 0.1851





Validation Acc: 69.60%


FER2013 Epoch 9/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 93.80%, Loss: 0.1718





Validation Acc: 70.24%


FER2013 Epoch 10/20: 100%|██████████| 898/898 [06:33<00:00,  2.28it/s]

Train Acc: 94.45%, Loss: 0.1560





Validation Acc: 70.09%
Early stopping triggered.


In [11]:
ck_num_classes = len(ck_dataset.classes)
if ck_num_classes != num_classes:
    model.classifier = nn.Linear(model.classifier.in_features, ck_num_classes)
    model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

best_acc = 0.0
patience = 3
no_improve_epochs = 0

for epoch in range(10):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(ck_train_loader, desc=f"CK+ Fine-tune Epoch {epoch+1}/10"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = 100 * correct / total
    print(f"Fine-tune Train Acc: {acc:.2f}%, Loss: {total_loss / total:.4f}")

    # Validation
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for imgs, labels in ck_val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs).logits
            preds = torch.argmax(outputs, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100 * val_correct / val_total
    print(f"CK+ Validation Acc: {val_acc:.2f}%")

    if val_acc > 95.0:
        best_acc = val_acc
        torch.save(model.state_dict(), os.path.join(folders['models'], 'best_swin_ck.pth'))
        print(f"Model saved at {val_acc:.2f}% accuracy — stopping fine-tuning early!")
        break

    if val_acc > best_acc:
        best_acc = val_acc
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print("Early stopping triggered due to no improvement.")
            break

print(f"Training completed. Best CK+ Validation Accuracy: {best_acc:.2f}%")


CK+ Fine-tune Epoch 1/10: 100%|██████████| 49/49 [00:19<00:00,  2.47it/s]

Fine-tune Train Acc: 88.78%, Loss: 0.3500





CK+ Validation Acc: 96.45%
Model saved at 96.45% accuracy — stopping fine-tuning early!
Training completed. Best CK+ Validation Accuracy: 96.45%
