In [1]:
import gc
import shutil
from dataclasses import dataclass, field
from itertools import islice
from pathlib import Path
from typing import Callable

import daft
import numpy as np
import timm
import torch
from datasets import load_dataset
from PIL import Image
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:

# Notebook-wide configuration; later cells read these names.

# Number of images per batch; also the default `batch_size` of compute_embeddings.
BATCH_SIZE = 64
# Model name handed to the image-transform UDF and to compute_embeddings.
MODEL_NAME = "vit_base_patch14_reg4_dinov2.lvd142m"
# NOTE(review): not referenced by any cell visible in this notebook — confirm whether it is still needed.
IMAGE_GLOB = None
# Folder that is wiped, recreated and then filled with the downloaded test images.
IMAGES_FOLDER = "./tmp-test-images"

# Hugging Face dataset the test images come from; same value as dl_hf_images' default dataset_name.
TEST_DATASET = "kvriza8/microscopy_images"
# Number of images to download; also embedded in the exported profiler trace file name.
NUM_TEST_IMAGES = 2000

# Presumably a shortlist of model names worth trying as MODEL_NAME (the current
# MODEL_NAME is one of them) — not referenced by any visible cell; TODO confirm.
nice_models = [
"mobilenetv3_large_100",
"vit_small_patch14_reg4_dinov2.lvd142m",
"vit_base_patch14_reg4_dinov2.lvd142m",
"vit_large_patch14_reg4_dinov2.lvd142m",
"aimv2_large_patch14_224.apple_pt_dist"
]

Timings with `vit_base_patch14` and the torch `DataLoader`:

| num_images | batch_size  | optimize | time  |
|------------|-------------|----------|-------|
| 500        | 32          | False    | 10:07 |
| 200        | 32          | False    | 04:50 |
| 2000       | 32          | False    | 41:00 |
| 50         | 32          | Static   | 01:14 |
| 50         | 32          | Dynamic  | 01:14 |
| 500        | 32          | Static   | 09:33 |
| 2000       | 32 (fixed)  | Static   | 36:22 |
| 2000       | 16 (fixed)  | Static   | 36:32 |
| 2000       | 4 (fixed)   | Static   | 39:17 |
| 2000       | 128 (fixed) | Static   | OOM   |
| 2000lrg    | 16 (fixed)  | Static   | -     |

In [3]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None, max_images: int = 50) -> None:
    """
    Stream a Hugging Face dataset and save its first images as PNG files.

    Images are written as ``0.png``, ``1.png``, ... into ``dir``, which is
    created if it does not exist yet.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``; its
            ``train`` split is streamed.
        dir: Target directory for the PNG files. Required, despite the
            ``None`` default kept for signature compatibility.
        max_images: Maximum number of images to save; values <= 0 save nothing.

    Raises:
        ValueError: If ``dir`` is None.
    """
    # Fail before opening the stream instead of crashing on `None / str` later.
    if dir is None:
        raise ValueError("dl_hf_images requires a target directory (`dir`).")
    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dataset = load_dataset(dataset_name, split="train", streaming=True)

    # islice stops after exactly `num_wanted` rows; the previous
    # enumerate-then-break loop fetched one extra sample from the stream.
    num_wanted = max(int(max_images), 0)
    for i, img_row in enumerate(tqdm(islice(dataset, num_wanted), total=num_wanted)):
        img_row["image"].save(out_dir / f"{i}.png")

    del dataset
    gc.collect()

    return None

In [4]:

# Start from an empty folder so images left over from earlier runs never mix in.
tmp_path = Path(IMAGES_FOLDER)
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Pass TEST_DATASET explicitly: previously the function's own default was used,
# so changing TEST_DATASET in the config cell had no effect.
dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_path, max_images=NUM_TEST_IMAGES)


100%|██████████| 2000/2000 [00:38<00:00, 51.76it/s] 


In [5]:

@dataclass
class Embedder:
    """
    Wraps a timm model (created with ``num_classes=0``) together with its
    matching preprocessing transform.

    On construction the pretrained model is created, put in eval mode, moved to
    ``device`` in channels_last format and wrapped with ``torch.compile``.
    """
    # Model name passed to timm.create_model.
    model_name: str
    # Target device; defaults to CUDA when available, otherwise CPU.
    device: torch.device = field(default_factory=lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # Compiled model, set in __post_init__.
    model: torch.nn.Module = field(init=False)
    # Preprocessing transform built from the model's data config, set in __post_init__.
    transform: Callable = field(init=False)

    def __post_init__(self):
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        model.eval()
        model.to(self.device, memory_format=torch.channels_last)
        # Resolve the data config on the plain model, before torch.compile wraps it,
        # so this does not rely on the wrapper forwarding attribute access.
        # Resolve config removes unneeded fields before create_transform
        cfg = timm.data.resolve_data_config(model.pretrained_cfg, model=model)
        self.transform = timm.data.create_transform(**cfg)
        self.model = torch.compile(model, dynamic=True)

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """
        Given a batch of pre-transformed images, compute pooled embeddings.

        The batch is moved to the proper device (with channels_last format) and
        processed in inference mode. A single un-batched 3-D image is treated
        as a batch of one.

        Args:
            batch_imgs: Pre-transformed images, 4-D batched or 3-D single image.

        Returns:
            The model output for the batch, one row per image.
        """
        # channels_last needs a 4-D tensor, so promote a lone image to a batch.
        if batch_imgs.ndim == 3:
            batch_imgs = batch_imgs.unsqueeze(0)
        batch_imgs = batch_imgs.to(self.device, non_blocking=True)
        batch_imgs = batch_imgs.contiguous(memory_format=torch.channels_last)
        if self.device.type == "cuda":
            with torch.amp.autocast("cuda"):
                return self.model(batch_imgs)
        # autocast can be comically slow for some CPU setups (PyTorch issue #118499)
        return self.model(batch_imgs)

In [6]:
@daft.udf(return_dtype=daft.DataType.python())
class TransformImageCol:
    """
    Daft UDF that applies a timm model's preprocessing transform to an image column.

    Only the transform is needed here, so the model is created without
    pretrained weights and is neither moved to a device nor compiled; it is
    used solely to resolve the data config the transform is built from.
    """
    def __init__(self, model_name: str):
        self.model_name = model_name
        # pretrained=False skips loading the weights; the data config is resolved
        # from the model's pretrained_cfg exactly as for the embedding model.
        model = timm.create_model(self.model_name, pretrained=False, num_classes=0)
        cfg = timm.data.resolve_data_config(model.pretrained_cfg, model=model)
        self.transform = timm.data.create_transform(**cfg)

    def __call__(self, batch_images) -> list:
        """
        Transform every image of the incoming column batch.

        Args:
            batch_images: Column batch whose ``to_pylist()`` yields one array per image.

        Returns:
            One transformed image per input row; null rows stay None.
        """
        return [
            None if im is None else self.transform(Image.fromarray(im))
            for im in batch_images.to_pylist()
        ]

In [7]:
# Glob pattern matching every PNG in the temporary image folder.
imglob = f"{tmp_path.as_posix()}/*.png"

# UDF bound to the configured model name.
TransformImForModel = TransformImageCol.with_init_args(model_name=MODEL_NAME)

# List files -> download and decode as RGB (failures become null) -> drop nulls
# -> add the model-ready tensor column -> drop columns not needed downstream.
images_df = (
    daft.from_glob_path(imglob)
    .with_column_renamed("path", "path_full_img")
    .with_column(
        "image",
        daft.col("path_full_img").url.download().image.decode(mode="RGB", on_error="null"),
    )
    .where(daft.col("image").not_null())
    .with_column("image_transformed", TransformImForModel(daft.col("image")))
    .exclude("image", "num_rows")
)

images_df.show(1)

path_full_img Utf8,size Int64,image_transformed Python
file://tmp-test-images/1354.png,40149,"tensor([[[-1.2959, -1.2617, -1.2103, ..., -1.8268, -1.8268, -1.8268],  [-1.3302, -1.2445, -1.1589, ..., -1.8610, -1.8610, -1.8439],  [-1.3130, -1.2445, -1.1932, ..., -1.8610, -1.8610, -1.8439],  ...,  [ 1.3242, 1.1872, 0.8789, ..., -0.3883, -0.3369, -0.2856],  [ 1.1872, 1.0502, 0.7933, ..., -0.2856, -0.2684, -0.2342],  [ 1.2557, 0.9988, 0.5878, ..., -0.0972, -0.0972, -0.1486]],  [[-0.9853, -0.9503, -0.9678, ..., -1.6856, -1.6681, -1.6506],  [-0.9503, -0.8627, -0.8277, ..., -1.6856, -1.6681, -1.6506],  [-0.9153, -0.8452, -0.8452, ..., -1.6856, -1.6681, -1.6506],  ...,  [ 1.3606, 1.1506, 0.8529, ..., -0.3550, -0.3025, -0.2500],  [ 1.2206, 1.0105, 0.7654, ..., -0.2675, -0.2500, -0.2150],  [ 1.2906, 0.9755, 0.5553, ..., -0.0924, -0.0749, -0.1275]],  [[-1.0376, -1.0550, -1.0201, ..., -1.3339, -1.3164, -1.3164],  [-1.0201, -0.9678, -0.8807, ..., -1.3513, -1.3339, -1.2990],  [-0.9853, -0.9504, -0.8807, ..., -1.3687, -1.3513, -1.3164],  ...,  [ 1.4374, 1.2805, 1.0888, ..., -0.1487, -0.0790, -0.0441],  [ 1.3154, 1.1411, 0.9842, ..., -0.0267, -0.0267, 0.0082],  [ 1.4025, 1.1237, 0.7576, ..., 0.1651, 0.1651, 0.0953]]])"


In [8]:

def compute_embeddings(model_name: str,
                       dataset: torch.utils.data.IterableDataset,
                       batch_size: int = BATCH_SIZE,
                       image_key: str = "image_transformed") -> np.ndarray:
    """
    Embed every image of a torch dataset with the given model.

    An Embedder is instantiated for ``model_name``, the dataset is wrapped in a
    DataLoader, and the images are processed batch by batch.

    Args:
        model_name: Model name passed to Embedder.
        dataset: Dataset whose batches are mappings containing the
            pre-transformed image tensors under ``image_key``.
        batch_size: Number of images per batch.
        image_key: Key of the pre-transformed image tensors in each batch.

    Returns:
        One array with one row per image (all batches stacked along axis 0).
        An empty dataset yields an empty array of shape (0, 0).
    """
    embedder = Embedder(model_name=model_name)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    batch_embeddings = []
    # Count real images instead of scaling batches by the global BATCH_SIZE:
    # that ignored the `batch_size` argument and over-counted the last batch.
    with tqdm(unit="img") as progress:
        for batch_images in dataloader:
            emb = embedder.embed(batch_images[image_key]).cpu().numpy()
            if not batch_embeddings:
                print(f"Shape of embedding for one batch: {emb.shape}")
            batch_embeddings.append(emb)
            progress.update(len(emb))

    # An empty dataset used to raise UnboundLocalError.
    if not batch_embeddings:
        return np.empty((0, 0), dtype=np.float32)

    # Concatenate once; doing it inside the loop re-copied all earlier rows per batch.
    return np.concatenate(batch_embeddings, axis=0)

In [9]:
# The torch.profiler names used by the profiling cell are imported in the
# top-of-notebook import cell, so all dependencies are declared in one place.

# Expose the daft dataframe as a torch iterable dataset for the DataLoader.
images_dataset = images_df.to_torch_iter_dataset()

In [10]:


# Always profile the CPU; add CUDA when it is available, because Embedder then
# runs the model on the GPU and a CPU-only profile would miss that work.
profiler_activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiler_activities.append(ProfilerActivity.CUDA)

with profile(activities=profiler_activities, profile_memory=True, record_shapes=True) as prof:
    # Label the whole embedding run so it appears as one row in the profile table.
    with record_function("model_inference"):
        embeddings = compute_embeddings(MODEL_NAME, images_dataset, BATCH_SIZE)

print(f"Got {len(embeddings)} embeddings.")

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

0it [00:00, ?it/s]

🗡️ 🐟 Project: 00:00 

🗡️ 🐟 Filter: 00:00 

🗡️ 🐟 Project: 00:00 

64it [00:30,  2.10it/s]

Shape of embedding for one batch: (64, 1280)


1984it [00:43, 156.16it/s][W227 12:39:47.181416818 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
[W227 12:39:47.184368241 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
2048it [00:43, 47.49it/s] 


Got 2000 embeddings.
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.54%     234.886ms       100.00%       43.286s       43.286s           0 b      -2.28 Gb             1  
                  _compile.compile_inner (dynamo_timed)         6.74%        2.916s        51.17%       22.149s       22.149s           0 b           0 b             1  
          OutputGraph.call_user_compiler (dynamo_timed)         0.15%      66.495ms        31.39%       13.587s       13.587s    

In [11]:
# The trace file name records the model and the run size (images x batch size).
trace_filename = f"daftiter_trace_{MODEL_NAME}_{NUM_TEST_IMAGES}x{BATCH_SIZE}.json"
prof.export_chrome_trace(trace_filename)