In [None]:
#| default_exp embed_images

In [None]:
#| hide
 
# Development convenience: reload edited modules automatically before each
# cell executes (mode 2 = all imported modules), so source changes are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#| export

from pathlib import Path
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
from loguru import logger
from transformers import pipeline

In [None]:
#| export

def _load_rgb(path):
    """Read a single image file and return it as a detached RGB `PIL.Image`."""
    # The context manager releases the underlying file handle as soon as the
    # pixels have been decoded. `convert` returns a new image object, so the
    # result remains usable after the source image is closed.
    with Image.open(path) as img:
        return img.convert("RGB")


def images_from_paths(pathlist):
    """Lazily load the images at `pathlist` as RGB `PIL.Image` objects.

    Args:
        pathlist: Iterable of image locations (`pathlib.Path` or `str`).

    Returns:
        A generator yielding one RGB image per entry, in input order. Files
        are only opened when the generator is consumed, and each file is
        closed again before the next one is opened.

    Raises:
        Whatever `PIL.Image.open` raises for a missing or unreadable file;
        the error surfaces when the offending item is requested, not when
        this function is called.
    """
    return (_load_rgb(p) for p in pathlist)

In [None]:
#| export

def embed_images(imagepaths : list[Path],
                 model_name : str = "timm/vit_small_patch14_reg4_dinov2.lvd142m",
                 batch_size : int = 4
                 ) -> np.ndarray:
    """Compute one pooled feature vector per image with a Hugging Face pipeline.

    Args:
        imagepaths: Paths of the images to embed; must support `len()`.
        model_name: Model identifier passed to the `image-feature-extraction`
            pipeline.
        batch_size: Number of images the pipeline processes per forward pass.

    Returns:
        A numpy array built from the collected pipeline outputs — presumably
        of shape `(len(imagepaths), embedding_dim)`; TODO confirm against the
        chosen model. An empty input yields an empty array.
    """
    n_images = len(imagepaths)
    if n_images == 0:
        # Nothing to embed: skip the (expensive) model load and return the
        # same empty array the loop below would have produced.
        return np.array([])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = pipeline(task="image-feature-extraction",
                    model=model_name, device=device, pool=True, use_fast=True)

    embeddings = []

    # When fed a generator, the pipeline batches internally but yields one
    # result per input item, so the progress bar total is the image count,
    # not the number of batches.
    results = pipe(images_from_paths(imagepaths), batch_size=batch_size)
    for out in tqdm(results, total=n_images):
        # Each result is expected to wrap its pooled vector in an outer list
        # (NOTE(review): confirm for the chosen model); extending flattens it.
        embeddings.extend(out)

    return np.array(embeddings)

In [None]:
#| hide

# Write the cells marked `#| export` out to the package's Python modules.
import nbdev; nbdev.nbdev_export()