In [1]:
import glob
import os
import gc
import shutil
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import timm
import torch
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import daft
from itertools import chain

In [2]:

# --- Inference settings ---------------------------------------------------
BATCH_SIZE = 8  # number of images per batch handed to the model
MODEL_NAME = "vit_large_patch14_reg4_dinov2.lvd142m"  # model identifier used for the run

# --- Image locations ------------------------------------------------------
IMAGE_GLOB = None  # NOTE(review): not referenced in the visible cells
IMAGES_FOLDER = "./tmp-test-images"  # scratch folder holding the downloaded test images

# --- Test data ------------------------------------------------------------
TEST_DATASET = "kvriza8/microscopy_images"  # Hugging Face dataset identifier
NUM_TEST_IMAGES = 50  # how many images to download for the test run

# Model identifiers of interest (MODEL_NAME above is one of them);
# presumably the candidates considered for benchmarking.
nice_models = [
    "mobilenetv3_large_100",
    "vit_small_patch14_reg4_dinov2.lvd142m",
    "vit_base_patch14_reg4_dinov2.lvd142m",
    "vit_large_patch14_reg4_dinov2.lvd142m",
    "aimv2_large_patch14_224.apple_pt_dist",
]

With vit_base_patch14 and torch dataloader:

| num_images | batch_size  | optimize | time  |
|------------|-------------|----------|-------|
| 500        | 32          | False    | 10:07 |
| 200        | 32          | False    | 04:50 |
| 2000       | 32          | False    | 41:00 |
| 50         | 32          | Static   | 01:14 |
| 50         | 32          | Dynamic  | 01:14 |
| 500        | 32          | Static   | 09:33 |
| 2000       | 32 (fixed)  | Static   | 36:22 |
| 2000       | 16 (fixed)  | Static   | 36:32 |
| 2000       | 4 (fixed)   | Static   | 39:17 |
| 2000       | 128 (fixed) | Static   | OOM   |
| 2000lrg    | 16 (fixed)  | Static   | -     |

In [3]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None, max_images: int = 50) -> None:
    """
    Stream images from a Hugging Face dataset and save them as PNG files.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; its
            "train" split is read in streaming mode.
        dir: Output directory (``Path`` or ``str``). It is created if it does
            not exist. Required: the ``None`` default is kept only so the
            signature stays unchanged.
        max_images: Maximum number of images to save. Files are named
            ``0.png``, ``1.png``, ... in stream order.

    Raises:
        ValueError: If ``dir`` is ``None``.
    """
    if dir is None:
        raise ValueError("dl_hf_images requires an output directory: pass dir=<path>")
    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if max_images <= 0:
        # Nothing to download; avoid opening the remote stream at all.
        return None

    dataset = load_dataset(dataset_name, split="train", streaming=True)

    # zip() consults range() first, so iteration stops after exactly
    # `max_images` rows without pulling one extra sample from the stream.
    rows = zip(range(max_images), dataset)
    for i, img_row in tqdm(rows, total=max_images):
        img = img_row["image"]
        img.save(out_dir / f"{i}.png")

    # Drop every reference to the streaming dataset before collecting.
    del rows, dataset
    gc.collect()

    return None

In [4]:

# Start every run from an empty scratch folder so images left over from a
# previous run cannot end up in the glob below.
tmp_path = Path(IMAGES_FOLDER)
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Pass the configured dataset explicitly; previously TEST_DATASET was defined
# but unused and the call silently relied on the function's default value.
dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_path, max_images=NUM_TEST_IMAGES)


100%|██████████| 50/50 [00:17<00:00,  2.90it/s]


In [6]:

@dataclass
class Embedder:
    """
    Bundle a pretrained timm model with its matching preprocessing transform.

    The model is created with ``num_classes=0``, put in eval mode, moved to
    ``device`` in channels_last memory format and wrapped with ``torch.compile``.
    ``transform`` is built from the model's resolved data config.
    """
    model_name: str
    device: torch.device = field(
        default_factory=lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    model: torch.nn.Module = field(init=False)
    transform: callable = field(init=False)

    def __post_init__(self):
        """Create the model and its transform once the dataclass fields are set."""
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.model.eval()
        self.model.to(self.device, memory_format=torch.channels_last)
        self.model = torch.compile(self.model, dynamic=True)
        # resolve_data_config strips fields create_transform does not accept
        data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg, model=self.model)
        self.transform = timm.data.create_transform(**data_cfg)

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """
        Run the model on a batch of already-transformed images.

        The batch is moved to ``self.device``, made contiguous in channels_last
        format, and passed through the model under inference mode. On CUDA the
        forward pass runs under autocast; on any other device it does not.
        """
        on_device = batch_imgs.to(self.device, non_blocking=True)
        model_input = on_device.contiguous(memory_format=torch.channels_last)
        if self.device.type != "cuda":
            # autocast can be comically slow for some CPU setups (PyTorch issue #118499)
            return self.model(model_input)
        with torch.amp.autocast("cuda"):
            return self.model(model_input)

