In [1]:
# |default_exp daft_embeddings

In [2]:
# | hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export

import gc
import shutil
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable

import daft
import numpy as np
import timm
import torch
from datasets import load_dataset
from einops import rearrange
from PIL import Image
from tqdm import tqdm

In [4]:
#| export


# --- Embedding run settings -------------------------------------------------
BATCH_SIZE = 16
MODEL_NAME = "vit_large_patch14_reg4_dinov2.lvd142m"
IMAGE_GLOB = None
IMAGES_FOLDER = "./tmp-test-images"

# --- Test data --------------------------------------------------------------
TEST_DATASET = "kvriza8/microscopy_images"
NUM_TEST_IMAGES = 2000

# 6 GB memory allocation, expressed in bytes
MEMORY_BYTES = 6 * 10**9

# Configure Daft once at import time: the morsel size is tied to BATCH_SIZE.
daft.set_execution_config(
    enable_native_executor=True,
    default_morsel_size=BATCH_SIZE,
)

# timm model names that are known-good choices for MODEL_NAME
nice_models = [
    "mobilenetv3_large_100",
    "vit_small_patch14_reg4_dinov2.lvd142m",
    "vit_base_patch14_reg4_dinov2.lvd142m",
    "vit_large_patch14_reg4_dinov2.lvd142m",
    "aimv2_large_patch14_224.apple_pt_dist",
]

[
"mobilenetv3_large_100",
"vit_small_patch14_reg4_dinov2.lvd142m",
"vit_base_patch14_reg4_dinov2.lvd142m",
"vit_large_patch14_reg4_dinov2.lvd142m"
]

In [5]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None, max_images: int = 50) -> None:
    """Save the first `max_images` images of a Hugging Face dataset as PNG files.

    The "train" split is opened with ``streaming=True`` and each row's
    ``"image"`` entry is written to `dir` as ``"<index>.png"``, with indices
    starting at 0.

    Args:
        dataset_name: Dataset identifier handed to `load_dataset`.
        dir: Output directory (any path-like value). It is created when it
            does not exist yet. Required despite the `None` default.
        max_images: Upper bound on the number of images written. Values
            <= 0 write nothing and return before the dataset is opened.

    Raises:
        ValueError: If `dir` is None.
    """
    if dir is None:
        # The old behaviour was an opaque TypeError from `None / str`.
        raise ValueError("dl_hf_images() needs an output directory: pass `dir=...`")
    out_dir = Path(dir)  # also accepts plain strings

    if max_images <= 0:
        return None

    out_dir.mkdir(parents=True, exist_ok=True)
    dataset = load_dataset(dataset_name, split="train", streaming=True)

    # zip() stops as soon as the index range is exhausted, so no row beyond
    # `max_images` is pulled from the stream.
    for i, img_row in tqdm(zip(range(max_images), dataset), total=max_images):
        img_row["image"].save(out_dir / f"{i}.png")

    del dataset
    gc.collect()

    return None

In [6]:
# Start from an empty scratch folder so reruns never mix old and new images.
tmp_path = Path(IMAGES_FOLDER)
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Pass TEST_DATASET explicitly so editing the config constant actually changes
# which dataset is downloaded (it was previously ignored in favour of the
# function's own default).
dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_path, max_images=NUM_TEST_IMAGES)

100%|██████████| 2000/2000 [00:35<00:00, 55.73it/s] 


In [7]:
#| export

