# Jaguar Re-Identification

## Score: .856

In [12]:
import os
import math
import random
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [13]:
# =============================================================================
# CONFIG
# =============================================================================
class CFG:
    data_dir = Path('jaguar-re-id')
    train_csv = data_dir / 'train.csv'
    test_csv = data_dir / 'test.csv'
    train_dir = data_dir / 'train' / 'train'
    test_dir = data_dir / 'test' / 'test'
    
    backbone = 'convnext_base.fb_in22k_ft_in1k'
    image_size = 256
    num_classes = 31
    
    epochs = 25
    batch_size = 24
    lr = 2e-4
    weight_decay = 0.01
    warmup_epochs = 2
    
    arcface_s = 30.0
    arcface_m = 0.6
    label_smoothing = 0.1
    
    samples_per_class = 60
    
    use_tta = True
    use_multiscale_tta = True
    multiscale_sizes = [256, 224]
    use_qe = True
    qe_top_k = 5
    use_rerank = True
    rerank_lambda = 0.5
    
    train_seeds = [42, 123]
    
    num_workers = 0
    mixed_precision = True

In [14]:
# =============================================================================
# DATA
# =============================================================================
train_df = pd.read_csv(CFG.train_csv)
test_df = pd.read_csv(CFG.test_csv)
print(f"Train: {len(train_df)} | Test pairs: {len(test_df)}")

Train: 1895 | Test pairs: 137270


In [15]:
# =============================================================================
# TRANSFORMS
# =============================================================================
def get_train_transforms():
    return A.Compose([
        A.LongestMaxSize(max_size=CFG.image_size),
        A.PadIfNeeded(CFG.image_size, CFG.image_size, border_mode=0),
        A.HorizontalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1), rotate=(-12, 12), shear=(-8, 8), p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.08, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

def get_test_transforms(flip=False, size=None):
    sz = size if size is not None else CFG.image_size
    t = [
        A.LongestMaxSize(max_size=sz),
        A.PadIfNeeded(sz, sz, border_mode=0),
    ]
    if flip:
        t.append(A.HorizontalFlip(p=1.0))
    t.extend([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    return A.Compose(t)

In [16]:
# =============================================================================
# DATASET & SAMPLER
# =============================================================================
class JaguarDataset(Dataset):
    def __init__(self, df, img_dir, transform):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        unique_ids = sorted(df['ground_truth'].unique())
        self.label_map = {name: i for i, name in enumerate(unique_ids)}
        self.labels = [self.label_map[gt] for gt in df['ground_truth']]
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = np.array(Image.open(self.img_dir / row['filename']).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, torch.tensor(self.labels[idx], dtype=torch.long)


class JaguarTestDataset(Dataset):
    def __init__(self, filenames, img_dir, transform):
        self.filenames = filenames
        self.img_dir = Path(img_dir)
        self.transform = transform
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        fname = self.filenames[idx]
        img = np.array(Image.open(self.img_dir / fname).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, fname


class BalancedSampler(Sampler):
    def __init__(self, labels, samples_per_class):
        self.labels = labels
        self.samples_per_class = samples_per_class
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)
        self.num_classes = len(self.class_indices)
    
    def __iter__(self):
        indices = []
        for label in self.class_indices:
            class_idx = self.class_indices[label]
            if len(class_idx) >= self.samples_per_class:
                sampled = random.sample(class_idx, self.samples_per_class)
            else:
                sampled = random.choices(class_idx, k=self.samples_per_class)
            indices.extend(sampled)
        random.shuffle(indices)
        return iter(indices)
    
    def __len__(self):
        return self.num_classes * self.samples_per_class

In [17]:
# =============================================================================
# MODEL
# =============================================================================
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
    
    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)


class ArcFaceLoss(nn.Module):
    def __init__(self, in_features, num_classes, s=30.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, x, labels):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = F.one_hot(labels, num_classes=cosine.size(1)).float()
        output = cosine * (1 - one_hot) + target_logits * one_hot
        return output * self.s


class JaguarModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(CFG.backbone, pretrained=True, num_classes=0)
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.dropout = nn.Dropout(0.1)
        self.arcface = ArcFaceLoss(self.feat_dim, CFG.num_classes, CFG.arcface_s, CFG.arcface_m)
        print(f"Loaded {CFG.backbone} | Features: {self.feat_dim}")
    
    def extract(self, x):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            B, N, C = features.shape
            H = W = int(math.sqrt(N))
            features = features.permute(0, 2, 1).reshape(B, C, H, W)
        emb = self.gem(features).flatten(1)
        emb = self.bn(emb)
        return emb
    
    def forward(self, x, labels=None):
        emb = self.extract(x)
        if labels is not None:
            emb = self.dropout(emb)
            return self.arcface(emb, labels)
        return emb

In [18]:
# =============================================================================
# POST-PROCESSING
# =============================================================================
def query_expansion(emb, top_k=None):
    top_k = top_k if top_k is not None else getattr(CFG, 'qe_top_k', 3)
    print("Applying Query Expansion...")
    sims = emb @ emb.T
    indices = np.argsort(-sims, axis=1)[:, :top_k]
    new_emb = np.zeros_like(emb)
    for i in range(len(emb)):
        new_emb[i] = np.mean(emb[indices[i]], axis=0)
    return new_emb / np.linalg.norm(new_emb, axis=1, keepdims=True)


def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=None):
    lambda_value = lambda_value if lambda_value is not None else getattr(CFG, 'rerank_lambda', 0.3)
    print("Applying Re-ranking...")
    q_g_dist = 1 - prob
    original_dist = q_g_dist.copy()
    initial_rank = np.argsort(original_dist, axis=1)
    
    nn_k1 = []
    for i in range(prob.shape[0]):
        forward_k1 = initial_rank[i, :k1+1]
        backward_k1 = initial_rank[forward_k1, :k1+1]
        fi = np.where(backward_k1 == i)[0]
        nn_k1.append(forward_k1[fi])
    
    jaccard_dist = np.zeros_like(original_dist)
    for i in range(prob.shape[0]):
        ind_non_zero = np.where(original_dist[i, :] < 0.6)[0]
        ind_images = [inv for inv in ind_non_zero if len(np.intersect1d(nn_k1[i], nn_k1[inv])) > 0]
        for j in ind_images:
            intersection = len(np.intersect1d(nn_k1[i], nn_k1[j]))
            union = len(np.union1d(nn_k1[i], nn_k1[j]))
            jaccard_dist[i, j] = 1 - intersection / union
    
    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

In [19]:
# =============================================================================
# TRAINING (two seeds by default for ensemble)
# =============================================================================
train_seeds = getattr(CFG, 'train_seeds', [42])
for run_idx, seed in enumerate(train_seeds):
    seed_everything(seed)
    ckpt_name = 'best_model.pth' if run_idx == 0 else 'best_model_s2.pth'
    train_dataset = JaguarDataset(train_df, CFG.train_dir, get_train_transforms())
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG.batch_size,
        sampler=BalancedSampler(train_dataset.labels, CFG.samples_per_class),
        num_workers=CFG.num_workers,
        pin_memory=True
    )
    model = JaguarModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=CFG.lr, epochs=CFG.epochs, steps_per_epoch=len(train_loader),
        pct_start=CFG.warmup_epochs/CFG.epochs
    )
    scaler = torch.amp.GradScaler('cuda')
    criterion = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)
    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f'Seed {seed} Epoch {epoch+1}/{CFG.epochs}')
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
                logits = model(imgs, labels)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), ckpt_name)
    print(f"Run {run_idx+1} (seed {seed}) complete | Best loss: {best_loss:.4f} | Saved {ckpt_name}")

Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


