In [1]:
# |default_exp daft_embeddings

In [2]:
# | hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export

import daft
import timm
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from dataclasses import dataclass, field
from typing import Callable

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
@dataclass
class TimmEmbedder:
    """
    embed an image with any timm model that supports this

    Reference: https://huggingface.co/docs/timm/main/en/feature_extraction#pooled
    """
    model_name: str
    device: torch.device = field(init=False)
    dtype: torch.dtype = field(init=False)
    _model: Callable = field(init=False)
    _transforms: Callable = field(init=False)


    def __post_init__(self):
        # initialize model and transforms
        self._model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        cfg = self._model.pretrained_cfg
        self._transform = timm.data.create_transform(**timm.data.resolve_data_config(cfg))

        # set device and dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.bfloat16 if self.device.type == "cuda" and torch.cuda.is_bf16_supported() else (
            torch.float16 if self.device.type == "cuda" else torch.float32
        )
        print(f"Inference device: {self.device} with dtype: {self.dtype}")

        # optimize model for inference
        self._model = self._model.eval().to(device=self.device, dtype=self.dtype)
        self._model = torch.jit.optimize_for_inference(torch.jit.script(self._model))


    def __call__(self, image: torch.Tensor | Image.Image) -> np.array:
            """transform image, run inference, extract embedding as 1D array"""
            image = self._transform(image).to(self.device, self.dtype).unsqueeze(0)
            emb = self._model(image)
            return emb.detach().cpu().float().numpy().squeeze()


In [5]:
#| hide

# it's a one-stop shop
e = TimmEmbedder("mobilenetv3_large_100")
e(torch.rand((3,256,256)))

Inference device: cpu with dtype: torch.float32


array([-0.17605317,  0.0147213 , -0.05652565, ..., -0.1414647 ,
        0.16466515,  0.0450598 ], dtype=float32)

In [6]:
#| export

@daft.udf(return_dtype=daft.DataType.tensor(daft.DataType.float32()))
class EmbedImages:
    """run timm embedder on an image column"""
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.embedder = TimmEmbedder(self.model_name)

    def __call__(self, images_col):
        images = images_col.to_pylist()
        print(f"First type: {type(images[0])}")
        print(f"First shape: {images[0].shape}")
        ## TODO: Modify to work on an array instead of a list
        ## Example: https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/mnist.ipynb
        # return [self.embedder(torch.tensor(image)) for image in images]
        return [np.random.random((512,)) for image in images]

In [None]:
#| hide

im_cols = daft.from_glob_path("../../../atlas-compare/data/orig_filt/*.png"
                              ).with_column_renamed("path", "path_full_img")

im_cols = im_cols.sample(0.01)

im_cols = im_cols.with_column("image", daft.col("path_full_img"
                                 ).url.download().image.decode(
                                     mode="RGB", on_error="null"))

MODEL_NAME = "mobilenetv3_large_100"
EmbedImagesWithModel = EmbedImages.with_init_args(MODEL_NAME)

im_cols = im_cols.with_column("embed", EmbedImagesWithModel(daft.col("image")))



In [8]:
im_cols.show(1)

Inference device: cpu with dtype: torch.float32
First type: <class 'numpy.ndarray'>
First shape: (157, 132, 3)


thread 'Compute-Thread-15' panicked at src/daft-core/src/array/ops/cast.rs:2192:18:
not implemented: List casting not implemented for dtype: Tensor(Float32)
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
--- PyO3 is resuming a panic after fetching a PanicException from Python. ---
Python stack trace below:


PanicException: not implemented: List casting not implemented for dtype: Tensor(Float32)

Error when running pipeline node Project


DaftCoreException: DaftError::External task 14177 panicked with message "not implemented: List casting not implemented for dtype: Tensor(Float32)"