In [None]:
#| default_exp test_end_to_end

In [None]:
#| export

from tempfile import TemporaryDirectory
from datasets import load_dataset
from humanize import naturalsize
from pathlib import Path
import shutil
from tqdm.auto import tqdm
from PIL import Image
from loguru import logger



In [None]:
#| export

import sys  # needed for the stdout sink below; it is not imported anywhere else in this file

# Number of images passed as `max_images` for the end-to-end run.
NUM_TEST_IMAGES = 256
# Hugging Face Hub dataset id used as the image source.
TEST_DATASET = "kvriza8/microscopy_images"

# easy timestamps
# Drop the handlers registered so far and log to stdout at INFO level instead.
logger.remove()
logger.add(sys.stdout, level="INFO")

In [None]:
#| export

def dl_hf_images(dataset_name: str = "kvriza8/microscopy_images",
                 dir: Path = None,
                 max_images: int = 64,
                 overwrite: bool = True,
                 format: str = "png") -> None:
    """Stream images from a Hugging Face dataset and save them to disk.

    The "train" split of `dataset_name` is opened in streaming mode and the
    first `max_images` rows are written to `dir` as `0.<format>`,
    `1.<format>`, ... The total size of the written files is printed at the
    end.

    Args:
        dataset_name: Dataset id handed to `load_dataset`.
        dir: Output directory. Required; it is created if it does not exist.
        max_images: Upper bound on the number of rows to save.
        overwrite: When true, `dir` is deleted (with its contents) before
            anything is written.
        format: File extension used for the saved images.

    Raises:
        ValueError: If `dir` is None.
    """
    # Fail fast, before touching the network, instead of crashing later with
    # an attribute error on None.
    if dir is None:
        raise ValueError("dir must be a directory path, got None")
    out_dir = Path(dir)

    dataset = load_dataset(dataset_name, split="train", streaming=True)
    if overwrite:
        shutil.rmtree(out_dir, ignore_errors=True)
    # Always make sure the target exists, so overwrite=False also works on a
    # directory that has not been created yet.
    out_dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    # zip() stops as soon as range() is exhausted, so no row beyond
    # `max_images` is pulled from the streaming dataset.
    for i, img_row in tqdm(zip(range(max_images), dataset), total=max_images):
        target = out_dir / f"{i}.{format}"
        # assumes img_row["image"] exposes .save(path) (presumably a PIL image) — TODO confirm
        img_row["image"].save(target)
        image_paths.append(target)

    total_bytes = sum(p.stat().st_size for p in image_paths)
    print(f"Size of images on disk: {naturalsize(total_bytes)}")

In [None]:
#| export

def _load_rgb(path):
    """Open the image at `path`, convert it to RGB, and close the file."""
    # The context manager closes the underlying file deterministically;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(path) as img:
        return img.convert("RGB")


def images_from_paths(pathlist):
    """Return a lazy generator of RGB images, one per path in `pathlist`.

    Args:
        pathlist: Iterable of image file paths (`str` or `pathlib.Path`).

    Returns:
        A generator yielding one RGB-converted image per path; each file is
        only opened when the corresponding item is requested.
    """
    return (_load_rgb(p) for p in pathlist)

In [None]:
#| export

# End-to-end smoke run: download the test images into a throwaway directory,
# load them back from disk, and display the first one.
with TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    logger.info("Downloading test images.")
    dl_hf_images(dataset_name=TEST_DATASET, dir=tmp_dir, max_images=NUM_TEST_IMAGES)
    # glob() returns files in arbitrary order; the files are named
    # "<index>.png", so sort numerically to get a deterministic first image.
    imagepaths = sorted(tmp_dir.glob("*.png"), key=lambda p: int(p.stem))
    if not imagepaths:
        # Surface an empty download as a clear failure rather than a bare
        # StopIteration from next() below.
        raise RuntimeError(f"No test images were downloaded from {TEST_DATASET!r}.")
    gen = images_from_paths(imagepaths)
    im = next(gen)
    im.show()