In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import rasterio
from kidney.datasets.kaggle import get_reader, SampleType, DatasetReader
from zeus.utils import list_files
from zeus.plotting.utils import axes, calculate_layout

In [None]:
def print_dataset_info(reader: DatasetReader, sample_type: SampleType):
    """Print split tag, key, raster size and band indexes for each sample.

    Args:
        reader: Reader whose keys and metadata are listed. It is used as
            given (it is no longer replaced by a fresh ``get_reader()``).
        sample_type: Which subset of keys to list, e.g. ``SampleType.All``.

    Each printed line looks like ``[trn] <key> <height> <width> <indexes>``,
    where ``[trn]`` marks samples whose metadata has a mask and ``[tst]``
    marks the rest.
    """
    # NOTE(review): an identity affine is passed to rasterio.open; presumably
    # meant to avoid georeferencing warnings for these TIFFs — confirm it has
    # any effect when opening in read mode.
    identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
    for key in reader.get_keys(sample_type):
        meta = reader.fetch_meta(key)
        with rasterio.open(meta["tiff"], transform=identity) as dataset:
            height, width = dataset.shape
            split_tag = "[trn]" if meta["mask"] is not None else "[tst]"
            print(split_tag, key, height, width, dataset.indexes)

In [None]:
# List every sample of SampleType.All with its raster size and band indexes.
print_dataset_info(reader=get_reader(), sample_type=SampleType.All)

In [None]:
import os

# Folder of prepared image/mask files consumed by read_png_images below.
# Set the KIDNEY_PREPARED_DIR environment variable to point elsewhere instead
# of editing this machine-specific absolute path; the default is unchanged.
PREPARED_DIR = os.environ.get("KIDNEY_PREPARED_DIR", "/mnt/fast/data/kidney/images_32_1024")

In [None]:
def read_png_images(folder: str):
    """Index the files in ``folder`` by image id.

    File stems must look like ``<type>.<image_id>``. ``type == "img"`` is the
    image and ``type == "seg"`` its mask; any other type is recorded by path
    only.

    Returns:
        A ``defaultdict(dict)`` mapping ``image_id`` to a dict holding:
        - ``<type>``: path of each file found for that id;
        - ``"masked"``: True if a ``seg`` file exists for the id, False if
          only an ``img`` file does;
        - ``"colored"``: ``colored_image`` of the ``img`` file;
        - ``"mask_image_ratio"``: ``non_zero_pixels_ratio`` of the ``seg``
          file (present only when masked).
    """
    samples = defaultdict(dict)
    for fn in list_files(folder):
        # maxsplit=1 keeps ids that themselves contain a dot intact.
        image_type, image_id = Path(fn).stem.split(".", 1)
        sample = samples[image_id]
        sample[image_type] = fn
        if image_type == "img":
            # setdefault: never clobber masked=True when the seg file happened
            # to be listed before the img file (list order is not assumed).
            sample.setdefault("masked", False)
            sample["colored"] = colored_image(fn)
        elif image_type == "seg":
            sample["masked"] = True
            sample["mask_image_ratio"] = non_zero_pixels_ratio(fn)
    return samples
        
def colored_image(filename: str) -> bool:
    """Return True if the image file is stored in PIL's ``"RGB"`` mode.

    Any other mode (e.g. ``"L"`` or ``"RGBA"``) returns False. The ``with``
    block closes the file handle immediately instead of leaving it open until
    garbage collection, which matters when scanning many files.
    """
    with PIL.Image.open(filename) as image:
        return image.mode == "RGB"

def non_zero_pixels_ratio(filename: str) -> float:
    """Return the fraction of non-zero values in the (mask) image file.

    Any non-zero value counts as foreground, so masks stored as 0/255 and as
    0/1 are both measured correctly. For multi-channel images the fraction is
    taken over all channel values.
    """
    with PIL.Image.open(filename) as image:
        # np.asarray copies the decoded pixels, so the array stays valid
        # after the file is closed.
        arr = np.asarray(image)
    return (arr != 0).mean()

In [None]:
# Build the id -> {paths, masked, colored, mask ratio} index for the folder.
images = read_png_images(folder=PREPARED_DIR)

