In [1]:
import os
import math
import random
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import timm

# Config

In [3]:
class Config:
    """Central hyper-parameter namespace for the jaguar re-ID pipeline.

    Values are plain class attributes read as ``Config.<name>``; the class is
    never instantiated.
    """

    # Reproducibility
    seed = 42

    # Backbone (timm model id) and feature/label sizes
    model_name = "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
    img_size = 448
    embedding_dim = 1024
    num_classes = 31

    # Optimisation; one optimizer step every `grad_accum` batches
    num_epochs = 10
    batch_size = 4
    grad_accum = 4
    lr = 2e-5
    weight_decay = 1e-3

    # Margin head: logit scale `s` and margin `m`
    arcface_s = 30.0
    arcface_m = 0.50

    # Inference-time post-processing switches
    use_tta = True
    use_qe = True
    use_rerank = True

    # Hardware selection, evaluated once when the class body runs
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_type)


def seed_everything(seed):
    """Seed every random number generator the pipeline relies on.

    Seeds Python's ``random``, NumPy's global RNG, and torch on the CPU and on
    all CUDA devices, and asks cuDNN for deterministic kernels. The
    ``PYTHONHASHSEED`` variable only affects interpreters started afterwards
    (e.g. subprocesses); it does not change hashing in the running process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=Config.seed)  # seed all RNGs before any data/model code runs

# Dataset

In [4]:
class JaguarDataset(Dataset):
    """Image dataset for jaguar re-identification.

    Args:
        df: DataFrame with a ``filename`` column; in training mode it must also
            contain ``ground_truth`` (the individual's identity).
        img_dir: directory that the filenames are relative to.
        transform: optional callable applied to the loaded PIL image.
        is_test: if True, items are ``(image, filename)``; otherwise
            ``(image, label)`` where ``label`` is a ``torch.long`` scalar.

    In training mode ``label_map`` maps each identity (in sorted order) to a
    contiguous integer id.
    """

    def __init__(self, df, img_dir, transform=None, is_test=False):
        # Work on a private copy so adding the ``label`` column does not
        # silently mutate the caller's DataFrame (e.g. ``train_df``).
        self.df = df.copy()
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.is_test = is_test
        if not is_test:
            unique_ids = sorted(self.df["ground_truth"].unique())
            self.label_map = {name: i for i, name in enumerate(unique_ids)}
            self.df["label"] = self.df["ground_truth"].map(self.label_map)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row["filename"]
        img_path = self.img_dir / img_name
        try:
            # ``with`` releases the file handle even if decoding fails.
            with Image.open(img_path) as src:
                img = src.convert("RGB")
        except OSError as exc:
            # Missing files and undecodable/truncated images raise OSError
            # (PIL's UnidentifiedImageError subclasses it). Keep the blank
            # placeholder fallback, but report it instead of silently training
            # on black images; a bare ``except`` would also swallow
            # KeyboardInterrupt and genuine programming errors.
            import warnings

            warnings.warn(f"Could not load {img_path}: {exc}; using a blank image.")
            img = Image.new("RGB", (Config.img_size, Config.img_size))

        if self.transform:
            img = self.transform(img)
        if self.is_test:
            return img, img_name
        return img, torch.tensor(row["label"], dtype=torch.long)

# Transforms

In [5]:
# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these look like rounded OpenAI-CLIP mean/std (the statistics
# EVA-02 was trained with) — confirm against the timm pretrained config.
_NORM_MEAN = [0.481, 0.457, 0.408]
_NORM_STD = [0.268, 0.261, 0.275]

train_transform = transforms.Compose(
    [
        transforms.Resize((Config.img_size, Config.img_size)),  # resize to a fixed square
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomAffine(
            degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)
        ),  # random rotation / shift / scale
        transforms.ColorJitter(brightness=0.2, contrast=0.2),  # random brightness/contrast
        transforms.ToTensor(),  # PIL image -> float tensor
        transforms.Normalize(_NORM_MEAN, _NORM_STD),  # per-channel standardisation
        transforms.RandomErasing(p=0.25),  # random erasing on the normalised tensor
    ]
)

test_transform = transforms.Compose(
    [
        transforms.Resize((Config.img_size, Config.img_size)),  # resize to a fixed square
        transforms.ToTensor(),  # PIL image -> float tensor
        transforms.Normalize(_NORM_MEAN, _NORM_STD),  # per-channel standardisation
    ]
)

# Model

In [6]:
class GeM(nn.Module):
    """Generalised-mean (GeM) pooling over the two trailing spatial dims.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` over H and W with a
    learnable exponent ``p``: ``p = 1`` is average pooling, large ``p``
    approaches max pooling. A ``(B, C, H, W)`` input yields ``(B, C, 1, 1)``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        # Lower clamp keeps pow() well-defined for non-positive activations.
        self.eps = eps

    def forward(self, x):
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(powered, kernel_size=tuple(x.shape[-2:]))
        return pooled.pow(1.0 / self.p)


class ArcFaceLayer(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Embeddings and class-centre weights are both L2-normalised, so the raw
    logits are cosines ``cos(theta)`` between an embedding and each centre.

    Args:
        in_features: embedding dimension.
        out_features: number of classes.
        s: scale applied to the margin-adjusted cosines when labels are given.
        m: additive angular margin, in radians.

    ``forward(input)`` returns raw cosines of shape ``(batch, out_features)``.
    ``forward(input, label)`` returns ``s * logits`` where the target class
    uses ``cos(theta + m)`` and every other class uses ``cos(theta)``.
    (The previous version subtracted ``m`` from the cosine, which is the
    CosFace margin, not ArcFace.)
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label=None):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine

        # Margin arithmetic in float32: under fp16 autocast 1 - cos^2 can round
        # below zero and sqrt would produce NaN.
        cos_f = cosine.float()
        # Clamp away from 0 so the sqrt gradient stays finite when |cos| == 1.
        sine = torch.sqrt((1.0 - cos_f * cos_f).clamp(min=1e-7, max=1.0))
        cos_m, sin_m = math.cos(self.m), math.sin(self.m)
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cos_f * cos_m - sine * sin_m
        # Once theta + m > pi, cos(theta + m) is no longer decreasing in theta;
        # fall back to a linear penalty there (reference ArcFace implementation).
        threshold = math.cos(math.pi - self.m)
        fallback = math.sin(math.pi - self.m) * self.m
        phi = torch.where(cos_f > threshold, phi, cos_f - fallback)

        one_hot = torch.zeros_like(cos_f)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = one_hot * phi + (1.0 - one_hot) * cos_f
        return output * self.s


class EVABoss(nn.Module):
    """timm backbone + GeM pooling + BatchNorm1d neck + ArcFace head.

    ``forward(x)`` returns the BatchNorm1d embedding (not L2-normalised);
    ``forward(x, label)`` returns the head's scaled margin logits for training.

    Args:
        pretrained: load pretrained timm weights (default ``True``, as
            before). Pass ``False`` to build the architecture without network
            access, e.g. for parameter counting or before loading a local
            checkpoint.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            Config.model_name, pretrained=pretrained, num_classes=0
        )
        self.feat_dim = self.backbone.num_features
        # NOTE(review): GeM clamps activations below eps, so negative token
        # values contribute nothing to the pooled feature — confirm intended.
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.head = ArcFaceLayer(
            self.feat_dim, Config.num_classes, s=Config.arcface_s, m=Config.arcface_m
        )

    def _tokens_to_map(self, tokens):
        """Reshape ViT tokens ``(B, N, C)`` into a patch grid ``(B, C, H, W)``.

        Prefix tokens (class/register tokens) are dropped first. Their count
        comes from the backbone's ``num_prefix_tokens`` when it exposes one;
        otherwise the old heuristic is used (keep the last ``H*W`` tokens with
        ``H = floor(sqrt(N))``).

        Raises:
            ValueError: if the remaining patch tokens do not form a square grid.
        """
        B, N, C = tokens.shape
        num_prefix = getattr(self.backbone, "num_prefix_tokens", None)
        if num_prefix is None:
            side = math.isqrt(N)
            num_prefix = N - side * side
        patches = tokens[:, num_prefix:, :]
        num_patches = patches.shape[1]
        H = W = math.isqrt(num_patches)
        if H * W != num_patches:
            raise ValueError(
                f"Cannot reshape {num_patches} patch tokens into a square grid "
                f"(got {N} tokens, {num_prefix} prefix tokens)"
            )
        return patches.permute(0, 2, 1).reshape(B, C, H, W)

    def forward(self, x, label=None):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            # Token sequence -> spatial grid so GeM can pool over it.
            features = self._tokens_to_map(features)

        emb = self.gem(features).flatten(1)
        emb = self.bn(emb)
        if label is not None:
            return self.head(emb, label)
        return emb

# Utils

In [7]:
def train_epoch(model, loader, optimizer, criterion, scaler):
    """Run one training epoch with autocast and gradient accumulation.

    Gradients are accumulated over ``Config.grad_accum`` batches per optimizer
    step. When the number of batches is not a multiple of ``grad_accum``, the
    trailing partial window is stepped too. Previously its gradients were left
    in ``.grad`` and added into the next epoch's first update.

    Args:
        model: module called as ``model(imgs, labels)`` returning logits.
        loader: sized iterable of ``(imgs, labels)`` batches.
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(logits, labels)``.
        scaler: ``torch.amp.GradScaler`` (or compatible) used for loss scaling.

    Returns:
        Mean per-batch loss over the epoch (accumulation scaling undone);
        0.0 for an empty loader.
    """
    model.train()
    num_batches = len(loader)
    loss_meter = 0.0
    # Start from clean gradients, e.g. after an interrupted previous run.
    optimizer.zero_grad()
    for i, (imgs, labels) in enumerate(tqdm(loader, leave=False)):
        imgs, labels = imgs.to(Config.device), labels.to(Config.device)

        with torch.amp.autocast(Config.device_type):
            loss = criterion(model(imgs, labels), labels)
            # Divide so the accumulated gradient approximates one large batch.
            loss = loss / Config.grad_accum

        scaler.scale(loss).backward()

        window_full = (i + 1) % Config.grad_accum == 0
        last_batch = (i + 1) == num_batches
        if window_full or last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        loss_meter += loss.item() * Config.grad_accum
    return loss_meter / max(num_batches, 1)


@torch.no_grad()
def extract_features(model, loader):
    """Embed every image in ``loader``, optionally with flip TTA.

    Returns:
        ``(embeddings, names)``: a NumPy array with one L2-normalised row per
        image, and the matching filenames in loader order.

    With ``Config.use_tta`` the embedding is the mean of the original batch
    and the batch flipped along dim 3 (width, assuming NCHW input), taken
    before normalisation.
    """
    model.eval()
    chunks = []
    all_names = []
    for batch, batch_names in tqdm(loader, desc="Inference"):
        batch = batch.to(Config.device)
        emb = model(batch)
        if Config.use_tta:
            emb = (emb + model(torch.flip(batch, [3]))) / 2
        chunks.append(F.normalize(emb, dim=1).cpu())
        all_names.extend(batch_names)
    return torch.cat(chunks, dim=0).numpy(), all_names


def query_expansion(emb, top_k=3):
    """Average query expansion over the ``top_k`` most similar rows.

    Each row is replaced by the mean of its ``top_k`` highest dot-product
    neighbours (the row itself is included whenever it ranks in the top k),
    then re-normalised to unit L2 norm.

    Args:
        emb: 2-D array, one embedding per row (expected to be L2-normalised so
            the dot product is cosine similarity).
        top_k: number of neighbours to average.

    Returns:
        Array with the same shape and dtype as ``emb``, with unit-norm rows.
    """
    print("Applying QE...")
    neighbour_idx = np.argsort(-(emb @ emb.T), axis=1)[:, :top_k]
    expanded = emb[neighbour_idx].mean(axis=1).astype(emb.dtype, copy=False)
    norms = np.linalg.norm(expanded, axis=1, keepdims=True)
    return expanded / norms


def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=0.3, max_dist=0.6):
    """Simplified k-reciprocal re-ranking of a square similarity matrix.

    Blends the original distance ``d = 1 - prob`` with a Jaccard distance
    ``d_J`` between the k-reciprocal neighbour sets of each pair:

        result = 1 - (lambda_value * d_J + (1 - lambda_value) * d)

    Args:
        prob: (n, n) similarity matrix, higher meaning more similar.
        k1: neighbourhood size used to build the k-reciprocal sets.
        k2: accepted for API compatibility; unused by this simplified version
            (the full method uses it for local query expansion).
        lambda_value: weight of the Jaccard distance in the blend.
        max_dist: Jaccard distances are computed only for pairs whose original
            distance is below this; all other pairs keep the maximal Jaccard
            distance of 1.

    Returns:
        (n, n) re-ranked similarity matrix (not clipped to [0, 1]).
    """
    print("Applying Re-ranking...")
    original_dist = 1 - prob
    n = prob.shape[0]
    initial_rank = np.argsort(original_dist, axis=1)

    # Reciprocal set of i: members of i's top-(k1+1) list whose own
    # top-(k1+1) list also contains i.
    reciprocal = []
    for i in range(n):
        forward_k1 = initial_rank[i, : k1 + 1]
        backward_k1 = initial_rank[forward_k1, : k1 + 1]
        hits = np.where(backward_k1 == i)[0]
        reciprocal.append(set(forward_k1[hits].tolist()))

    # Default to the maximal Jaccard distance (1). Pairs that share no
    # reciprocal neighbours, or are too far apart to be considered, must not
    # look identical. The old zero initialisation gave exactly those unrelated
    # pairs the best possible Jaccard score.
    jaccard_dist = np.ones_like(original_dist)
    for i in range(n):
        for j in np.where(original_dist[i, :] < max_dist)[0]:
            shared = len(reciprocal[i] & reciprocal[j])
            if shared:
                union = len(reciprocal[i] | reciprocal[j])
                jaccard_dist[i, j] = 1 - shared / union
    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

# Train

In [8]:
TRAIN_CSV = "jaguar-re-id/train.csv"
TEST_CSV = "jaguar-re-id/test.csv"
TRAIN_DIR = "jaguar-re-id/train/train"
TEST_DIR = "jaguar-re-id/test/test"

train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)

train_dataset = JaguarDataset(train_df, TRAIN_DIR, train_transform)
# Labels run 0..len(label_map)-1, but the ArcFace head has Config.num_classes
# outputs. More identities than outputs would make scatter_ index out of
# range (a cryptic device-side assert on CUDA) partway through training.
assert len(train_dataset.label_map) <= Config.num_classes, (
    f"train.csv has {len(train_dataset.label_map)} identities but "
    f"Config.num_classes is {Config.num_classes}"
)

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
    # BatchNorm1d in train mode raises on a single-sample batch. The leftover
    # final batch has one sample whenever len(train_df) % batch_size == 1.
    drop_last=True,
)
model = EVABoss().to(Config.device)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay
)
# Loss scaling for mixed-precision training.
scaler = torch.amp.GradScaler(Config.device_type)
# Stepped once per epoch in the training loop, so T_max counts epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=Config.num_epochs
)

print(f"üî• Training EVA-02 Large (448px)...")

for epoch in range(Config.num_epochs):
    loss = train_epoch(model, train_loader, optimizer, nn.CrossEntropyLoss(), scaler)
    scheduler.step()
    print(
        f"Epoch {epoch+1}/{Config.num_epochs} | Loss: {loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}"
    )

# Every image referenced by any (query, gallery) pair, embedded exactly once.
unique_test = sorted(set(test_df["query_image"]) | set(test_df["gallery_image"]))
test_dataset = JaguarDataset(
    pd.DataFrame({"filename": unique_test}),
    TEST_DIR,
    transform=test_transform,
    is_test=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
)

emb, names = extract_features(model, test_loader)
# filename -> row index into `emb` and `sim_matrix`
img_map = {fname: row_idx for row_idx, fname in enumerate(names)}

if Config.use_qe:
    emb = query_expansion(emb)
sim_matrix = emb @ emb.T
if Config.use_rerank:
    sim_matrix = k_reciprocal_rerank(sim_matrix)

# Look up each pair's similarity and clip it into [0, 1] for the submission.
pairs = zip(test_df["query_image"], test_df["gallery_image"])
preds = [
    max(0.0, min(1.0, sim_matrix[img_map[q], img_map[g]]))
    for q, g in tqdm(pairs, total=len(test_df), desc="Mapping")
]

sub = pd.DataFrame({"row_id": test_df["row_id"], "similarity": preds})
sub.to_csv("submission.csv", index=False)
print(f"‚úÖ Done! Mean Sim: {np.mean(preds):.4f}")

KeyboardInterrupt: 

In [9]:
# Parameter-count helper.
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: any ``nn.Module``.
        trainable_only: if True (default), count only parameters with
            ``requires_grad=True``; otherwise count all of them.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)

# Reuse the model from the training cell when it exists. Building a fresh
# EVABoss here would overwrite the trained weights held in `model` and
# re-download the pretrained backbone, which fails offline. Parameter counts
# depend only on the architecture and requires_grad flags, so the trained
# model gives the same numbers (assumes `model` is an EVABoss instance).
if "model" not in globals():
    model = EVABoss().to(Config.device)

# Count each component separately (trainable parameters only).
backbone_params = count_parameters(model.backbone)
gem_params = count_parameters(model.gem)
bn_params = count_parameters(model.bn)
head_params = count_parameters(model.head)
total_params = count_parameters(model)

# Print the results (units: M = 1e6, K = 1e3).
print(f"‰∏ªÂπ≤ÁΩëÁªúÂèÇÊï∞Èáè: {backbone_params / 1e6:.2f} M")
print(f"GeMÊ±†ÂåñÂèÇÊï∞Èáè: {gem_params}")
print(f"BatchNorm1dÂèÇÊï∞Èáè: {bn_params}")
print(f"ArcFaceÂ§¥ÂèÇÊï∞Èáè: {head_params / 1e3:.2f} K")
print(f"Ê®°ÂûãÊÄªÂèÇÊï∞Èáè: {total_params / 1e6:.2f} M")

LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.