In [None]:
# |default_exp daft_embeddings

In [None]:
# | hide
# Development convenience: load IPython's autoreload extension and use mode 2,
# which re-imports edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
#| export

import daft
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from einops import rearrange
from dataclasses import dataclass, field
from typing import Callable
from tqdm import tqdm

from datasets import load_dataset

In [None]:
#| export

# Upper bound on the number of sample images written to disk for the tests below.
MAX_IMAGES = 50

# Run Daft queries on the native executor.
# (`default_morsel_size=50` is another knob that can be passed here; it is left unset.)
daft.set_execution_config(enable_native_executor=True)

In [None]:
#| hide

# Scratch directory that receives the sample images used by the cells below.
tmp_path = Path("./tmp-test-images")
tmp_path.mkdir(parents=True, exist_ok=True)

# Stream the dataset so that only the examples we iterate over are fetched.
dataset_name = "kvriza8/microscopy_images"
dataset = load_dataset(dataset_name, split="train", streaming=True)

# Write the first MAX_IMAGES examples to disk as 0.png, 1.png, ...
for idx, sample in enumerate(tqdm(iter(dataset))):
    if idx >= MAX_IMAGES:
        break
    sample["image"].save(tmp_path / f"{idx}.png")

50it [00:04, 11.35it/s]


In [None]:
#| export

@dataclass
class TimmEmbedder:
    """
    Embed an image with any timm model that supports pooled feature extraction.

    Instances are cached per ``model_name``: constructing the class again with a
    name that was already used returns the existing instance (and its already
    loaded model) instead of loading the weights a second time. Different model
    names get different instances, so ``model_name`` always matches the model
    that is actually run.

    Reference: https://huggingface.co/docs/timm/main/en/feature_extraction#pooled

    Attributes:
        model_name: name passed to ``timm.create_model``.
        device: ``cuda`` when available, otherwise ``cpu``.
        dtype: bfloat16 on CUDA when supported, else float16 on CUDA, else float32.
    """
    model_name: str
    device: torch.device = field(init=False)
    dtype: torch.dtype = field(init=False)
    # Excluded from repr(): they are implementation details and the module repr is long.
    _model: Callable = field(init=False, repr=False)
    _transform: Callable = field(init=False, repr=False)

    # Cache of live instances keyed by (class, model_name). Deliberately left
    # un-annotated so @dataclass treats it as a plain class attribute, not a field.
    _instances = {}

    def __new__(cls, model_name):
        """Return the cached instance for ``model_name``, creating it on first use."""
        key = (cls, model_name)
        instance = cls._instances.get(key)
        if instance is None:
            instance = super().__new__(cls)
            cls._instances[key] = instance
        return instance

    def __post_init__(self):
        """Load the model and its preprocessing pipeline (first construction only)."""
        # __init__/__post_init__ run on every TimmEmbedder(...) call, including
        # cache hits, so skip the expensive setup when it has already happened.
        if hasattr(self, "_model"):
            return

        # num_classes=0 asks timm for the model without its classifier head.
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
        self._transform = timm.data.create_transform(**data_cfg)

        # set device and dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type != "cuda":
            self.dtype = torch.float32
        elif torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float16
        print(f"Inference device: {self.device} with dtype: {self.dtype}")

        # optimize model for inference
        model = model.eval().to(device=self.device, dtype=self.dtype)
        # Assigned last: the presence of `_model` is the "fully initialised"
        # marker checked at the top of this method.
        self._model = torch.compile(model)

    def __call__(self, image: "torch.Tensor | Image.Image") -> np.ndarray:
        """Transform a single image, run inference and return its embedding."""
        return self.embed(self.transform(image))

    def transform(self, image: "torch.Tensor | Image.Image") -> torch.Tensor:
        """Preprocess one image and return it with a leading batch axis of size 1."""
        return self._transform(image).to(self.device, self.dtype).unsqueeze(0)

    def embed(self, img_tensor: torch.Tensor) -> np.ndarray:
        """
        Embed a tensor of already-transformed images.

        Returns a float32 NumPy array with all size-1 axes squeezed out, so a
        batch holding a single image comes back as a 1D vector.
        """
        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            features = self._model(img_tensor)
        return features.detach().cpu().float().numpy().squeeze()


