In [1]:
!pip install kaggle



In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import numpy as np
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import random
import kagglehub

In [3]:
# Download latest version
path = kagglehub.dataset_download("wildlifedatasets/seaturtleid2022")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/wildlifedatasets/seaturtleid2022?dataset_version_number=4...


100%|██████████| 1.64G/1.64G [00:25<00:00, 70.5MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/wildlifedatasets/seaturtleid2022/versions/4


In [4]:
root_dir = os.path.join(path, "turtles-data", "data")
data_csv = os.path.join(path, "turtles-data", "data", "metadata_splits.csv")

# Load the split file into a DataFrame
df = pd.read_csv(data_csv)
df.head()

Unnamed: 0,id,width,height,file_name,timestamp,identity,date,year,split_closed,split_closed_random,split_open,clarity
0,1,2000,1333,images/t001/CAluWEgwPX.JPG,2014-07-14 14:49:45,t001,2014-07-14,2014,test,test,train,3
1,2,2000,1333,images/t001/EKyrFKHQzh.JPG,2014-07-14 14:48:49,t001,2014-07-14,2014,test,train,train,2
2,3,2000,1333,images/t001/ELAvEqeXxT.JPG,2014-07-14 14:49:48,t001,2014-07-14,2014,test,train,train,2
3,4,2000,1124,images/t001/IxRLFwTGCv.JPG,2010-07-02 14:09:40,t001,2010-07-02,2010,train,test,train,3
4,5,2000,1333,images/t001/LKCJAhfLBJ.JPG,2014-07-14 14:48:28,t001,2014-07-14,2014,test,test,train,4


In [5]:
annotations_path = os.path.join(path, "turtles-data", "data", "annotations.json")

coco = COCO(annotations_path)

loading annotations into memory...
Done (t=6.05s)
creating index...
index created!


In [6]:
class ArcFace(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        theta = torch.acos(torch.clamp(cosine, -1.0, 1.0))
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = self.s * (cosine * (1 - one_hot) + torch.cos(theta + self.m) * one_hot)
        return output

In [7]:
class TurtleReIDSegmentsDataset(Dataset):
    def __init__(self, df, root_dir, coco, transform=None):
        self.df = df.reset_index(drop=True)
        self.root_dir = root_dir
        self.coco = coco
        self.transform = transform
        self.identity_to_idx = {id_: i for i, id_ in enumerate(sorted(df['identity'].unique()))}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, row['file_name'])
        image = Image.open(img_path).convert("RGB")

        image_id = int(row['id'])
        cat_ids = self.coco.getCatIds()
        ann_ids = self.coco.getAnnIds(imgIds=image_id, catIds=cat_ids, iscrowd=None)
        anns = self.coco.loadAnns(ann_ids)

        masks = {}
        for ann in anns:
            cat_name = self.coco.loadCats(ann['category_id'])[0]['name']
            if cat_name not in ['head', 'flipper', 'turtle']:
                continue
            submask = self.coco.annToMask(ann)
            masks[cat_name] = submask

        # extract segment crops
        crops = {}
        for part in ['head', 'flipper', 'turtle']:
            if part in masks:
                y, x = np.where(masks[part] > 0)
                if len(y) > 0 and len(x) > 0:
                    ymin, ymax, xmin, xmax = y.min(), y.max(), x.min(), x.max()
                    crops[part] = image.crop((xmin, ymin, xmax, ymax))
                else:
                    crops[part] = image
            else:
                crops[part] = image

        if self.transform:
            for part in crops:
                crops[part] = self.transform(crops[part])

        label = self.identity_to_idx[row['identity']]
        return crops['head'], crops['turtle'], crops['flipper'], torch.tensor(label, dtype=torch.long)



In [8]:
class TurtleReIDModel(nn.Module):
    def __init__(self, embedding_dim=512, num_classes=100):
        super().__init__()

        def get_backbone():
            backbone = models.swin_b(weights=models.Swin_B_Weights.IMAGENET1K_V1)
            backbone.head = nn.Linear(backbone.head.in_features, embedding_dim)
            return backbone

        self.head_net = get_backbone()
        self.shell_net = get_backbone()
        self.flipper_net = get_backbone()
        self.arcface = ArcFace(in_features=embedding_dim * 3, out_features=num_classes)

    def forward(self, head, shell, flipper, label=None):
        head_emb = self.head_net(head)
        shell_emb = self.shell_net(shell)
        flipper_emb = self.flipper_net(flipper)
        combined = torch.cat([head_emb, shell_emb, flipper_emb], dim=1)

        if label is not None:
            logits = self.arcface(combined, label)
            return logits, combined
        else:
            return combined

In [9]:
def train_model(model, dataloader, optimizer, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    for head, shell, flipper, labels in dataloader:
        head, shell, flipper, labels = head.to(device), shell.to(device), flipper.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(head, shell, flipper, labels)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


In [10]:
def evaluate_open_set(model, dataloader, device):
    model.eval()
    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for head, shell, flipper, labels in dataloader:
            head, shell, flipper = head.to(device), shell.to(device), flipper.to(device)
            embeddings = model(head, shell, flipper)
            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels)
    all_embeddings = torch.cat(all_embeddings)
    all_labels = torch.cat(all_labels)
    sims = F.cosine_similarity(all_embeddings.unsqueeze(1), all_embeddings.unsqueeze(0), dim=2)
    preds = sims.argmax(dim=1)
    acc = (all_labels[preds] == all_labels).float().mean().item()
    return acc


In [None]:
coco = COCO(annotations_path)
train_df = df[df['split_open'] == 'train']
test_df = df[df['split_open'] == 'test']

transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

train_dataset = TurtleReIDSegmentsDataset(train_df, root_dir, coco, transform)
test_dataset = TurtleReIDSegmentsDataset(test_df, root_dir, coco, transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(train_dataset.identity_to_idx)

model = TurtleReIDModel(embedding_dim=256, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print("Training started...")
for epoch in range(5):
    loss = train_model(model, train_loader, optimizer, device)
    print(f"Epoch [{epoch+1}/5], Loss: {loss:.4f}")

print("Evaluating open-set re-ID accuracy...")
acc = evaluate_open_set(model, test_loader, device)
print(f"Open-set matching accuracy: {acc:.4f}")

loading annotations into memory...
Done (t=5.49s)
creating index...
index created!
Downloading: "https://download.pytorch.org/models/swin_b-68c6b09e.pth" to /root/.cache/torch/hub/checkpoints/swin_b-68c6b09e.pth


100%|██████████| 335M/335M [00:06<00:00, 57.5MB/s]


Training started...
