In [1]:
import os
# try not to run concurrent inference processes
# instead do batch inference with correct batch size
os.environ["DAFT_ENABLE_ACTOR_POOL_PROJECTIONS"]="1"
import gc
import sys
import shutil
from glob import glob
from pathlib import Path
from tempfile import TemporaryDirectory

import daft
import numpy as np
import einops
import timm
import torch
from datasets import load_dataset
from humanize import naturalsize
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from loguru import logger

In [2]:
# Send loguru output to stdout at INFO level so timestamped log lines show up
# inline with the rest of the cell output.
logger.remove()  # drop all currently registered sinks first
logger.add(sink=sys.stdout, level="INFO")

1

## Use Daft for the end-to-end pipeline

Result: peak memory usage drops from 3.0 GB to 2.1 GB. This is still more than plain torch at 1.9 GB, but it is a substantial improvement.

## TODO:
- [x] fix NoneType issue
    - fixed by returning the result from `__call__`
- [ ] test using url.decode here
- [x] remove DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1
    - result: max 2.248 GB
- [x] test setting morsel size to batch size and native executor = True
    - result: initializes 8 parallel UDFs. Max memory 3.7 GB
- [x] test skipping embedder UDF
    - result: ~1 GB max memory, no NoneType error, 8 parallel UDF instances
- [x] test manually setting concurrency level to 1
    - result: `invalid record subtype` error
- [x] test native runner, fixed NoneType error, enable pool projections 1, morsel size is batch size
    - result (tmpatnsi3en): ran 8.5 minutes, spiky memory graph that dropped over time, max 4 GB
- [x] test giving the embedder UDF num_cpus=16 to see if that avoids concurrency with the native runner
    - result: a delay between loading the first embedder UDF and the others; max mem 2.36 GB; took 4 minutes to finish
- [x] test both UDFs with num_cpus=16
    - result: log messages similar to the previous run; max mem 2.4 GB; took 3.5 min to finish
- [x] test loading the Pillow image within the embedding UDF to avoid a pileup of images beforehand
    - result: max resident size 4 GB, but it diverges from max heap size (<1 GB). Runs in 3m10s, but the embedding pipeline itself takes 2m10s.
- [x] try the same as above with native runner "false"
    - result: resident size 2.3 GB, heap size 1.1 GB; the embedding pipeline itself took about 2m16s
- [ ] try with native runner and remove the morsel size parameter
- [ ] test (separate notebook) replicating the transform function outside the embedder class as a UDF to see if that's what causes memory duplication in notebook 28

## 1. Set variables for the test

In [3]:
# --- test parameters -------------------------------------------------------
BATCH_SIZE = 32                                        # rows per UDF call; also the Daft morsel size below
MODEL_NAME = "vit_small_patch14_reg4_dinov2.lvd142m"   # timm model identifier
TEST_DATASET = "kvriza8/microscopy_images"             # Hugging Face dataset id (dl_hf_images uses the same default)
NUM_TEST_IMAGES = 256                                  # number of images downloaded for the run

# Run on the native executor with morsels sized to one UDF batch.
# (Previously tried: enable_native_executor=True with the default morsel size.)
daft.set_execution_config(
    enable_native_executor=True,
    default_morsel_size=BATCH_SIZE,
)

DaftContext(_ctx=<builtins.PyDaftContext object at 0x7fdd18877990>)

## 2. Define a helper to download a small test dataset

