In [1]:
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split as tts
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
# Dataset location. Image paths from the metadata are joined onto root_dir when
# images are opened, so the CSV path is derived from the same root instead of
# repeating the directory string.
root_dir = './animal-clef-2025'
metadata_dir = os.path.join(root_dir, 'metadata.csv')  # path of the CSV file itself

df = pd.read_csv(metadata_dir)

# Rows of the "database" split that have a known identity.
# The explicit .copy() makes database_df an independent frame, so adding
# columns to it in a later cell cannot raise pandas' SettingWithCopyWarning
# or depend on copy-vs-view behaviour of the filtered slice.
database_df = df[df["split"] == "database"].dropna(subset=['identity']).copy()
database_df

Unnamed: 0,image_id,identity,path,date,orientation,species,split,dataset
0,0,LynxID2025_lynx_37,images/LynxID2025/database/000f9ee1aad063a4485...,,right,lynx,database,LynxID2025
1,1,LynxID2025_lynx_37,images/LynxID2025/database/0020edb6689e9f78462...,,left,lynx,database,LynxID2025
2,2,LynxID2025_lynx_49,images/LynxID2025/database/003152e4145b5b69400...,,left,lynx,database,LynxID2025
4,4,LynxID2025_lynx_13,images/LynxID2025/database/003c3f82011e9c3f849...,,right,lynx,database,LynxID2025
6,6,LynxID2025_lynx_07,images/LynxID2025/database/0051adb5bd1b63867b9...,,left,lynx,database,LynxID2025
...,...,...,...,...,...,...,...,...
14704,14704,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022
14705,14705,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022
14706,14706,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022
14707,14707,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,left,loggerhead turtle,database,SeaTurtleID2022


In [4]:
# Number of "database" images per identity (value_counts sorts descending and
# skips rows whose identity is missing).
identity_counts = (
    df.loc[df['split'] == 'database', 'identity']
    .value_counts()
    .rename_axis('identity')
    .reset_index(name='num_images')
)
identity_counts

Unnamed: 0,identity,num_images
0,LynxID2025_lynx_49,353
1,LynxID2025_lynx_62,289
2,LynxID2025_lynx_32,256
3,LynxID2025_lynx_43,234
4,SeaTurtleID2022_t243,190
...,...,...
1097,SalamanderID2025_64,1
1098,SalamanderID2025_65,1
1099,SalamanderID2025_646,1
1100,SalamanderID2025_647,1


In [5]:
# Identities that have exactly one database image. Every (identity, num_images)
# row is unique, so each count shown is 1; the reported Length is the number of
# such single-image identities.
identity_counts.loc[identity_counts['num_images'].eq(1)].value_counts()

identity              num_images
LynxID2025_lynx_08    1             1
SalamanderID2025_552  1             1
SalamanderID2025_559  1             1
SalamanderID2025_558  1             1
SalamanderID2025_557  1             1
                                   ..
SalamanderID2025_344  1             1
SalamanderID2025_341  1             1
SalamanderID2025_334  1             1
SalamanderID2025_331  1             1
SeaTurtleID2022_t103  1             1
Name: count, Length: 317, dtype: int64

In [6]:
# Map each identity string to an integer class index for the classifier.
encoder = LabelEncoder()

# .assign builds a new frame instead of writing a column into a filtered slice
# of `df`, which avoids SettingWithCopyWarning and keeps this cell safe to
# re-run (the label column is simply recomputed).
database_df = database_df.assign(label=encoder.fit_transform(database_df['identity']))
database_df

