# Jaguar Re-Identification

## Score: .888

In [1]:
import os
import math
import random
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [2]:
# =============================================================================
# CONFIG
# =============================================================================
BASE_DIR = Path.cwd()  
MODELS_DIR = BASE_DIR.parent / 'MODELS'
CONVNEXT_DIR = MODELS_DIR / 'convnext-tensorflow2-large-21k-1k-384-fe-v1'
EVA02_DIR = MODELS_DIR / 'eva02-pytorch-default-v1'

class CFG:
    # data
    data_dir = BASE_DIR / 'jaguar-re-id'
    train_csv = data_dir / 'train.csv'
    test_csv = data_dir / 'test.csv'
    train_dir = data_dir / 'train' / 'train'
    test_dir = data_dir / 'test' / 'test'

    # model backbones / weights roots (not all are used at once)
    models_dir = MODELS_DIR
    convnext_dir = CONVNEXT_DIR
    eva02_dir = EVA02_DIR

    backbone = 'convnext_large.fb_in22k_ft_in1k'
    image_size = 224
    num_classes = 31

    epochs = 10
    batch_size = 8
    grad_accum_steps = 4
    lr = 2e-4
    weight_decay = 0.01
    warmup_epochs = 1
    
    arcface_s = 30.0
    arcface_m = 0.5
    arcface_subcenters = 1
    label_smoothing = 0.1
    
    samples_per_class = 60
    val_split_seed = 42
    early_stop_patience = 3

    use_tta = True
    use_multiscale_tta = False
    multiscale_sizes = (224,)
    use_qe = True
    qe_top_k = 3
    use_rerank = True
    rerank_lambda = 0.3
    
    train_seeds = [42, 420, 666]

    use_supcon = False
    use_triplet = False
    triplet_margin = 0.2
    supcon_tau = 0.07
    pk_p = 8
    pk_k = 4

    num_workers = 0
    mixed_precision = True

In [3]:
# =============================================================================
# DATA
# =============================================================================
full_train = pd.read_csv(CFG.train_csv)
test_df = pd.read_csv(CFG.test_csv)
rng = random.Random(getattr(CFG, 'val_split_seed', 42))
val_indices = []
for gt, grp in full_train.groupby('ground_truth', sort=True):
    idx = grp.index.tolist()
    if len(idx) >= 2:
        val_indices.extend(rng.sample(idx, 2))
train_df = full_train.drop(index=val_indices).reset_index(drop=True)
val_df = full_train.loc[val_indices].reset_index(drop=True)
print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test pairs: {len(test_df)}")

Train: 1833 | Val: 62 | Test pairs: 137270


In [4]:
# =============================================================================
# TRANSFORMS
# =============================================================================
NORM_MEAN, NORM_STD = ((0.481, 0.457, 0.408), (0.268, 0.261, 0.275)) if 'eva' in CFG.backbone else ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

def get_train_transforms():
    return A.Compose([
        A.LongestMaxSize(max_size=CFG.image_size),
        A.PadIfNeeded(CFG.image_size, CFG.image_size, border_mode=0),
        A.HorizontalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1), rotate=(-12, 12), shear=(-8, 8), p=0.5),
        A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.6),
        A.CoarseDropout(num_holes_range=(4, 12), hole_height_range=(16, 48), hole_width_range=(16, 48), p=0.3),
        A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ToTensorV2(),
    ])

def get_test_transforms(flip=False, size=None):
    sz = size if size is not None else CFG.image_size
    t = [
        A.LongestMaxSize(max_size=sz),
        A.PadIfNeeded(sz, sz, border_mode=0),
    ]
    if flip:
        t.append(A.HorizontalFlip(p=1.0))
    t.extend([
        A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ToTensorV2(),
    ])
    return A.Compose(t)