Seed 42 Epoch 1/25:   0%|          | 0/78 [00:00<?, ?it/s]



Epoch 1 | Loss: 17.6624 | LR: 1.05e-04


Seed 42 Epoch 2/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2 | Loss: 9.1943 | LR: 2.00e-04


Seed 42 Epoch 3/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3 | Loss: 3.2598 | LR: 1.99e-04


Seed 42 Epoch 4/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.7778 | LR: 1.96e-04


Seed 42 Epoch 5/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.2756 | LR: 1.92e-04


Seed 42 Epoch 6/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.1354 | LR: 1.85e-04


Seed 42 Epoch 7/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.9427 | LR: 1.77e-04


Seed 42 Epoch 8/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 8 | Loss: 0.8585 | LR: 1.68e-04


Seed 42 Epoch 9/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.8313 | LR: 1.58e-04


Seed 42 Epoch 10/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.8843 | LR: 1.46e-04


Seed 42 Epoch 11/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.7560 | LR: 1.33e-04


Seed 42 Epoch 12/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.7347 | LR: 1.20e-04


Seed 42 Epoch 13/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.7241 | LR: 1.07e-04


Seed 42 Epoch 14/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.7121 | LR: 9.30e-05


Seed 42 Epoch 15/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.7031 | LR: 7.95e-05


Seed 42 Epoch 16/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.6988 | LR: 6.63e-05


Seed 42 Epoch 17/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.6961 | LR: 5.38e-05


Seed 42 Epoch 18/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.6971 | LR: 4.22e-05


Seed 42 Epoch 19/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 19 | Loss: 0.6958 | LR: 3.16e-05


