# Experiment 01

- Loss function: ArcFace
- SwimB
- closed set

In [1]:
import os
import sys
import math
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm.notebook import tqdm

In [2]:
sys.path.append('..')

from src.dataset import SeaTurtleDataset
from src.arcface import ArcFace
from src.utils import get_device

In [3]:
# --- Configuration ---
BATCH_SIZE = 10
IMG_SIZE = 224

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
dataset_dir='../data'

train_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_train.csv")
eval_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_eval.csv")
test_csv_path = os.path.join(dataset_dir, "metadata_splits_filtered_closed_test.csv")

train_dataset = SeaTurtleDataset(annotations_file=train_csv_path, img_dir=dataset_dir, transform=train_transform)
eval_dataset = SeaTurtleDataset(annotations_file=eval_csv_path, img_dir=dataset_dir, transform=test_transform)
test_dataset = SeaTurtleDataset(annotations_file=test_csv_path, img_dir=dataset_dir, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
EMBEDDING_SIZE = 512
NUM_CLASSES = train_dataset.img_annotations['identity'].nunique()
EPOCHS = 10
LEARNING_RATE = 1e-4

device = get_device()
model_save_path = '../models/filtered_closed_arcface_swin_b.pth'

print(f'Using device: {device}')

# --- Swin B Backbone Model ---
model = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
# Replace the final classification head with a layer that produces the embeddings
model.head = nn.Linear(model.head.in_features, EMBEDDING_SIZE)
print("Original head:", model.head)
print(model.head.in_features)
model.to(device)

# --- ArcFace Head & Loss Func ---
metric = ArcFace(num_classes=NUM_CLASSES, embedding_size=EMBEDDING_SIZE, scale=30.0, margin=0.50).to(device)
criterion = nn.CrossEntropyLoss()

# --- Optimizer ---
optimizer = optim.AdamW(
    list(model.parameters()) + list(metric.parameters()),
    lr=LEARNING_RATE
)

# --- Scheduler ---
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

Using device: mps
Original head: Linear(in_features=1024, out_features=512, bias=True)
1024


In [8]:
# --- Training Loop ---
best_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, (images, labels, _identities) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        features = model(images)
        output = metric(features, labels)
        loss = criterion(output, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix({'loss': running_loss / (i + 1)})

    scheduler.step()

    epoch_loss = running_loss / len(train_loader)
    
    # Evaluation
    model.eval()
    correct = 0
    total = 0

    # Build gallery from training set
    gallery_embeddings = []
    gallery_labels = []
    with torch.no_grad():
        for images, labels, _ in train_loader:
            images = images.to(device)
            embeddings = model(images)
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            gallery_embeddings.append(embeddings.cpu())
            gallery_labels.append(labels)

    gallery_embeddings = torch.cat(gallery_embeddings, dim=0)
    gallery_labels = torch.cat(gallery_labels, dim=0)

    # Evaluate on eval set
    with torch.no_grad():
        for images, labels, _identities in eval_loader:
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            
            # Compute cosine similarity with gallery
            similarities = torch.mm(embeddings, gallery_embeddings.to(device).t())
            
            # Get kNN (k=1 - top-1) prediction
            _, predicted_indices = torch.max(similarities, 1)
            predicted = gallery_labels[predicted_indices.cpu()].to(device)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, Test Accuracy: {epoch_acc:.2f}%")
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), model_save_path)
        print("Saved best model.")


print(f"Finished Training. Best Test Accuracy: {best_acc:.2f}%")

Epoch 1/10:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 1/10, Loss: 16.9084, Test Accuracy: 20.00%
Saved best model.


Epoch 2/10:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 2/10, Loss: 14.5352, Test Accuracy: 16.84%


Epoch 3/10:   0%|          | 0/24 [00:00<?, ?it/s]

Epoch 3/10, Loss: 11.1536, Test Accuracy: 16.84%


Epoch 4/10:   0%|          | 0/24 [00:00<?, ?it/s]

KeyboardInterrupt: 