# 🐯 Tiger Re-Identification — Metric Learning (Paper-Style Eval, DINOv2)

This notebook implements a **metric-learning** pipeline for tiger re-ID with **paper-style evaluation**:

- **Train/val split** is **disjoint by identity** (no ID leakage).
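- **Eval within val**: for every val ID with ≥ 2 images, exactly **one image goes to the gallery** and the rest become **queries**. The gallery image is chosen either **at random (seeded, so still reproducible)**, as in many re-ID papers, or **deterministically as the first image by filename**.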
- Backbone: **DINOv2 ViT** via `torch.hub` wrapped in **TigerDINO**.
- Training: **Balanced P×K sampler**, **batch-hard triplet** on **L2-normalized embeddings**.
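- Metrics: **Rank-1** and **mAP**, ranking gallery images by cosine similarity to each query. Note: the "best" checkpoint is selected on this same val split, so best-epoch numbers are optimistic.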

In [1]:
# =====================
# Config
# =====================
# Data: a header-less CSV of (tiger_id, image_filename) rows; filenames are
# resolved relative to IMAGE_ROOT.
CSV_PATH   = "./training_annotation/reid_list_train.csv"  # columns: tiger_id, image_filename
IMAGE_ROOT = "./train"  # where image files live

SEED       = 42
VAL_RATIO  = 0.2         # fraction of IDENTITIES (not images) held out for val
IMG_SIZE   = 224         # square input size fed to the backbone

# Backbone / Head
DINO_MODEL = "dinov2_vitb14"  # torch.hub entry point in facebookresearch/dinov2
EMBED_DIM  = 512              # output size of the linear + BatchNorm embedding head

# PK Sampler & Optim
P, K       = 8, 4        # identities per batch x images per identity (batch = P*K)
EPOCHS     = 20
LR         = 1e-4        # head LR; the backbone param group uses LR * 0.1
WD         = 5e-2        # AdamW weight decay
MARGIN     = 0.3         # batch-hard triplet margin
BATCH_EVAL = 64          # batch size when embedding gallery/query images
NUM_WORKERS= 4
SAVE_DIR   = "checkpoints"

# Evaluation protocol inside VAL (only IDs with >= 2 val images are used):
# True  -> "paper style": ONE image per ID is drawn at random (seeded) for the
#          GALLERY; all other images of that ID become QUERIES.
# False -> deterministic: the FIRST image per ID (sorted by filename) goes to the
#          GALLERY; the rest become QUERIES (identical across runs).
PAPER_RANDOM_GALLERY = True

In [5]:
# =====================
# Imports & setup
# =====================
import os, math, random
import numpy as np, pandas as pd
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from tqdm import tqdm

# Let cuDNN benchmark conv algorithms for the fixed input size. This can be
# faster, but the chosen algorithms may differ between runs (not bit-reproducible).
torch.backends.cudnn.benchmark = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

def seed_all(seed=SEED):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU, plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

seed_all(SEED)

# The annotation CSV has no header row: column 0 is the tiger ID, column 1 the
# image filename (relative to IMAGE_ROOT).
assert os.path.exists(CSV_PATH), f"CSV not found: {CSV_PATH}"
df = pd.read_csv(CSV_PATH, header=None, names=["tiger_id", "image_filename"])
n_images, n_tigers = len(df), df["tiger_id"].nunique()
print(f"Total images: {n_images} | Unique tigers: {n_tigers}")

Device: cuda
Total images: 1887 | Unique tigers: 107


In [6]:
# =====================
# Split-by-ID into train / val (CLOSED-SET eval is within VAL only)
# =====================
# Shuffle the unique identities with a seeded generator and hold out the first
# VAL_RATIO fraction (at least one ID) for validation. Because the split is over
# IDs, no tiger appears in both train and val.
uniq_ids = df["tiger_id"].unique()
rng = np.random.default_rng(SEED)
rng.shuffle(uniq_ids)
n_val = max(1, int(len(uniq_ids) * VAL_RATIO))
val_ids, train_ids = set(uniq_ids[:n_val]), set(uniq_ids[n_val:])

train_df = df.loc[df["tiger_id"].isin(train_ids)].reset_index(drop=True)
val_df   = df.loc[df["tiger_id"].isin(val_ids)].reset_index(drop=True)

print(f"Train IDs: {len(train_ids)}, Val IDs: {len(val_ids)}")
print(f"Train images: {len(train_df)}, Val images: {len(val_df)}")

Train IDs: 86, Val IDs: 21
Train images: 1428, Val images: 459


In [7]:
# =====================
# Transforms
# =====================
from torchvision import transforms
def make_transforms(img_size=IMG_SIZE):
    """Build the (train, eval) torchvision pipelines for `img_size` x `img_size` inputs.

    Both pipelines first resize to int(1.2 * img_size).
    - Training adds a random resized crop (scale 0.6-1.0), a random horizontal
      flip and colour jitter.
    - Evaluation uses a deterministic centre crop.
    Both finish with ToTensor and ImageNet mean/std normalisation.
    """
    pre_size = int(img_size * 1.2)
    # Normalize is stateless, so one instance can be shared by both pipelines.
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # NOTE(review): tiger stripe patterns differ between flanks; a horizontal flip
    # produces a mirrored pattern -- confirm this augmentation actually helps.
    train_steps = [
        transforms.Resize(pre_size),
        transforms.RandomResizedCrop(img_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        imagenet_norm,
    ]
    eval_steps = [
        transforms.Resize(pre_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        imagenet_norm,
    ]
    return transforms.Compose(train_steps), transforms.Compose(eval_steps)

train_tf, val_tf = make_transforms(IMG_SIZE)

In [8]:
# =====================
# Dataset with SHARED ID MAP (consistent labels everywhere)
# =====================
class TigerCSVReID(Dataset):
    """Image dataset for tiger re-ID driven by a (tiger_id, image_filename) table.

    Each item is ``(image, label, raw_tiger_id, path_str)``, where ``label`` is a
    long tensor holding ``id2idx[raw_tiger_id]``.

    Args:
        dataframe: table with ``tiger_id`` and ``image_filename`` columns. It is
            copied and re-indexed, so the caller's frame is not mutated.
        image_root: directory that ``image_filename`` values are relative to.
        transform: optional callable applied to the RGB PIL image.
        id_map: optional shared ``{tiger_id: int}`` mapping. Pass the same map to
            every split so labels agree. When omitted, the mapping is built from
            the sorted unique IDs of ``dataframe``.
    """

    def __init__(self, dataframe, image_root, transform=None, id_map=None):
        self.df = dataframe.copy().reset_index(drop=True)
        self.image_root = Path(image_root)
        self.transform = transform
        if id_map is not None:
            self.id2idx = id_map  # shared mapping across train/val/gallery/query
        else:
            ordered_ids = sorted(self.df["tiger_id"].unique().tolist())
            self.id2idx = dict(zip(ordered_ids, range(len(ordered_ids))))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = self.image_root / str(row["image_filename"])
        # Convert inside the context manager; convert() returns a new image, so the
        # file handle can be closed on exit.
        with Image.open(path) as handle:
            image = handle.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        raw_id = row["tiger_id"]
        label = torch.tensor(self.id2idx[raw_id], dtype=torch.long)
        return image, label, raw_id, str(path)

# One global raw-ID -> contiguous-int mapping built from ALL identities (train + val).
# Every dataset shares it, so an integer label means the same tiger everywhere.
ALL_IDS = sorted(df["tiger_id"].unique().tolist())
IDMAP = dict(zip(ALL_IDS, range(len(ALL_IDS))))

In [9]:
# =====================
# Balanced P×K Sampler (for training)
# =====================
from collections import defaultdict

class BalancedPKSampler(Sampler):
    """Batch sampler yielding P identities x K images per batch (for batch-hard triplet).

    Each batch picks ``min(P, #labels)`` distinct labels. For each chosen label it
    picks ``K`` indices: without replacement when the label has at least ``K``
    images, otherwise with replacement. One epoch has ``ceil(len(dataset) / (P*K))``
    batches. All randomness comes from Python's global ``random`` module, so
    ``random.seed`` makes the batch stream reproducible.

    Args:
        dataset: object with ``__len__``, a ``df`` DataFrame and an ``id2idx``
            mapping (raw ID -> int label).
        P: identities per batch.
        K: images per identity.
        id_column: column of ``dataset.df`` holding the raw identity.
    """

    def __init__(self, dataset, P=8, K=4, id_column="tiger_id"):
        self.dataset = dataset
        self.P, self.K = int(P), int(K)

        # label -> list of dataset positions, in first-seen label order.
        buckets = defaultdict(list)
        raw_ids = dataset.df[id_column]
        for position in range(len(dataset)):
            buckets[dataset.id2idx[raw_ids.iloc[position]]].append(position)
        self.lbl2idxs = dict(buckets)
        self.labels = list(self.lbl2idxs)
        self.batch_size = self.P * self.K
        n_indexed = sum(map(len, self.lbl2idxs.values()))
        self.num_batches = max(1, math.ceil(n_indexed / self.batch_size))

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        pool = list(self.labels)
        ids_per_batch = min(self.P, len(pool))
        for _ in range(self.num_batches):
            batch = []
            for label in random.sample(pool, k=ids_per_batch):
                members = self.lbl2idxs[label]
                if len(members) < self.K:
                    # Too few images: draw with replacement (duplicates allowed).
                    batch += random.choices(members, k=self.K)
                else:
                    batch += random.sample(members, self.K)
            random.shuffle(batch)
            yield batch

In [10]:
# =====================
# Triplet (batch-hard) with cosine distance
# =====================
def pairwise_cosine_distance(emb):
    """Pairwise distance ``1 - <e_i, e_j>`` between every pair of rows of ``emb``.

    This is a cosine distance only if the rows are already L2-normalised; callers
    in this notebook apply ``F.normalize`` first. For unit rows the values lie in
    [0, 2], up to floating-point rounding.
    """
    return 1.0 - (emb @ emb.t())

def batch_hard_triplet(labels, emb, margin=0.3):
    """Batch-hard triplet loss on cosine distance.

    For each anchor i:
      - the hardest positive is the farthest same-label sample (excluding i);
      - the hardest negative is the closest different-label sample.
    The loss is ``mean(relu(d_pos - d_neg + margin))`` over anchors that have at
    least one positive AND at least one negative.

    Args:
        labels: 1-D integer tensor of identity labels, one per row of ``emb``.
        emb: (B, D) embeddings, expected to be L2-normalised.
        margin: triplet margin.

    Returns:
        Scalar loss tensor. If no anchor has both a positive and a negative, a
        zero still attached to ``emb``'s graph is returned, so ``backward()``
        works and yields zero gradients.
    """
    with torch.no_grad():
        labels = labels.view(-1, 1)
        same = labels == labels.t()
        eye = torch.eye(same.size(0), dtype=torch.bool, device=same.device)
        pos_mask = same & ~eye
        neg_mask = ~same
        # Decide validity from the masks, not from sentinel values. 1 - cos can
        # round to a tiny negative for near-identical positives, and anchors
        # without any negative must not be scored against a fake distance.
        valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)

    if not bool(valid.any()):
        zero = emb.sum() * 0.0
        return zero if zero.requires_grad else zero.requires_grad_()

    dist = pairwise_cosine_distance(emb)
    # The ±inf fills only affect rows that `valid` drops; valid rows always have
    # at least one real positive and one real negative to pick from.
    hardest_pos = dist.masked_fill(~pos_mask, float("-inf")).max(dim=1).values
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).min(dim=1).values
    return F.relu(hardest_pos[valid] - hardest_neg[valid] + margin).mean()

In [11]:
# =====================
# DINOv2 backbone wrapper (TigerDINO)
# =====================
class TigerDINO(nn.Module):
    """DINOv2 backbone + linear/BatchNorm embedding head (+ optional ID classifier).

    Args:
        dino_model: torch.hub entry-point name in ``facebookresearch/dinov2``
            (e.g. ``"dinov2_vitb14"``).
        embed_dim: size of the output embedding.
        num_classes: if given, adds a linear classifier on top of the embedding.
        pretrained: forwarded to the hub entry point. ``False`` builds the
            architecture without downloading weights, which is useful when a
            checkpoint is loaded right afterwards.

    If the hub load fails, a small randomly initialised conv net is used instead
    and a warning is printed. Its parameters do not match a DINOv2 checkpoint.
    """

    def __init__(self, dino_model="dinov2_vitb14", embed_dim=512, num_classes=None, pretrained=True):
        super().__init__()
        self.backbone, feat_dim = self._load_dinov2(dino_model, pretrained)
        # Projection head: Linear(feat_dim -> embed_dim) followed by BatchNorm1d.
        self.embedding = nn.Linear(feat_dim, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes) if num_classes is not None else None

    def _load_dinov2(self, name, pretrained=True):
        """Return ``(backbone, feature_dim)``, falling back to a tiny conv net on any load error."""
        try:
            # Forward `pretrained` to the entry point. Previously it was ignored,
            # so pretrained=False still downloaded the weights.
            model = torch.hub.load('facebookresearch/dinov2', name,
                                   trust_repo=True, pretrained=pretrained)
            feat_dim = getattr(model, "embed_dim", 768)
            return model, feat_dim
        except Exception as e:
            print(f"⚠️ Could not load DINOv2 via torch.hub ({e}). Using a small fallback convnet.")
            fallback = nn.Sequential(
                nn.Conv2d(3,64,3,2,1), nn.ReLU(),
                nn.Conv2d(64,128,3,2,1), nn.ReLU(),
                nn.AdaptiveAvgPool2d(1), nn.Flatten()
            )
            return fallback, 128

    def forward(self, x, return_features=False):
        """Embed a batch of images.

        Returns the BatchNorm-ed (NOT L2-normalised) embedding ``z`` of shape
        (B, embed_dim) when ``return_features`` is True or there is no classifier.
        Otherwise returns ``(logits, z)``.
        """
        h = self.backbone(x)  # presumably (B, feat_dim) for DINOv2 hub models -- confirm
        if h.ndim > 2:  # defensive: flatten any spatial output
            h = torch.flatten(h, 1)
        z = self.bn(self.embedding(h))
        if return_features or self.classifier is None:
            return z
        logits = self.classifier(z)
        return logits, z

In [12]:
# =====================
# Build training datasets/loaders
# =====================
train_ds = TigerCSVReID(train_df, IMAGE_ROOT, transform=train_tf, id_map=IDMAP)
val_ds   = TigerCSVReID(val_df,   IMAGE_ROOT, transform=val_tf,   id_map=IDMAP)

# Prefer identity-balanced P×K batches, so every anchor in batch-hard triplet has
# positives. If building them fails, fall back to plain shuffled batches of the
# same size; those batches may contain few or no same-ID pairs.
_loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=True)
try:
    pk_sampler = BalancedPKSampler(train_ds, P=P, K=K, id_column="tiger_id")
    train_loader = DataLoader(train_ds, batch_sampler=pk_sampler, **_loader_kwargs)
    print(f"Using BalancedPKSampler: P={P}, K={K}, batch={P*K}")
except Exception as e:
    print("BalancedPKSampler failed; falling back to simple batching:", e)
    train_loader = DataLoader(train_ds, batch_size=P*K, shuffle=True, **_loader_kwargs)

Using BalancedPKSampler: P=8, K=4, batch=32


In [16]:
# --- Build gallery/query FROM VAL (paper-style or deterministic) ---
def build_val_gallery_query(val_df, random_one_per_id=True, seed=42):
    """
    Split VAL images into a one-image-per-ID GALLERY and a QUERY set.

    Identities with fewer than two VAL images are dropped, since they cannot
    appear in both sets. For every remaining identity, exactly one image becomes
    the gallery entry and all its other images become queries:

      - random_one_per_id=True (paper style, default): the gallery image is drawn
        per ID with ``Series.sample(n=1, random_state=seed)``.
      - random_one_per_id=False: rows are sorted by (tiger_id, image_filename) and
        the first row of each ID is the gallery image.

    Returns:
        (gal_df, qry_df): DataFrames with columns ``tiger_id`` and
        ``image_filename`` and a fresh RangeIndex.
    """
    id_sizes = val_df["tiger_id"].value_counts()
    eligible = set(id_sizes[id_sizes >= 2].index)

    # `orig_idx` (the VAL row label) is a unique key used to mark gallery rows.
    pool = (
        val_df[val_df["tiger_id"].isin(eligible)]
        .copy()
        .reset_index(drop=False)
        .rename(columns={"index": "orig_idx"})
    )

    if not random_one_per_id:
        pool = pool.sort_values(["tiger_id", "image_filename"]).reset_index(drop=True)
        gallery_keys = pool.groupby("tiger_id")["orig_idx"].first().tolist()
    else:
        # NOTE: the same `seed` is reused for every ID, so IDs with the same number
        # of images get the same positional pick. The split is still reproducible.
        gallery_keys = (
            pool.groupby("tiger_id")["orig_idx"]
            .apply(lambda rows: rows.sample(n=1, random_state=seed))
            .tolist()
        )

    is_gallery = pool["orig_idx"].isin(gallery_keys)
    columns = ["tiger_id", "image_filename"]
    gal_df = pool.loc[is_gallery, columns].reset_index(drop=True)
    qry_df = pool.loc[~is_gallery, columns].reset_index(drop=True)
    return gal_df, qry_df
gal_df, qry_df = build_val_gallery_query(val_df, random_one_per_id=PAPER_RANDOM_GALLERY, seed=SEED)
print("Gallery size:", len(gal_df), "Query size:", len(qry_df), "| #IDs:", qry_df["tiger_id"].nunique())

gallery_ds = TigerCSVReID(gal_df, IMAGE_ROOT, transform=val_tf, id_map=IDMAP)
query_ds   = TigerCSVReID(qry_df, IMAGE_ROOT, transform=val_tf, id_map=IDMAP)

def _eval_loader(dataset):
    """Unshuffled loader (deterministic order) used for embedding gallery/query images."""
    return DataLoader(dataset, batch_size=BATCH_EVAL, shuffle=False,
                      num_workers=NUM_WORKERS, pin_memory=True)

gallery_loader = _eval_loader(gallery_ds)
query_loader   = _eval_loader(query_ds)


Gallery size: 21 Query size: 438 | #IDs: 21


In [17]:
# =====================
# Embedding & evaluation (Rank-1, mAP)
# =====================
@torch.no_grad()
def embed_loader(model, loader, device):
    """Embed every batch in ``loader`` with ``model`` and L2-normalise the result.

    Switches ``model`` to eval mode and leaves it there. ``loader`` must yield
    ``(images, int_labels, raw_ids, paths)`` tuples.

    Returns:
        (features, labels): an (N, D) CPU tensor of unit-norm embeddings and the
        matching (N,) integer labels.
    """
    model.eval()
    feats, ids = [], []
    for imgs, seq_ids, _raw_ids, _paths in loader:
        emb = model(imgs.to(device, non_blocking=True), return_features=True)
        feats.append(F.normalize(emb, dim=1).cpu())
        ids.append(seq_ids.cpu())
    return torch.cat(feats, 0), torch.cat(ids, 0)

@torch.no_grad()
def reid_eval(model, gallery_loader, query_loader, device):
    """Rank-1 accuracy and mAP (both in percent) of ``model`` on a gallery/query split.

    Gallery items are ranked by cosine similarity to each query. As in the standard
    re-ID protocol, both metrics use only queries that have at least one same-ID
    gallery item. Previously Rank-1 counted such queries as misses while mAP
    skipped them. Returns (0.0, 0.0) if no query has a match.
    """
    g_feats, g_ids = embed_loader(model, gallery_loader, device)
    q_feats, q_ids = embed_loader(model, query_loader, device)

    sims = q_feats @ g_feats.t()  # cosine, since features are unit-norm

    # match[i, j]: gallery item j has the same identity as query i.
    match = q_ids.view(-1, 1) == g_ids.view(1, -1)
    valid = match.any(dim=1)
    if not bool(valid.any()):
        return 0.0, 0.0
    sims, match = sims[valid], match[valid]

    # hits[i, r] = 1.0 when the rank-(r+1) gallery item for query i is correct.
    order = torch.argsort(sims, dim=1, descending=True)
    hits = match.float().gather(1, order)

    rank1 = hits[:, 0].mean().item() * 100.0

    # AP = mean of precision@k over the ranks k where a correct item appears.
    ranks = torch.arange(1, hits.size(1) + 1, dtype=hits.dtype)
    precision_at_k = hits.cumsum(dim=1) / ranks
    ap = (precision_at_k * hits).sum(dim=1) / hits.sum(dim=1)
    mAP = ap.mean().item() * 100.0
    return rank1, mAP

In [None]:
# =====================
# Train loop + eval
# =====================
from pathlib import Path
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

model = TigerDINO(dino_model=DINO_MODEL, embed_dim=EMBED_DIM, num_classes=None, pretrained=True).to(device)

# Two AdamW param groups: the pretrained backbone is fine-tuned at LR * 0.1,
# everything else (the embedding head) at the full LR.
trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
bb_params = [param for name, param in trainable if "backbone" in name]
head_params = [param for name, param in trainable if "backbone" not in name]
optimizer = torch.optim.AdamW(
    [
        {"params": bb_params, "lr": LR * 0.1},
        {"params": head_params, "lr": LR},
    ],
    weight_decay=WD,
)

# NOTE: best.pt is selected by mAP on the same val split it is scored on, so
# the reported best mAP is an optimistic estimate.
best_map = -1.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    running, steps = 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for imgs, labels, *_ in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        # The triplet loss expects unit-norm embeddings (cosine distance).
        feats = F.normalize(model(imgs, return_features=True), dim=1)
        loss = batch_hard_triplet(labels, feats, margin=MARGIN)
        loss.backward()
        optimizer.step()

        running += float(loss.item())
        steps += 1
        pbar.set_postfix(loss=running / max(1, steps))

    train_loss = running / max(1, steps)
    rank1, mAP = reid_eval(model, gallery_loader, query_loader, device=device)
    print(f"[Val] Epoch {epoch:02d} | loss={train_loss:.4f} | Rank-1={rank1:.2f}% | mAP={mAP:.2f}%")

    ckpt = {"model": model.state_dict(), "epoch": epoch, "mAP": mAP}
    torch.save(ckpt, os.path.join(SAVE_DIR, "last.pt"))
    if not (mAP > best_map):
        continue
    best_map = mAP
    torch.save(ckpt, os.path.join(SAVE_DIR, "best.pt"))
    print("💾 Saved new best checkpoint.")
print(f"Done. Best mAP = {best_map:.2f}%")

Using cache found in /home/sagemaker-user/.cache/torch/hub/facebookresearch_dinov2_main
Epoch 1/20: 100%|██████████| 45/45 [00:17<00:00,  2.63it/s, loss=0.123]


[Val] Epoch 01 | loss=0.1227 | Rank-1=82.88% | mAP=88.79%
💾 Saved new best checkpoint.


Epoch 2/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0531]


[Val] Epoch 02 | loss=0.0531 | Rank-1=86.07% | mAP=90.95%
💾 Saved new best checkpoint.


Epoch 3/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0452]


[Val] Epoch 03 | loss=0.0452 | Rank-1=87.21% | mAP=90.80%


Epoch 4/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0331]


[Val] Epoch 04 | loss=0.0331 | Rank-1=80.82% | mAP=87.65%


Epoch 5/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0321]


[Val] Epoch 05 | loss=0.0321 | Rank-1=85.16% | mAP=89.36%


Epoch 6/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0187]


[Val] Epoch 06 | loss=0.0187 | Rank-1=86.30% | mAP=89.92%


Epoch 7/20: 100%|██████████| 45/45 [00:16<00:00,  2.76it/s, loss=0.0148]


[Val] Epoch 07 | loss=0.0148 | Rank-1=85.39% | mAP=88.98%


Epoch 8/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0133] 


[Val] Epoch 08 | loss=0.0133 | Rank-1=86.53% | mAP=89.48%


Epoch 9/20: 100%|██████████| 45/45 [00:16<00:00,  2.76it/s, loss=0.0114] 


[Val] Epoch 09 | loss=0.0114 | Rank-1=86.76% | mAP=89.74%


Epoch 10/20: 100%|██████████| 45/45 [00:16<00:00,  2.78it/s, loss=0.0115]


[Val] Epoch 10 | loss=0.0115 | Rank-1=85.39% | mAP=89.61%


Epoch 11/20: 100%|██████████| 45/45 [00:16<00:00,  2.76it/s, loss=0.00879]


[Val] Epoch 11 | loss=0.0088 | Rank-1=85.16% | mAP=89.06%


Epoch 12/20: 100%|██████████| 45/45 [00:16<00:00,  2.79it/s, loss=0.00577]


[Val] Epoch 12 | loss=0.0058 | Rank-1=85.62% | mAP=89.96%


Epoch 13/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0143] 


[Val] Epoch 13 | loss=0.0143 | Rank-1=86.07% | mAP=90.14%


Epoch 14/20: 100%|██████████| 45/45 [00:16<00:00,  2.76it/s, loss=0.00955]


[Val] Epoch 14 | loss=0.0096 | Rank-1=85.39% | mAP=88.89%


Epoch 15/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0112] 


[Val] Epoch 15 | loss=0.0112 | Rank-1=82.88% | mAP=87.44%


Epoch 16/20: 100%|██████████| 45/45 [00:16<00:00,  2.77it/s, loss=0.0072] 


[Val] Epoch 16 | loss=0.0072 | Rank-1=85.84% | mAP=90.42%


Epoch 17/20: 100%|██████████| 45/45 [00:16<00:00,  2.76it/s, loss=0.00675]


[Val] Epoch 17 | loss=0.0068 | Rank-1=83.11% | mAP=87.85%


Epoch 18/20: 100%|██████████| 45/45 [00:16<00:00,  2.78it/s, loss=0.00863]


[Val] Epoch 18 | loss=0.0086 | Rank-1=81.96% | mAP=86.96%


Epoch 19/20: 100%|██████████| 45/45 [00:16<00:00,  2.78it/s, loss=0.00365]


[Val] Epoch 19 | loss=0.0036 | Rank-1=84.25% | mAP=88.85%


Epoch 20/20:  93%|█████████▎| 42/45 [00:15<00:01,  2.87it/s, loss=0.0091] 

In [None]:
# =====================
# Evaluate helper
# =====================
def evaluate_checkpoint(checkpoint_path):
    """Rebuild TigerDINO, load weights from `checkpoint_path` and score them on the VAL gallery/query split.

    Accepts either the dict saved by the training loop (``{"model": state_dict, ...}``)
    or a bare state_dict. Prints and returns ``(rank1, mAP)`` in percent.
    NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    """
    payload = torch.load(checkpoint_path, map_location="cpu")
    net = TigerDINO(dino_model=DINO_MODEL, embed_dim=EMBED_DIM, num_classes=None, pretrained=False).to(device)
    is_wrapped = isinstance(payload, dict) and "model" in payload
    net.load_state_dict(payload["model"] if is_wrapped else payload)
    rank1, mean_ap = reid_eval(net, gallery_loader, query_loader, device=device)
    print(f"Rank-1 = {rank1:.2f}% | mAP = {mean_ap:.2f}%")
    return rank1, mean_ap

# Example:
# evaluate_checkpoint(os.path.join(SAVE_DIR, "best.pt"))