In [1]:
import glob
import os
import gc
import shutil
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import timm
import torch
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

from itertools import chain

In [None]:

# Images per inference batch: default of compute_embeddings and passed to it in the profiling cell.
# NOTE(review): the saved profiling output shows 16-image batches, so it was apparently
# produced with a different value — confirm before comparing timings.
BATCH_SIZE = 8
# timm model identifier; Embedder hands it to timm.create_model.
MODEL_NAME = "vit_large_patch14_reg4_dinov2.lvd142m"
# Default value of get_file_list's `pattern` argument.
IMAGE_GLOB = None
# Local folder the test images are downloaded into and later read back from.
IMAGES_FOLDER = "./tmp-test-images"

# Hugging Face dataset id of the test image set (same value as dl_hf_images' default dataset_name).
TEST_DATASET = "kvriza8/microscopy_images"
# Number of test images to download (passed to dl_hf_images as max_images).
NUM_TEST_IMAGES = 2_000

Timings measured with `vit_base_patch14`:

num_images | batch_size | optimize | time |
-----------|------------|----------|------|
500        |         32 | False    | 10:07 
200        |         32 | False    | 04:50 
2000       |         32 | False    | 41:00 
50         |         32 | Static   | 01:14
50         |         32 | Dynamic  | 01:14
500        |         32 | Static   | 09:33 
2000       | 32 (fixed) | Static   | 36:22
2000       | 16 (fixed) | Static   | 36:32
2000       |  4 (fixed) | Static   | 39:17
2000       | 128 (fixed) | Static  | OOM
2000lrg    | 16 (fixed) | Static   | -

In [3]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None, max_images: int = 50) -> None:
    """
    Download the first `max_images` images of a Hugging Face dataset as PNG files.

    The "train" split is opened in streaming mode and image number i is written
    to `<dir>/<i>.png`.

    Args:
        dataset_name: Hugging Face dataset identifier. Each row must provide an
            "image" entry exposing a `save(path)` method.
        dir: Destination folder (str or Path). Defaults to the current working
            directory; the folder is created when it does not exist.
        max_images: Maximum number of images to write.
    """
    # The original crashed with a TypeError on `None / str` when `dir` was omitted.
    out_dir = Path.cwd() if dir is None else Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dataset = load_dataset(dataset_name, split="train", streaming=True)

    # zip() polls range() first, so the stream is never asked for an item
    # beyond max_images (enumerate + break fetched one extra row).
    for i, img_row in tqdm(zip(range(max_images), dataset), total=max_images):
        img_row["image"].save(out_dir / f"{i}.png")

    del dataset
    gc.collect()

    return None

In [4]:

# Start from an empty image folder so files from earlier runs cannot leak into this one.
tmp_path = Path(IMAGES_FOLDER)
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Pass the configured dataset explicitly so that editing TEST_DATASET actually takes effect
# (previously the function's own default was silently used instead).
dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_path, max_images=NUM_TEST_IMAGES)


100%|██████████| 2000/2000 [00:35<00:00, 56.95it/s] 


In [5]:

def get_file_list(source: str | Path | list[str | Path], pattern: str = IMAGE_GLOB) -> list[Path]:
    """
    Resolve a folder, a glob expression, a single file or an explicit list into image paths.

    Args:
        source: One of
            - a list of paths: returned as Path objects, order preserved;
            - a directory: searched (non-recursively) with `pattern`;
            - a string/Path containing '*': expanded with glob;
            - anything else: treated as a single file path.
        pattern: Glob pattern applied when `source` is a directory. When falsy,
            "*.png", "*.jpg" and "*.jpeg" are used.

    Returns:
        A list of Path objects. Directory and glob results are sorted so that the
        order (and therefore the order of any derived embeddings) is reproducible.
    """
    if isinstance(source, list):
        return [Path(s) for s in source]

    source_path = Path(source)
    if source_path.is_dir():
        # Honour the caller's pattern; fall back to the common image extensions.
        patterns = [pattern] if pattern else ["*.png", "*.jpg", "*.jpeg"]
        return sorted(chain.from_iterable(source_path.glob(p) for p in patterns))

    # str() so that Path inputs containing a wildcard are expanded as well.
    if "*" in str(source):
        return sorted(Path(p) for p in glob.glob(str(source)))

    # Single file: always hand back a Path, as the return annotation promises.
    return [source_path]

In [6]:

