In [None]:
# |default_exp daft_embeddings

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export

import daft
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from einops import rearrange
from dataclasses import dataclass, field
from typing import Callable
from tqdm import tqdm
import shutil
from functools import partial
from datasets import load_dataset

In [None]:
#| export

# Tunable settings for the embedding pipeline.
MAX_IMAGES = 500   # number of sample images pulled for the smoke test
BATCH_SIZE = 24    # images per forward pass; also used as the Daft morsel size
MODEL_NAME = "vit_large_patch14_reg4_dinov2.lvd142m"
MEMORY_BYTES = int(6e9) # 6 GB memory allocation

# Use the native executor and hand the UDF morsels of BATCH_SIZE rows.
daft.set_execution_config(enable_native_executor=True,
                          default_morsel_size=BATCH_SIZE
                          )

# timm models that are known to work well with the embedder.
nice_models = [
"mobilenetv3_large_100",
"vit_small_patch14_reg4_dinov2.lvd142m",
"vit_base_patch14_reg4_dinov2.lvd142m",
"vit_large_patch14_reg4_dinov2.lvd142m"
]

In [None]:
#| hide

# Start from an empty scratch directory for the sample images.
tmp_path = Path("./tmp-test-images")
shutil.rmtree(tmp_path, ignore_errors=True)
tmp_path.mkdir(parents=True, exist_ok=True)

# Stream the dataset so only the examples we actually save are fetched.
dataset_name = "kvriza8/microscopy_images"
dataset = load_dataset(dataset_name, split="train", streaming=True)

for idx, record in enumerate(tqdm(iter(dataset))):
    if not idx < MAX_IMAGES:
        break
    target_file = tmp_path / f"{idx}.png"
    record["image"].save(target_file)

500it [00:11, 41.93it/s] 


In [None]:
#| export

