# Jaguar Re-Identification

## Score: .874

In [1]:
import os
import math
import random
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [2]:
# =============================================================================
# CONFIG
# =============================================================================
class CFG:
    data_dir = Path('jaguar-re-id')
    train_csv = data_dir / 'train.csv'
    test_csv = data_dir / 'test.csv'
    train_dir = data_dir / 'train' / 'train'
    test_dir = data_dir / 'test' / 'test'
    
    backbone = 'convnext_base.fb_in22k_ft_in1k'
    image_size = 256
    num_classes = 31
    
    epochs = 25
    batch_size = 24
    lr = 2e-4
    weight_decay = 0.01
    warmup_epochs = 2
    
    arcface_s = 30.0
    arcface_m = 0.55
    label_smoothing = 0.1
    
    samples_per_class = 60
    
    use_tta = True
    use_qe = True
    use_rerank = True
    
    num_workers = 0
    mixed_precision = True

In [3]:
# =============================================================================
# DATA
# =============================================================================
train_df = pd.read_csv(CFG.train_csv)
test_df = pd.read_csv(CFG.test_csv)
print(f"Train: {len(train_df)} | Test pairs: {len(test_df)}")

Train: 1895 | Test pairs: 137270


In [4]:
# =============================================================================
# TRANSFORMS
# =============================================================================
def get_train_transforms():
    return A.Compose([
        A.LongestMaxSize(max_size=CFG.image_size),
        A.PadIfNeeded(CFG.image_size, CFG.image_size, border_mode=0),
        A.HorizontalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1), rotate=(-12, 12), shear=(-8, 8), p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.08, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

def get_test_transforms(flip=False):
    t = [
        A.LongestMaxSize(max_size=CFG.image_size),
        A.PadIfNeeded(CFG.image_size, CFG.image_size, border_mode=0),
    ]
    if flip:
        t.append(A.HorizontalFlip(p=1.0))
    t.extend([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    return A.Compose(t)

In [5]:
# =============================================================================
# DATASET & SAMPLER
# =============================================================================
class JaguarDataset(Dataset):
    def __init__(self, df, img_dir, transform):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        unique_ids = sorted(df['ground_truth'].unique())
        self.label_map = {name: i for i, name in enumerate(unique_ids)}
        self.labels = [self.label_map[gt] for gt in df['ground_truth']]
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = np.array(Image.open(self.img_dir / row['filename']).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, torch.tensor(self.labels[idx], dtype=torch.long)


class JaguarTestDataset(Dataset):
    def __init__(self, filenames, img_dir, transform):
        self.filenames = filenames
        self.img_dir = Path(img_dir)
        self.transform = transform
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        fname = self.filenames[idx]
        img = np.array(Image.open(self.img_dir / fname).convert('RGB'))
        img = self.transform(image=img)['image']
        return img, fname


class BalancedSampler(Sampler):
    def __init__(self, labels, samples_per_class):
        self.labels = labels
        self.samples_per_class = samples_per_class
        self.class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)
        self.num_classes = len(self.class_indices)
    
    def __iter__(self):
        indices = []
        for label in self.class_indices:
            class_idx = self.class_indices[label]
            if len(class_idx) >= self.samples_per_class:
                sampled = random.sample(class_idx, self.samples_per_class)
            else:
                sampled = random.choices(class_idx, k=self.samples_per_class)
            indices.extend(sampled)
        random.shuffle(indices)
        return iter(indices)
    
    def __len__(self):
        return self.num_classes * self.samples_per_class

In [6]:
# =============================================================================
# MODEL
# =============================================================================
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
    
    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)


class ArcFaceLoss(nn.Module):
    def __init__(self, in_features, num_classes, s=30.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, x, labels):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = F.one_hot(labels, num_classes=cosine.size(1)).float()
        output = cosine * (1 - one_hot) + target_logits * one_hot
        return output * self.s


class JaguarModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(CFG.backbone, pretrained=True, num_classes=0)
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.dropout = nn.Dropout(0.1)
        self.arcface = ArcFaceLoss(self.feat_dim, CFG.num_classes, CFG.arcface_s, CFG.arcface_m)
        print(f"Loaded {CFG.backbone} | Features: {self.feat_dim}")
    
    def extract(self, x):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            B, N, C = features.shape
            H = W = int(math.sqrt(N))
            features = features.permute(0, 2, 1).reshape(B, C, H, W)
        emb = self.gem(features).flatten(1)
        emb = self.bn(emb)
        return emb
    
    def forward(self, x, labels=None):
        emb = self.extract(x)
        if labels is not None:
            emb = self.dropout(emb)
            return self.arcface(emb, labels)
        return emb

In [7]:
# =============================================================================
# POST-PROCESSING
# =============================================================================
def query_expansion(emb, top_k=3):
    print("Applying Query Expansion...")
    sims = emb @ emb.T
    indices = np.argsort(-sims, axis=1)[:, :top_k]
    new_emb = np.zeros_like(emb)
    for i in range(len(emb)):
        new_emb[i] = np.mean(emb[indices[i]], axis=0)
    return new_emb / np.linalg.norm(new_emb, axis=1, keepdims=True)


def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=0.3):
    print("Applying Re-ranking...")
    q_g_dist = 1 - prob
    original_dist = q_g_dist.copy()
    initial_rank = np.argsort(original_dist, axis=1)
    
    nn_k1 = []
    for i in range(prob.shape[0]):
        forward_k1 = initial_rank[i, :k1+1]
        backward_k1 = initial_rank[forward_k1, :k1+1]
        fi = np.where(backward_k1 == i)[0]
        nn_k1.append(forward_k1[fi])
    
    jaccard_dist = np.zeros_like(original_dist)
    for i in range(prob.shape[0]):
        ind_non_zero = np.where(original_dist[i, :] < 0.6)[0]
        ind_images = [inv for inv in ind_non_zero if len(np.intersect1d(nn_k1[i], nn_k1[inv])) > 0]
        for j in ind_images:
            intersection = len(np.intersect1d(nn_k1[i], nn_k1[j]))
            union = len(np.union1d(nn_k1[i], nn_k1[j]))
            jaccard_dist[i, j] = 1 - intersection / union
    
    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

In [8]:
# =============================================================================
# TRAINING
# =============================================================================
train_dataset = JaguarDataset(train_df, CFG.train_dir, get_train_transforms())
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    sampler=BalancedSampler(train_dataset.labels, CFG.samples_per_class),
    num_workers=CFG.num_workers,
    pin_memory=True
)

model = JaguarModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=CFG.lr, epochs=CFG.epochs, steps_per_epoch=len(train_loader),
    pct_start=CFG.warmup_epochs/CFG.epochs
)
scaler = torch.amp.GradScaler('cuda')
criterion = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)

best_loss = float('inf')
for epoch in range(CFG.epochs):
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
    
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        with torch.amp.autocast('cuda'):
            logits = model(imgs, labels)
            loss = criterion(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")
    
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), 'best_model.pth')

print(f"Training complete | Best loss: {best_loss:.4f}")

Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


Epoch 1/25:   0%|          | 0/78 [00:00<?, ?it/s]



Epoch 1 | Loss: 16.4924 | LR: 1.05e-04


Epoch 2/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2 | Loss: 8.1711 | LR: 2.00e-04


Epoch 3/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3 | Loss: 2.9528 | LR: 1.99e-04


Epoch 4/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.6806 | LR: 1.96e-04


Epoch 5/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.2129 | LR: 1.92e-04


Epoch 6/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.0621 | LR: 1.85e-04


Epoch 7/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.8856 | LR: 1.77e-04


Epoch 8/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 8 | Loss: 0.8842 | LR: 1.68e-04


Epoch 9/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.8327 | LR: 1.58e-04


Epoch 10/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.7337 | LR: 1.46e-04


Epoch 11/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.7110 | LR: 1.33e-04


Epoch 12/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.7115 | LR: 1.20e-04


Epoch 13/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.7010 | LR: 1.07e-04


Epoch 14/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.7014 | LR: 9.30e-05


Epoch 15/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.6963 | LR: 7.95e-05


Epoch 16/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.6962 | LR: 6.63e-05


Epoch 17/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.6941 | LR: 5.38e-05


Epoch 18/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.6944 | LR: 4.22e-05


Epoch 19/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 19 | Loss: 0.6927 | LR: 3.16e-05


Epoch 20/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 20 | Loss: 0.6926 | LR: 2.23e-05


Epoch 21/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.6896 | LR: 1.45e-05


Epoch 22/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.6915 | LR: 8.21e-06


Epoch 23/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.6948 | LR: 3.66e-06


Epoch 24/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.6897 | LR: 9.09e-07


Epoch 25/25:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.6902 | LR: 9.53e-10
Training complete | Best loss: 0.6896


In [9]:
# =============================================================================
# OPTIONAL: SECOND RUN (ENSEMBLE) â€” run after first training, then run inference
# =============================================================================
seed_everything(123)
train_dataset_s2 = JaguarDataset(train_df, CFG.train_dir, get_train_transforms())
train_loader_s2 = DataLoader(
    train_dataset_s2,
    batch_size=CFG.batch_size,
    sampler=BalancedSampler(train_dataset_s2.labels, CFG.samples_per_class),
    num_workers=CFG.num_workers,
    pin_memory=True
)
model_s2 = JaguarModel().to(device)
optimizer_s2 = torch.optim.AdamW(model_s2.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler_s2 = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_s2, max_lr=CFG.lr, epochs=CFG.epochs, steps_per_epoch=len(train_loader_s2),
    pct_start=CFG.warmup_epochs/CFG.epochs
)
scaler_s2 = torch.amp.GradScaler('cuda')
criterion_s2 = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)