Unnamed: 0,image_id,identity,path,date,orientation,species,split,dataset,label
0,0,LynxID2025_lynx_37,images/LynxID2025/database/000f9ee1aad063a4485...,,right,lynx,database,LynxID2025,29
1,1,LynxID2025_lynx_37,images/LynxID2025/database/0020edb6689e9f78462...,,left,lynx,database,LynxID2025,29
2,2,LynxID2025_lynx_49,images/LynxID2025/database/003152e4145b5b69400...,,left,lynx,database,LynxID2025,40
4,4,LynxID2025_lynx_13,images/LynxID2025/database/003c3f82011e9c3f849...,,right,lynx,database,LynxID2025,11
6,6,LynxID2025_lynx_07,images/LynxID2025/database/0051adb5bd1b63867b9...,,left,lynx,database,LynxID2025,7
...,...,...,...,...,...,...,...,...,...
14704,14704,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022,1101
14705,14705,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022,1101
14706,14706,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,top,loggerhead turtle,database,SeaTurtleID2022,1101
14707,14707,SeaTurtleID2022_t610,images/SeaTurtleID2022/database/turtles-data/d...,2022-07-08,left,loggerhead turtle,database,SeaTurtleID2022,1101


In [7]:
class CustomDataset(Dataset):
    """Image dataset backed by a metadata DataFrame.

    Each row must provide a ``path`` column (image path relative to the
    dataset root) and a ``label`` column (the class label returned with the
    image).

    Args:
        df: Frame with ``path`` and ``label`` columns. Its index is reset so
            that integer positions 0..len-1 address the rows.
        transform: Optional callable applied to the loaded RGB PIL image.
        root: Directory the ``path`` values are relative to. When ``None``
            (default) the notebook-level ``root_dir`` is looked up at access
            time, so existing ``CustomDataset(df, transform=...)`` calls keep
            working unchanged.
    """

    def __init__(self, df, transform=None, root=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.root = root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``."""
        base_dir = root_dir if self.root is None else self.root
        img_path = os.path.join(base_dir, self.df.loc[idx, 'path'])
        label = self.df.loc[idx, 'label']

        # Image.open is lazy; the context manager guarantees the file handle is
        # released, while convert() returns a separate, fully loaded image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform:
            img = self.transform(img)
        return img, label


# ==== Transformations ====
# Resize to 224x224, convert to a tensor, then normalise each channel with the
# mean/std values commonly used with torchvision's ImageNet-pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [13]:
# Seed PyTorch's global RNG so the DataLoader shuffle order and the random
# initialisation of the new classification head start from a known state.
SEED = 42
torch.manual_seed(SEED)

# All labelled database images are used for training (no validation split).
train_dataset = CustomDataset(database_df, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# ==== Model ====
# `pretrained=True` is deprecated in torchvision >= 0.13; the weights enum
# below selects the same ImageNet checkpoint that flag used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the final layer with one output per known identity.
model.fc = nn.Linear(model.fc.in_features, len(encoder.classes_))
model = model.to(device)

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
EPOCHS = 10

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}]", leave=False)

    # Count batches explicitly: tqdm only refreshes `pbar.n` when it redraws
    # the bar, so using it as the divisor gives a wrong running average
    # between refreshes.
    for batch_idx, (imgs, labels) in enumerate(pbar, start=1):
        imgs, labels = imgs.to(device), labels.to(device).long()
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        avg_loss = running_loss / batch_idx
        pbar.set_postfix(loss=avg_loss)

    print(f"✅ Epoch {epoch+1}/{EPOCHS} finished. Average Loss: {running_loss/len(train_loader):.4f}")

                                                                                                                       

✅ Epoch 1/10 finished. Average Loss: 5.0518


                                                                                                                       

✅ Epoch 2/10 finished. Average Loss: 3.5289


                                                                                                                       

✅ Epoch 3/10 finished. Average Loss: 2.5919


                                                                                                                       

✅ Epoch 4/10 finished. Average Loss: 1.8760


                                                                                                                       

✅ Epoch 5/10 finished. Average Loss: 1.3577


                                                                                                                       

✅ Epoch 6/10 finished. Average Loss: 1.0025


                                                                                                                       

✅ Epoch 7/10 finished. Average Loss: 0.7880


                                                                                                                       

✅ Epoch 8/10 finished. Average Loss: 0.6267


                                                                                                                       

✅ Epoch 9/10 finished. Average Loss: 0.4882


                                                                                                                       

✅ Epoch 10/10 finished. Average Loss: 0.3407




In [16]:
# Persist the trained weights together with the fitted label encoder; the
# encoder is needed later to turn predicted class indices back into identities.
torch.save(model.state_dict(), 'reid_classifier.pth')
joblib.dump(encoder, 'label_encoder.pkl')
print("✅ Model and label encoder saved.")

✅ Model and label encoder saved.


In [17]:
# Re-read the metadata and keep only the rows of the "query" split,
# renumbered from 0.
df = pd.read_csv(metadata_dir)
query_df = df.loc[df['split'].eq('query')].reset_index(drop=True)
query_df

Unnamed: 0,image_id,identity,path,date,orientation,species,split,dataset
0,3,,images/LynxID2025/query/003b89301c7b9f6d18f722...,,back,lynx,query,LynxID2025
1,5,,images/LynxID2025/query/004d500301a70ec9b5ba08...,,left,lynx,query,LynxID2025
2,12,,images/LynxID2025/query/00d97c67f0cb0d13a3a449...,,left,lynx,query,LynxID2025
3,13,,images/LynxID2025/query/00dcbabf03826937bcf6a0...,,right,lynx,query,LynxID2025
4,18,,images/LynxID2025/query/011d81e0402d1be66bccab...,,right,lynx,query,LynxID2025
...,...,...,...,...,...,...,...,...
2130,15204,,images/SeaTurtleID2022/query/images/fecd2dfed0...,2024-06-07,,loggerhead turtle,query,SeaTurtleID2022
2131,15205,,images/SeaTurtleID2022/query/images/ff1a0c812b...,2023-06-28,,loggerhead turtle,query,SeaTurtleID2022
2132,15206,,images/SeaTurtleID2022/query/images/ff22f1cfa6...,2024-06-09,,loggerhead turtle,query,SeaTurtleID2022
2133,15207,,images/SeaTurtleID2022/query/images/ff5d5116d1...,2023-06-21,,loggerhead turtle,query,SeaTurtleID2022


In [18]:
# ==== Load model and encoder ====
# NOTE: joblib.load unpickles arbitrary Python objects; only load files that
# this notebook wrote itself. The encoder is loaded once and reused for the
# output-layer size.
encoder = joblib.load("label_encoder.pkl")

# weights=None builds an uninitialised ResNet-18 (replaces the deprecated
# `pretrained=False`); the trained parameters come from the checkpoint below.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(encoder.classes_))

# map_location lets a checkpoint saved from a CUDA model load on a CPU-only
# machine instead of raising a deserialisation error.
model.load_state_dict(torch.load("reid_classifier.pth", map_location=device))

# eval() returns the module itself, so this keeps the cell output empty while
# switching the model to inference mode.
model = model.to(device).eval()



In [19]:
# Minimum top-1 softmax probability required to accept a known identity;
# predictions below it are reported as a new individual.
threshold = 0.5
NEW_INDIVIDUAL = "new_individual"
SUBMISSION_PATH = "submission.csv"
results = []

for image_id, rel_path in zip(query_df["image_id"], query_df["path"]):
    img_path = os.path.join(root_dir, rel_path)
    # Close the file handle deterministically; convert() yields a loaded copy.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    # Add a batch dimension of 1 for single-image inference.
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        max_prob, pred_class = torch.max(probs, 1)

    if max_prob.item() >= threshold:
        identity = encoder.inverse_transform([pred_class.item()])[0]
    else:
        identity = NEW_INDIVIDUAL

    results.append({"image_id": image_id, "identity": identity})

# === Save to CSV ===
# Explicit columns keep the header present even if there are no query rows.
submission_df = pd.DataFrame(results, columns=["image_id", "identity"])
submission_df.to_csv(SUBMISSION_PATH, index=False)
# Report the file that was actually written (the old message named a different file).
print(f"Saved predictions to {SUBMISSION_PATH}")

Saved predictions to sample_submission.csv