In [None]:
#| hide

# it's a one-stop shop
e = TimmEmbedder("mobilenetv3_large_100")
e(torch.rand((3,256,256)))

# or do in two steps

t = e.transform(torch.rand((3,256,256)))
v = e.embed(t)

Inference device: cpu with dtype: torch.float32


In [None]:
#| export

@daft.udf(return_dtype=daft.DataType.list(daft.DataType.float32()),
          memory_bytes=int(6e9))
class EmbedImages:
    """
    Run a timm embedder on a Daft image column.

    The embedder is created once in ``__init__``; ``__call__`` is invoked with a
    batch of rows and returns one embedding (a list of floats) per row.
    """
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.embedder = TimmEmbedder(self.model_name)

    def __call__(self, images_col) -> list:
        """
        Embed every image in ``images_col``.

        Args:
            images_col: Daft series of decoded images. Each non-null row is
                assumed to be an ``(h, w, c)`` uint8 array -- TODO confirm for
                image modes other than RGB.

        Returns:
            A list with one entry per input row: the embedding as a list of
            floats, or ``None`` for rows whose image is null.
        """
        images = images_col.to_pylist()
        # Rows whose image is null (e.g. decode failures mapped to null upstream)
        # are passed through as None instead of crashing the whole batch.
        valid_rows = [row for row, im in enumerate(images) if im is not None]
        embeddings = [None] * len(images)
        if not valid_rows:
            return embeddings

        # Hand PIL images to the timm transform so its tensor-conversion step
        # rescales pixels to [0, 1] before normalisation; a raw 0-255 float
        # tensor would skip that rescaling.
        # transform() already adds a leading batch axis of size 1, so the
        # per-image tensors are concatenated (not stacked) into (n, c, h, w).
        batch = torch.cat([
            self.embedder.transform(Image.fromarray(np.asarray(images[row])))
            for row in valid_rows
        ])
        ## Example: https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/mnist.ipynb
        # embed() squeezes size-1 axes, which drops the batch axis for a
        # single-image batch; restore the (n_images, n_features) layout.
        features = np.asarray(self.embedder.embed(batch)).reshape(len(valid_rows), -1)
        for row, vector in zip(valid_rows, features):
            embeddings[row] = vector.tolist()
        return embeddings

## TODO: figure out how to create two UDFs that work together

## Test it out

In [None]:
#| hide

# Glob pattern matching every PNG written to the scratch directory.
glob = f"{tmp_path.as_posix()}/*.png"

# List the files, then download and decode each one into an RGB image column;
# files that fail to decode become nulls instead of raising.
decoded_image = daft.col("path_full_img").url.download().image.decode(
    mode="RGB", on_error="null"
)
images_df = (
    daft.from_glob_path(glob)
    .with_column_renamed("path", "path_full_img")
    .with_column("image", decoded_image)
)

## Inference time for some embedding models (CPU, on my laptop)

model_name | 50 images | 500 images | 1000 images | 2000 images
---------- | --------- | ---------- | ----------- | -----------
mobilenetv3_large_100 | 0m02s | 0m08s | 0m12s | 0m25s
vit_base_patch14_reg4_dinov2.lvd142m | 1m20s | 11m48s | 23m29s | OOM
vit_large_patch14_reg4_dinov2.lvd142m | 4m04s | 38m30s | OOM | OOM

In [None]:
#| hide

# Bind the constructor argument up front so the UDF can be applied to a column.
EmbedImagesWithModel = EmbedImages.with_init_args("mobilenetv3_large_100")

# Add an "embed" column holding the embedding of each decoded image.
embed_expr = EmbedImagesWithModel(daft.col("image"))
images_df = images_df.with_column("embed", embed_expr)

In [None]:
#| hide

# Execute the query built in the cells above and materialize its results,
# including the "embed" column produced by the UDF.
images_df = images_df.collect()

In [None]:
#| hide

import shutil

# Remove the scratch image directory. Guarded so that re-running this cell
# after the directory is already gone does not raise FileNotFoundError.
if tmp_path.exists():
    shutil.rmtree(tmp_path)