In [None]:
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm import tqdm
import shutil
from datasets import load_dataset
import gc

# Notebook-wide configuration constants.
# Upper bound on how many streamed examples are written to disk as PNG files.
MAX_IMAGES = 2000
# Number of images per DataLoader batch during embedding.
BATCH_SIZE = 12
# timm model identifier handed to TimmEmbedder (and from there to timm.create_model).
MODEL_NAME = "vit_base_patch14_reg4_dinov2.lvd142m"


In [None]:

# Scratch directory for the downloaded images: discard anything left over from
# an earlier run, then recreate it empty.
tmp_path = Path("./tmp-test-images")
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Open the train split in streaming mode so examples are pulled on demand
# rather than downloading the whole dataset up front.
dataset_name = "kvriza8/microscopy_images"
dataset = load_dataset(path=dataset_name, split="train", streaming=True)

In [None]:

# Write the first MAX_IMAGES streamed examples to disk as "<index>.png".
# take() bounds the stream itself, so no extra example is fetched just to
# reach a break, and total= lets tqdm render an actual progress bar.
for i, example in enumerate(tqdm(dataset.take(MAX_IMAGES), total=MAX_IMAGES)):
    image = example["image"]
    image.save(tmp_path / f"{i}.png")


In [None]:

class ImageDataset(Dataset):
    """Map-style dataset over the ``*.png`` files directly inside a directory.

    Args:
        img_dir: Directory to scan (non-recursively) for ``*.png`` files.
        transform: Callable applied to each RGB-converted PIL image; its
            return value is what ``__getitem__`` yields.

    Attributes:
        img_paths: The PNG paths in a deterministic order, so that index ``i``
            refers to the same file on every run.
    """

    def __init__(self, img_dir, transform):
        # glob() yields files in arbitrary filesystem order; sort so the row
        # order of anything computed from this dataset is reproducible.
        self.img_paths = sorted(Path(img_dir).glob("*.png"), key=self._sort_key)
        self.transform = transform

    @staticmethod
    def _sort_key(path):
        """Order numeric stems by value ("2" before "10"), then other names alphabetically."""
        stem = path.stem
        if stem.isdecimal():
            return (0, int(stem), "")
        return (1, 0, stem)

    def __len__(self):
        """Return the number of PNG files found at construction time."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and return ``transform(image)``."""
        img_path = self.img_paths[idx]
        # Bind the context manager to the opened file (not the converted copy)
        # so the underlying file handle is always released.
        with Image.open(img_path) as img:
            img_tensor = self.transform(img.convert("RGB"))
        return img_tensor


In [None]:

@dataclass
class TimmEmbedder:
    """Pretrained timm model wrapped as an image feature extractor.

    On construction the model is created with ``num_classes=0`` (timm's
    setting for returning pooled features instead of class logits), moved to
    CUDA when available (CPU otherwise) and switched to eval mode.

    Attributes:
        model_name: timm model identifier passed to ``timm.create_model``.
        device: Device the model lives on (set in ``__post_init__``).
        model: The loaded model in eval mode (set in ``__post_init__``).
        transform: Preprocessing pipeline built from the model's
            ``pretrained_cfg`` (set in ``__post_init__``).
    """

    model_name: str

    def __post_init__(self):
        """Load the pretrained model and build its matching preprocessing transform."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.model = self.model.to(self.device).eval()
        cfg = self.model.pretrained_cfg
        self.transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

    def prepare(self, image):
        """Preprocess one raw image into a single-item batch on the model's device.

        Args:
            image: An untransformed image accepted by ``self.transform``.

        Returns:
            The transformed image with a leading batch dimension added,
            placed on ``self.device``.
        """
        return self.transform(image).unsqueeze(0).to(self.device)

    def embed(self, img_tensor):
        """Run the model on an already-preprocessed batch without tracking gradients.

        Args:
            img_tensor: Batch of preprocessed images. It may live on any
                device; it is moved to ``self.device`` before the forward pass.

        Returns:
            The model output copied to the CPU as a NumPy array.
        """
        with torch.no_grad():
            # .to() is a no-op when the tensor is already on the right device,
            # and avoids a device-mismatch error for CPU batches on a CUDA model.
            return self.model(img_tensor.to(self.device)).cpu().numpy()


In [None]:

# Build the embedder first: its preprocessing transform is what the dataset
# applies to every image it loads.
embedder = TimmEmbedder(MODEL_NAME)
# Deliberately not named `dataset`: that name is bound to the streamed
# Hugging Face dataset, and rebinding it would break re-running the download cell.
image_dataset = ImageDataset(tmp_path, embedder.transform)
# shuffle=False keeps batches in the dataset's own order.
dataloader = DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:

# Per-batch outputs are collected under their own name so `embeddings` is only
# ever the final stacked array.
embedding_batches = []

for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
    # ImageDataset already applied embedder.transform to every image and the
    # DataLoader stacked the results into one batch tensor, so the batch must
    # NOT go through embedder.prepare() again (that would re-run the
    # preprocessing on already-preprocessed tensors). It only needs to be
    # moved to the model's device.
    batch_embeddings = embedder.embed(batch.to(embedder.device))
    embedding_batches.append(batch_embeddings)
    if i % 5 == 0:
        # Periodically drop unreferenced objects, then hand cached GPU blocks
        # back to the driver (skipped entirely on CPU-only machines).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Size of embeddings: {sum(e.nbytes for e in embedding_batches)}")

if not embedding_batches:
    # np.vstack([]) fails with an opaque message; say what actually went wrong.
    raise RuntimeError("The DataLoader produced no batches - were any images saved to disk?")

embeddings = np.vstack(embedding_batches)
print("Embedding shape:", embeddings.shape)


In [None]:

# Remove the scratch images. ignore_errors=True keeps this cell safe to re-run
# after the directory is already gone (same behaviour as the setup cell).
shutil.rmtree(tmp_path, ignore_errors=True)

In [None]:
# Total size of the stacked embedding array in bytes; equal to summing the
# per-row sizes, without a Python-level loop over every row.
embeddings.nbytes