In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import math

# ----------------------------
# ArcFace Loss (Additive Angular Margin)
# ----------------------------
class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        cosine = nn.functional.linear(nn.functional.normalize(input), nn.functional.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.clamp(torch.pow(cosine, 2), 0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

# ----------------------------
# Configs and Transforms
# ----------------------------
DATA_DIR = "D:\\FYP\\Vision Model\\dataset\\cropped"
IMG_SIZE = 112
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Dataset and Loader
train_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
num_classes = len(train_dataset.classes)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Number of classes: {num_classes}")
print(f"Number of images: {len(train_dataset)}")

# ----------------------------
# Backbone + ArcFace Head
# ----------------------------
backbone = models.resnet18(pretrained=True)
backbone.fc = nn.Identity()  # Remove final classification layer
feature_dim = 512  # Output of resnet18 without the final fc layer

arc_margin = ArcMarginProduct(in_features=feature_dim, out_features=num_classes).to(DEVICE)
model = backbone.to(DEVICE)

# ----------------------------
# Optimizer and Loss
# ----------------------------
optimizer = optim.Adam(list(model.parameters()) + list(arc_margin.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# ----------------------------
# Training Loop with tqdm Progress Bar
# ----------------------------
print("Starting training...")
model.train()

for epoch in range(EPOCHS):
    running_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for inputs, labels in loop:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        features = model(inputs)
        outputs = arc_margin(features, labels)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        loop.set_postfix(loss=loss.item(), acc=correct/total)

    epoch_acc = correct / total if total > 0 else 0
    print(f"Epoch {epoch+1}/{EPOCHS} | Total Loss: {running_loss:.4f} | Accuracy: {epoch_acc:.4f}")

# ----------------------------
# Save Model
# ----------------------------
torch.save({
    'backbone': model.state_dict(),
    'arc_margin': arc_margin.state_dict(),
    'classes': train_dataset.classes
}, "arcface_model.pth")

print("Model saved as arcface_model.pth")


Number of classes: 140
Number of images: 44982
Starting training...


                                                                                       

Epoch 1/10 | Total Loss: 16606.7963 | Accuracy: 0.0025


                                                                                 

Epoch 2/10 | Total Loss: 13389.5319 | Accuracy: 0.0000


                                                                                 

Epoch 3/10 | Total Loss: 13173.0019 | Accuracy: 0.0000


                                                                                 

Epoch 4/10 | Total Loss: 13006.5353 | Accuracy: 0.0000


                                                                                 

Epoch 5/10 | Total Loss: 12859.0443 | Accuracy: 0.0000


                                                                                    

Epoch 6/10 | Total Loss: 12746.4478 | Accuracy: 0.0000


                                                                                 

Epoch 7/10 | Total Loss: 12650.1103 | Accuracy: 0.0000


                                                                                 

Epoch 8/10 | Total Loss: 12564.0451 | Accuracy: 0.0000


                                                                                 

Epoch 9/10 | Total Loss: 12501.8613 | Accuracy: 0.0000


                                                                                  

Epoch 10/10 | Total Loss: 12445.6283 | Accuracy: 0.0000
Model saved as arcface_model.pth
