In [1]:
import os
import torch
import math
import json
import random
import trackio
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from torchvision import datasets
from dataclasses import dataclass
from typing import List, Dict, Any
from torch.utils.data import DataLoader, random_split
from transformers import AutoImageProcessor, AutoModel, AutoConfig, get_cosine_schedule_with_warmup
from dinov3_linear import DinoV3Linear

  from .autonotebook import tqdm as notebook_tqdm


Load the dataset and split it 80/20 into training and validation sets

In [2]:
# Fixed seed so that the train/val partition (and every later consumer of the
# global RNGs, e.g. head initialisation and loader shuffling) is the same on
# every Restart & Run All. Without it the validation set changes per run, so
# metrics are not comparable and a resumed run can train on old val images.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Image root read by ImageFolder; class names come from `full_dataset.classes`.
data_dir = "./downloads/birds-200-species/CUB_200_2011/images"
full_dataset = datasets.ImageFolder(root=data_dir)

# 80/20 split; the validation share absorbs the rounding remainder so the two
# sizes always sum to the dataset length.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    # Dedicated generator: the split does not depend on how much of the global
    # RNG stream was consumed before this cell ran.
    generator=torch.Generator().manual_seed(SEED),
)

# Index <-> class-name lookup tables.
num_classes = len(full_dataset.classes)
id2label = dict(enumerate(full_dataset.classes))
label2id = {name: idx for idx, name in id2label.items()}

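Download the pretrained weights from the link below and place them in `./downloads/dinov3-vitb16-pretrain-lvd1689m` (the path the next cell loads from):  
https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m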

In [None]:
# Local directory holding the downloaded checkpoint, and the device to train on.
MODEL_NAME = "./downloads/dinov3-vitb16-pretrain-lvd1689m"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained backbone and its matching preprocessing pipeline.
backbone = AutoModel.from_pretrained(MODEL_NAME)
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Plain-dict copies of both configurations (round-tripped through JSON) holding
# all model / preprocessing details; the training cell stores them in the
# checkpoint.
backbone_config = json.loads(
    AutoConfig.from_pretrained(MODEL_NAME).to_json_string()
)
image_processor_config = json.loads(image_processor.to_json_string())

# Classification model built on top of the backbone, moved to the target device.
freeze_backbone = True
model = DinoV3Linear(backbone, num_classes, freeze_backbone=freeze_backbone)
model = model.to(device)

In [None]:
# --- Optimisation ---
EPOCHS = 15
LR = 5e-4
WEIGHT_DECAY = 1e-4
WARMUP_RATIO = 0.05  # fraction of the total optimizer steps used for warm-up

# --- Data loading ---
BATCH_SIZE = 16
# At most 8 workers; fall back to 2 when the CPU count cannot be determined.
NUM_WORKERS = min(8, os.cpu_count() or 2)

# --- Evaluation / checkpointing ---
EVAL_EVERY_STEPS = 100  # run validation every this many optimizer steps
CHECKPOINT_DIR = "./weights"

# Make sure the checkpoint directory exists before training starts.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

@dataclass
class Collator:
    """Turn a list of ``(image, label)`` samples into one model-ready batch.

    Calling the instance returns a dict with two entries:
    ``"pixel_values"`` (taken from the processor output) and ``"labels"``
    (a ``torch.long`` tensor built from the sample labels).
    """

    # Image processor applied to the whole batch at once.
    processor: AutoImageProcessor

    def __call__(self, batch):
        # Transpose [(img, label), ...] into (imgs...), (labels...).
        raw_images, targets = zip(*batch)

        # PIL images are forced to RGB; anything else is passed through as-is.
        converted = []
        for item in raw_images:
            if isinstance(item, Image.Image):
                item = item.convert("RGB")
            converted.append(item)

        encoded = self.processor(images=converted, return_tensors="pt")
        label_tensor = torch.tensor(targets, dtype=torch.long)

        return {"pixel_values": encoded["pixel_values"], "labels": label_tensor}
    
collate_fn = Collator(image_processor)

# Settings shared by both loaders. Pinned host memory is only requested when a
# CUDA device is present; on CPU-only machines it has no benefit and PyTorch
# emits a warning for it.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=WEIGHT_DECAY)

# One scheduler step per optimizer step; len(train_loader) is already the
# integer number of batches per epoch.
total_steps = EPOCHS * len(train_loader)
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
criterion = nn.CrossEntropyLoss()

# `torch.cuda.amp.GradScaler` is deprecated in favour of the device-agnostic
# `torch.amp.GradScaler`; fall back to the old entry point on older PyTorch
# releases that do not provide the new one yet.
if hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
else:
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


In [5]:
def evaluate() -> Dict[str, float]:
    """Score the global ``model`` on every batch of ``val_loader``.

    Puts the model in eval mode (it is NOT switched back afterwards) and runs
    the whole pass with gradients disabled.

    Returns:
        A dict with ``"val_loss"`` (per-sample average of ``criterion``) and
        ``"val_acc"`` (fraction of correct argmax predictions). Both are 0.0
        when the loader yields no samples.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            images = batch["pixel_values"].to(device, non_blocking=True)
            targets = batch["labels"].to(device, non_blocking=True)
            batch_count = targets.size(0)

            scores = model(images)
            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the overall average.
            weighted_loss += criterion(scores, targets).item() * batch_count
            n_correct += (scores.argmax(dim=-1) == targets).sum().item()
            n_seen += batch_count

    # Guard against division by zero for an empty validation loader.
    denom = max(n_seen, 1)
    return {"val_loss": weighted_loss / denom, "val_acc": n_correct / denom}

In [None]:
# Training driver: runs EPOCHS passes over `train_loader`, validates every
# EVAL_EVERY_STEPS optimizer steps and once more at the end of each epoch, and
# keeps the checkpoint with the best periodic validation accuracy.
best_acc = 0.0
global_step = 0
ckpt_path = os.path.join(CHECKPOINT_DIR, "model_best.pt")

trackio.init(
    project="dinov3",
    config={
        "epochs": EPOCHS,
        "learning_rate": LR,
        "batch_size": BATCH_SIZE,
    },
)

for epoch in range(1, EPOCHS + 1):
    model.train()
    model.backbone.eval()  # the backbone is kept in eval mode while training

    # Loss accumulated since the last periodic evaluation. The step count is
    # tracked explicitly because this window restarts at every epoch while
    # `global_step` does not, so a window can hold fewer than
    # EVAL_EVERY_STEPS steps.
    window_loss, window_steps = 0.0, 0

    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        # NOTE(review): the forward pass is not wrapped in torch.autocast, so
        # the GradScaler below only scales/unscales a full-precision loss.
        # Confirm whether mixed precision was intended.
        logits = model(pixel_values)
        loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        window_loss += loss.item()
        window_steps += 1
        global_step += 1

        if global_step % EVAL_EVERY_STEPS == 0:
            metrics = evaluate()
            # evaluate() switches the whole model to eval mode; restore the
            # training configuration before the next optimizer step.
            model.train()
            model.backbone.eval()

            print(
                f"[epoch {epoch} | step {global_step}] "
                f"train_loss={window_loss / window_steps:.4f} "
                f"val_loss={metrics['val_loss']:.4f} val_acc={metrics['val_acc']*100:.2f}%"
            )
            window_loss, window_steps = 0.0, 0

            # Log the accuracy just measured, not the previous best.
            trackio.log(
                {
                    "epoch": epoch,
                    "val_acc": metrics["val_acc"],
                }
            )

            if metrics["val_acc"] > best_acc:
                best_acc = metrics["val_acc"]
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "config": {
                            "model_name": MODEL_NAME,
                            "classes": full_dataset.classes,
                            "backbone": backbone_config,
                            "image_processor": image_processor_config,
                            "freeze_backbone": freeze_backbone,
                        },
                        "step": global_step,
                        "epoch": epoch,
                    },
                    ckpt_path,
                )

    # End-of-epoch validation (report only; it does not update the checkpoint).
    metrics = evaluate()
    print(
        f"END EPOCH {epoch}: val_loss={metrics['val_loss']:.4f} val_acc={metrics['val_acc']*100:.2f}% "
        f"(best_acc={best_acc*100:.2f}%)"
    )

# Close the tracking run once, after ALL epochs have completed.
trackio.finish()