In [None]:
# Reload edited modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below each cell.
%matplotlib inline
# Project helper that presumably adds the parent folder to sys.path (inferred from its name — confirm);
# it runs before the next cell, whose `kidney` imports appear to depend on it.
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
from monai.data import create_test_image_2d
from kidney.datasets.kaggle import get_reader, SampleType

In [None]:
# Dataset reader from the project's kidney.datasets.kaggle module, created with get_reader()'s defaults.
reader = get_reader()

In [None]:
# Take the first key returned for the labeled subset as the example sample to inspect.
key = reader.get_keys(SampleType.Labeled)[0]

In [None]:
# Load that single sample; the following cells index it with "image" and "mask".
sample = reader.fetch_one(key)

In [None]:
# List the fields available in the fetched sample.
sample.keys()

In [None]:
# Show the image and mask array shapes side by side to check that they line up.
sample["image"].shape, sample["mask"].shape

In [None]:
# Synthetic 256x256 test image (`img`) plus its integer label map (`seg`) with 3 segmentation classes.
# The generator is stochastic: seeding it keeps the image, the masks derived from `seg`, and every
# figure below identical across Restart & Run All.
img, seg = create_test_image_2d(
    256,
    256,
    num_seg_classes=3,
    num_objs=12,
    noise_max=0.3,
    rad_max=50,
    random_state=np.random.RandomState(42),
)

In [None]:
# Show the synthetic test image; the trailing `;` keeps the repr of the last call out of the output.
fig, ax = plt.subplots()
ax.imshow(img, cmap="gray")
ax.set_title("Synthetic test image");

In [None]:
# One binary uint8 mask (values 0/1) per label 1, 2 and 3; label 0 (presumably background) is skipped.
m1, m2, m3 = ((seg == label).astype(np.uint8) for label in (1, 2, 3))

In [None]:
# Show the binary mask for label 1 as a grayscale image; the trailing `;` suppresses the call's repr.
fig, ax = plt.subplots()
ax.imshow(m1, cmap="gray")
ax.set_title("Mask m1 (seg == 1)");

In [None]:
from typing import List, Tuple

In [None]:
def overlay_masks(
    image: np.ndarray,
    masks: List[Tuple[np.ndarray, Tuple[int, int, int]]],
    convert_to_uint: bool = True,
    alpha: float = 0.5,
) -> np.ndarray:
    """Blend solid-colored binary masks over a 3-D image.

    Every pixel where a mask equals 1 is painted with that mask's color on a
    copy of the image; the painted copy is then blended with the unpainted
    image, so pixels outside all masks keep their original value.

    Args:
        image: 3-D array whose first two dimensions match every mask. Must
            already be uint8 unless ``convert_to_uint`` is True.
        masks: ``(mask, color)`` pairs; ``mask`` is a 2-D uint8 array and
            ``color`` has three components. Later masks overwrite earlier
            ones where they overlap. Only mask values equal to 1 are painted.
        convert_to_uint: cast ``image`` with ``astype(np.uint8)`` first
            (values are not rescaled or clipped).
        alpha: weight of the painted copy in the blend; the unpainted image
            gets ``1 - alpha``. The default 0.5 matches the original
            hard-coded 50/50 blend.

    Returns:
        A new uint8 array; the caller's ``image`` is never modified.

    Raises:
        AssertionError: if any input check fails.
    """
    assert image.ndim == 3
    assert convert_to_uint or image.dtype == np.uint8
    assert 0.0 <= alpha <= 1.0

    _verify_overlay_masks_input(image, masks)

    # astype() returns a fresh array, but with convert_to_uint=False `base` is
    # the caller's array, so paint on a separate copy to avoid mutating it.
    base = image.astype(np.uint8) if convert_to_uint else image
    painted = base.copy()
    for mask, color in masks:
        painted[mask == 1] = color
    return cv.addWeighted(base, 1.0 - alpha, painted, alpha, 0)


def _verify_overlay_masks_input(image: np.ndarray, masks: List) -> None:
    """Assert that every ``(mask, color)`` pair is compatible with ``image``.

    Each mask must be a 2-D uint8 array with the same first two dimensions
    as ``image``, and each color must have exactly three components.
    """
    for mask, color in masks:
        assert mask.ndim == 2
        assert mask.dtype == np.uint8
        assert mask.shape[:2] == image.shape[:2]
        assert len(color) == 3

In [None]:
# Stack the grayscale test image (float, presumably in [0, 1]) into three identical channels and scale
# to uint8. Round before casting so e.g. 0.999 * 255 becomes 255 instead of being truncated to 254.
u_img = np.rint(np.repeat(img[:, :, np.newaxis], 3, axis=-1) * 255).astype(np.uint8)

# One (mask, RGB color) pair per label: 1 -> red, 2 -> green, 3 -> blue.
colored_masks = [
    (m1, (255, 0, 0)),
    (m2, (0, 255, 0)),
    (m3, (0, 0, 255)),
]

fig, ax = plt.subplots()
ax.imshow(overlay_masks(u_img, colored_masks))
ax.set_title("Label masks overlaid on the synthetic image");