In [1]:
import gc
import shutil
import sys
from glob import glob
from itertools import batched, islice
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import torch
from datasets import load_dataset
from humanize import naturalsize
from loguru import logger
from PIL import Image
from tqdm import tqdm
from transformers import pipeline


In [2]:
# Send timestamped loguru messages to stdout at INFO level.
# remove() drops every existing handler first, so re-running this cell does not
# register a second handler (which would print each message twice).
# Binding the returned handler id keeps its repr out of the cell output.
logger.remove()
_log_handler_id = logger.add(sys.stdout, level="INFO")

1

## 1. Set variables for the test

In [3]:
# Test data: Hugging Face dataset id and how many images to stream from it.
TEST_DATASET = "kvriza8/microscopy_images"
NUM_TEST_IMAGES = 256

# Model id handed to `transformers.pipeline`, and how many image paths are
# grouped together per embedding step.
MODEL_NAME = "timm/vit_small_patch14_reg4_dinov2.lvd142m"
BATCH_SIZE = 32

## 2. Define a helper to download a small test dataset

In [4]:
def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None,
                 max_images: int = 64,
                 overwrite: bool = True,
                 format: str = "png") -> None:

    dataset = load_dataset(dataset_name, split="train", streaming=True)
    if overwrite:
        shutil.rmtree(dir, ignore_errors=True)
        dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    for i, img_row in enumerate(tqdm(iter(dataset), total=max_images)):
        if i >= max_images:
            break
        img = img_row["image"]
        image_paths += [(dir / f"{i}.{format}")]
        img.save(image_paths[-1])

    print(f"Size of images on disk: {naturalsize(sum([p.stat().st_size for p in image_paths]))}")

    del dataset
    gc.collect()

    return None

In [5]:
# Load memray's IPython extension; it provides the %%memray_flamegraph cell
# magic used to profile the download + embedding cell.
%load_ext memray

In [8]:
%%memray_flamegraph --native --follow-fork --temporal

with TemporaryDirectory() as tmp:    
    logger.info("Downloading test images.")
    dl_hf_images(dir=Path(tmp), max_images=NUM_TEST_IMAGES)
    imglob = tmp+"/*.png"
    imagepaths = list(Path(tmp).glob("*.png"))
    embeddings = []
    batch = []

    logger.info("Instantiating pipeline.")
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = pipeline(task="image-feature-extraction",
                    model=MODEL_NAME, device=DEVICE, pool=True, use_fast=True)
            

    logger.info("Starting embedding pipeline.")

    for batch_paths in batched(imagepaths, BATCH_SIZE):
        batch = [Image.open(p.as_posix()).convert("RGB") for p in batch_paths]
        embeddings += pipe(batch)

    logger.info("Done with embedding pipeline.")

[32m2025-03-29 13:21:02.770[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mDownloading test images.[0m


100%|██████████| 256/256 [00:07<00:00, 33.36it/s] 


Size of images on disk: 19.5 MB
[32m2025-03-29 13:21:11.540[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mInstantiating pipeline.[0m


Device set to use cpu


[32m2025-03-29 13:21:12.159[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mStarting embedding pipeline.[0m
[32m2025-03-29 13:23:13.773[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mDone with embedding pipeline.[0m


Output()

Output()