In [None]:
from pathlib import Path
import re

import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load every image_N.npy / mask_N.npy pair from the data directory, verify
# that the pairs line up (same index N, same array shape), print the mask's
# unique values and dtype, and plot each image next to its mask.

# NOTE(review): this path is relative to the current working directory, so it
# presumably assumes the notebook is run from a subdirectory of the project
# root. Confirm before running elsewhere.
base_dir = Path("../")
data_dir = base_dir / "data" / "all_data"

images = list(data_dir.glob("image_*.npy"))
masks = list(data_dir.glob("mask_*.npy"))

# Fail loudly on an empty or wrong data directory instead of silently
# plotting nothing.
if not images:
    raise FileNotFoundError(f"No image_*.npy files found in {data_dir.resolve()}")

# zip() below stops at the shorter list, so an image without a mask (or vice
# versa) would otherwise be dropped without any warning.
if len(images) != len(masks):
    raise ValueError(
        f"Found {len(images)} images but {len(masks)} masks in {data_dir}"
    )

# Sort by the number in the filename. The sort is numeric, so image_10
# follows image_9 rather than image_1.
images.sort(key=lambda x: int(x.stem.split("_")[1]))
masks.sort(key=lambda x: int(x.stem.split("_")[1]))

# Check that the images and masks are in the same order, i.e. that every
# pair shares the same index N.
for img, mask in zip(images, masks):
    img_num = int(re.search(r"image_(\d+)", img.stem).group(1))
    mask_num = int(re.search(r"mask_(\d+)", mask.stem).group(1))
    assert img_num == mask_num, f"Image and mask numbers do not match: {img_num} != {mask_num}"

# Print and plot each image/mask pair.
for img, mask in zip(images, masks):
    img_data = np.load(img)
    mask_data = np.load(mask)

    print(f"mask unique values: {np.unique(mask_data)} dtype: {mask_data.dtype}")

    # Check that the image and mask have the same shape.
    assert img_data.shape == mask_data.shape, f"Image and mask shapes do not match: {img_data.shape} != {mask_data.shape}"

    # Plot the image (left) and its mask (right), titled by file stem.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_data)
    ax[0].set_title(img.stem)
    ax[1].imshow(mask_data)
    ax[1].set_title(mask.stem)
    plt.show()