In [None]:
#| default_exp embed_images

In [None]:
#| hide

%load_ext autoreload
%autoreload 2

In [None]:
#| export

from pathlib import Path
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
from loguru import logger
from transformers import pipeline
import daft

from bedmap.config import Cfg
from bedmap.step import step

In [None]:
#| export

# Number of images per batch; used as the `batch_size` argument when calling
# `create_embeddings_col`.
BATCH_SIZE = 4

In [None]:
#| export

def images_from_paths(pathlist):
    """
    Lazily load images from disk, converted to RGB.

    Args:
        pathlist: iterable of image locations (`pathlib.Path` or `str`).

    Yields:
        One RGB `PIL.Image.Image` per entry of `pathlist`, in input order.
        Nothing is opened until the generator is consumed.
    """
    for p in pathlist:
        # The context manager closes the underlying file deterministically,
        # instead of leaving it open until garbage collection.
        with Image.open(p) as img:
            # convert() returns a new image with its pixel data loaded, so the
            # result stays valid after the source file is closed (no extra copy needed).
            rgb = img.convert("RGB")
        # Yield outside the `with` so no file stays open while the consumer works.
        yield rgb

In [None]:
#| export

def embed_images(imagepaths : list[Path],
                 model_name : str = "timm/vit_small_patch14_reg4_dinov2.lvd142m",
                 batch_size : int = 4
                 ) -> np.ndarray:
    """
    Compute one embedding per image with a Hugging Face feature-extraction pipeline.

    Args:
        imagepaths: paths of the images to embed; must support `len()`.
        model_name: model identifier handed to `transformers.pipeline`.
        batch_size: number of images the pipeline processes per batch.

    Returns:
        `np.ndarray` built from the collected pipeline outputs, in the order of
        `imagepaths`. Presumably shaped (n_images, embedding_dim) -- the caller
        in this file reads `shape[-1]` as the embedding size.
    """
    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = pipeline(task="image-feature-extraction",
                    model=model_name, device=device, pool=True, use_fast=True)

    embeddings = []

    # The pipeline yields one result per input image (batching is internal),
    # so the progress total is the number of images, not the number of batches.
    for out in tqdm(pipe(images_from_paths(imagepaths), batch_size=batch_size),
                    total=len(imagepaths)):
        # Each `out` is a list wrapping the vector(s) for one image; flatten it in.
        embeddings.extend(out)

    return np.array(embeddings)

In [None]:
# df = daft.from_pydict({"img_path": ["file://cae.", "file:////acea/"]})
# df = df.with_column("img_path_trim",
#                     df["img_path"].str.replace(pattern="^file:/+", replacement="/", regex=True)
#                    )
# df.select(daft.col("img_path_trim")).to_pylist()



In [None]:
#| export

@step(requires=["img_path"], provides=["embeddings"])
def create_embeddings_col(df: daft.DataFrame, model_name: str, batch_size: int) -> daft.DataFrame:
    """
    Embed the images referenced by `img_path` and attach the result as an `embeddings` column.

    Args:
        df: dataframe with an `img_path` column of `file://` URIs.
        model_name: model identifier forwarded to `embed_images`.
        batch_size: batch size forwarded to `embed_images`.

    Returns:
        `df` joined with a fixed-size-list float32 `embeddings` column. The temporary
        `img_path_nouri` join key is removed before returning.
    """
    ## daft encodes paths as URIs, always starting with file://
    df = df.with_column("img_path_nouri", df["img_path"].str.replace(
        pattern="^file://", replacement="", regex=True))

    # Materialize the path column exactly once and reuse the same list for both the
    # embedding input and the join key, so every vector is paired with the path it
    # was computed from (a second collect re-runs the plan and relies on row order).
    rows = df.select("img_path_nouri").to_pylist()
    # dict.fromkeys de-duplicates while keeping first-seen order: each distinct image
    # is embedded once and duplicate paths do not multiply rows in the join below.
    path_strs = list(dict.fromkeys(r["img_path_nouri"] for r in rows))
    paths = [Path(s) for s in path_strs]

    embeds = embed_images(paths, model_name=model_name, batch_size=batch_size)
    # fixed_size_list lets us use normal arrow methods to calculate length later
    embeds_type = daft.DataType.fixed_size_list(daft.DataType.float32(), embeds.shape[-1])
    embeds_series = daft.Series.from_numpy(embeds).cast(embeds_type)

    df_embs = daft.from_pydict({"embeddings": embeds_series,
                                "img_path_nouri": path_strs})

    df = df.join(df_embs, on="img_path_nouri")

    return df.exclude("img_path_nouri")

In [None]:
#| hide

from bedmap.prepare_images import df_images_from_pattern
from bedmap.config import Cfg

cfg = Cfg()
TEST_DIR = "../tests/test-data/smithsonian_butterflies_10/jpgs"

# Smoke test: embed the bundled sample images and check the resulting frame,
# so this cell fails on wrong output and not only on exceptions.
df_images = df_images_from_pattern(TEST_DIR)
df = create_embeddings_col(df_images, model_name=cfg.model_name, batch_size=BATCH_SIZE)

assert "embeddings" in df.column_names, "embeddings column was not added"
assert "img_path_nouri" not in df.column_names, "temporary join key leaked into the result"
assert df.count_rows() == df_images.count_rows(), "row count changed while embedding"

  self._settings_warn_unused_config_keys(sources, self.model_config)


üó°Ô∏è üêü InMemorySource: 00:00 

üó°Ô∏è üêü Project: 00:00 

Device set to use cpu
12it [00:04,  2.73it/s]                      


üó°Ô∏è üêü InMemorySource: 00:00 

üó°Ô∏è üêü Project: 00:00 

In [None]:
#| hide

# Export every `#| export` cell of this notebook into the package module.
import nbdev

nbdev.nbdev_export()