In [None]:
def images_summary(images: Dict) -> pd.DataFrame:
    """Flatten the per-image index into one row per image id.

    Args:
        images: Mapping of image id to a dict with ``"masked"``, ``"colored"``
            and, when masked, ``"mask_image_ratio"``.

    Returns:
        DataFrame with columns ``image_id``, ``masked``, ``colored`` and
        ``ratio`` (mask coverage; NaN for images without a mask). The columns
        are passed explicitly so an empty ``images`` still yields a frame with
        these columns rather than a column-less one.
    """
    columns = ["image_id", "masked", "colored", "ratio"]
    records = [
        {
            "image_id": image_id,
            "masked": info["masked"],
            "colored": info["colored"],
            "ratio": info["mask_image_ratio"] if info["masked"] else np.nan,
        }
        for image_id, info in images.items()
    ]
    return pd.DataFrame(records, columns=columns)

In [None]:
# One row per image: id, masked flag, colored flag and mask coverage ratio.
info = images_summary(images=images)

In [None]:
# Mask-coverage buckets, as pandas query strings over the `ratio` column.
# Images without a mask have ratio == NaN and match none of the buckets.
MASK_SIZE_QUERIES = {
    "empty": "ratio == 0",
    "small": "ratio > 0 and ratio <= 0.05",
    "medium": "ratio > 0.05 and ratio <= 0.20",
    "large": "ratio > 0.20",
}


def group_by_mask_size(frame: pd.DataFrame) -> Dict[str, list]:
    """Return ``{bucket name: [image_id, ...]}`` for each MASK_SIZE_QUERIES bucket."""
    return {
        size: frame.query(condition).image_id.tolist()
        for size, condition in MASK_SIZE_QUERIES.items()
    }


colored = info.query("colored")
grayscale = info.query("not colored")

# Built by one helper so both color groups use identical buckets (the
# hand-written version mapped grayscale "large" to the medium bucket).
image_groups = {
    "colored": group_by_mask_size(colored),
    "grayscale": group_by_mask_size(grayscale),
}

In [None]:
# Mask coverage distribution for colored images (unmasked images have NaN ratio).
ax = info.query("colored").ratio.plot.hist(bins=20)
ax.set(title="Mask coverage: colored images", xlabel="mask pixel ratio", ylabel="images");

In [None]:
# Mask coverage distribution for grayscale images (unmasked images have NaN ratio).
ax = info.query("not colored").ratio.plot.hist(bins=20)
ax.set(title="Mask coverage: grayscale images", xlabel="mask pixel ratio", ylabel="images");

In [None]:
# Visual check: for every (color, mask size) group, overlay masks on a random
# sample of up to n*n images. A fixed seed keeps the grids reproducible.
n = 7
SAMPLE_SEED = 0
rng = random.Random(SAMPLE_SEED)
for color, mask_groups in image_groups.items():
    for mask_size, image_ids in mask_groups.items():
        if not image_ids:
            continue  # nothing to draw for an empty bucket
        # Cap k: random.sample raises ValueError when k exceeds the population.
        keys = rng.sample(image_ids, k=min(n * n, len(image_ids)))
        canvas = axes(subplots=(n, n), figsize=(20, 20))
        cells = list(canvas.flat)
        for key, ax in zip(keys, cells):
            sample = images[key]
            # Close the files right after decoding instead of leaking handles.
            with PIL.Image.open(sample["img"]) as img_file:
                img = np.asarray(img_file)
            with PIL.Image.open(sample["seg"]) as seg_file:
                seg = np.asarray(seg_file)
            ax.imshow(img, cmap="gray" if color == "grayscale" else None)
            ax.imshow(seg, alpha=0.3)
            ax.axis(False)
        # Blank out grid cells left unused by small groups.
        for ax in cells[len(keys):]:
            ax.axis(False)
        plt.gcf().suptitle(f"{color} ({mask_size})")
        plt.show()

In [None]:
# Scratch: overlay the mask on one image by id (the index is named `images`, not `samples`).
# x = images["8242609fa_19584_10759_20608_11783"]
# img = np.asarray(PIL.Image.open(x["img"]))
# seg = np.asarray(PIL.Image.open(x["seg"]))
# plt.figure(figsize=(10,10))
# plt.imshow(img)
# plt.imshow(seg, alpha=0.3)
# plt.show()

In [None]:
# Scratch: random n*n grid over all images (the index is named `images`, not `samples`).
# n = 7
# keys = random.sample(list(images), k=n*n)
# canvas = axes(subplots=(n, n), figsize=(20, 20))
# for key, ax in zip(keys, canvas.flat):
#     x = images[key]
#     img = np.asarray(PIL.Image.open(x["img"]))
#     seg = np.asarray(PIL.Image.open(x["seg"]))
#     grayscale = img.ndim == 2
#     ax.imshow(img, cmap="gray" if grayscale else None)
#     ax.imshow(seg, alpha=0.3)
#     ax.axis(False)
#     ax.set_title("grayscale" if grayscale else "colored")