In [4]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None,
                 max_images: int = 64,
                 overwrite: bool = True,
                 format: str = "png") -> None:
    """Stream up to ``max_images`` images from a Hugging Face dataset to disk.

    Rows are read from the ``train`` split in streaming mode. Each row's
    ``"image"`` value (expected to be a PIL image, i.e. something with
    ``.save()``) is written to ``<dir>/<index>.<format>``. The total on-disk
    size is logged.

    Args:
        dataset_name: Hugging Face dataset id passed to ``load_dataset``.
        dir: Destination directory, as a ``Path`` or ``str``. Required.
        max_images: Maximum number of images to save. Negative values are
            treated as 0.
        overwrite: If True, delete ``dir`` (and its contents) first. The
            directory is (re)created in either case.
        format: File extension; PIL picks the image format from it.

    Returns:
        None.

    Raises:
        ValueError: If ``dir`` is not provided.
    """
    from itertools import islice

    # Fail fast, before opening the dataset stream. Previously dir=None only
    # crashed later with an AttributeError on ``None.mkdir``.
    if dir is None:
        raise ValueError("dl_hf_images requires a destination directory `dir`")
    dir = Path(dir)
    max_images = max(max_images, 0)

    dataset = load_dataset(dataset_name, split="train", streaming=True)
    if overwrite:
        shutil.rmtree(dir, ignore_errors=True)
    # Always make sure the directory exists. With overwrite=False it was
    # previously never created, and img.save() failed on a missing directory.
    dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    # islice stops after exactly max_images rows. The old enumerate/break loop
    # pulled one extra row from the stream before noticing the limit.
    for i, img_row in enumerate(tqdm(islice(dataset, max_images), total=max_images)):
        out_path = dir / f"{i}.{format}"
        img_row["image"].save(out_path)
        image_paths.append(out_path)

    logger.info(f"Size of images on disk: {naturalsize(sum(p.stat().st_size for p in image_paths))}")

    # Release the streaming dataset promptly; this notebook profiles memory.
    del dataset
    gc.collect()

    return None

## 3. Define timm-based embedding model

In [5]:

class Embedder:
    """Instantiate a pretrained timm model (classifier head removed) to generate embeddings.

    Attributes:
        model_name: timm model identifier passed to ``timm.create_model``.
        device: ``torch.device`` the model runs on.
        dtype: autocast dtype used on CUDA (bfloat16 if supported, else
            float16). On other devices it is float32 and no autocast is applied.
        model: the model in eval mode, wrapped with ``torch.compile``.
        transform: timm preprocessing transform built from the model's
            pretrained data config.
    """

    def __init__(self, model_name: str, device: "torch.device | str | None" = None):
        self.model_name = model_name
        # Accept a torch.device, a device string such as "cuda:0", or None
        # (auto-select). Normalizing to torch.device keeps `.type` valid below;
        # previously a plain string crashed on `.type`.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        if self.device.type == "cuda":
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.dtype = torch.float32

        # Create and prepare the eager model.
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        model.to(self.device, memory_format=torch.channels_last)
        model.eval()

        # Resolve the preprocessing config from the eager model, before
        # compiling. Resolving keeps only the data fields create_transform needs.
        cfg = timm.data.resolve_data_config(model.pretrained_cfg)
        self.transform = timm.data.create_transform(**cfg)

        self.model = torch.compile(model, dynamic=True, mode="reduce-overhead")

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """Embed a batch of preprocessed images.

        Args:
            batch_imgs: image batch. It must be 4-D for the channels_last
                conversion, presumably (batch, channels, height, width) as
                produced by stacking ``self.transform`` outputs.

        Returns:
            Model outputs, always cast to float32. Under CUDA autocast the
            model may return bfloat16, which NumPy cannot represent, so
            callers doing ``.cpu().numpy()`` would otherwise fail.
        """
        batch_imgs = batch_imgs.to(self.device, non_blocking=True, memory_format=torch.channels_last)
        if self.device.type == "cuda":
            with torch.amp.autocast("cuda", dtype=self.dtype):
                embeddings = self.model(batch_imgs)
        else:
            embeddings = self.model(batch_imgs)
        return embeddings.float()

## 4. Define Daft UDFs for two pipeline variants (load-then-embed, and embed-from-path)

