# Experiment 01

- Loss function: ArcFace
- SwimB
- closed set

In [6]:
import os
import sys
import math
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm.notebook import tqdm

In [7]:
sys.path.append('..')

from src.dataset import SeaTurtleDataset
from src.arcface import ArcFaceLoss
from src.utils import get_device

In [8]:
# --- Configuration ---
BATCH_SIZE = 8
IMG_SIZE = 224

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [9]:
dataset_dir='../data'

train_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_train.csv")
eval_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_eval.csv")
test_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_test.csv")

train_dataset = SeaTurtleDataset(annotations_file=train_csv_path, img_dir=dataset_dir, transform=train_transform)
eval_dataset = SeaTurtleDataset(annotations_file=eval_csv_path, img_dir=dataset_dir, transform=test_transform)
test_dataset = SeaTurtleDataset(annotations_file=test_csv_path, img_dir=dataset_dir, transform=test_transform)


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
EMBEDDING_SIZE = 512
NUM_CLASSES = train_dataset.img_annotations['identity'].nunique()
EPOCHS = 1
LEARNING_RATE = 1e-4

device = get_device()
model_save_path = '../models/filtered_closed_arcface_swin_b.pth'

print(f'Using device: {device}')

# --- Swin B Backbone Model ---
model = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
# Replace the final classification head with a layer that produces the embeddings
model.head = nn.Linear(model.head.in_features, EMBEDDING_SIZE)
model.to(device)

# --- ArcFace Head & Loss Func ---
arcface_loss = ArcFaceLoss(in_features=EMBEDDING_SIZE, out_features=NUM_CLASSES, scale=30.0, margin=0.50).to(device)
criterion = nn.CrossEntropyLoss()

# --- Optimizer ---
optimizer = optim.AdamW(
    list(model.parameters()) + list(arcface_loss.parameters()),
    lr=LEARNING_RATE
)

# --- Scheduler ---
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# --- Training Loop ---
best_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, (images, labels, text_labels) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        display(images.shape, labels.shape)

        # Forward pass
        embeddings = model(images)
        print(embeddings.shape, labels.shape)
        outputs = arcface_loss(embeddings, labels)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix({'loss': running_loss / (i + 1)})

    scheduler.step()

    epoch_loss = running_loss / len(train_loader)
    
    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels, _ in eval_loader:
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)
            # For evaluation, we can just use the embeddings and a KNN classifier, but for simplicity,
            # let's use the ArcFace output directly with a simple classification check.
            # A more accurate reproduction would involve creating a gallery of embeddings from the training set.
            outputs = arcface_loss(embeddings, labels)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, Test Accuracy: {epoch_acc:.2f}%")
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), model_save_path)
        print("Saved best model.")


print(f"Finished Training. Best Test Accuracy: {best_acc:.2f}%")

Using device: mps


Epoch 1/1:   0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([8, 3, 224, 224])

torch.Size([8])

torch.Size([8, 512]) torch.Size([8])
torch.Size([8, 512]) torch.Size([8])


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (8x512 and 10x512)