In [1]:
# |default_exp daft_embeddings

In [2]:
# | hide
# Reload edited Python modules automatically before each cell executes,
# so changes to imported code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
#| export

import daft
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from einops import rearrange
from dataclasses import dataclass, field
from typing import Callable
from tqdm import tqdm

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
#| export

# Switch daft to its native executor for every query run from this module.
daft.set_execution_config(enable_native_executor=True)

# Upper bound on the number of sample images fetched for testing.
MAX_IMAGES = 2_000

In [5]:
#| hide

from itertools import islice

tmp_path = Path("./tmp-test-images")
tmp_path.mkdir(parents=True, exist_ok=True)

dataset_name = "kvriza8/microscopy_images"

# streaming=True yields records lazily instead of downloading the whole split first
dataset = load_dataset(dataset_name, split="train", streaming=True)

# islice stops after exactly MAX_IMAGES records, so no extra record is pulled from
# the stream just to hit a `break`; `total` gives tqdm a proper progress bar.
for i, example in enumerate(tqdm(islice(dataset, MAX_IMAGES), total=MAX_IMAGES)):
    image = example["image"]
    image.save(tmp_path / f"{i}.png")

2000it [00:37, 52.88it/s] 


In [6]:
#| export

@dataclass
class TimmEmbedder:
    """
    Embed an image with any timm model that supports pooled feature extraction.

    The classifier head is removed (`num_classes=0`), so the model outputs a pooled
    feature vector per image. Preprocessing is the model's own eval transform,
    resolved from its pretrained config.

    Reference: https://huggingface.co/docs/timm/main/en/feature_extraction#pooled
    """
    model_name: str
    device: torch.device = field(init=False)
    dtype: torch.dtype = field(init=False)
    # repr=False: the default dataclass repr would otherwise print the whole network.
    _model: Callable = field(init=False, repr=False)
    # Field name matches the attribute actually assigned/used below (`_transform`);
    # a mismatched name leaves the field unset and breaks the generated __repr__.
    _transform: Callable = field(init=False, repr=False)

    def __post_init__(self):
        # initialize model (headless) and the matching eval-time preprocessing
        self._model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        cfg = self._model.pretrained_cfg
        self._transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

        # set device and dtype: bf16 on CUDA when supported, fp16 on other CUDA, fp32 on CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.bfloat16 if self.device.type == "cuda" and torch.cuda.is_bf16_supported() else (
            torch.float16 if self.device.type == "cuda" else torch.float32
        )
        print(f"Inference device: {self.device} with dtype: {self.dtype}")

        # optimize model for inference
        self._model = self._model.eval().to(device=self.device, dtype=self.dtype)
        # NOTE(review): torch.jit.script can fail for timm models that are not scriptable
        self._model = torch.jit.optimize_for_inference(torch.jit.script(self._model))

    def __call__(self, image: "torch.Tensor | Image.Image") -> np.ndarray:
        """
        Transform one image, run inference, and return its embedding as a 1D array.

        Args:
            image: a single image accepted by the timm transform (e.g. a PIL image).

        Returns:
            float32 numpy array of shape (embedding_dim,), assuming the headless model
            returns (batch, embedding_dim).
        """
        batch = self._transform(image).to(self.device, self.dtype).unsqueeze(0)
        # no autograd bookkeeping is needed for pure inference
        with torch.inference_mode():
            emb = self._model(batch)
        # index the batch axis instead of squeeze() so a size-1 embedding stays 1D
        return emb[0].detach().cpu().float().numpy()


In [7]:
#| hide

# it's a one-stop shop
# e = TimmEmbedder("mobilenetv3_large_100")
# e(torch.rand((3,256,256)))

In [8]:
#| export

@daft.udf(return_dtype=daft.DataType.list(daft.DataType.float32()))
class EmbedImages:
    """Run a timm embedder over an image column, one image at a time.

    Null images (e.g. rows whose decode failed with on_error="null") yield a
    null embedding instead of raising.
    """
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.embedder = TimmEmbedder(self.model_name)

    def __call__(self, images_col):
        # Hand the timm transform PIL images: its to-tensor step scales uint8 pixels
        # to [0, 1] before Normalize. A float tensor of raw 0-255 values passes through
        # that step unscaled, so the model would see mis-normalized inputs.
        # Assumes decoded images arrive as uint8 (h, w, c) RGB arrays — verify if the
        # decode mode changes.
        ## Note: expect images are different sizes
        ## could maybe speed up by doing the resize transform separately,
        ## then doing batch inference
        ## Example: https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/mnist.ipynb
        return [
            None if im is None else self.embedder(Image.fromarray(np.asarray(im)))
            for im in images_col.to_pylist()
        ]

## Test it out

In [9]:
#| hide

# Lazily list the sample PNGs, then download and decode each one into an RGB image
# column; undecodable files become nulls rather than failing the query.
glob = (tmp_path / "*.png").as_posix()
images_df = (
    daft.from_glob_path(glob)
    .with_column_renamed("path", "path_full_img")
    .with_column(
        "image",
        daft.col("path_full_img").url.download().image.decode(mode="RGB", on_error="null"),
    )
)



## Inference time for some embedding models on CPU (measured on my laptop)

model_name | 50 images | 500 images | 1000 images | 2000 images
---------- | --------- | ---------- | ----------- | -----------
mobilenetv3_large_100 | 0m02s | 0m08s | 0m12s | 0m25s
vit_base_patch14_reg4_dinov2.lvd142m | 1m20s | 11m48s | 23m29s | OOM
vit_large_patch14_reg4_dinov2.lvd142m | 4m04s | 38m30s | OOM | OOM

In [10]:
#| hide

MODEL_NAME = "mobilenetv3_large_100"

# MODEL_NAME is forwarded to EmbedImages.__init__, which builds the TimmEmbedder.
EmbedImagesWithModel = EmbedImages.with_init_args(MODEL_NAME)

# Add the (still lazy) embedding column computed from the decoded images.
images_df = images_df.with_column(
    "embed",
    EmbedImagesWithModel(daft.col("image")),
)

In [11]:
#| hide

# Execute the lazy plan built above (download, decode, embed) and keep the
# materialized results in memory.
images_df = images_df.collect()

🗡️ 🐟 Project: 00:00 2,000 rows received, 0 rows emitted[AInference device: cpu with dtype: torch.float32
                                                            d

In [12]:
#| hide

import shutil
# Delete the temporary image directory created for these test cells.
shutil.rmtree(tmp_path)