In [1]:
import gc
import sys
import shutil
from glob import glob
from pathlib import Path
from tempfile import TemporaryDirectory

import daft
import numpy as np
import timm
import torch
from datasets import load_dataset
from humanize import naturalsize
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from loguru import logger

In [2]:
# easy timestamps
# Route loguru output to stdout at INFO level so timestamped log lines show up
# in the cell output next to print() and tqdm output.
# `logger.remove()` is called first, with no arguments, so that the sink added
# below is not combined with any handler registered earlier.
logger.remove()
# NOTE(review): the bare integer this cell displays is presumably the handler id
# returned by `logger.add` — confirm; end the line with `;` to hide it.
logger.add(sys.stdout, level="INFO")

1

## Test daft concurrency levels

Adapt [test 21 notebook (torch vs. daft-to-torch)](https://github.com/fr1ll/bedmap-dev/blob/clearer-atlas/nbs/daft-try/21_compare-daft-to-torch-data_clean.ipynb):
- modify concurrency level for daft UDF used in embedding loop

So far:
- None: Same as not calling `with_concurrency` at all (makes sense, as this is the default)
- 1 or 2: fails with error:
`AttributeError: Can't get local object '_ensure_registered_super_ext_type.<locals>.DaftExtension'`
    - This then triggers a memray error. Without memray, get same AttributeError.

## 1. Set variables for test

In [3]:
# --- test data -------------------------------------------------------------
# Dataset identifier, in the form accepted by `datasets.load_dataset`.
TEST_DATASET: str = "kvriza8/microscopy_images"
# Number of images written to disk for one test run.
NUM_TEST_IMAGES: int = 256

# --- model and batching ----------------------------------------------------
# Name of the pretrained timm model used to compute embeddings.
MODEL_NAME: str = "vit_small_patch14_reg4_dinov2.lvd142m"
# Number of images per DataLoader batch.
BATCH_SIZE: int = 32

# --- dataset backend -------------------------------------------------------
USE_DAFT: bool = True  # else use torch dataset

## 2. Define way to download small test dataset 

In [4]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None,
                 max_images: int = 64,
                 overwrite: bool = True,
                 format: str = "png") -> None:

    dataset = load_dataset(dataset_name, split="train", streaming=True)
    if overwrite:
        shutil.rmtree(dir, ignore_errors=True)
        dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    for i, img_row in enumerate(tqdm(iter(dataset), total=max_images)):
        if i >= max_images:
            break
        img = img_row["image"]
        image_paths += [(dir / f"{i}.{format}")]
        img.save(image_paths[-1])

    print(f"Size of images on disk: {naturalsize(sum([p.stat().st_size for p in image_paths]))}")

    del dataset
    gc.collect()

    return None

## 3. Define timm-based embedding model

In [5]:

class Embedder:
    """instantiate pretrained timm model to generate embeddings"""
    def __init__(self, model_name: str, device: torch.device = None):
        self.model_name = model_name
        # choose device and dtype
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.dtype = torch.float32

        # Create and prepare the model
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.model.to(self.device, memory_format=torch.channels_last)
        self.model.eval()
        self.model = torch.compile(self.model, dynamic=True, mode="reduce-overhead")

        # must resolve config to drop unneeded fields
        cfg = timm.data.resolve_data_config(self.model.pretrained_cfg)
        self.transform = timm.data.create_transform(**cfg)

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """set up input and embed it"""
        batch_imgs = batch_imgs.to(self.device, non_blocking=True, memory_format=torch.channels_last)
        if self.device.type == "cuda":
            with torch.amp.autocast("cuda", dtype=self.dtype):
                return self.model(batch_imgs)
        return self.model(batch_imgs)

## 4. Define two types of datasets

