In [1]:
import gc
import shutil
from collections.abc import Callable
from dataclasses import dataclass, field
from itertools import islice
from pathlib import Path

import daft
import numpy as np
import timm
import torch
from datasets import load_dataset
from PIL import Image
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:

# --- Embedding run -----------------------------------------------------------
BATCH_SIZE = 32  # images per forward pass
MODEL_NAME = "vit_base_patch14_reg4_dinov2.lvd142m"  # timm model id
IMAGE_GLOB = None  # optional glob pattern for input images (None = unset)
IMAGES_FOLDER = "./tmp-test-images"  # scratch folder for the downloaded test images

# --- Test data ---------------------------------------------------------------
TEST_DATASET = "kvriza8/microscopy_images"  # Hugging Face dataset id
NUM_TEST_IMAGES = 500

# Other timm model ids worth trying with this pipeline.
nice_models = [
    "mobilenetv3_large_100",
    "vit_small_patch14_reg4_dinov2.lvd142m",
    "vit_base_patch14_reg4_dinov2.lvd142m",
    "vit_large_patch14_reg4_dinov2.lvd142m",
    "aimv2_large_patch14_224.apple_pt_dist",
]

Benchmark: `vit_base_patch14` embeddings fed through a torch `DataLoader` (time = wall-clock mm:ss).

| num_images         | batch_size  | optimize | time  |
|--------------------|-------------|----------|-------|
| 500                | 32          | False    | 10:07 |
| 200                | 32          | False    | 04:50 |
| 2000               | 32          | False    | 41:00 |
| 50                 | 32          | Static   | 01:14 |
| 50                 | 32          | Dynamic  | 01:14 |
| 500                | 32          | Static   | 09:33 |
| 2000               | 32 (fixed)  | Static   | 36:22 |
| 2000               | 16 (fixed)  | Static   | 36:32 |
| 2000               | 4 (fixed)   | Static   | 39:17 |
| 2000               | 128 (fixed) | Static   | OOM   |
| 2000 (large model) | 16 (fixed)  | Static   | -     |

In [3]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None, max_images: int = 50) -> None:
    """Stream the first ``max_images`` images of a Hugging Face dataset to PNG files.

    Args:
        dataset_name: Hugging Face dataset id. Its ``train`` split is streamed and
            each row's ``"image"`` value is saved.
        dir: Destination directory (``Path`` or ``str``). It is created if missing.
            Files are named ``0.png``, ``1.png``, ... in stream order.
        max_images: Maximum number of images to save. Values <= 0 save nothing
            and do not open the dataset stream.

    Returns:
        None. Images are written to disk as a side effect.

    Raises:
        ValueError: If ``dir`` is not provided.
    """
    if dir is None:
        raise ValueError("dl_hf_images: `dir` (destination directory) is required")
    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if max_images <= 0:
        return None

    dataset = load_dataset(dataset_name, split="train", streaming=True)
    # islice stops after exactly `max_images` rows, so the stream is never asked
    # for (and never downloads) an extra row just to detect the limit.
    for i, img_row in enumerate(tqdm(islice(dataset, max_images), total=max_images)):
        img_row["image"].save(out_dir / f"{i}.png")

    # Release the streaming dataset handle before the (memory-heavy) model cells run.
    del dataset
    gc.collect()

    return None

In [4]:

tmp_path = Path(IMAGES_FOLDER)
# Start from an empty folder so images left over from a previous run (e.g. with a
# larger NUM_TEST_IMAGES) are not picked up by the "*.png" glob later on.
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Pass TEST_DATASET explicitly so changing it in the config cell takes effect.
dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_path, max_images=NUM_TEST_IMAGES)


100%|██████████| 500/500 [00:10<00:00, 46.84it/s] 


In [5]:

