# Experiment 01

- Loss function: ArcFace
- SwimB
- closed set

In [1]:
# !pip install ipywidgets

In [2]:
import os
import sys
import math
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm.notebook import tqdm

In [3]:
from src.dataset import SeaTurtleDataset, download_dataset
from src.arcface import ArcFace
from src.utils import get_device

In [None]:
# --- Configuration ---
IMG_SIZE = 224
DATA_DIR = './data/seaturtleid2022-subset'

paths = download_dataset()

img_dir=paths['images_path']

model_dir = './models'
Path(model_dir).mkdir(parents=True, exist_ok=True)
model_save_path = f'{model_dir}/filtered_closed_arcface_swin_b.pth'

train_csv_path = os.path.join(DATA_DIR, "metadata_closed_set_splits_train.csv")
eval_csv_path = os.path.join(DATA_DIR, "metadata_closed_set_splits_valid.csv")
test_csv_path = os.path.join(DATA_DIR, "metadata_closed_set_splits_test.csv")

Dataset downloaded and extracted to: /Users/nhut/.cache/kagglehub/datasets/wildlifedatasets/seaturtleid2022/versions/4


In [11]:
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [12]:
BATCH_SIZE = 10

train_dataset = SeaTurtleDataset(
    annotations_file=train_csv_path, img_dir=img_dir, transform=train_transform)
eval_dataset = SeaTurtleDataset(
    annotations_file=eval_csv_path, img_dir=img_dir, transform=test_transform)
test_dataset = SeaTurtleDataset(
    annotations_file=test_csv_path, img_dir=img_dir, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
EMBEDDING_SIZE = 512
NUM_CLASSES = train_dataset.img_annotations['identity'].nunique()
EPOCHS = 10
LEARNING_RATE = 1e-4

device = get_device()

print(f'Using device: {device}')

# --- Swin B Backbone Model ---
model = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
# Replace the final classification head with a layer that produces the embeddings
model.head = nn.Linear(model.head.in_features, EMBEDDING_SIZE)
print("Original head:", model.head)
print(model.head.in_features)
model.to(device)

# --- ArcFace Head & Loss Func ---
metric = ArcFace(num_classes=NUM_CLASSES, embedding_size=EMBEDDING_SIZE, scale=30.0, margin=0.50).to(device)
criterion = nn.CrossEntropyLoss()

# --- Optimizer ---
optimizer = optim.AdamW(
    list(model.parameters()) + list(metric.parameters()),
    lr=LEARNING_RATE
)

# --- Scheduler ---
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

Using device: mps
Original head: Linear(in_features=1024, out_features=512, bias=True)
1024


In [14]:
# --- Training Loop ---
best_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, (images, labels, _identities) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        features = model(images)
        output = metric(features, labels)
        loss = criterion(output, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix({'loss': running_loss / (i + 1)})

    scheduler.step()

    epoch_loss = running_loss / len(train_loader)
    
    # Evaluation
    model.eval()
    correct = 0
    total = 0

    # Build gallery from training set
    gallery_embeddings = []
    gallery_labels = []
    with torch.no_grad():
        for images, labels, _ in train_loader:
            images = images.to(device)
            embeddings = model(images)
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            gallery_embeddings.append(embeddings.cpu())
            gallery_labels.append(labels)

    gallery_embeddings = torch.cat(gallery_embeddings, dim=0)
    gallery_labels = torch.cat(gallery_labels, dim=0)

    # Evaluate on eval set
    with torch.no_grad():
        for images, labels, _identities in eval_loader:
            images = images.to(device)
            labels = labels.to(device)
            embeddings = model(images)
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            
            # Compute cosine similarity with gallery
            similarities = torch.mm(embeddings,
                                    gallery_embeddings.to(device).t())
            
            # Get kNN (k=1 - top-1) prediction
            _, predicted_indices = torch.max(similarities, 1)
            predicted = gallery_labels[predicted_indices.cpu()].to(device)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, Eval Accuracy: {epoch_acc:.2f}%")
    
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), model_save_path)
        print("Saved best model.")


print(f"Finished Training. Best Test Accuracy: {best_acc:.2f}%")

Epoch 1/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 1/10, Loss: 16.9239, Eval Accuracy: 29.27%
Saved best model.


Epoch 2/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 2/10, Loss: 15.3125, Eval Accuracy: 35.37%
Saved best model.


Epoch 3/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 3/10, Loss: 13.3351, Eval Accuracy: 40.24%
Saved best model.


Epoch 4/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 4/10, Loss: 9.7785, Eval Accuracy: 36.59%


Epoch 5/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 5/10, Loss: 7.2563, Eval Accuracy: 40.24%


Epoch 6/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 6/10, Loss: 4.3317, Eval Accuracy: 42.68%
Saved best model.


Epoch 7/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 7/10, Loss: 2.6223, Eval Accuracy: 43.90%
Saved best model.


Epoch 8/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 8/10, Loss: 1.3476, Eval Accuracy: 42.68%


Epoch 9/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 9/10, Loss: 0.9109, Eval Accuracy: 45.12%
Saved best model.


Epoch 10/10:   0%|          | 0/21 [00:00<?, ?it/s]

Epoch 10/10, Loss: 0.7091, Eval Accuracy: 45.12%
Finished Training. Best Test Accuracy: 45.12%