In [6]:
@daft.udf(return_dtype=daft.DataType.python())
class TransformImagesDaft:
    """Daft UDF that applies an image transform to every entry of a column batch.

    The transform is supplied once at construction and reused for each batch.
    """
    def __init__(self, transform: callable):
        # Callable applied to a PIL image; its result is returned as-is.
        self.transform = transform

    def __call__(self, batch_images) -> list:
        """Return one transformed item per entry in `batch_images`.

        Each entry is converted with `Image.fromarray` before the transform is
        applied — presumably entries are array-like pixel data; TODO confirm.
        """
        transformed = []
        for pixels in batch_images.to_pylist():
            pil_image = Image.fromarray(pixels)
            transformed.append(self.transform(pil_image))
        return transformed

In [7]:

class TorchImageIterAsDict(Dataset):
    """Map-style torch dataset that loads images from a list of file paths.

    Each item is returned as a dict so batches have the same shape as the
    daft-backed dataset and can be compared directly.
    """
    def __init__(self, filelist: list[Path], transform: callable):
        """
        Args:
            filelist: paths of the image files, one dataset item per path.
            transform: callable applied to each RGB PIL image; if falsy, the
                untransformed image is returned instead.
        """
        self.filelist = filelist
        self.transform = transform

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.filelist)

    def __getitem__(self, idx: int):
        """Load image `idx` as RGB and return it wrapped in a dict.

        Returns:
            `{"image_transformed": transform(image)}` when a transform is set,
            otherwise `{"image": [image]}`.
        """
        # Image.open is lazy and keeps the file open; convert() forces the
        # decode, after which the context manager closes the handle right away
        # instead of waiting for garbage collection.
        with Image.open(self.filelist[idx]) as raw:
            image = raw.convert("RGB")
        if self.transform:
            return {"image_transformed": self.transform(image)}
        # return as dict for easy comparison vs. daft
        return {"image": [image]}


In [8]:
def daft_to_torch_iter_from_glob(image_glob: str, transform: callable,
                                 num_concurr: int | None):
    """Build a torch iterable dataset from an image glob, using daft.

    Steps: list the files matching `image_glob`, download and decode each one
    as RGB (decode failures become nulls and are filtered out), apply
    `transform` through the `TransformImagesDaft` UDF, then drop the "image"
    and "num_rows" columns before converting to a torch dataset.

    Args:
        image_glob: glob pattern handed to `daft.from_glob_path`.
        transform: callable forwarded to `TransformImagesDaft`.
        num_concurr: value passed to the UDF's `with_concurrency`.

    Returns:
        The result of `to_torch_iter_dataset()` on the final dataframe, with
        the transformed images in the "image_transformed" column.
    """
    # Expression that fetches each path and decodes it to an RGB image.
    decoded = daft.col("path").url.download().image.decode(mode="RGB", on_error="null")

    frame = daft.from_glob_path(image_glob).with_column("image", decoded)
    frame = frame.where(frame["image"].not_null())

    transform_udf = TransformImagesDaft.with_init_args(transform=transform)
    transform_udf = transform_udf.with_concurrency(num_concurr)

    frame = frame.with_column("image_transformed", transform_udf(daft.col("image")))
    frame = frame.exclude("image", "num_rows")
    return frame.to_torch_iter_dataset()

In [9]:
def torch_iter_from_glob(image_glob: str, transform: callable):
    """Build a plain torch map-style image dataset from a glob (no daft).

    Args:
        image_glob: glob pattern selecting the image files.
        transform: callable forwarded to `TorchImageIterAsDict`.

    Returns:
        A `TorchImageIterAsDict` over the matching files in sorted path order.
    """
    # glob() returns matches in arbitrary, filesystem-dependent order. Sorting
    # makes the row order of the resulting embeddings reproducible.
    image_list = sorted(Path(p) for p in glob(image_glob))
    return TorchImageIterAsDict(image_list, transform)

## 5. Embedding computation pipeline including dataset instantiation