In [6]:
@daft.udf(return_dtype=daft.DataType.python(), batch_size=BATCH_SIZE, num_cpus=16)
class LoadPillow:
    """Daft UDF: load each path in a batch as an RGB PIL image.

    A leading ``file://`` prefix (e.g. on paths from ``daft.from_glob_path``)
    is stripped before opening. Null paths map to null images, so a
    downstream ``not_null()`` filter can drop them.
    """

    def __init__(self):
        logger.info("Initializing Pillow loader UDF")

    def _pil_load(self, path):
        """Return a fully loaded RGB copy of the image at ``path`` (None for None).

        The file is opened in a ``with`` block, so its handle is closed as
        soon as the converted copy exists.
        """
        if path is None:
            return None
        # Strip only a leading prefix. str.replace removed "file://" anywhere.
        if path.startswith("file://"):
            path = path[len("file://"):]
        with Image.open(path) as img:
            return img.convert("RGB")

    def __call__(self, batch_paths: daft.Series) -> list:
        return [self._pil_load(p) for p in batch_paths.to_pylist()]

In [7]:
@daft.udf(return_dtype=daft.DataType.python(), batch_size=BATCH_SIZE, num_cpus=16)
class DaftTimmEmbed:
    """Daft UDF: transform PIL images and embed them with a timm pretrained model.

    Each UDF instance builds one ``Embedder`` for ``model_name``. A call
    returns a list with one row of the embedding array per input image.
    """

    def __init__(self, model_name: str):
        self._embedder = Embedder(model_name)
        logger.info(f"Initialize embedder on device {self._embedder.device}")
        logger.info(f"with dtype {self._embedder.dtype}")

    def __call__(self, batch_images: daft.Series) -> list:
        images = batch_images.to_pylist()
        # torch.stack raises on an empty list, so return early for an empty
        # batch (possible e.g. after the upstream not_null filter).
        if not images:
            return []
        batch_t = torch.stack([self._embedder.transform(im) for im in images])
        batch_t = batch_t.to(memory_format=torch.channels_last)
        # Cast to float32 before .numpy(): CUDA autocast may produce bfloat16,
        # which NumPy cannot represent.
        embeddings = self._embedder.embed(batch_t).float().cpu().numpy()
        return list(embeddings)

In [8]:
@daft.udf(return_dtype=daft.DataType.python(), batch_size=BATCH_SIZE, num_cpus=16)
class DaftTimmEmbedFromPath:
    """Daft UDF: load, transform and embed images directly from their paths.

    Loading inside the embedding UDF avoids materializing a separate column
    of PIL images. Rows with a null path get a null embedding. Every other
    row gets one row of the embedding array.
    """

    def __init__(self, model_name: str):
        self._embedder = Embedder(model_name)
        logger.info(f"Initialize embedder on device {self._embedder.device}")
        logger.info(f"with dtype {self._embedder.dtype}")

    def _pil_load(self, path):
        """Return a fully loaded RGB copy of the image at ``path``; the file handle is closed."""
        # Strip only a leading prefix. str.replace removed "file://" anywhere.
        if path.startswith("file://"):
            path = path[len("file://"):]
        with Image.open(path) as img:
            return img.convert("RGB")

    def __call__(self, batch_paths: daft.Series) -> list:
        paths = batch_paths.to_pylist()
        results = [None] * len(paths)
        valid_idx = [i for i, p in enumerate(paths) if p is not None]
        # Nothing to embed (empty batch or all nulls). torch.stack would raise
        # on an empty list.
        if not valid_idx:
            return results
        batch_t = torch.stack(
            [self._embedder.transform(self._pil_load(paths[i])) for i in valid_idx]
        )
        batch_t = batch_t.to(memory_format=torch.channels_last)
        # Cast to float32 before .numpy(): CUDA autocast may produce bfloat16,
        # which NumPy cannot represent.
        embeddings = self._embedder.embed(batch_t).float().cpu().numpy()
        for i, emb in zip(valid_idx, embeddings):
            results[i] = emb
        return results

## 5. Embedding computation pipeline including dataset instantiation

