In [None]:
#|default_exp embeddings

In [None]:
#| hide
# Reload edited modules (e.g. clip_plot.utils) before running each cell, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export

from pathlib import Path

import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import pipeline

from clip_plot.utils import timestamp


# Create and save embeddings

In [None]:
#| export

def images_from_paths(pathlist):
    """Lazily open each image in `pathlist` and yield it converted to RGB.

    Entries may be ``pathlib.Path`` objects or plain strings. Each file is
    opened in a ``with`` block, so its handle is released as soon as the RGB
    image has been produced rather than whenever the object is garbage-collected.
    """
    for p in pathlist:
        with Image.open(Path(p)) as im:
            # convert() returns a converted copy (per the Pillow docs), so `rgb`
            # stays usable after the source file is closed.
            rgb = im.convert("RGB")
        yield rgb

In [None]:
#| export

def images_iterator(ImageEngine):
    """Return a generator over the `.original` attribute of each item in `ImageEngine`.

    Items are pulled from `ImageEngine` only as the generator is consumed.
    """
    originals = (image.original for image in ImageEngine)
    return originals

In [None]:
#| export

def embed_images(imagepaths : list[Path],
                 model_name : str = "timm/convnext_tiny.dinov3_lvd1689m",
                 batch_size : int = 4
                 ) -> np.ndarray:
    """Compute pooled feature vectors for a list of images.

    Runs a Hugging Face ``image-feature-extraction`` pipeline (with
    ``pool=True``) on GPU when available, otherwise CPU.

    Parameters
    ----------
    imagepaths : image locations (``Path`` or ``str``).
    model_name : model id passed to ``transformers.pipeline``.
    batch_size : number of images sent through the model per forward pass.

    Returns
    -------
    np.ndarray built from the per-image pipeline outputs, in input order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device for inference: {device}")
    pipe = pipeline(task="image-feature-extraction",
                    model=model_name, device=device, pool=True, use_fast=True)

    print(timestamp(), f"Creating embeddings using {model_name}")
    imagepath_strs = [Path(p).as_posix() for p in imagepaths]

    # Feed the pipeline a generator rather than a list: for list input the
    # pipeline gathers every result before returning, whereas a generator makes
    # it return a lazy iterator. That lets `batch_size` take effect while tqdm
    # still advances once per image.
    path_stream = (p for p in imagepath_strs)
    embeddings = []
    for out in tqdm(pipe(path_stream, batch_size=batch_size), total=len(imagepath_strs)):
        # `out` is the pipeline's (unbatched) result for a single image; extend
        # with it exactly as the former one-image-per-call loop did.
        embeddings += out

    print(timestamp(), "Done creating embeddings.")

    return np.array(embeddings)

In [None]:
#| export

def get_embeddings(ImageEngine,
                   model_name : str = "timm/convnext_tiny.dinov3_lvd1689m",
                   batch_size : int = 4
                   ) -> np.ndarray:
    """Embed every image listed in `ImageEngine.image_paths`.

    Convenience wrapper that forwards those paths, `model_name` and
    `batch_size` to `embed_images` and returns its array unchanged.
    """
    paths = ImageEngine.image_paths
    return embed_images(paths, batch_size=batch_size, model_name=model_name)

In [None]:
#| export

def write_embeddings(embeddings : np.ndarray, names: list[str], dir: Path):
    """Write one ``.npy`` file per embedding row and return the per-name paths.

    Parameters
    ----------
    embeddings : array whose first axis indexes items; row ``i`` is saved for ``names[i]``.
    names : file names relative to `dir`, one per embedding row.
    dir : output directory (``Path`` or ``str``); missing directories are created.

    Returns
    -------
    list[Path] of resolved ``dir/name`` paths. The file actually written for each
    entry is ``path.with_suffix('.npy')``, which replaces any existing extension
    (``"a.jpg"`` is saved as ``a.npy``). Names that differ only in extension
    therefore overwrite each other.

    Raises
    ------
    ValueError if the number of names differs from the number of embeddings
    (previously the extra items were silently dropped by ``zip``).
    """
    names = list(names)
    if len(names) != len(embeddings):
        raise ValueError(
            f"got {len(embeddings)} embeddings but {len(names)} names")
    out_dir = Path(dir)  # also accept plain strings
    paths = [(out_dir / n).resolve() for n in names]
    for p, e in zip(paths, embeddings):
        p.parent.mkdir(parents=True, exist_ok=True)
        np.save(p.with_suffix('.npy'), e)
    return paths


In [None]:
#|hide
# Regenerate the library modules from the project's notebooks.
import nbdev
nbdev.nbdev_export()