In [13]:
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

from inference import StableEmbeddingModel

  from .autonotebook import tqdm as notebook_tqdm


In [15]:

CKPT_PATH = "model.ckpt"
NUM_NEW_CLASSES = 10

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
EPOCHS = 20
LR = 1e-4

# Model

In [11]:
old_state = torch.load(CKPT_PATH, map_location="cpu")

In [17]:
model = StableEmbeddingModel(
    embedding_dim=512,
    num_classes=NUM_NEW_CLASSES,
    pretrained_backbone=True,
)

filtered_state = {
    k: v for k, v in old_state.items()
    if not k.startswith("arcface_head.")
}

missing, unexpected = model.load_state_dict(filtered_state, strict=False)
print("Missing keys (expected: arcface_head.*):", missing)
print("Unexpected keys:", unexpected)

for name, param in model.named_parameters():
    if name.startswith("backbone"):
        param.requires_grad = False
        
model.to(DEVICE)

Loading ViT backbone: beitv2_base_patch16_224.in1k_ft_in22k_in1k
StableEmbeddingModel initialized with ViT backbone: beitv2_base_patch16_224.in1k_ft_in22k_in1k
  Embedding Dim: 512, Num Classes: 10
  ArcFace s: 64.0, m: 0.5
  Backbone out features (ViT embed_dim): 768
  BN in embedding: False, Dropout in embedding: 0.11
Missing keys (expected: arcface_head.*): ['arcface_head.weight', 'arcface_head.cos_m', 'arcface_head.sin_m', 'arcface_head.th', 'arcface_head.mm', 'arcface_head.eps']
Unexpected keys: []


StableEmbeddingModel(
  (backbone): Beit(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_features=3072, out_features=768

In [18]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR,
    weight_decay=1e-4,
)

In [None]:
class FishDataset(Dataset):
    def __init__(self, image_root, annotations, transform=None):
        """
        image_root: folder with images
        annotations: list of (image_filename, class_id)
        """
        self.image_root = image_root
        self.samples = annotations
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.image_root, img_name)

        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)

        return img, label

In [None]:
df_annot = pd.read_csv("my_annotations.csv")

label_to_id = {label: idx for idx, label in enumerate(sorted(df_annot["label"].unique()))}

train_annotations = [
    (row["image_file"], label_to_id[row["label"]]) 
    for _, row in df.iterrows()
]
val_annotations = []

train_dataset = FishDataset("data_images/train", train_annotations)
val_dataset   = FishDataset("data_images/val",   val_annotations)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


In [20]:
def train_one_epoch(model, loader, optimizer, epoch):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Manual forward using modelâ€™s components
        # 1) ViT features
        features = model.backbone_feature_extractor(images)  # [B, N+1, D]

        # 2) Remove CLS token, keep patches
        if hasattr(model.backbone, "cls_token"):
            patch_tokens = features[:, 1:, :]
        else:
            patch_tokens = features

        # 3) Attention pooling
        pooled, _ = model.pooling(patch_tokens, return_attention_map=True)

        # 4) Embedding head + L2-normalization
        emb_raw = model.embedding_fc(pooled)
        emb_norm = F.normalize(emb_raw, p=2, dim=1)

        # 5) ArcFace logits (margin injected using labels)
        logits = model.arcface_head(emb_norm, labels)

        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += images.size(0)

    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    acc = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"Epoch {epoch} - Train loss: {avg_loss:.4f}, acc: {acc:.4f}")


In [None]:

@torch.no_grad()
def validate(model, loader, epoch):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        features = model.backbone_feature_extractor(images)
        if hasattr(model.backbone, "cls_token"):
            patch_tokens = features[:, 1:, :]
        else:
            patch_tokens = features

        pooled, _ = model.pooling(patch_tokens, return_attention_map=True)
        emb_raw = model.embedding_fc(pooled)
        emb_norm = F.normalize(emb_raw, p=2, dim=1)
        logits = model.arcface_head(emb_norm, labels)

        loss = F.cross_entropy(logits, labels)

        total_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += images.size(0)

    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    acc = total_correct / total_samples if total_samples > 0 else 0.0
    print(f"Epoch {epoch} - Val loss: {avg_loss:.4f}, acc: {acc:.4f}")

In [None]:
for epoch in range(1, EPOCHS + 1):
    train_one_epoch(model, train_loader, optimizer, epoch)
    validate(model, val_loader, epoch)

# Save your new transfer-learned weights
torch.save(model.state_dict(), "model_transfer_newdataset.pt")