In [12]:
from transformers import AutoImageProcessor, AutoModel
import torch
from torch.utils.data import DataLoader
from src.model import SiameseDino
from src.train import evaluate
from src.data import CachedCollection, LazyLoadCollection, create_dataset_splits, make_transform
from pathlib import Path

In [5]:
# Hugging Face Hub id of the pretrained DINOv3 backbone.
base_model_name = "facebook/dinov3-vitb16-pretrain-lvd1689m"

# Processor and backbone come from the same Hub id so the preprocessing matches
# the checkpoint; the processor is handed to `evaluate` further down.
processor = AutoImageProcessor.from_pretrained(base_model_name)

# Backbone weights, loaded in full float32 precision.
dinov3_model = AutoModel.from_pretrained(base_model_name, dtype=torch.float32)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [23]:
# Dimensions passed to SiameseDino; presumably they size its trainable head, so
# they must match the run that produced the checkpoint (TODO confirm).
hidden_dim = 512
output_dim = 256
checkpoint_path = "../model_checkpoints/legendary-rain-171.pth"

# Pick the best available accelerator. The previous one-liner tested for MPS but
# then requested "cuda": that crashes on Apple-silicon machines (no CUDA) and
# silently falls back to CPU on machines that do have a CUDA GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

siamese_model = SiameseDino(dinov3_model, hidden_dim, output_dim)
# Deserialize onto the CPU first so a checkpoint saved on any device (e.g. a CUDA
# training box) can be restored here. weights_only=True refuses arbitrary pickled
# objects; a state_dict only needs tensors.
siamese_model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu", weights_only=True)
)
# Inference only: eval mode disables dropout / uses running norm statistics, if
# the model has any. evaluate() may already do this; it is harmless if so.
_ = siamese_model.to(device).eval()

In [9]:
# Both image trees live under one data root: the original images (gallery and
# queries in the summary printed below) and the augmented copies (training).
data_root = Path("../data")
original_root_path = data_root / "original_data"
augmented_root_path = data_root / "augmented_data16"

# NOTE(review): the meaning of the positional 0, 1, 1 is not visible here. With
# these values the printed summary reports 1 train / 101 val / 0 test classes.
# Consider passing them by keyword once the signature is confirmed.
splits = create_dataset_splits(original_root_path, augmented_root_path, 0, 1, 1)

Found 102 total classes.
Splitting classes: 1 Train / 101 Val / 0 Test

Processing Train classes...
Processing Validation classes...
Processing Test classes...

--- Data Split Summary ---
Training Loader:      16 samples from 1 classes (Augmented)
Gallery Loader:      307 samples from 102 classes (Original)
Val Query Loader:    101 samples from 101 classes (Original)
Test Query Loader:     0 samples from 0 classes (Original)


In [13]:
EVAL_BATCH_SIZE = 32


def _cached_split(split_name):
    """Wrap entries 0 and 1 of ``splits[split_name]`` in a CachedCollection.

    The two entries are presumably (samples, labels); confirm against
    create_dataset_splits. Each call builds its own transform via make_transform().
    """
    split = splits[split_name]
    return CachedCollection(split[0], split[1], make_transform())


gallery = _cached_split("gallery")
queries = _cached_split("val_query")

# No shuffle argument: DataLoader defaults to sequential order.
gallery_dataloader = DataLoader(gallery, batch_size=EVAL_BATCH_SIZE)
queries_dataloader = DataLoader(queries, batch_size=EVAL_BATCH_SIZE)

In [None]:
# Run: "magic field" — recall@1/3/5 of validation queries against the gallery.
# NOTE(review): this cell evaluates whatever weights `siamese_model` currently
# holds. The recorded output was presumably produced with the magic-field
# checkpoint loaded, but the loading cell points at legendary-rain-171.pth.
# Restart & Run All will therefore not reproduce this result. Load the matching
# checkpoint in this cell to make it self-contained.
evaluate(siamese_model, processor, gallery_dataloader, queries_dataloader, [1, 3, 5])

{'recall@1': 0.8316831683168316, 'recall@3': 0.9207920792079208, 'recall@5': 0.9504950495049505}


In [None]:
# Run: "leafy sunset" — recall@1/3/5 of validation queries against the gallery.
# NOTE(review): this cell evaluates whatever weights `siamese_model` currently
# holds. The recorded output was presumably produced with the leafy-sunset
# checkpoint loaded, but the loading cell points at legendary-rain-171.pth.
# Restart & Run All will therefore not reproduce this result. Load the matching
# checkpoint in this cell to make it self-contained.
evaluate(siamese_model, processor, gallery_dataloader, queries_dataloader, [1, 3, 5])

{'recall@1': 0.8118811881188119, 'recall@3': 0.8811881188118812, 'recall@5': 0.9306930693069307}


In [19]:
# Run: "hearty lion" — recall@1/3/5 of validation queries against the gallery.
# NOTE(review): this cell evaluates whatever weights `siamese_model` currently
# holds. The recorded output was presumably produced with the hearty-lion
# checkpoint loaded, but the loading cell points at legendary-rain-171.pth.
# Restart & Run All will therefore not reproduce this result. Load the matching
# checkpoint in this cell to make it self-contained.
evaluate(siamese_model, processor, gallery_dataloader, queries_dataloader, [1, 3, 5])

{'recall@1': 0.8514851485148515, 'recall@3': 0.9108910891089109, 'recall@5': 0.9504950495049505}


In [21]:
# Run: "misunderstood thunder" — recall@1/3/5 of validation queries against the gallery.
# NOTE(review): this cell evaluates whatever weights `siamese_model` currently
# holds. The recorded output was presumably produced with the
# misunderstood-thunder checkpoint loaded, but the loading cell points at
# legendary-rain-171.pth. Restart & Run All will therefore not reproduce this
# result. Load the matching checkpoint in this cell to make it self-contained.
evaluate(siamese_model, processor, gallery_dataloader, queries_dataloader, [1, 3, 5])

{'recall@1': 0.8811881188118812, 'recall@3': 0.9306930693069307, 'recall@5': 0.9603960396039604}


In [24]:
# Run: "legendary rain" — recall@1/3/5 of validation queries against the gallery.
# This matches the checkpoint the loading cell uses (legendary-rain-171.pth), so
# this is the only evaluation cell whose result a fresh Restart & Run All reproduces.
evaluate(siamese_model, processor, gallery_dataloader, queries_dataloader, [1, 3, 5])

{'recall@1': 0.801980198019802, 'recall@3': 0.9108910891089109, 'recall@5': 0.9504950495049505}