In [9]:
def daft_glob_infer(image_glob: str,
                    batch_size: int = 32):
    """Build the two-UDF Daft pipeline: glob paths -> PIL images -> embeddings.

    Steps:
    1. ``LoadPillow`` adds an ``image`` column, and the ``num_rows`` column
       from the glob listing is dropped.
    2. Rows whose image is null are filtered out.
    3. ``DaftTimmEmbed`` (initialized with ``MODEL_NAME``) adds an
       ``embedding`` column, and the intermediate ``image`` column is dropped.

    Args:
        image_glob: glob pattern handed to ``daft.from_glob_path``.
        batch_size: currently unused. The UDF batch size is fixed by
            ``BATCH_SIZE`` when the UDFs are decorated.

    Returns:
        A Daft DataFrame; call ``.collect()`` to execute it.
    """
    listing = daft.from_glob_path(image_glob)
    embed_udf = DaftTimmEmbed.with_init_args(model_name=MODEL_NAME)

    with_images = listing.with_column("image", LoadPillow(daft.col("path"))).exclude("num_rows")
    with_images = with_images.where(with_images["image"].not_null())

    embedded = with_images.with_column("embedding", embed_udf(daft.col("image")))
    return embedded.exclude("image")

In [10]:
def daft_glob_infer_oneudf(image_glob: str,
                           batch_size: int = 32):
    """Build the single-UDF Daft pipeline: embed images straight from their paths.

    ``DaftTimmEmbedFromPath`` (initialized with ``MODEL_NAME``) loads,
    transforms and embeds each image. It adds an ``embedding`` column next to
    the columns produced by ``daft.from_glob_path``.

    Args:
        image_glob: glob pattern handed to ``daft.from_glob_path``.
        batch_size: currently unused. The UDF batch size is fixed by
            ``BATCH_SIZE`` when the UDF is decorated.

    Returns:
        A Daft DataFrame; call ``.collect()`` to execute it.
    """
    listing = daft.from_glob_path(image_glob)
    embed_from_path = DaftTimmEmbedFromPath.with_init_args(model_name=MODEL_NAME)
    return listing.with_column("embedding", embed_from_path(daft.col("path")))

## 6. Profile memory usage of one pipeline variant

Results are recorded in the TODO list near the top of the notebook.

In [11]:
%load_ext memray

In [12]:
%%memray_flamegraph --native --follow-fork --temporal

with TemporaryDirectory() as tmp:
    logger.info("Downloading test images.")
    dl_hf_images(dir=Path(tmp), max_images=NUM_TEST_IMAGES)
    imglob = tmp+"/*.png"
    logger.info("Starting embedding pipeline.")
    df_embeds = daft_glob_infer_oneudf(imglob)
    logger.info("Set up embedding dataframe")
    df_embeds = df_embeds.collect()
    logger.info("Done with embedding pipeline.")


[32m2025-03-25 23:10:10.254[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mDownloading test images.[0m


100%|██████████| 256/256 [00:09<00:00, 26.29it/s] 

[32m2025-03-25 23:10:22.720[0m | [1mINFO    [0m | [36m__main__[0m:[36mdl_hf_images[0m:[36m20[0m - [1mSize of images on disk: 19.5 MB[0m





[32m2025-03-25 23:10:22.890[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1mStarting embedding pipeline.[0m
[32m2025-03-25 23:10:22.936[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [1mSet up embedding dataframe[0m
[32m2025-03-25 23:10:25.418[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m7[0m - [1mInitialize embedder on device cpu[0m
[32m2025-03-25 23:10:25.420[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m8[0m - [1mwith dtype torch.float32[0m


üó°Ô∏è üêü Project: 00:00 

[32m2025-03-25 23:11:05.818[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m7[0m - [1mInitialize embedder on device cpu[0m
[32m2025-03-25 23:11:05.820[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m8[0m - [1mwith dtype torch.float32[0m
[32m2025-03-25 23:11:18.477[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m7[0m - [1mInitialize embedder on device cpu[0m
[32m2025-03-25 23:11:18.479[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m8[0m - [1mwith dtype torch.float32[0m
[32m2025-03-25 23:11:31.331[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m7[0m - [1mInitialize embedder on device cpu[0m
[32m2025-03-25 23:11:31.333[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m8[0m - [1mwith dtype torch.float32[0m
[32m2025-03-25 23:11:44.317[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m7[0m - [1mInitialize embedder on device cpu[0m
[32m2025-03

Output()

Output()