In [1]:
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm import tqdm
import shutil
from datasets import load_dataset
import gc

# Constants
MAX_IMAGES = 2000
BATCH_SIZE = 12
MODEL_NAME = "vit_base_patch14_reg4_dinov2.lvd142m"


In [2]:

tmp_path = Path("./tmp-test-images")
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

dataset_name = "kvriza8/microscopy_images"
dataset = load_dataset(dataset_name, split="train", streaming=True)

In [3]:

for i, example in enumerate(tqdm(iter(dataset))):  # Stream images to avoid full load
    if i >= MAX_IMAGES:
        break
    image = example["image"]
    image.save(tmp_path / f"{i}.png")


2000it [00:39, 51.13it/s] 


In [4]:

class ImageDataset(Dataset):
    def __init__(self, img_dir, transform):
        self.img_paths = list(Path(img_dir).glob("*.png"))
        self.transform = transform
    
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        with Image.open(img_path).convert("RGB") as img:
            img_tensor = self.transform(img)
        return img_tensor


In [5]:

@dataclass
class TimmEmbedder:
    model_name: str

    def __post_init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.model = self.model.to(self.device).eval()
        cfg = self.model.pretrained_cfg
        self.transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

    def prepare(self, image):
        return self.transform(image).unsqueeze(0).to(self.device)

    def embed(self, img_tensor):
        with torch.no_grad():
            return self.model(img_tensor).cpu().numpy()


In [None]:

embedder = TimmEmbedder(MODEL_NAME)
dataset = ImageDataset(tmp_path, embedder.transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)


: 

In [None]:

embeddings = []

for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
    batch = torch.cat([embedder.prepare(img) for img in batch], dim=0)
    batch_embeddings = embedder.embed(batch)
    embeddings.append(batch_embeddings)
    # if i%25 == 0:
    #     print(torch.cuda.memory_summary())
    if i%5 == 0:
        torch.cuda.empty_cache()
        gc.collect()
        print(f"Size of embeddings: {sum(e.nbytes for e in embeddings)}")


embeddings = np.vstack(embeddings)
print("Embedding shape:", embeddings.shape)


Processing batches:   1%|          | 1/167 [00:15<43:12, 15.62s/it]

Size of embeddings: 36864


Processing batches:   4%|▎         | 6/167 [01:34<42:22, 15.79s/it]

Size of embeddings: 221184


Processing batches:   7%|▋         | 11/167 [02:56<42:23, 16.31s/it]

Size of embeddings: 405504


Processing batches:  10%|▉         | 16/167 [04:19<42:08, 16.74s/it]

Size of embeddings: 589824


Processing batches:  13%|█▎        | 21/167 [05:43<40:37, 16.70s/it]

Size of embeddings: 774144


Processing batches:  16%|█▌        | 26/167 [07:07<39:34, 16.84s/it]

Size of embeddings: 958464


Processing batches:  19%|█▊        | 31/167 [08:32<38:21, 16.92s/it]

Size of embeddings: 1142784


Processing batches:  22%|██▏       | 36/167 [09:56<37:09, 17.02s/it]

Size of embeddings: 1327104


Processing batches:  25%|██▍       | 41/167 [11:20<34:54, 16.63s/it]

Size of embeddings: 1511424


Processing batches:  28%|██▊       | 46/167 [12:43<33:29, 16.61s/it]

Size of embeddings: 1695744


Processing batches:  31%|███       | 51/167 [14:06<32:10, 16.65s/it]

Size of embeddings: 1880064


Processing batches:  34%|███▎      | 56/167 [15:30<31:13, 16.87s/it]

Size of embeddings: 2064384


Processing batches:  37%|███▋      | 61/167 [16:52<29:02, 16.44s/it]

Size of embeddings: 2248704


Processing batches:  40%|███▉      | 66/167 [18:13<27:22, 16.26s/it]

Size of embeddings: 2433024


Processing batches:  43%|████▎     | 71/167 [19:34<25:53, 16.18s/it]

Size of embeddings: 2617344


Processing batches:  46%|████▌     | 76/167 [20:55<24:37, 16.23s/it]

Size of embeddings: 2801664


Processing batches:  49%|████▊     | 81/167 [22:17<23:28, 16.38s/it]

Size of embeddings: 2985984


Processing batches:  51%|█████▏    | 86/167 [23:39<22:05, 16.37s/it]

Size of embeddings: 3170304


Processing batches:  54%|█████▍    | 91/167 [25:01<20:54, 16.50s/it]

Size of embeddings: 3354624


Processing batches:  57%|█████▋    | 96/167 [26:23<19:29, 16.47s/it]

Size of embeddings: 3538944


Processing batches:  60%|█████▉    | 100/167 [27:29<18:29, 16.55s/it]

In [None]:

shutil.rmtree(tmp_path)

In [None]:
sum(e.nbytes for e in embeddings)