In [5]:
# =============================================================================
# DATASET & SAMPLER
# =============================================================================
class JaguarDataset(Dataset):
    def __init__(self, df, img_dir, transform):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        unique_ids = sorted(df['ground_truth'].unique())
        self.label_map = {name: i for i, name in enumerate(unique_ids)}
        self.labels = [self.label_map[gt] for gt in df['ground_truth']]
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = np.array(Image.open(self.img_dir / row['filename']).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, torch.tensor(self.labels[idx], dtype=torch.long)


class JaguarTestDataset(Dataset):
    def __init__(self, filenames, img_dir, transform):
        self.filenames = filenames
        self.img_dir = Path(img_dir)
        self.transform = transform
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        fname = self.filenames[idx]
        img = np.array(Image.open(self.img_dir / fname).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, fname


class BalancedSampler(Sampler):
    def __init__(self, labels, samples_per_class):
        self.labels = labels
        self.samples_per_class = samples_per_class
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)
        self.num_classes = len(self.class_indices)
    
    def __iter__(self):
        indices = []
        for label in self.class_indices:
            class_idx = self.class_indices[label]
            if len(class_idx) >= self.samples_per_class:
                sampled = random.sample(class_idx, self.samples_per_class)
            else:
                sampled = random.choices(class_idx, k=self.samples_per_class)
            indices.extend(sampled)
        random.shuffle(indices)
        return iter(indices)
    
    def __len__(self):
        return self.num_classes * self.samples_per_class


class PKSampler(Sampler):
    def __init__(self, labels, p, k):
        self.labels = np.asarray(labels)
        self.p, self.k = p, k
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)
        self.classes = list(self.class_indices.keys())
        n_samples = sum(len(v) for v in self.class_indices.values())
        self.num_batches = max(1, n_samples // (p * k))

    def __iter__(self):
        for _ in range(self.num_batches):
            batch_classes = random.sample(self.classes, min(self.p, len(self.classes)))
            indices = []
            for c in batch_classes:
                idx = self.class_indices[c]
                if len(idx) >= self.k:
                    indices.extend(random.sample(idx, self.k))
                else:
                    indices.extend(random.choices(idx, k=self.k))
            random.shuffle(indices)
            yield indices

    def __len__(self):
        return self.num_batches

In [6]:
# =============================================================================
# MODEL
# =============================================================================
def supcon_loss(emb, labels, tau=0.07):
    emb = F.normalize(emb.float(), dim=1)
    sim = torch.mm(emb, emb.t()) / tau
    B = emb.size(0)
    eye = torch.eye(B, device=emb.device, dtype=torch.bool)
    mask_same = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
    large_neg = -1e4
    log_denom = torch.logsumexp(sim.masked_fill(eye, large_neg), dim=1)
    log_num = torch.logsumexp(sim.masked_fill(~mask_same, large_neg), dim=1)
    valid = mask_same.sum(1) > 0
    if valid.sum() == 0:
        return sim.sum() * 0
    return (log_denom[valid] - log_num[valid]).mean()


def triplet_loss(emb, labels, margin=0.2):
    emb = F.normalize(emb.float(), dim=1)
    sim = torch.mm(emb, emb.t())
    B = emb.size(0)
    eye = torch.eye(B, device=emb.device, dtype=torch.bool)
    mask_pos = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
    mask_neg = labels.unsqueeze(0) != labels.unsqueeze(1)
    sim_pos = sim.masked_fill(~mask_pos, -2.0)
    sim_neg = sim.masked_fill(~mask_neg, -2.0)
    sim_pos_max = sim_pos.max(1)[0]
    sim_neg_max = sim_neg.max(1)[0]
    valid = mask_pos.sum(1) > 0
    if valid.sum() == 0:
        return emb.sum() * 0
    loss = (margin - sim_pos_max[valid] + sim_neg_max[valid]).clamp(min=0).mean()
    return loss


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
    
    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)


class ArcFaceLoss(nn.Module):
    def __init__(self, in_features, num_classes, s=30.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, x, labels):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = F.one_hot(labels, num_classes=cosine.size(1)).float()
        output = cosine * (1 - one_hot) + target_logits * one_hot
        return output * self.s