In [26]:
@daft.udf(return_dtype=daft.DataType.python())
class TransformImageCol:
    """
    Apply a timm model's preprocessing transform to an image column.

    Only the transform is needed here, so it is built from the model's
    registered pretrained config instead of instantiating a full ``Embedder``
    (which would load the pretrained weights and move them to the device a
    second time, just for preprocessing).
    """
    def __init__(self, model_name: str):
        self.model_name = model_name
        pretrained_cfg = timm.get_pretrained_cfg(model_name)
        if pretrained_cfg is not None:
            data_cfg = timm.data.resolve_data_config(pretrained_cfg=pretrained_cfg.to_dict())
            self.transform = timm.data.create_transform(**data_cfg)
        else:
            # No config registered under this name: fall back to building the
            # model and reading the transform from it, as before.
            self.transform = Embedder(model_name).transform

    def __call__(self, batch_images) -> list:
        """
        Transform every image in the batch.

        Null entries (e.g. images that failed to decode upstream) are passed
        through as ``None`` instead of crashing in ``Image.fromarray``.
        """
        return [
            None if im is None else self.transform(Image.fromarray(im))
            for im in batch_images.to_pylist()
        ]

In [27]:
# Glob pattern matching every PNG in the scratch folder.
imglob = f"{tmp_path.as_posix()}/*.png"

# UDF bound to the configured model; applies that model's preprocessing transform.
TransformImForModel = TransformImageCol.with_init_args(model_name=MODEL_NAME)

# List the files, decode each one as RGB (undecodable files become null and
# are filtered out), transform the decoded images, then drop the columns that
# are no longer needed.
images_df = (
    daft.from_glob_path(imglob)
    .with_column_renamed("path", "path_full_img")
    .with_column(
        "image",
        daft.col("path_full_img").url.download().image.decode(mode="RGB", on_error="null"),
    )
    .where(daft.col("image").not_null())
    .with_column("image_transformed", TransformImForModel(daft.col("image")))
    .exclude("image", "num_rows")
)

images_df.show(1)

path_full_img Utf8,size Int64,image_transformed Python
file://tmp-test-images/43.png,52246,"tensor([[[ 2.0263, 1.9749, 1.8550, ..., 1.8037, 1.8208, 1.8208],  [ 1.9920, 1.9064, 1.6495, ..., 1.2728, 1.2899, 1.2899],  [ 1.8550, 1.7009, 1.3070, ..., 0.6563, 0.6563, 0.6563],  ...,  [ 1.1015, 0.7419, -0.1486, ..., -0.6623, -0.6965, -0.6965],  [ 1.2557, 0.8961, 0.0056, ..., -0.6965, -0.7308, -0.7308],  [ 1.3413, 0.9817, 0.0912, ..., -0.7822, -0.7993, -0.7993]],  [[ 2.3761, 2.3410, 2.2360, ..., 2.2535, 2.2535, 2.2535],  [ 2.4286, 2.3761, 2.1310, ..., 1.8158, 1.8158, 1.8158],  [ 2.3936, 2.2885, 1.8859, ..., 1.2731, 1.2731, 1.2731],  ...,  [ 2.1660, 1.7808, 0.8354, ..., -0.3200, -0.3025, -0.3025],  [ 2.3235, 1.9209, 0.9930, ..., -0.3025, -0.2850, -0.2850],  [ 2.3761, 1.9909, 1.0805, ..., -0.2850, -0.2850, -0.2850]],  [[ 2.6400, 2.6400, 2.6051, ..., 2.5180, 2.5180, 2.5180],  [ 2.6400, 2.6226, 2.5354, ..., 2.2217, 2.2217, 2.2217],  [ 2.6400, 2.5877, 2.3611, ..., 1.7860, 1.7860, 1.7860],  ...,  [ 2.6400, 2.4483, 1.5245, ..., 0.1302, 0.1476, 0.1476],  [ 2.6400, 2.4657, 1.6117, ..., 0.1825, 0.1999, 0.1999],  [ 2.6400, 2.4831, 1.6814, ..., 0.2348, 0.2522, 0.2522]]])"


