In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import numpy.typing as npt
import re


In [None]:
def plot_gallery(images: list[npt.NDArray], masks: list[npt.NDArray], filenames: list[str], pair_cols:int=3):
    """Show each image next to its mask in a grid with ``pair_cols`` pairs per row.

    Args:
        images: Image arrays, passed to ``Axes.imshow``.
        masks: Mask arrays, paired with ``images`` by list position.
        filenames: Titles for each pair, paired with ``images`` by list position.
        pair_cols: Number of image/mask pairs per grid row; the grid therefore
            has ``2 * pair_cols`` axes columns.

    Raises:
        AssertionError: If ``images``, ``masks`` and ``filenames`` differ in length.
        ValueError: If ``pair_cols`` is smaller than 1.
    """
    assert len(images) == len(masks), "Number of images and masks must be the same"
    assert len(filenames) == len(images), "Number of filenames and images must be the same"
    if pair_cols < 1:
        raise ValueError(f"pair_cols must be >= 1, got {pair_cols}")
    if not images:
        # plt.subplots rejects a 0-row grid; there is nothing to draw anyway
        return
    cols = pair_cols * 2
    # Ceiling division: exactly enough rows, without a trailing empty row
    rows = (len(images) + pair_cols - 1) // pair_cols
    # squeeze=False keeps `ax` 2-D even when rows == 1, so ax[row, col] always works
    fig, ax = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
    # plot each image with its mask side by side in a grid
    for i, (image, mask, filename) in enumerate(zip(images, masks, filenames)):
        row, pair_idx = divmod(i, pair_cols)
        image_ax = ax[row, pair_idx * 2]
        mask_ax = ax[row, pair_idx * 2 + 1]
        image_ax.imshow(image)
        image_ax.set_title(filename)
        mask_ax.imshow(mask)
        mask_ax.set_title(f"{filename} mask")
    # Turn off every axis, including unused trailing cells, so no empty frames are drawn
    for cell_ax in ax.flat:
        cell_ax.axis("off")
    plt.tight_layout()
    plt.show()




In [None]:
base_dir = Path("../")
all_data_dir = base_dir / "data" / "all_data"
assert all_data_dir.exists(), f"Directory {all_data_dir} does not exist"


def _file_index(path: Path) -> int:
    """Return the first run of digits in ``path.name`` as an int (``image_12.png`` -> 12).

    Raises:
        ValueError: If the filename contains no digits. Without this check,
            ``re.search(...).group()`` would fail with an opaque AttributeError on ``None``.
    """
    match = re.search(r"\d+", path.name)
    if match is None:
        raise ValueError(f"No numeric index found in filename {path.name!r}")
    return int(match.group())


# Sort the lists by the number in the filename (numerically, so image_10 follows image_9)
images = sorted(all_data_dir.glob("image_*.png"), key=_file_index)
masks = sorted(all_data_dir.glob("mask_*.png"), key=_file_index)
filenames = [img.name for img in images]

print(f"Found {len(images)} images and {len(masks)} masks")
print(f"Images: {images}")
print(f"Masks: {masks}")

# Images and masks are later paired purely by list position, so each image index
# must line up with the same mask index; otherwise the pairs silently shift.
image_ids = [_file_index(img) for img in images]
mask_ids = [_file_index(mask) for mask in masks]
assert image_ids == mask_ids, (
    "Image and mask indices do not match; unpaired indices: "
    f"{sorted(set(image_ids) ^ set(mask_ids))}"
)

In [None]:
# Read every image and mask from disk, then show each pair side by side
images_list = list(map(plt.imread, images))
masks_list = list(map(plt.imread, masks))
plot_gallery(images_list, masks_list, filenames, pair_cols=3)