class SubCenterArcFaceLoss(nn.Module):
    def __init__(self, in_features, num_classes, K=3, s=30.0, m=0.5):
        super().__init__()
        self.s, self.m, self.K = s, m, K
        self.weight = nn.Parameter(torch.FloatTensor(num_classes * K, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, x, labels):
        x_n = F.normalize(x)
        w_n = F.normalize(self.weight)
        cosine = F.linear(x_n, w_n)
        B, NK = cosine.shape
        num_classes = NK // self.K
        cosine = cosine.view(B, num_classes, self.K)
        cosine_max, _ = cosine.max(dim=2)
        theta_target = torch.acos(cosine_max[range(B), labels].clamp(-1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta_target + self.m)
        one_hot = F.one_hot(labels, num_classes=num_classes).float()
        output = cosine_max * (1 - one_hot) + target_logits.unsqueeze(1) * one_hot
        return output * self.s


class JaguarModel(nn.Module):
    def __init__(self):
        super().__init__()
        kwargs = {'pretrained': True, 'num_classes': 0}
        if 'vit' in CFG.backbone or 'eva' in CFG.backbone:
            kwargs['img_size'] = CFG.image_size
        self.backbone = timm.create_model(CFG.backbone, **kwargs)
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.dropout = nn.Dropout(0.1)
        K = getattr(CFG, 'arcface_subcenters', 1)
        if K > 1:
            self.arcface = SubCenterArcFaceLoss(self.feat_dim, CFG.num_classes, K=K, s=CFG.arcface_s, m=CFG.arcface_m)
        else:
            self.arcface = ArcFaceLoss(self.feat_dim, CFG.num_classes, CFG.arcface_s, CFG.arcface_m)
        print(f"Loaded {CFG.backbone} | Features: {self.feat_dim}")

    def extract(self, x):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            emb = features[:, 0]
        else:
            emb = self.gem(features).flatten(1)
        emb = self.bn(emb)
        return emb

    def forward(self, x, labels=None):
        emb = self.extract(x)
        if labels is not None:
            emb = self.dropout(emb)
            return self.arcface(emb, labels)
        return emb

In [7]:
# =============================================================================
# POST-PROCESSING
# =============================================================================
def query_expansion(emb, top_k=None, verbose=True):
    top_k = top_k if top_k is not None else getattr(CFG, 'qe_top_k', 3)
    if verbose:
        print("Applying Query Expansion...")
    sims = emb @ emb.T
    indices = np.argsort(-sims, axis=1)[:, :top_k]
    new_emb = np.zeros_like(emb)
    for i in range(len(emb)):
        new_emb[i] = np.mean(emb[indices[i]], axis=0)
    return new_emb / (np.linalg.norm(new_emb, axis=1, keepdims=True) + 1e-8)


def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=None, verbose=True):
    lambda_value = lambda_value if lambda_value is not None else getattr(CFG, 'rerank_lambda', 0.3)
    if verbose:
        print("Applying Re-ranking...")
    q_g_dist = 1 - prob
    original_dist = q_g_dist.copy()
    initial_rank = np.argsort(original_dist, axis=1)
    
    nn_k1 = []
    for i in range(prob.shape[0]):
        forward_k1 = initial_rank[i, :k1+1]
        backward_k1 = initial_rank[forward_k1, :k1+1]
        fi = np.where(backward_k1 == i)[0]
        nn_k1.append(forward_k1[fi])
    
    jaccard_dist = np.zeros_like(original_dist)
    for i in range(prob.shape[0]):
        ind_non_zero = np.where(original_dist[i, :] < 0.6)[0]
        ind_images = [inv for inv in ind_non_zero if len(np.intersect1d(nn_k1[i], nn_k1[inv])) > 0]
        for j in ind_images:
            intersection = len(np.intersect1d(nn_k1[i], nn_k1[j]))
            union = len(np.union1d(nn_k1[i], nn_k1[j]))
            jaccard_dist[i, j] = 1 - intersection / union
    
    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

In [8]:
# =============================================================================
# VALIDATION
# =============================================================================
def compute_val_mAP(emb, labels):
    emb = np.asarray(emb)
    labels = np.asarray(labels)
    n = len(labels)
    sim = emb @ emb.T
    mAP_per_id = []
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        if len(idx) < 2:
            continue
        aps = []
        for q in idx:
            gallery = np.array([i for i in range(n) if i != q])
            rel = (labels[gallery] == c).astype(float)
            if rel.sum() == 0:
                continue
            order = np.argsort(-sim[q, gallery])
            rel_ord = rel[order]
            prec = np.cumsum(rel_ord) / (1 + np.arange(len(rel_ord)))
            ap = (prec[rel_ord == 1].sum()) / rel.sum()
            aps.append(ap)
        if aps:
            mAP_per_id.append(np.mean(aps))
    return float(np.mean(mAP_per_id)) if mAP_per_id else 0.0

def compute_val_mAP_from_sim(sim, labels):
    labels = np.asarray(labels)
    n = len(labels)
    mAP_per_id = []
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        if len(idx) < 2:
            continue
        aps = []
        for q in idx:
            gallery = np.array([i for i in range(n) if i != q])
            rel = (labels[gallery] == c).astype(float)
            if rel.sum() == 0:
                continue
            order = np.argsort(-sim[q, gallery])
            rel_ord = rel[order]
            prec = np.cumsum(rel_ord) / (1 + np.arange(len(rel_ord)))
            ap = (prec[rel_ord == 1].sum()) / rel.sum()
            aps.append(ap)
        if aps:
            mAP_per_id.append(np.mean(aps))
    return float(np.mean(mAP_per_id)) if mAP_per_id else 0.0

In [9]:
# =============================================================================
# TRAINING
# =============================================================================
train_dataset = JaguarDataset(train_df, CFG.train_dir, get_train_transforms())
val_dataset = JaguarDataset(val_df, CFG.train_dir, get_test_transforms(flip=False))

for seed in CFG.train_seeds:
    seed_everything(seed)
    if getattr(CFG, 'use_supcon', False) or getattr(CFG, 'use_triplet', False):
        train_loader = DataLoader(
            train_dataset,
            batch_sampler=PKSampler(train_dataset.labels, CFG.pk_p, CFG.pk_k),
            num_workers=CFG.num_workers,
            pin_memory=True
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=CFG.batch_size,
            sampler=BalancedSampler(train_dataset.labels, CFG.samples_per_class),
            num_workers=CFG.num_workers,
            pin_memory=True
        )
    val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
    grad_accum = getattr(CFG, 'grad_accum_steps', 1)

    model = JaguarModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)
    scaler = torch.amp.GradScaler('cuda')
    criterion = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)

    best_val_mAP = 0.0
    patience = getattr(CFG, 'early_stop_patience', 5)
    no_improve = 0
    print(f"--- Seed {seed} ---")

    for epoch in range(CFG.epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
        optimizer.zero_grad()
        for step, (imgs, labels) in enumerate(pbar):
            imgs, labels = imgs.to(device), labels.to(device)
            with torch.amp.autocast('cuda'):
                if getattr(CFG, 'use_supcon', False):
                    emb = model(imgs)
                    loss = supcon_loss(emb, labels, CFG.supcon_tau)
                elif getattr(CFG, 'use_triplet', False):
                    emb = model(imgs)
                    loss = triplet_loss(emb, labels, getattr(CFG, 'triplet_margin', 0.2))
                else:
                    logits = model(imgs, labels)
                    loss = criterion(logits, labels)
                loss = loss / grad_accum
            scaler.scale(loss).backward()
            if (step + 1) % grad_accum == 0 or (step + 1) == len(train_loader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            total_loss += loss.item() * grad_accum
            pbar.set_postfix({'loss': f'{loss.item() * grad_accum:.4f}'})
        avg_loss = total_loss / len(train_loader)

        model.eval()
        emb_list, label_list = [], []
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc='Val', leave=False):
                emb = model(imgs.to(device))
                emb_list.append(F.normalize(emb, dim=1).cpu().numpy())
                label_list.append(labels.numpy())
        emb_val = np.concatenate(emb_list)
        labels_val = np.concatenate(label_list)
        val_mAP = compute_val_mAP(emb_val, labels_val)

        scheduler.step()
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val mAP: {val_mAP:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")
        if val_mAP > best_val_mAP:
            best_val_mAP = val_mAP
            no_improve = 0
            torch.save(model.state_dict(), f'best_model_seed{seed}.pth')
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stop at epoch {epoch+1} (no val mAP improvement for {patience} epochs)")
                break

    print(f"Seed {seed} done | Best val mAP: {best_val_mAP:.4f}")

print("Training complete | All seeds done.")

Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536
--- Seed 42 ---


Epoch 1/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1 | Loss: 12.2524 | Val mAP: 0.6803 | LR: 1.95e-04


Epoch 2/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2 | Loss: 4.2159 | Val mAP: 0.8598 | LR: 1.81e-04


Epoch 3/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3 | Loss: 1.8879 | Val mAP: 0.8685 | LR: 1.59e-04


Epoch 4/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.3717 | Val mAP: 0.9162 | LR: 1.31e-04


Epoch 5/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.1575 | Val mAP: 0.8981 | LR: 1.00e-04


Epoch 6/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6 | Loss: 0.9822 | Val mAP: 0.9137 | LR: 6.91e-05


Epoch 7/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.8620 | Val mAP: 0.9096 | LR: 4.12e-05
Early stop at epoch 7 (no val mAP improvement for 3 epochs)
Seed 42 done | Best val mAP: 0.9162
Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536
--- Seed 420 ---


Epoch 1/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1 | Loss: 12.2844 | Val mAP: 0.6016 | LR: 1.95e-04


Epoch 2/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2 | Loss: 4.3386 | Val mAP: 0.8716 | LR: 1.81e-04


Epoch 3/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3 | Loss: 2.0171 | Val mAP: 0.9075 | LR: 1.59e-04


Epoch 4/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.5210 | Val mAP: 0.9203 | LR: 1.31e-04


Epoch 5/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.1237 | Val mAP: 0.9083 | LR: 1.00e-04


Epoch 6/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6 | Loss: 0.9450 | Val mAP: 0.8949 | LR: 6.91e-05


Epoch 7/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.8569 | Val mAP: 0.8958 | LR: 4.12e-05
Early stop at epoch 7 (no val mAP improvement for 3 epochs)
Seed 420 done | Best val mAP: 0.9203
Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536
--- Seed 666 ---


Epoch 1/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1 | Loss: 12.2935 | Val mAP: 0.7276 | LR: 1.95e-04


Epoch 2/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2 | Loss: 4.4516 | Val mAP: 0.8728 | LR: 1.81e-04


Epoch 3/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3 | Loss: 2.0119 | Val mAP: 0.9327 | LR: 1.59e-04


Epoch 4/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.4196 | Val mAP: 0.9233 | LR: 1.31e-04


Epoch 5/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.1643 | Val mAP: 0.9308 | LR: 1.00e-04


Epoch 6/10:   0%|          | 0/233 [00:00<?, ?it/s]

Val:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6 | Loss: 0.9463 | Val mAP: 0.9199 | LR: 6.91e-05
Early stop at epoch 6 (no val mAP improvement for 3 epochs)
Seed 666 done | Best val mAP: 0.9327
Training complete | All seeds done.


In [10]:
# =============================================================================
# INFERENCE
# =============================================================================
unique_images = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
print(f"Extracting embeddings for {len(unique_images)} images...")

def extract_embeddings(transform, m):
    loader = DataLoader(JaguarTestDataset(unique_images, CFG.test_dir, transform), batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
    feats, names = [], []
    with torch.no_grad():
        for imgs, fnames in tqdm(loader, leave=False):
            emb = m(imgs.to(device))
            feats.append(F.normalize(emb, dim=1).cpu())
            names.extend(fnames)
    return torch.cat(feats, dim=0), names

def get_embeddings(m):
    if getattr(CFG, 'use_multiscale_tta', False):
        sizes = getattr(CFG, 'multiscale_sizes', (224, 256))
        parts = []
        for sz in sizes:
            t = get_test_transforms(flip=False, size=sz)
            e1, names = extract_embeddings(t, m)
            if CFG.use_tta:
                t2 = get_test_transforms(flip=True, size=sz)
                e2, _ = extract_embeddings(t2, m)
                e1 = F.normalize((e1 + e2) / 2, dim=1)
            parts.append(e1.numpy())
        e = np.mean(parts, axis=0).astype(np.float32)
        e = e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)
        return e, names
    e1, names = extract_embeddings(get_test_transforms(flip=False), m)
    if CFG.use_tta:
        e2, _ = extract_embeddings(get_test_transforms(flip=True), m)
        e = F.normalize((e1 + e2) / 2, dim=1)
    else:
        e = e1
    return e.numpy(), names

emb_list = []
for seed in CFG.train_seeds:
    ckpt_path = f'best_model_seed{seed}.pth'
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Run training first; missing {ckpt_path}")
    model = JaguarModel().to(device)
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    e, names = get_embeddings(model)
    emb_list.append(e)
emb = np.mean(emb_list, axis=0).astype(np.float32)
emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-8)
img_map = {n: i for i, n in enumerate(names)}

Extracting embeddings for 371 images...
Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

Loaded convnext_large.fb_in22k_ft_in1k | Features: 1536


  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

In [11]:
# =============================================================================
# SUBMISSION
# =============================================================================
if CFG.use_qe:
    emb = query_expansion(emb)
sim_matrix = emb @ emb.T
if CFG.use_rerank:
    sim_matrix = k_reciprocal_rerank(sim_matrix)
preds = [float(np.clip(sim_matrix[img_map[row['query_image']], img_map[row['gallery_image']]], 0, 1)) for _, row in test_df.iterrows()]
pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds}).to_csv('submission.csv', index=False)
print(f"Saved submission.csv | Mean sim: {np.mean(preds):.4f}")

Applying Query Expansion...
Applying Re-ranking...
Saved submission.csv | Mean sim: 0.3253