Seed 42 Epoch 20/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 20 | Loss: 0.6943 | LR: 2.23e-05


Seed 42 Epoch 21/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.6921 | LR: 1.45e-05


Seed 42 Epoch 22/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.6938 | LR: 8.21e-06


Seed 42 Epoch 23/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.6952 | LR: 3.66e-06


Seed 42 Epoch 24/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.6921 | LR: 9.09e-07


Seed 42 Epoch 25/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.6923 | LR: 9.53e-10
Run 1 (seed 42) complete | Best loss: 0.6921 | Saved best_model.pth
Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


Seed 123 Epoch 1/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 1 | Loss: 17.8149 | LR: 1.05e-04


Seed 123 Epoch 2/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2 | Loss: 9.5434 | LR: 2.00e-04


Seed 123 Epoch 3/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3 | Loss: 3.4748 | LR: 1.99e-04


Seed 123 Epoch 4/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.7893 | LR: 1.96e-04


Seed 123 Epoch 5/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.3307 | LR: 1.92e-04


Seed 123 Epoch 6/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.1158 | LR: 1.85e-04


Seed 123 Epoch 7/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.9222 | LR: 1.77e-04


Seed 123 Epoch 8/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 8 | Loss: 0.8910 | LR: 1.68e-04


Seed 123 Epoch 9/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.8616 | LR: 1.58e-04


Seed 123 Epoch 10/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.7923 | LR: 1.46e-04


Seed 123 Epoch 11/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.8069 | LR: 1.33e-04


Seed 123 Epoch 12/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.7400 | LR: 1.20e-04


Seed 123 Epoch 13/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.7154 | LR: 1.07e-04


Seed 123 Epoch 14/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.7193 | LR: 9.30e-05


Seed 123 Epoch 15/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.7186 | LR: 7.95e-05


Seed 123 Epoch 16/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.7034 | LR: 6.63e-05


Seed 123 Epoch 17/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.7015 | LR: 5.38e-05


Seed 123 Epoch 18/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.6978 | LR: 4.22e-05


Seed 123 Epoch 19/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 19 | Loss: 0.6946 | LR: 3.16e-05


Seed 123 Epoch 20/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 20 | Loss: 0.6967 | LR: 2.23e-05


Seed 123 Epoch 21/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.6965 | LR: 1.45e-05


Seed 123 Epoch 22/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.6944 | LR: 8.21e-06


Seed 123 Epoch 23/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.6935 | LR: 3.66e-06


Seed 123 Epoch 24/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.6950 | LR: 9.09e-07


Seed 123 Epoch 25/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.6948 | LR: 9.53e-10
Run 2 (seed 123) complete | Best loss: 0.6935 | Saved best_model_s2.pth


In [20]:
# =============================================================================
# OPTIONAL: SECOND RUN (ENSEMBLE) â€” run after first training, then run inference
# =============================================================================
seed_everything(123)
train_dataset_s2 = JaguarDataset(train_df, CFG.train_dir, get_train_transforms())
train_loader_s2 = DataLoader(
    train_dataset_s2,
    batch_size=CFG.batch_size,
    sampler=BalancedSampler(train_dataset_s2.labels, CFG.samples_per_class),
    num_workers=CFG.num_workers,
    pin_memory=True
)
model_s2 = JaguarModel().to(device)
optimizer_s2 = torch.optim.AdamW(model_s2.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler_s2 = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_s2, max_lr=CFG.lr, epochs=CFG.epochs, steps_per_epoch=len(train_loader_s2),
    pct_start=CFG.warmup_epochs/CFG.epochs
)
scaler_s2 = torch.amp.GradScaler('cuda')
criterion_s2 = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)

best_loss_s2 = float('inf')
for epoch in range(CFG.epochs):
    model_s2.train()
    total_loss = 0
    pbar = tqdm(train_loader_s2, desc=f'Epoch {epoch+1}/{CFG.epochs} (s2)')
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_s2.zero_grad()
        with torch.amp.autocast('cuda'):
            logits = model_s2(imgs, labels)
            loss = criterion_s2(logits, labels)
        scaler_s2.scale(loss).backward()
        scaler_s2.step(optimizer_s2)
        scaler_s2.update()
        scheduler_s2.step()
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    avg_loss = total_loss / len(train_loader_s2)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | LR: {scheduler_s2.get_last_lr()[0]:.2e}")
    if avg_loss < best_loss_s2:
        best_loss_s2 = avg_loss
        torch.save(model_s2.state_dict(), 'best_model_s2.pth')
print(f"Second run complete | Best loss: {best_loss_s2:.4f} | Saved best_model_s2.pth")

Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


Epoch 1/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 1 | Loss: 17.8443 | LR: 1.05e-04