In [28]:

def pad_to_batch_size_from_dict(batch_dict, batch_size):
    """
    Collate transformed images into one tensor and zero-pad it to ``batch_size``.

    Args:
        batch_dict: Either a mapping whose "image_transformed" entry holds a
            sequence of equally-shaped tensors, or a list/tuple of per-sample
            mappings that each hold one tensor under "image_transformed".
            The second form is what a ``DataLoader`` hands to ``collate_fn``
            when the dataset yields one dict per sample.
        batch_size: Target size of the first (batch) dimension.

    Returns:
        ``(batch, orig_size)``: the stacked tensor, padded with zeros along
        dim 0 when fewer than ``batch_size`` samples were given, and the number
        of real samples so callers can strip the padding afterwards.
    """
    key = "image_transformed"
    if isinstance(batch_dict, (list, tuple)):
        # Row-oriented batch: one mapping per sample.
        images = [sample[key] for sample in batch_dict]
    else:
        # Column-oriented batch: one mapping holding all samples.
        images = list(batch_dict[key])

    orig_size = len(images)
    batch = torch.stack(images)
    if orig_size < batch_size:
        pad_tensor = torch.zeros((batch_size - orig_size, *batch.shape[1:]),
                                 dtype=batch.dtype, device=batch.device)
        batch = torch.cat([batch, pad_tensor], dim=0)
    return batch, orig_size

In [29]:

def compute_embeddings(model_name: str,
                       dataset: torch.utils.data.IterableDataset,
                       batch_size: int = BATCH_SIZE) -> np.ndarray:
    """
    Compute embeddings for every sample in ``dataset`` with the named model.

    An ``Embedder`` is created for ``model_name`` and the dataset is read
    through a ``DataLoader`` whose batches are collated by
    ``pad_to_batch_size_from_dict``, i.e. every batch is padded to
    ``batch_size``. The padding rows are dropped again after inference.

    Args:
        model_name: Model identifier passed to ``Embedder``.
        dataset: Dataset yielding samples that ``pad_to_batch_size_from_dict``
            can collate (it reads the "image_transformed" entry).
        batch_size: Number of samples per batch (after padding).

    Returns:
        A single array with one row per sample, in dataset order. If the
        dataset yields nothing, an empty ``(0, 0)`` float32 array is returned.
    """
    embedder = Embedder(model_name=model_name)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=lambda b: pad_to_batch_size_from_dict(b, batch_size),
        shuffle=False,
    )

    # Collect per-batch results and concatenate once at the end; concatenating
    # inside the loop re-copies all previous embeddings on every batch.
    batch_embeddings = []
    for batch_imgs, actual_batch_size in tqdm(dataloader):
        emb = embedder.embed(batch_imgs).cpu().numpy()
        # Strip the rows that only exist because of padding.
        emb = emb[:actual_batch_size, ...]

        if not batch_embeddings:
            print(f"Shape of embedding for one batch: {emb.shape}")
        batch_embeddings.append(emb)

    if not batch_embeddings:
        # Empty dataset: return an empty array instead of failing on an
        # unassigned local variable.
        return np.empty((0, 0), dtype=np.float32)

    return np.concatenate(batch_embeddings, axis=0)

In [30]:
from torch.profiler import ProfilerActivity, profile, record_function

# Expose the Daft dataframe as a torch iterable dataset for the DataLoader.
images_dataset = images_df.to_torch_iter_dataset()

# Profile the whole embedding run (CPU activity, with memory and input shapes
# recorded) under a single labelled region.
with profile(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
) as prof, record_function("model_inference"):
    embeddings = compute_embeddings(MODEL_NAME, images_dataset, BATCH_SIZE)

print(f"Got {len(embeddings)} embeddings.")

# Ten most expensive operations by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

: 

In [None]:
# Export the profiler results as a Chrome trace JSON file; the file name
# encodes the model name, the number of test images and the batch size.
prof.export_chrome_trace(f"daftiter_trace_{MODEL_NAME}_{NUM_TEST_IMAGES}x{BATCH_SIZE}.json")

SyntaxError: f-string: single '}' is not allowed (965503420.py, line 1)