# Jaguar Re-Identification

## Score: .859

In [8]:
import os
import math
import random
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import timm

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

Device: cuda


In [9]:
# =============================================================================
# CONFIG
# =============================================================================
class CFG:
    data_dir = Path('jaguar-re-id')
    train_csv = data_dir / 'train.csv'
    test_csv = data_dir / 'test.csv'
    train_dir = data_dir / 'train' / 'train'
    test_dir = data_dir / 'test' / 'test'
    
    backbone = 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k'
    image_size = 448
    num_classes = 31
    
    epochs = 10
    batch_size = 4
    grad_accum = 4
    lr = 2e-5
    weight_decay = 1e-3
    
    arcface_s = 30.0
    arcface_m = 0.5
    
    use_tta = True
    use_qe = True
    use_rerank = True
    
    num_workers = 0

In [10]:
# =============================================================================
# DATA
# =============================================================================
train_df = pd.read_csv(CFG.train_csv)
test_df = pd.read_csv(CFG.test_csv)
print(f"Train: {len(train_df)} | Test pairs: {len(test_df)}")

Train: 1895 | Test pairs: 137270


In [11]:
# =============================================================================
# TRANSFORMS
# =============================================================================
train_transform = T.Compose([
    T.Resize((CFG.image_size, CFG.image_size)),
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize([0.481, 0.457, 0.408], [0.268, 0.261, 0.275]),
    T.RandomErasing(p=0.25),
])

test_transform = T.Compose([
    T.Resize((CFG.image_size, CFG.image_size)),
    T.ToTensor(),
    T.Normalize([0.481, 0.457, 0.408], [0.268, 0.261, 0.275]),
])

In [12]:
# =============================================================================
# DATASET
# =============================================================================
class JaguarDataset(Dataset):
    def __init__(self, df, img_dir, transform, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.is_test = is_test
        if not is_test:
            unique_ids = sorted(df['ground_truth'].unique())
            self.label_map = {name: i for i, name in enumerate(unique_ids)}
            self.df['label'] = self.df['ground_truth'].map(self.label_map)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(self.img_dir / row['filename']).convert('RGB')
        img = self.transform(img)
        if self.is_test:
            return img, row['filename']
        return img, torch.tensor(row['label'], dtype=torch.long)

In [13]:
# =============================================================================
# MODEL
# =============================================================================
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
    
    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)


class ArcFaceLayer(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.5):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, x, label=None):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        if label is None:
            return cosine
        phi = cosine - self.m
        one_hot = torch.zeros_like(cosine).scatter_(1, label.view(-1, 1), 1)
        return ((one_hot * phi) + ((1.0 - one_hot) * cosine)) * self.s


class JaguarModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(CFG.backbone, pretrained=True, num_classes=0)
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.head = ArcFaceLayer(self.feat_dim, CFG.num_classes, CFG.arcface_s, CFG.arcface_m)
        print(f"Loaded {CFG.backbone} | Feat: {self.feat_dim}")
    
    def forward(self, x, label=None):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            B, N, C = features.shape
            H = W = int(math.sqrt(N))
            if H * W != N:
                features = features[:, -H*W:, :]
            features = features.permute(0, 2, 1).reshape(B, C, H, W)
        emb = self.gem(features).flatten(1)
        emb = self.bn(emb)
        if label is not None:
            return self.head(emb, label)
        return emb

In [14]:
# =============================================================================
# POST-PROCESSING
# =============================================================================
def query_expansion(emb, top_k=3):
    print("Applying Query Expansion...")
    sims = emb @ emb.T
    indices = np.argsort(-sims, axis=1)[:, :top_k]
    new_emb = np.zeros_like(emb)
    for i in range(len(emb)):
        new_emb[i] = np.mean(emb[indices[i]], axis=0)
    return new_emb / np.linalg.norm(new_emb, axis=1, keepdims=True)