Epoch 2/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2 | Loss: 9.6172 | LR: 2.00e-04


Epoch 3/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3 | Loss: 3.4170 | LR: 1.99e-04


Epoch 4/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.7908 | LR: 1.96e-04


Epoch 5/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.3033 | LR: 1.92e-04


Epoch 6/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.1238 | LR: 1.85e-04


Epoch 7/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.9523 | LR: 1.77e-04


Epoch 8/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 8 | Loss: 0.8396 | LR: 1.68e-04


Epoch 9/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.8073 | LR: 1.58e-04


Epoch 10/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.7340 | LR: 1.46e-04


Epoch 11/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.7489 | LR: 1.33e-04


Epoch 12/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.7139 | LR: 1.20e-04


Epoch 13/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.7111 | LR: 1.07e-04


Epoch 14/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.7027 | LR: 9.30e-05


Epoch 15/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.7012 | LR: 7.95e-05


Epoch 16/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.6975 | LR: 6.63e-05


Epoch 17/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.6989 | LR: 5.38e-05


Epoch 18/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.6964 | LR: 4.22e-05


Epoch 19/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 19 | Loss: 0.6936 | LR: 3.16e-05


Epoch 20/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 20 | Loss: 0.6942 | LR: 2.23e-05


Epoch 21/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.6951 | LR: 1.45e-05


Epoch 22/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.6931 | LR: 8.21e-06


Epoch 23/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.6930 | LR: 3.66e-06


Epoch 24/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.6926 | LR: 9.09e-07


Epoch 25/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.6928 | LR: 9.53e-10
Second run complete | Best loss: 0.6926 | Saved best_model_s2.pth


In [21]:
# =============================================================================
# INFERENCE
# =============================================================================
unique_images = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
print(f"Extracting embeddings for {len(unique_images)} images...")

def extract_embeddings(transform, m):
    loader = DataLoader(
        JaguarTestDataset(unique_images, CFG.test_dir, transform),
        batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers
    )
    feats, names = [], []
    with torch.no_grad():
        for imgs, fnames in tqdm(loader, leave=False):
            emb = m(imgs.to(device))
            feats.append(F.normalize(emb, dim=1).cpu())
            names.extend(fnames)
    return torch.cat(feats, dim=0), names

def get_embeddings(m, size=None):
    sz = size if size is not None else CFG.image_size
    e1, names = extract_embeddings(get_test_transforms(flip=False, size=sz), m)
    if CFG.use_tta:
        e2, _ = extract_embeddings(get_test_transforms(flip=True, size=sz), m)
        e = F.normalize((e1 + e2) / 2, dim=1)
    else:
        e = e1
    return e.numpy(), names

model.load_state_dict(torch.load('best_model.pth'))
model.eval()

if getattr(CFG, 'use_multiscale_tta', False) and getattr(CFG, 'multiscale_sizes', None):
    embs = []
    for sz in CFG.multiscale_sizes:
        e, names = get_embeddings(model, size=sz)
        embs.append(e)
    emb = np.stack(embs).mean(axis=0)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
else:
    emb, names = get_embeddings(model)

if Path('best_model_s2.pth').exists():
    print("Ensemble: loading second model...")
    model2 = JaguarModel().to(device)
    model2.load_state_dict(torch.load('best_model_s2.pth'))
    model2.eval()
    if getattr(CFG, 'use_multiscale_tta', False) and getattr(CFG, 'multiscale_sizes', None):
        embs2 = [get_embeddings(model2, size=sz)[0] for sz in CFG.multiscale_sizes]
        emb2 = np.stack(embs2).mean(axis=0)
        emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
    else:
        emb2, _ = get_embeddings(model2)
    emb = (emb + emb2) / 2
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

img_map = {n: i for i, n in enumerate(names)}

Extracting embeddings for 371 images...


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Ensemble: loading second model...
Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [22]:
# =============================================================================
# POST-PROCESSING & SUBMISSION
# =============================================================================
if CFG.use_qe:
    emb = query_expansion(emb)

sim_matrix = emb @ emb.T

if CFG.use_rerank:
    sim_matrix = k_reciprocal_rerank(sim_matrix)

print("Computing similarities...")
preds = []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    s = sim_matrix[img_map[row['query_image']], img_map[row['gallery_image']]]
    preds.append(float(np.clip(s, 0, 1)))

submission = pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds})
submission.to_csv('submission.csv', index=False)
print(f"Saved submission.csv | Mean sim: {np.mean(preds):.4f}")

Applying Query Expansion...
Applying Re-ranking...
Computing similarities...


  0%|          | 0/137270 [00:00<?, ?it/s]

Saved submission.csv | Mean sim: 0.5040