@dataclass
class Embedder:
    """
    Bundles a pretrained timm model with its matching preprocessing transform.

    On construction the model named by `model_name` is created with `num_classes=0`,
    put in eval mode, moved to `device` in channels_last memory format and wrapped
    with `torch.compile(dynamic=True)`. The transform built from the model's data
    config is stored in `transform`.
    """
    # timm model identifier passed to timm.create_model
    model_name: str
    # target device: CUDA when available, otherwise CPU
    device: torch.device = field(default_factory=lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # compiled network, filled in by __post_init__
    model: torch.nn.Module = field(init=False)
    # preprocessing callable, filled in by __post_init__
    transform: callable = field(init=False)

    def __post_init__(self):
        """Create, place and compile the model, then derive its input transform."""
        net = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        net.eval()
        net.to(self.device, memory_format=torch.channels_last)
        self.model = torch.compile(net, dynamic=True)
        # resolve_data_config strips the fields create_transform does not need
        data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg, model=self.model)
        self.transform = timm.data.create_transform(**data_cfg)

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """
        Run the model on a batch of already-transformed images and return its output.

        The batch is copied to `self.device` and made contiguous in channels_last
        format first. On CUDA the forward pass runs under autocast; on any other
        device it runs without autocast.
        """
        on_device = batch_imgs.to(self.device, non_blocking=True).contiguous(
            memory_format=torch.channels_last
        )
        if self.device.type != "cuda":
            # autocast can be comically slow for some CPU setups (PyTorch issue #118499)
            return self.model(on_device)
        with torch.amp.autocast("cuda"):
            return self.model(on_device)

In [7]:

class ImageListIterator(Dataset):
    """
    Map-style dataset over a list of image files.

    Item `idx` is the image at `filelist[idx]`, converted to RGB and, when a
    transform is set, passed through that transform.
    """

    def __init__(self, filelist: list[Path], transform: callable):
        """
        Args:
            filelist: Paths of the images to serve, in the order they are indexed.
            transform: Callable applied to each RGB image; a falsy value disables it.
        """
        self.filelist = filelist
        self.transform = transform

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.filelist)

    def __getitem__(self, idx: int):
        """Load image `idx` as RGB and return it (transformed when a transform is set)."""
        # The context manager closes the file handle deterministically instead of
        # leaving it to garbage collection; convert() yields an independent image
        # that remains usable after the file is closed.
        with Image.open(self.filelist[idx]) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [8]:

def pad_to_batch_size(batch, batch_size):
    """
    Stack a list of tensors and zero-pad the result along dim 0 up to `batch_size`.

    Returns a `(stacked_batch, original_count)` tuple, where `original_count` is the
    number of items in `batch` before padding. Batches that already hold at least
    `batch_size` items are returned unpadded.
    """
    n_real = len(batch)
    stacked = torch.stack(batch)
    if n_real >= batch_size:
        return stacked, n_real

    # Filler rows share the per-item shape, dtype and device of the real data.
    filler_shape = (batch_size - n_real, *stacked.shape[1:])
    filler = torch.zeros(filler_shape, dtype=stacked.dtype, device=stacked.device)
    return torch.cat([stacked, filler], dim=0), n_real

In [None]:

def compute_embeddings(model_name: str, filelist: list[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    """
    Compute one embedding per image for the given files.

    An Embedder is instantiated for `model_name`, the files are wrapped in a dataset
    and DataLoader, and the images are processed in fixed-size batches (the last
    batch is zero-padded and the padding rows are dropped from the result).

    Args:
        model_name: timm model identifier passed to Embedder.
        filelist: Image paths to embed.
        batch_size: Number of images per forward pass.

    Returns:
        A 2-D array with one row per image, in `filelist` order. For an empty
        `filelist` an empty float32 array of shape (0, 0) is returned without
        loading the model.
    """
    # Local import keeps this block self-contained; partial replaces a lambda,
    # which cannot be pickled when DataLoader workers are started via spawn.
    from functools import partial

    # The original raised UnboundLocalError here because the loop never ran.
    if len(filelist) == 0:
        return np.empty((0, 0), dtype=np.float32)

    embedder = Embedder(model_name=model_name)

    dataset = ImageListIterator(filelist, embedder.transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=partial(pad_to_batch_size, batch_size=batch_size),
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=embedder.device.type == "cuda",
    )

    # Collect per-batch results and concatenate once; concatenating inside the
    # loop re-copies everything accumulated so far on every batch.
    chunks = []
    for batch_imgs, actual_batch_size in tqdm(dataloader):
        emb = embedder.embed(batch_imgs).cpu().numpy()
        emb = emb[:actual_batch_size, ...]  # drop rows that came from padding

        if not chunks:
            print(f"Shape of embedding for one batch: {emb.shape}")
        chunks.append(emb)

    return np.concatenate(chunks, axis=0)

: 

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

file_list = get_file_list(IMAGES_FOLDER)

# Profile CPU activity (including memory and input shapes) for the whole embedding
# run; everything inside is recorded under the "model_inference" label.
with profile(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
) as prof, record_function("model_inference"):
    embeddings = compute_embeddings(MODEL_NAME, file_list, BATCH_SIZE)

print(f"Processed {len(file_list)} images. Got {len(embeddings)} embeddings.")

# Ten most expensive operations by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

  1%|          | 1/125 [01:28<3:02:04, 88.10s/it]

Shape of embedding for one batch: (16, 1024)


 87%|████████▋ | 109/125 [1:46:58<15:59, 59.96s/it] 

In [None]:
# Write the profiler trace via export_chrome_trace; the file name records the model,
# the number of test images and the batch size taken from the config constants.
prof.export_chrome_trace(f"trace_{MODEL_NAME}_{NUM_TEST_IMAGES}x{BATCH_SIZE}.json")

SyntaxError: f-string: single '}' is not allowed (965503420.py, line 1)