best_loss_s2 = float('inf')
for epoch in range(CFG.epochs):
    model_s2.train()
    total_loss = 0
    pbar = tqdm(train_loader_s2, desc=f'Epoch {epoch+1}/{CFG.epochs} (s2)')
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_s2.zero_grad()
        with torch.amp.autocast('cuda'):
            logits = model_s2(imgs, labels)
            loss = criterion_s2(logits, labels)
        scaler_s2.scale(loss).backward()
        scaler_s2.step(optimizer_s2)
        scaler_s2.update()
        scheduler_s2.step()
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    avg_loss = total_loss / len(train_loader_s2)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | LR: {scheduler_s2.get_last_lr()[0]:.2e}")
    if avg_loss < best_loss_s2:
        best_loss_s2 = avg_loss
        torch.save(model_s2.state_dict(), 'best_model_s2.pth')
print(f"Second run complete | Best loss: {best_loss_s2:.4f} | Saved best_model_s2.pth")

Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


Epoch 1/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 1 | Loss: 16.6338 | LR: 1.05e-04


Epoch 2/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2 | Loss: 8.4894 | LR: 2.00e-04


Epoch 3/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3 | Loss: 3.0059 | LR: 1.99e-04


Epoch 4/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4 | Loss: 1.6203 | LR: 1.96e-04


Epoch 5/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5 | Loss: 1.2080 | LR: 1.92e-04


Epoch 6/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.1243 | LR: 1.85e-04


Epoch 7/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 7 | Loss: 0.9754 | LR: 1.77e-04


Epoch 8/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 8 | Loss: 0.8443 | LR: 1.68e-04


Epoch 9/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.7788 | LR: 1.58e-04


Epoch 10/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.7470 | LR: 1.46e-04


Epoch 11/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 11 | Loss: 0.7465 | LR: 1.33e-04


Epoch 12/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 12 | Loss: 0.7130 | LR: 1.20e-04


Epoch 13/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 13 | Loss: 0.7116 | LR: 1.07e-04


Epoch 14/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 14 | Loss: 0.7013 | LR: 9.30e-05


Epoch 15/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 15 | Loss: 0.6999 | LR: 7.95e-05


Epoch 16/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 16 | Loss: 0.7078 | LR: 6.63e-05


Epoch 17/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 17 | Loss: 0.6987 | LR: 5.38e-05


Epoch 18/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 18 | Loss: 0.6957 | LR: 4.22e-05


Epoch 19/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 19 | Loss: 0.6927 | LR: 3.16e-05


Epoch 20/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 20 | Loss: 0.6940 | LR: 2.23e-05


Epoch 21/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 21 | Loss: 0.6932 | LR: 1.45e-05


Epoch 22/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 22 | Loss: 0.6912 | LR: 8.21e-06


Epoch 23/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 23 | Loss: 0.6914 | LR: 3.66e-06


Epoch 24/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 24 | Loss: 0.6907 | LR: 9.09e-07


Epoch 25/25 (s2):   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 25 | Loss: 0.6910 | LR: 9.53e-10
Second run complete | Best loss: 0.6907 | Saved best_model_s2.pth


In [10]:
# =============================================================================
# INFERENCE
# =============================================================================
def extract_embeddings(transform, m):
    loader = DataLoader(
        JaguarTestDataset(unique_images, CFG.test_dir, transform),
        batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers
    )
    feats, names = [], []
    with torch.no_grad():
        for imgs, fnames in tqdm(loader, leave=False):
            emb = m(imgs.to(device))
            feats.append(F.normalize(emb, dim=1).cpu())
            names.extend(fnames)
    return torch.cat(feats, dim=0), names

def get_embeddings(m):
    e1, names = extract_embeddings(get_test_transforms(flip=False), m)
    if CFG.use_tta:
        e2, _ = extract_embeddings(get_test_transforms(flip=True), m)
        e = F.normalize((e1 + e2) / 2, dim=1)
    else:
        e = e1
    return e.numpy(), names

unique_images = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
print(f"Extracting embeddings for {len(unique_images)} images...")

model.load_state_dict(torch.load('best_model.pth'))
model.eval()
emb, names = get_embeddings(model)

if Path('best_model_s2.pth').exists():
    print("Ensemble: loading second model...")
    model2 = JaguarModel().to(device)
    model2.load_state_dict(torch.load('best_model_s2.pth'))
    model2.eval()
    emb2, _ = get_embeddings(model2)
    emb = (emb + emb2) / 2
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

img_map = {n: i for i, n in enumerate(names)}

Extracting embeddings for 371 images...


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Ensemble: loading second model...
Loaded convnext_base.fb_in22k_ft_in1k | Features: 1024


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [11]:
# =============================================================================
# POST-PROCESSING & SUBMISSION
# =============================================================================
if CFG.use_qe:
    emb = query_expansion(emb)

sim_matrix = emb @ emb.T

if CFG.use_rerank:
    sim_matrix = k_reciprocal_rerank(sim_matrix)

print("Computing similarities...")
preds = []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    s = sim_matrix[img_map[row['query_image']], img_map[row['gallery_image']]]
    preds.append(float(np.clip(s, 0, 1)))

submission = pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds})
submission.to_csv('submission.csv', index=False)
print(f"Saved submission.csv | Mean sim: {np.mean(preds):.4f}")

Applying Query Expansion...
Applying Re-ranking...
Computing similarities...


  0%|          | 0/137270 [00:00<?, ?it/s]

Saved submission.csv | Mean sim: 0.3063
