In [None]:
import os
import torch
import math
import random
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import transformers
from torchvision import datasets
from dataclasses import dataclass
from typing import List, Dict, Any
from torch.utils.data import DataLoader, random_split
from transformers import AutoImageProcessor, AutoModel, AutoConfig
from dinov3_linear import DinoV3Linear

In [2]:
# Build the labelled image dataset from the image directory and carve out an
# 80/20 train/validation split.
data_dir = "./downloads/birds-200-species/CUB_200_2011/images"
full_dataset = datasets.ImageFolder(root=data_dir)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seed the split so the same images land in each subset on every run; without
# a fixed generator the accuracies reported below are not reproducible.
# NOTE(review): this is only a genuine train/val split if it matches the one
# used when the checkpoint was trained — confirm against the training code.
split_seed = 42
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Index <-> class-name lookups, in the order given by full_dataset.classes.
num_classes = len(full_dataset.classes)
id2label = dict(enumerate(full_dataset.classes))
label2id = {name: idx for idx, name in id2label.items()}

In [10]:
# Restore the classifier from a saved checkpoint: rebuild the image processor
# and backbone from the configuration stored alongside the weights, then load
# the weights and switch to evaluation mode.
ckpt_path = "./weights/model_best.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
ckpt_config = ckpt["config"]

# The processor class is looked up by the name recorded in the checkpoint and
# constructed from the same stored settings.
processor_kwargs = ckpt_config["image_processor"]
ProcessorClass = getattr(transformers, processor_kwargs["image_processor_type"])
image_processor = ProcessorClass(**processor_kwargs)

# The backbone is instantiated from its config only; its parameters are filled
# in by load_state_dict below.
backbone_config = AutoConfig.for_model(**ckpt_config["backbone"])
backbone = AutoModel.from_config(backbone_config)

model = DinoV3Linear(
    backbone=backbone,
    num_classes=len(ckpt_config["classes"]),
    freeze_backbone=ckpt_config.get("freeze_backbone", True),
)
model = model.to(device)

model.load_state_dict(ckpt["model_state_dict"])
model = model.eval()

In [4]:
# Measure top-1 accuracy of the loaded model on the training split, running
# one image at a time with gradients disabled.
correct, total = 0, 0
with torch.no_grad():
    for img, label in tqdm(train_dataset):
        if isinstance(img, Image.Image):
            img = img.convert("RGB")

        inputs = image_processor(images=img, return_tensors="pt").to(device)
        logits = model(inputs["pixel_values"])
        # argmax over the raw logits: softmax is monotonic, so normalising
        # first cannot change the winning class.
        pred = logits.argmax(dim=-1).item()
        correct += int(pred == label)
        total += 1

# An empty split would otherwise raise ZeroDivisionError; report NaN instead.
accuracy = correct / total if total else float("nan")
print(f"Train Accuracy: {accuracy:.4f}")

100%|██████████| 9430/9430 [02:58<00:00, 52.84it/s]

Train Accuracy: 0.9696





In [5]:
# Measure top-1 accuracy of the loaded model on the validation split, running
# one image at a time with gradients disabled.
correct, total = 0, 0
with torch.no_grad():
    for img, label in tqdm(val_dataset):
        if isinstance(img, Image.Image):
            img = img.convert("RGB")

        inputs = image_processor(images=img, return_tensors="pt").to(device)
        logits = model(inputs["pixel_values"])
        # argmax over the raw logits: softmax is monotonic, so normalising
        # first cannot change the winning class.
        pred = logits.argmax(dim=-1).item()
        correct += int(pred == label)
        total += 1

# An empty split would otherwise raise ZeroDivisionError; report NaN instead.
accuracy = correct / total if total else float("nan")
print(f"Val Accuracy: {accuracy:.4f}")

100%|██████████| 2358/2358 [00:44<00:00, 53.44it/s]

Val Accuracy: 0.9656



