In [1]:
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm import tqdm
import shutil
from datasets import load_dataset
import gc
import torch.profiler as profiler


In [2]:

# Constants
MAX_IMAGES = 50
BATCH_SIZE = 12
MODEL_NAME = "vit_base_patch14_reg4_dinov2.lvd142m"


In [3]:

tmp_path = Path("./tmp-test-images")
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

dataset_name = "kvriza8/microscopy_images"
dataset = load_dataset(dataset_name, split="train", streaming=True)



In [4]:

for i, example in enumerate(tqdm(iter(dataset))):  # Stream images to avoid full load
    if i >= MAX_IMAGES:
        break
    image = example["image"]
    image.save(tmp_path / f"{i}.png")

del dataset
gc.collect()

50it [00:04, 10.52it/s]


8

In [5]:

class ImageDataset(Dataset):
    def __init__(self, img_dir, transform):
        self.img_paths = list(Path(img_dir).glob("*.png"))
        self.transform = transform
    
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        with Image.open(img_path).convert("RGB") as img:
            img_tensor = self.transform(img)
        return img_tensor


In [6]:

@dataclass
class TimmEmbedder:
    model_name: str

    def __post_init__(self):
        self.device = torch.device("cpu")  # Ensure CPU inference
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.model = self.model.to(self.device).eval()
        cfg = self.model.pretrained_cfg
        self.transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

    def prepare(self, image):
        return self.transform(image).unsqueeze(0).to(self.device)

    def embed(self, img_tensor):
        with torch.no_grad():
            return self.model(img_tensor).cpu().numpy()


In [7]:

embedder = TimmEmbedder(MODEL_NAME)
dataset = ImageDataset(tmp_path, embedder.transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

embeddings = []


In [8]:

with profiler.profile(activities=[profiler.ProfilerActivity.CPU],
                      record_shapes=True, profile_memory=True) as prof:
    for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
        batch = torch.cat([embedder.prepare(img) for img in batch], dim=0)
        batch_embeddings = embedder.embed(batch)
        embeddings.append(batch_embeddings)
        
        if i % 5 == 0:
            gc.collect()

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

embeddings = np.vstack(embeddings)
print("Embedding shape:", embeddings.shape)

shutil.rmtree(tmp_path)

# sum(e.nbytes for e in embeddings)

Processing batches: 100%|██████████| 5/5 [01:12<00:00, 14.43s/it]


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::addmm        59.93%       42.810s        65.36%       46.691s     194.545ms      21.23 Gb      21.23 Gb           240  
                                             aten::gelu         3.69%        2.635s         3.69%        2.635s      43.917ms       9.43 Gb       9.43 Gb            60  
                                            aten::empty         0.01%       7.142ms         0.01%       7.142ms      11.595us       5.46 Gb       5.46

In [9]:
prof.export_chrome_trace("trace.json")

In [None]:
import objgraph
objgraph.show_backrefs([embedder], max_depth=3)  # Visualize lingering references

Graph written to /tmp/objgraph-mooah0r8.dot (17 nodes)
Spawning graph viewer (xdot)