@dataclass
class Embedder:
    """Wrap a pretrained timm model to compute pooled image embeddings in batches.

    Attributes:
        model_name: timm model id. It is created with ``pretrained=True`` and
            ``num_classes=0`` (no classifier head), so the forward pass returns
            pooled features.
        device: Device for the model and its inputs. Defaults to CUDA when
            available, else CPU.
        compile_model: If True (default), wrap the model with
            ``torch.compile(dynamic=True)``. Set False to run eagerly, e.g. to
            benchmark the uncompiled model.
        model: The eval-mode, channels_last model (compiled when ``compile_model``).
            Built in ``__post_init__``.
        transform: Preprocessing callable (PIL image -> tensor) built from the
            model's pretrained data config. Built in ``__post_init__``.
    """
    model_name: str
    device: torch.device = field(default_factory=lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    compile_model: bool = True
    model: torch.nn.Module = field(init=False)
    transform: Callable = field(init=False)

    def __post_init__(self):
        """Load the pretrained model, move it to ``device`` and build its transform."""
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        model.eval()
        model.to(self.device, memory_format=torch.channels_last)
        # Resolve the preprocessing config from the eager model, before compiling,
        # so timm reads `pretrained_cfg` from the real module rather than through
        # the torch.compile wrapper. resolve_data_config also drops fields that
        # create_transform does not accept.
        cfg = timm.data.resolve_data_config(model.pretrained_cfg, model=model)
        self.transform = timm.data.create_transform(**cfg)
        # dynamic=True lets batches of different sizes (e.g. a short final batch)
        # reuse the compiled graph instead of triggering a recompile.
        self.model = torch.compile(model, dynamic=True) if self.compile_model else model

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """Compute pooled embeddings for a batch of already-transformed images.

        Args:
            batch_imgs: Images produced by ``self.transform`` and stacked into one
                4-D tensor (required by channels_last), presumably (N, C, H, W).

        Returns:
            The model output on ``self.device``, one row per image. Under CUDA
            autocast it may be in reduced precision (e.g. float16).
        """
        batch_imgs = batch_imgs.to(self.device, non_blocking=True)
        batch_imgs = batch_imgs.contiguous(memory_format=torch.channels_last)
        if self.device.type == "cuda":
            with torch.amp.autocast("cuda"):
                return self.model(batch_imgs)
        else:
            # autocast can be comically slow for some CPU setups (PyTorch issue #118499)
            return self.model(batch_imgs)

In [6]:
@daft.udf(return_dtype=daft.DataType.python())
class TransformImageCol:
    """Daft UDF that applies a timm model's preprocessing transform to an image column.

    Only preprocessing is applied; no embedding is computed here. Each output row
    is what ``Embedder.transform`` returns for that image, or None when the input
    image is null.

    NOTE(review): building a full ``Embedder`` loads the pretrained weights and
    compiles the model just to obtain the transform. Consider constructing only
    the transform if UDF start-up cost or memory matters.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.embedder = Embedder(self.model_name)

    def __call__(self, batch_images) -> list:
        """Transform one batch of images; null images map to None instead of raising."""
        transform = self.embedder.transform
        return [
            None if im is None else transform(Image.fromarray(im))
            for im in batch_images.to_pylist()
        ]

In [7]:
# Use IMAGE_GLOB from the config cell when it is set; otherwise read the PNGs
# downloaded into tmp_path above.
imglob = IMAGE_GLOB or (tmp_path.as_posix() + "/*.png")

images_df = (
    daft.from_glob_path(imglob)
    .with_column_renamed("path", "path_full_img")
    # Decode to RGB; files that fail to decode become null instead of failing the query.
    .with_column(
        "image",
        daft.col("path_full_img").url.download().image.decode(mode="RGB", on_error="null"),
    )
    .where(daft.col("image").not_null())
)

TransformImForModel = TransformImageCol.with_init_args(model_name=MODEL_NAME)

# Keep only the path, file size and the model-ready tensor; drop the decoded image.
images_df = images_df.with_column(
    "image_transformed", TransformImForModel(daft.col("image"))
).exclude("image", "num_rows")

images_df.show(1)

path_full_img Utf8,size Int64,image_transformed Python
file://tmp-test-images/60.png,37758,"tensor([[[-0.7993, -0.7308, -0.6794, ..., 0.6734, 1.5810, 2.2489],  [-0.9363, -0.8678, -0.8164, ..., 0.6049, 1.5468, 2.2489],  [-1.2617, -1.2103, -1.1589, ..., 0.4851, 1.4783, 2.2318],  ...,  [-0.4739, -0.5596, -0.6623, ..., 0.3309, 1.4269, 2.2489],  [-0.7822, -0.9363, -1.0219, ..., 0.3309, 1.4269, 2.2489],  [-0.9363, -1.1075, -1.2103, ..., 0.3309, 1.4269, 2.2489]],  [[-0.6877, -0.6176, -0.5651, ..., 0.8179, 1.7458, 2.4286],  [-0.8277, -0.7577, -0.7052, ..., 0.7479, 1.7108, 2.4286],  [-1.1604, -1.1078, -1.0553, ..., 0.6254, 1.6408, 2.4111],  ...,  [-0.3550, -0.4426, -0.5476, ..., 0.4678, 1.5882, 2.4286],  [-0.6702, -0.8277, -0.9153, ..., 0.4678, 1.5882, 2.4286],  [-0.8277, -1.0028, -1.1078, ..., 0.4678, 1.5882, 2.4286]],  [[-0.4624, -0.3927, -0.3404, ..., 1.0365, 1.9603, 2.6400],  [-0.6018, -0.5321, -0.4798, ..., 0.9668, 1.9254, 2.6400],  [-0.9330, -0.8807, -0.8284, ..., 0.8448, 1.8557, 2.6226],  ...,  [-0.1312, -0.2184, -0.3230, ..., 0.6879, 1.8034, 2.6400],  [-0.4450, -0.6018, -0.6890, ..., 0.6879, 1.8034, 2.6400],  [-0.6018, -0.7761, -0.8807, ..., 0.6879, 1.8034, 2.6400]]])"


In [8]:

def compute_embeddings(model_name: str,
                       dataset: torch.utils.data.IterableDataset,
                       batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Embed every image in ``dataset`` with a timm model and stack the results.

    An ``Embedder`` is built for ``model_name``. The dataset is batched with a
    non-shuffling ``DataLoader``, and each batch's embeddings are moved to CPU.

    Args:
        model_name: timm model id passed to ``Embedder``.
        dataset: Iterable dataset whose collated batches are mappings holding an
            ``"image_transformed"`` tensor, e.g. from ``images_df.to_torch_iter_dataset()``.
        batch_size: Images per forward pass.

    Returns:
        One array with a row per image, in dataset order (batches concatenated
        along axis 0). If the dataset yields nothing, an empty ``(0, 0)`` float32 array.
    """
    embedder = Embedder(model_name=model_name)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    batch_embeddings: list[np.ndarray] = []
    # Advance the bar by the real number of images in each batch, so it honours
    # `batch_size` and ends at the true total even when the last batch is short.
    with tqdm(unit="img") as pbar:
        for batch in dataloader:
            emb = embedder.embed(batch["image_transformed"]).cpu().numpy()
            if not batch_embeddings:
                print(f"Shape of embedding for one batch: {emb.shape}")
            batch_embeddings.append(emb)
            pbar.update(emb.shape[0])

    if not batch_embeddings:
        return np.empty((0, 0), dtype=np.float32)
    # Concatenate once at the end: growing the array per batch re-copies all
    # previous rows every iteration (quadratic in the number of batches).
    return np.concatenate(batch_embeddings, axis=0)

In [9]:
# Expose the daft pipeline as a torch IterableDataset for compute_embeddings.
# (The torch.profiler import lives in the top import cell.)
images_dataset = images_df.to_torch_iter_dataset()

In [10]:


# Embedder runs on CUDA whenever it is available; add the CUDA activity in that
# case so time spent on the GPU shows up in the profile, not just CPU launch time.
profiler_activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiler_activities.append(ProfilerActivity.CUDA)

with profile(activities=profiler_activities, profile_memory=True, record_shapes=True) as prof:
    with record_function("model_inference"):
        embeddings = compute_embeddings(MODEL_NAME, images_dataset, BATCH_SIZE)

print(f"Got {len(embeddings)} embeddings.")

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

0it [00:00, ?it/s]

🗡️ 🐟 Project: 00:00 

🗡️ 🐟 Filter: 00:00 

🗡️ 🐟 Project: 00:00 

32it [00:47,  1.50s/it]

Shape of embedding for one batch: (32, 768)


480it [08:50,  1.07s/it][W227 14:03:37.814903742 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
512it [09:12,  1.08s/it]


Got 500 embeddings.
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.04%     224.783ms       100.00%      553.379s      553.379s           0 b      -3.33 Gb             1  
                             Torch-Compiled Region: 0/0         3.53%       19.516s        96.88%      536.134s       33.508s       1.46 Mb     -25.91 Gb            16  
                                            aten::addmm        64.72%      358.138s        68.68%      380.043s     395.878ms     

In [11]:
# Chrome-trace JSON of the profiled run, named after model, image count and batch size.
trace_file = f"daftiter_trace_{MODEL_NAME}_{NUM_TEST_IMAGES}x{BATCH_SIZE}.json"
prof.export_chrome_trace(trace_file)