
# 03 — DINOv2 Embedding Cache

**Goal:** Compute DINOv2 features for each **tile** and save to Parquet shards.
- DINOv2 stays frozen; only a small multiple-instance learning (MIL) head is trained later.


In [None]:

# Install pinned torch/torchvision/torchaudio wheels, adding the PyTorch cu121 wheel index
# as an extra package source, then timm (pinned) plus the data/IO dependencies.
# NOTE(review): numpy, pandas, pyarrow, pillow and tqdm are not version-pinned here,
# so a fresh environment may resolve different versions over time.
%pip -q install --extra-index-url https://download.pytorch.org/whl/cu121   torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
%pip -q install timm==1.0.9 numpy pandas pyarrow pillow tqdm


In [None]:

import os, math, numpy as np, pandas as pd, torch, timm
from pathlib import Path
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

BASE = Path('/content')  # change if needed
IM_TILE_DIR = BASE/'data/tiles'        # per-split tile manifests (tiles_<split>.parquet) are read from here
CACHE_DIR   = BASE/'cache/embeddings'  # embedding shards are written here
CACHE_DIR.mkdir(parents=True, exist_ok=True)

ENC_NAME = "vit_small_patch14_dinov2.lvd142m"
IMG_SIZE = 224  # side of the square crop fed to the encoder (224 = 16 patches of 14 px)
device = "cuda" if torch.cuda.is_available() else "cpu"

# timm's DINOv2 ViT checkpoints are configured for 518x518 inputs and the patch
# embedding rejects other input sizes by default. Building the model with
# img_size=IMG_SIZE makes it accept the crops produced by `preproc` below
# (timm resamples the pretrained position embeddings to the new grid on load).
# NOTE(review): overriding global_pool to "avg" may change where timm applies the
# final LayerNorm (fc_norm vs norm) -- confirm the pretrained norm weights are
# still loaded and used.
encoder = timm.create_model(
    ENC_NAME,
    pretrained=True,
    num_classes=0,      # no classification head: forward() returns pooled features
    global_pool="avg",
    img_size=IMG_SIZE,
)
encoder.requires_grad_(False)        # the encoder is frozen; it is only used for inference here
encoder = encoder.eval().to(device)
FEAT_DIM = encoder.num_features
print("Encoder:", ENC_NAME, "Feat dim:", FEAT_DIM, "Device:", device)

# Resize the short side to 256, take the central IMG_SIZE crop, and normalise with
# the per-channel mean/std given below.
preproc = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),
])


In [None]:

def embed_paths(paths, bs=128):
    """Embed image files with the frozen encoder, one feature row per path.

    Args:
        paths: Sequence of image file paths that PIL can open.
        bs: Number of images per forward pass; must be >= 1.

    Returns:
        A NumPy array with one row per entry of ``paths``, in the same order.
        For empty ``paths`` the result is a ``(0, FEAT_DIM)`` float32 array.

    Raises:
        ValueError: If ``bs`` is smaller than 1 (a negative step would otherwise
            silently yield an empty result).
    """
    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs}")

    def load_tile(path):
        # The context manager releases the file handle deterministically.
        # convert() returns a separate image object, so the tensor built from
        # it does not depend on the closed file.
        with Image.open(path) as img:
            return preproc(img.convert('RGB'))

    feats = []
    with torch.no_grad():
        for i in range(0, len(paths), bs):
            x = torch.stack([load_tile(p) for p in paths[i:i+bs]]).to(device)
            # Move each batch to host memory right away so device memory use
            # stays bounded by one batch.
            feats.append(encoder(x).cpu().numpy())

    if not feats:
        return np.zeros((0, FEAT_DIM), dtype=np.float32)
    return np.concatenate(feats, axis=0)


In [None]:

import pyarrow as pa, pyarrow.parquet as pq

def process_split(name, shard_size=5000, bs=128):
    """Embed every tile of one split and cache the result as Parquet shards.

    Reads ``IM_TILE_DIR/tiles_{name}.parquet``, embeds the files listed in its
    ``tile_path`` column in chunks of ``shard_size`` rows, and writes each chunk
    to ``CACHE_DIR/emb_{name}_{start:06d}_{end:06d}.parquet``. Each shard holds
    the ``image_id``, ``tile_id``, ``tile_path`` and ``label`` columns followed
    by one ``emb_{i}`` column per feature dimension.

    The function is resumable: a shard whose file already exists is skipped.
    Shards are written to a temporary file and renamed into place, so an
    interrupted run never leaves a truncated shard that a later run would
    mistake for a finished one.

    Args:
        name: Split name used in the manifest and shard file names.
        shard_size: Number of tiles per shard. Shard file names encode the row
            range, so changing this between runs on the same cache directory
            will not reuse shards written with a different size.
        bs: Batch size forwarded to ``embed_paths``.

    Raises:
        ValueError: If ``shard_size`` is smaller than 1.
    """
    if shard_size < 1:
        raise ValueError(f"shard_size must be a positive integer, got {shard_size}")

    tiles_df = pd.read_parquet(IM_TILE_DIR/f'tiles_{name}.parquet')
    meta_cols = ['image_id', 'tile_id', 'tile_path', 'label']

    for start in tqdm(range(0, len(tiles_df), shard_size), desc=f"embed {name}"):
        end = min(len(tiles_df), start + shard_size)
        shard_path = CACHE_DIR/f'emb_{name}_{start:06d}_{end:06d}.parquet'
        if shard_path.exists():
            continue  # resume: this row range was already embedded

        batch = tiles_df.iloc[start:end]
        feats = embed_paths(batch['tile_path'].tolist(), bs=bs)

        # One column per embedding dimension, aligned row-by-row with the metadata.
        emb_df = pd.DataFrame(feats, columns=[f"emb_{i}" for i in range(feats.shape[1])])
        shard = pd.concat([batch[meta_cols].reset_index(drop=True), emb_df], axis=1)

        # Write under a temporary name, then rename: the final path only ever
        # appears once the shard is complete.
        tmp_path = shard_path.with_name(shard_path.name + '.tmp')
        pq.write_table(pa.Table.from_pandas(shard), tmp_path)
        tmp_path.replace(shard_path)
    print("Done", name)

# Build the embedding cache for each dataset split, one after another.
for split in ('train', 'val', 'test'):
    process_split(split)