@dataclass
class TimmEmbedder:
    """
    Embed an image with any timm model that supports pooled feature extraction.

    The class is a singleton: every ``TimmEmbedder(name)`` call returns the same
    object. When ``name`` differs from the model that is currently loaded, the
    model and its preprocessing transform are rebuilt, so ``model_name`` always
    matches the weights that are actually used for inference.

    Reference: https://huggingface.co/docs/timm/main/en/feature_extraction#pooled
    """
    model_name: str
    device: torch.device = field(init=False)
    dtype: torch.dtype = field(init=False)
    # Heavy objects are kept out of the generated repr.
    _model: Callable = field(init=False, repr=False)
    _transform: Callable = field(init=False, repr=False)
    _amp_autocast: Callable = field(default=None, init=False, repr=False)
    _instance: "TimmEmbedder" = field(default=None, init=False, repr=False)
    # Name of the model whose weights are currently loaded (None until first load).
    _loaded_model_name: str | None = field(default=None, init=False, repr=False)

    def __new__(cls, model_name):
        """Return the shared instance, creating it on first use."""
        if getattr(cls, "_instance", None) is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __post_init__(self):
        """Load the model and transform, unless the requested model is already loaded."""
        # __init__ runs again on every TimmEmbedder(...) call and overwrites
        # model_name, so compare against what is actually loaded instead of
        # only checking that some model exists.
        if self._loaded_model_name == self.model_name:
            return

        # initialize model and transforms
        model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        cfg = model.pretrained_cfg
        self._transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

        # set device and dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        on_cuda = self.device.type == "cuda"
        if not on_cuda:
            self.dtype = torch.float32
        elif torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float16

        # Autocast is only useful on CUDA. On CPU the model stays in float32,
        # so autocast is disabled; bfloat16 is passed there only because it is
        # a dtype CPU autocast accepts without warning.
        _amp_dtype = torch.float16 if self.dtype == torch.float16 else torch.bfloat16
        self._amp_autocast = partial(torch.autocast, device_type=self.device.type,
                                     dtype=_amp_dtype, enabled=on_cuda)
        print(f"Inference device: {self.device} with dtype: {self.dtype}")

        # optimize model for inference
        model = model.eval().to(device=self.device, dtype=self.dtype)
        model = model.to(memory_format=torch.channels_last)
        self._model = torch.compile(model)
        self._loaded_model_name = self.model_name


    def prepare(self, image: torch.Tensor | Image.Image) -> torch.Tensor:
        """
        Apply the model's preprocessing transform to a single image.

        Returns the transformed image with a leading batch dimension added,
        moved to the embedder's device and dtype in channels-last memory format.
        """
        return self._transform(image).unsqueeze(0).to(self.device, self.dtype,
                                        memory_format=torch.channels_last)


    def embed(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of already-transformed images.

        Runs the model without gradient tracking and returns the result as a
        float32 tensor on the CPU.
        """
        with torch.no_grad(), self._amp_autocast():
            return self._model(img_tensor).cpu().float()


    def __call__(self, image: torch.Tensor | Image.Image) -> torch.Tensor:
        """Transform one image and embed it; the result keeps the batch dimension of 1."""
        image = self.prepare(image)
        return self.embed(image)
            

In [None]:
#| hide

# it's a one-stop shop
e = TimmEmbedder("mobilenetv3_large_100")
one_shot = e(torch.rand((3,256,256)))

# or do in two steps

t = e.prepare(torch.rand((3,256,256)))
v = e.embed(t)

# both routes must produce one float32 embedding row of the same width
assert t.shape[0] == 1
assert v.shape[0] == 1
assert one_shot.shape == v.shape
assert v.dtype == torch.float32

Inference device: cpu with dtype: torch.float32


In [None]:
#| export

@daft.udf(return_dtype=daft.DataType.list(daft.DataType.float32()),
          memory_bytes=MEMORY_BYTES)
class EmbedImages:
    """
    Run the timm embedder on an image column.

    Args:
        model_name: timm model identifier passed to ``TimmEmbedder``.
        batch_size: number of images sent through the model per forward pass.
    """
    def __init__(self, model_name: str, batch_size: int = 4):
        self.model_name = model_name
        self.batch_size = batch_size
        self.embedder = TimmEmbedder(self.model_name)

    def __call__(self, images_col) -> np.ndarray | list:
        """
        Embed every image in ``images_col``.

        Returns a 2D array with one embedding per row when every row holds an
        image. If some rows are null (e.g. images that failed to decode
        upstream), returns a list with ``None`` at those positions instead.
        """
        raw_images = images_col.to_pylist()
        valid_rows = [row for row, im in enumerate(raw_images) if im is not None]
        if not valid_rows:
            return [None] * len(raw_images)

        # Hand the transform PIL images rather than raw float tensors: the timm
        # transform scales PIL input to [0, 1] before normalizing, whereas a
        # tensor built from the 0-255 pixel array would be normalized unscaled.
        # assumes each decoded image is an (H, W, C) uint8 array - TODO confirm
        images = torch.cat([self.embedder.prepare(Image.fromarray(np.asarray(raw_images[row])))
                            for row in valid_rows],
                           dim=0)
        ## Example: https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/mnist.ipynb
        embeddings = torch.cat([self.embedder.embed(batch) for batch in torch.split(
                                    images, self.batch_size, dim=0)],
                               dim=0).numpy()
        if len(valid_rows) == len(raw_images):
            return embeddings

        # Put each embedding back at its original row; null rows stay None.
        results = [None] * len(raw_images)
        for row, vector in zip(valid_rows, embeddings):
            results[row] = vector.tolist()
        return results

## TODO: I need to figure out how to create two udfs that work together

## Test it out

In [None]:
#| hide

# Every PNG written to the scratch directory.
glob = f"{tmp_path.as_posix()}/*.png"

# Download each file and decode it as RGB; on_error="null" asks Daft to emit
# a null instead of raising when a file cannot be decoded.
decoded_image = daft.col("path_full_img").url.download().image.decode(
    mode="RGB", on_error="null")

images_df = (
    daft.from_glob_path(glob)
    .with_column_renamed("path", "path_full_img")
    .with_column("image", decoded_image)
)



In [None]:
#| hide

# Bind the constructor arguments once, then apply the UDF to the image column.
EmbedImagesWithModel = EmbedImages.with_init_args(
    batch_size=BATCH_SIZE,
    model_name=MODEL_NAME,
)
embed_expr = EmbedImagesWithModel(daft.col("image"))
images_df = images_df.with_column("embed", embed_expr)

In [None]:
#| hide

# Execute the query built in the cells above and keep the materialized result.
images_df = images_df.collect()

🗡️ 🐟 Project: 00:00 

In [None]:
#| hide

# Remove the scratch directory holding the downloaded test images.
shutil.rmtree(tmp_path)