In [10]:
def compute_embeddings(model_name: str, images_glob: str, batch_size: int = BATCH_SIZE,
                       dataset_type: str = "plain_torch", daft_nconcurr: int | None = 1
                       ) -> np.ndarray:
    """
    Return the stacked embeddings for a glob of images.
    Uses a timm pretrained model to generate embeddings.

    Args:
        model_name: timm model name used to build the `Embedder`.
        images_glob: glob pattern selecting the images to embed.
        batch_size: number of images per DataLoader batch.
        dataset_type: "daft_to_torch" or "plain_torch".
        daft_nconcurr: concurrency passed to the daft UDF; only used when
            `dataset_type` is "daft_to_torch".

    Returns:
        One array with all per-batch embeddings stacked row-wise via
        `np.vstack`.

    Raises:
        ValueError: if `dataset_type` is not recognised, or if the dataset
            yields no batches.
    """
    # Validate before loading the model so a typo does not cost a weight download.
    if dataset_type not in ("daft_to_torch", "plain_torch"):
        raise ValueError("Dataset type must be `daft_to_torch` or `plain_torch`.")

    logger.info("Instantiating embedding model.")
    embedder = Embedder(model_name=model_name)

    logger.info(f"Creating dataset of type {dataset_type}.")
    if dataset_type == "daft_to_torch":
        dataset = daft_to_torch_iter_from_glob(images_glob, embedder.transform, daft_nconcurr)
    else:
        dataset = torch_iter_from_glob(images_glob, embedder.transform)

    logger.info("Creating dataloader.")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    embeddings = []

    logger.info("Generating embeddings.")
    # Scale the progress bar by the batch size actually in use, not the
    # module-level default.
    for batch in tqdm(dataloader, unit_scale=batch_size):
        emb = embedder.embed(batch["image_transformed"])
        # numpy has no bfloat16, so upcast before converting. Other dtypes are
        # left untouched.
        if emb.dtype == torch.bfloat16:
            emb = emb.float()
        embeddings.append(emb.cpu().numpy())

    if not embeddings:
        raise ValueError(f"No images to embed for glob {images_glob!r}.")

    logger.info("Stacking embeddings.")
    return np.vstack(embeddings)

## 6. Do memory profiling with one type of dataset

Results written near top of notebook

In [11]:
%load_ext memray

In [None]:
%%memray_flamegraph --native --follow-fork --temporal

USE_DAFT = True
CONCURRENCY_LEVEL = 1
ds_type = "daft_to_torch" if USE_DAFT else "plain_torch"

with TemporaryDirectory() as tmp:
    logger.info("Downloading test images.")
    dl_hf_images(dir=Path(tmp), max_images=NUM_TEST_IMAGES)
    imglob = tmp+"/*.png"
    logger.info("Starting embedding pipeline.")
    embeddings = compute_embeddings(model_name=MODEL_NAME,
                                    images_glob = imglob,
                                    batch_size=BATCH_SIZE,
                                    dataset_type=ds_type,
                                    daft_nconcurr=CONCURRENCY_LEVEL)
    logger.info("Done with embedding pipeline.")

2025-03-16 13:55:16.892 | INFO     | __main__:<module>:15 - Downloading test images.


100%|██████████| 256/256 [00:14<00:00, 17.50it/s]


Size of images on disk: 19.5 MB
2025-03-16 13:55:34.684 | INFO     | __main__:<module>:18 - Starting embedding pipeline.
2025-03-16 13:55:34.685 | INFO     | __main__:compute_embeddings:8 - Instantiating embedding model.
2025-03-16 13:55:37.441 | INFO     | __main__:compute_embeddings:11 - Creating dataset of type daft_to_torch.
2025-03-16 13:55:37.520 | INFO     | __main__:compute_embeddings:19 - Creating dataloader.
2025-03-16 13:55:37.522 | INFO     | __main__:compute_embeddings:24 - Generating embeddings.


0it [00:00, ?it/s]

üó°Ô∏è üêü Project: 00:00 

üó°Ô∏è üêü Filter: 00:00 

üó°Ô∏è üêü ActorPoolProject: 00:00 

Error when running pipeline node ActorPoolProject
0it [00:00, ?it/s]


AttributeError: Can't get local object '_ensure_registered_super_ext_type.<locals>.DaftExtension'

Output()

Memray ERROR: Invalid record subtype


Output()

KeyboardInterrupt: 