def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=0.3):
    print("Applying Re-ranking...")
    q_g_dist = 1 - prob
    original_dist = q_g_dist.copy()
    initial_rank = np.argsort(original_dist, axis=1)
    
    nn_k1 = []
    for i in range(prob.shape[0]):
        forward_k1 = initial_rank[i, :k1+1]
        backward_k1 = initial_rank[forward_k1, :k1+1]
        fi = np.where(backward_k1 == i)[0]
        nn_k1.append(forward_k1[fi])
    
    jaccard_dist = np.zeros_like(original_dist)
    for i in range(prob.shape[0]):
        ind_non_zero = np.where(original_dist[i, :] < 0.6)[0]
        ind_images = [inv for inv in ind_non_zero if len(np.intersect1d(nn_k1[i], nn_k1[inv])) > 0]
        for j in ind_images:
            intersection = len(np.intersect1d(nn_k1[i], nn_k1[j]))
            union = len(np.union1d(nn_k1[i], nn_k1[j]))
            jaccard_dist[i, j] = 1 - intersection / union
    
    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

In [15]:
# =============================================================================
# TRAINING
# =============================================================================
train_loader = DataLoader(
    JaguarDataset(train_df, CFG.train_dir, train_transform),
    batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=True
)

model = JaguarModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)
scaler = torch.amp.GradScaler('cuda')
criterion = nn.CrossEntropyLoss()

for epoch in range(CFG.epochs):
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
    
    for i, (imgs, labels) in enumerate(pbar):
        imgs, labels = imgs.to(device), labels.to(device)
        
        with torch.amp.autocast('cuda'):
            loss = criterion(model(imgs, labels), labels) / CFG.grad_accum
        
        scaler.scale(loss).backward()
        
        if (i + 1) % CFG.grad_accum == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += loss.item() * CFG.grad_accum
        pbar.set_postfix({'loss': f'{loss.item() * CFG.grad_accum:.4f}'})
    
    scheduler.step()
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

torch.save(model.state_dict(), 'best_model.pth')
print("Training complete")

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/348M [00:00<?, ?B/s]

Loaded eva02_base_patch14_448.mim_in22k_ft_in22k_in1k | Feat: 768


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Epoch 1/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 1 | Loss: 13.7642 | LR: 1.95e-05


Epoch 2/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 2 | Loss: 6.9714 | LR: 1.81e-05


Epoch 3/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 3 | Loss: 4.0637 | LR: 1.59e-05


Epoch 4/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 4 | Loss: 2.7159 | LR: 1.31e-05


Epoch 5/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 5 | Loss: 2.0453 | LR: 1.00e-05


Epoch 6/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 6 | Loss: 1.5801 | LR: 6.91e-06


Epoch 7/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 7 | Loss: 1.2582 | LR: 4.12e-06


Epoch 8/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 8 | Loss: 1.1076 | LR: 1.91e-06


Epoch 9/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 9 | Loss: 0.9474 | LR: 4.89e-07


Epoch 10/10:   0%|          | 0/474 [00:00<?, ?it/s]

Epoch 10 | Loss: 0.8589 | LR: 0.00e+00
Training complete


In [16]:
# =============================================================================
# INFERENCE
# =============================================================================
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

unique_images = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
test_loader = DataLoader(
    JaguarDataset(pd.DataFrame({'filename': unique_images}), CFG.test_dir, test_transform, is_test=True),
    batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers
)

print(f"Extracting embeddings for {len(unique_images)} images...")
feats, names = [], []
with torch.no_grad():
    for imgs, fnames in tqdm(test_loader):
        imgs = imgs.to(device)
        f1 = model(imgs)
        if CFG.use_tta:
            f2 = model(torch.flip(imgs, [3]))
            f1 = (f1 + f2) / 2
        feats.append(F.normalize(f1, dim=1).cpu())
        names.extend(fnames)

emb = torch.cat(feats, dim=0).numpy()
img_map = {n: i for i, n in enumerate(names)}

if CFG.use_qe:
    emb = query_expansion(emb)

sim_matrix = emb @ emb.T

if CFG.use_rerank:
    sim_matrix = k_reciprocal_rerank(sim_matrix)

Extracting embeddings for 371 images...


  0%|          | 0/93 [00:00<?, ?it/s]

Applying Query Expansion...
Applying Re-ranking...


In [17]:
# =============================================================================
# SUBMISSION
# =============================================================================
print("Computing similarities...")
preds = []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    s = sim_matrix[img_map[row['query_image']], img_map[row['gallery_image']]]
    preds.append(max(0.0, min(1.0, s)))

submission = pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds})
submission.to_csv('submission.csv', index=False)
print(f"Saved | Mean: {np.mean(preds):.4f}")

Computing similarities...


  0%|          | 0/137270 [00:00<?, ?it/s]

Saved | Mean: 0.3411
