In [None]:
#| default_exp embed_images

In [None]:
#| export

import gc
import sys
import shutil
from glob import glob
from pathlib import Path
from tempfile import TemporaryDirectory
from PIL import Image
import numpy as np
import torch
from datasets import IterableDataset
from tqdm import tqdm
from loguru import logger
from transformers import pipeline
from itertools import batched


In [None]:
# Easy timestamps: reset loguru's handlers, then emit INFO-and-above
# messages on stdout so they show up in the notebook output.
logger.remove()
logger.add(sink=sys.stdout, level="INFO")

1

### 1. Set variables for test

In [None]:
#| export

# Number of images grouped into one batch during test runs.
BATCH_SIZE: int = 4
# Model identifier used during test runs.
MODEL_NAME: str = "timm/vit_small_patch14_reg4_dinov2.lvd142m"


In [None]:
#| export

def images_from_paths(pathlist):
    """Lazily load every entry of ``pathlist`` as an RGB PIL image.

    Images are decoded one at a time, only when the consumer asks for the
    next item, so the whole collection never has to sit in memory at once.

    Args:
        pathlist: iterable of image locations (``pathlib.Path`` or ``str``).

    Yields:
        The decoded image converted to mode ``"RGB"``.
    """
    # Deliberately a generator (not a list or ``map``) so that downstream
    # consumers can stream the images.
    for path in pathlist:
        # Close the file handle as soon as the pixels are decoded.
        # ``convert`` returns a new, independent image, so the result stays
        # valid after the ``with`` block and needs no extra ``.copy()``.
        with Image.open(path) as img:
            rgb_image = img.convert("RGB")
        yield rgb_image

In [None]:
#| export

def embed_images(imagepaths: list[Path], model_name: str, batch_size: int
                 ) -> list[np.array]:
    """Embed images with a ``transformers`` image-feature-extraction pipeline.

    Args:
        imagepaths: locations of the images to embed; must support ``len``.
        model_name: model identifier handed to ``transformers.pipeline``.
        batch_size: number of images the pipeline groups into one batch.

    Returns:
        The pooled pipeline outputs, accumulated in iteration order. An empty
        ``imagepaths`` returns ``[]`` without loading the model.
        NOTE(review): the annotation says ``np.array`` but the pipeline
        appears to return nested Python lists -- confirm before relying on
        array methods.
    """
    # Nothing to embed: skip the (expensive) model load entirely.
    if len(imagepaths) == 0:
        return []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Honour the caller's arguments instead of the module-level defaults.
    pipe = pipeline(task="image-feature-extraction",
                    model=model_name, device=device, pool=True, use_fast=True)

    embeddings = []
    results = pipe(images_from_paths(imagepaths), batch_size=batch_size)

    # When fed a generator the pipeline yields one result per input image
    # (batching is internal), so the progress total is the image count.
    for out in tqdm(results, total=len(imagepaths)):
        embeddings += out

    return embeddings