@dataclass
class Embedder:
    """Wrap a pretrained timm model together with its matching preprocessing.

    Constructing an instance creates the timm model named `model_name` with
    ``num_classes=0``, puts it in eval mode, moves it to `device` in
    channels_last memory format, builds the timm transform from the model's
    pretrained config, and finally wraps the model with `torch.compile`.
    """
    # Name passed to `timm.create_model`.
    model_name: str
    # Target device; defaults to CUDA when available, otherwise CPU.
    device: torch.device = field(default_factory=lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # Compiled model, populated in __post_init__.
    model: torch.nn.Module = field(init=False)
    # Preprocessing callable built by timm, populated in __post_init__.
    transform: Callable = field(init=False)

    def __post_init__(self):
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        model.eval()
        model.to(self.device, memory_format=torch.channels_last)
        # Read the data config from the plain timm model *before* compiling,
        # so the lookup of `pretrained_cfg` does not depend on attribute
        # forwarding by the compiled wrapper.
        # Resolve config removes unneeded fields before create_transform
        cfg = timm.data.resolve_data_config(model.pretrained_cfg, model=model)
        self.transform = timm.data.create_transform(**cfg)
        self.model = torch.compile(model, dynamic=True)

    @torch.inference_mode()
    def embed(self, batch_imgs: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of already-transformed images.

        The batch is moved to `self.device`, made contiguous in channels_last
        format (so it must be a 4-D tensor), and passed through the model in
        inference mode. On CUDA the forward pass runs under autocast.

        Args:
            batch_imgs: 4-D tensor of pre-transformed images.

        Returns:
            The model output for the batch, on `self.device`.
        """
        batch_imgs = batch_imgs.to(self.device, non_blocking=True)
        batch_imgs = batch_imgs.contiguous(memory_format=torch.channels_last)
        if self.device.type == "cuda":
            with torch.amp.autocast("cuda"):
                return self.model(batch_imgs)
        # autocast can be comically slow for some CPU setups (PyTorch issue #118499)
        return self.model(batch_imgs)

In [8]:
#| export

@daft.udf(return_dtype=daft.DataType.list(daft.DataType.float32()))
class EmbedImageCol:
    """Run a timm embedder on an image column.

    The embedder is built once in `__init__`. Each call transforms the images
    it receives, runs them through the model in chunks of `batch_size`, and
    returns one embedding row per input image.
    """
    def __init__(self, model_name: str, batch_size: int = 4):
        self.model_name = model_name
        self.batch_size = batch_size
        self.embedder = Embedder(self.model_name)

    @staticmethod
    def _pad_to_batch_size(batch, batch_size):
        """Stack per-image tensors and zero-pad the result to `batch_size` rows.

        Args:
            batch: Non-empty sequence of equally shaped image tensors.
            batch_size: Number of rows the stacked batch should have.

        Returns:
            `(stacked, orig_size)`, where `orig_size` is the number of real
            (non-padding) rows at the front of `stacked`.
        """
        orig_size = len(batch)
        batch = torch.stack(list(batch))
        if orig_size < batch_size:
            pad_tensor = torch.zeros((batch_size - orig_size, *batch.shape[1:]),
                                    dtype=batch.dtype, device=batch.device)
            batch = torch.cat([batch, pad_tensor], dim=0)
        return batch, orig_size

    def _to_model_input(self, im) -> torch.Tensor:
        """Apply the embedder's transform to one image, converting to PIL first if needed."""
        # NOTE(review): Daft is expected to hand decoded images over as numpy
        # arrays, while the timm transform pipeline works on PIL images --
        # confirm against the Daft version in use.
        if not isinstance(im, Image.Image):
            im = Image.fromarray(np.asarray(im))
        return self.embedder.transform(im)

    def __call__(self, batch_images) -> np.ndarray:
        """Embed every image in `batch_images`.

        Args:
            batch_images: Daft series of decoded images (must support
                `to_pylist()`), without null entries.

        Returns:
            float32 array with one row per input image, in input order.
        """
        ### this needs to be lazy -- something like a dataloader
        ### it's currently going to load the whole image_col
        images = [self._to_model_input(im) for im in batch_images.to_pylist()]
        if not images:
            return np.empty((0, 0), dtype=np.float32)

        step = max(int(self.batch_size), 1)
        outputs = []
        # Feed the model fixed-size chunks: the last (short) chunk is padded
        # with zeros and the padding rows are dropped from the output again.
        for start in range(0, len(images), step):
            padded, orig_size = self._pad_to_batch_size(images[start:start + step], step)
            chunk_out = self.embedder.embed(padded)[:orig_size]
            # Cast explicitly: autocast may produce half precision, while the
            # UDF declares float32 elements.
            outputs.append(chunk_out.float().cpu())
        return torch.cat(outputs, dim=0).numpy()

## Test it out

In [9]:
#| hide

# List every PNG in the scratch folder and keep its path under a clearer name.
glob = f"{tmp_path.as_posix()}/*.png"
images_df = daft.from_glob_path(glob)
images_df = images_df.with_column_renamed("path", "path_full_img")

# Download each file and decode it as RGB; undecodable files become nulls.
decoded_image = (
    daft.col("path_full_img")
    .url.download()
    .image.decode(mode="RGB", on_error="null")
)
images_df = images_df.with_column("image", decoded_image)



In [10]:
#| hide

# Bind the model and batch size into the UDF, and request batches of the same size.
EmbedImageColWithModel = EmbedImageCol.with_init_args(
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
).override_options(batch_size=BATCH_SIZE)

# Drop rows whose image failed to decode, embed the rest, then discard the
# columns that are no longer needed.
has_image = images_df["image"].not_null()
images_df = (
    images_df
    .where(has_image)
    .with_column("embed", EmbedImageColWithModel(daft.col("image")))
    .exclude("image", "num_rows")
)

In [12]:
# Print the query plans Daft built for the embedding pipeline (show_all=True
# includes the unoptimized and the optimized logical plan, as seen below).
images_df.explain(show_all=True)

== Unoptimized Logical Plan ==

* Project: col(path_full_img), col(size), col(embed)
|
* Project: col(path_full_img), col(size), col(num_rows), col(image),
|     py_udf(col(image)) as embed
|
* Filter: not_null(col(image))
|
* Project: col(path_full_img), col(size), col(num_rows),
|     image_decode(download(col(path_full_img))) as image
|
* Project: col(path) as path_full_img, col(size), col(num_rows)
|
* Source:
|   Number of partitions = 1
|   Output schema = path#Utf8, size#Int64, num_rows#Int64


== Optimized Logical Plan ==

* Project: col(path_full_img), col(size), py_udf(col(image)) as embed
|   Stats = { Approx num rows = 1,800, Approx size bytes = 95.93 KiB, Accumulated
|     selectivity = 0.90 }
|
* Filter: not_null(col(image))
|   Stats = { Approx num rows = 1,800, Approx size bytes = 95.93 KiB, Accumulated
|     selectivity = 0.90 }
|
* Project: col(path) as path_full_img, col(size), image_decode(download(col(path)
|     as path_full_img)) as image
|   Stats = { Approx num

In [12]:
#| hide

# Run the pipeline (download, decode, filter, embed) and keep the result.
# NOTE(review): this rebinds `images_df`, so re-running earlier cells is
# needed to get the un-collected dataframe back.
images_df = images_df.collect()

🗡️ 🐟 Filter: 00:00 

🗡️ 🐟 Project: 00:00 

🗡️ 🐟 Project: 00:00 

: 

In [None]:
#| hide

# Remove the scratch image folder. ignore_errors=True keeps this cell
# re-runnable when the folder is already gone, matching the rmtree call that
# prepares the folder before the download.
shutil.rmtree(tmp_path, ignore_errors=True)