In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import rasterio
from kidney.datasets.kaggle import get_reader, SampleType, DatasetReader
from zeus.core.random import super_seed
from zeus.utils import list_files
from zeus.plotting.utils import axes, calculate_layout

In [None]:
# Fix the random seed up front so the random image samples drawn further down
# are repeatable (presumably super_seed seeds `random`/NumPy/etc. — confirm in
# zeus.core.random).
super_seed(1)

## Images preview

In [None]:
def print_dataset_info(reader: DatasetReader, sample_type: SampleType):
    """Print one summary line per sample of the dataset.

    Each line shows a ``[trn]``/``[tst]`` tag (whether the sample's meta has a
    non-None ``"mask"`` entry), the sample key, the raster height and width,
    and the raster band indexes.

    Args:
        reader: dataset reader used to list keys and fetch per-sample meta.
            ``None`` falls back to ``get_reader()``.
        sample_type: which subset of keys to list.
    """
    # The original unconditionally replaced the argument with get_reader(),
    # silently ignoring whatever reader the caller passed in.
    if reader is None:
        reader = get_reader()
    # Explicit identity affine handed to rasterio.open for every file
    # (presumably because the TIFFs carry no georeferencing — TODO confirm).
    identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
    for key in reader.get_keys(sample_type):
        meta = reader.fetch_meta(key)
        with rasterio.open(meta["tiff"], transform=identity) as dataset:
            height, width = dataset.shape
            has_mask = "[trn]" if meta["mask"] is not None else "[tst]"
            print(has_mask, key, height, width, dataset.indexes)

In [None]:
print_dataset_info(get_reader(), SampleType.All)

In [None]:
# Folder holding the prepared PNG tiles that the cells below read ("img.<id>"
# and "seg.<id>" files, per read_png_images). Hard-coded absolute path: adjust
# it for the machine the notebook runs on.
PREPARED_DIR = "/mnt/fast/data/kidney/images_32_1024"

In [None]:
def read_png_images(folder: str):
    """Index the image/mask files in ``folder`` by image id.

    File stems are expected to look like ``"<type>.<image_id>"`` where
    ``<type>`` is ``"img"`` or ``"seg"``; a stem with a different number of
    dots raises ``ValueError`` from the tuple unpacking.

    Returns:
        A ``defaultdict`` mapping image id to a dict with:
        ``"img"`` / ``"seg"`` (file paths), ``"masked"`` (True when a seg file
        exists), ``"colored"`` (result of ``colored_image`` on the img file)
        and ``"mask_image_ratio"`` (result of ``non_zero_pixels_ratio`` on the
        seg file).
    """
    samples = defaultdict(dict)
    for fn in list_files(folder):
        image_type, image_id = Path(fn).stem.split(".")
        sample = samples[image_id]
        sample[image_type] = fn
        if image_type == "img":
            # setdefault instead of plain assignment: if the seg file was
            # listed first, "masked" is already True and must not be reset.
            sample.setdefault("masked", False)
            sample["colored"] = colored_image(fn)
        elif image_type == "seg":
            sample["masked"] = True
            sample["mask_image_ratio"] = non_zero_pixels_ratio(fn)
    return samples
        
def colored_image(filename: str) -> bool:
    """Return True when PIL reports the image's mode as ``"RGB"``."""
    # Context manager so the file handle is released right away; this helper
    # is called once per image file in the folder.
    with PIL.Image.open(filename) as image:
        return image.mode == "RGB"

def non_zero_pixels_ratio(filename: str) -> float:
    """Return the fraction of array elements equal to 255 in the image file.

    Despite the name, only the value 255 is counted; any other non-zero value
    contributes nothing to the ratio.
    """
    # Context manager so the file handle is closed once the pixels are copied
    # into the array.
    with PIL.Image.open(filename) as image:
        arr = np.asarray(image)
    # The mean of the boolean mask equals the mean of np.where(mask, 1, 0).
    return float((arr == 255).mean())

In [None]:
# Build the image-id -> file/flag index for every tile in PREPARED_DIR.
# read_png_images opens each img/seg file once to fill in its flags.
images = read_png_images(PREPARED_DIR)

In [None]:
# Show an n x n grid of randomly chosen tiles. Each tile is titled with its
# running index, and `decode` maps that index back to the image id so that
# interesting tiles can be picked by number in the next cell.
n = 14
# random.sample requires a sequence; passing a dict keys view raises TypeError
# on Python 3.11+, so materialise the keys as a list first.
keys = random.sample(list(images.keys()), k=n*n)
canvas = axes(subplots=(n, n), figsize=(30, 30))
decode = {}
for i, (key, ax) in enumerate(zip(keys, canvas.flat)):
    # Close each file as soon as it has been handed to imshow.
    with PIL.Image.open(images[key]["img"]) as img:
        ax.imshow(img, cmap=None if images[key]["colored"] else "gray")
    ax.axis(False)
    ax.set_title(f"{i}")
    decode[i] = key

In [None]:
# Earlier anchor selections, kept for reference:
# anchors = (3, 25, 42, 54, 57, 60, 64, 104, 135)
# anchors = (0, 11, 19, 56, 81, 91, 101, 122, 178, 195)
# anchors = (81, 0, 11, 91, 101, 6, 149)
anchors = (149, 91, 101, 81, 109, 93)
n = len(anchors)
canvas = axes(subplots=(1, n), figsize=(18, 4))
# Paths of the anchor tiles, in anchor order; the colour-transfer cells below
# read them back from `filenames`.
filenames = []
for i, (anchor, ax) in enumerate(zip(anchors, canvas.flat)):
    # Anchors are the grid indices shown in the preview above; `decode` turns
    # them back into image ids.
    path = images[decode[anchor]]["img"]
    filenames.append(path)
    img = PIL.Image.open(path)
    ax.imshow(img)
    ax.axis(False)
    ax.set_title(anchor)

In [None]:
import json
import cv2 as cv

In [None]:
class ColorTransfer:
    """Per-channel mean/std statistics of a reference image and the transform
    that re-colours another image using them.

    Attributes:
        mean: sequence with one mean value per channel.
        std: sequence with one standard deviation per channel.
        ref: free-form label of the reference the statistics came from
            (the notebook passes the reference image path).
    """

    def __init__(self, mean: np.ndarray, std: np.ndarray, ref: str = "default"):
        self.mean = mean
        self.std = std
        self.ref = ref

    @staticmethod
    def read_json(filename: str):
        """Load statistics previously stored with ``write_json``.

        The ``"reference"`` entry is restored when present; files without it
        fall back to the constructor default ``"default"``.
        """
        with open(filename, "r") as fp:
            contents = json.load(fp)
        mean = [np.array(c) for c in contents["mean"]]
        std = [np.array(c) for c in contents["std"]]
        # The original dropped the stored reference, so a write/read round trip
        # lost track of which image the statistics came from.
        return ColorTransfer(mean, std, contents.get("reference", "default"))

    def write_json(self, filename: str):
        """Store mean, std and reference label as JSON in ``filename``."""
        with open(filename, "w") as fp:
            json.dump({
                # np.asarray(...) lets plain Python floats be serialised too,
                # not only NumPy scalars/arrays that have .tolist().
                "mean": [np.asarray(c).tolist() for c in self.mean],
                "std": [np.asarray(c).tolist() for c in self.std],
                "reference": self.ref
            }, fp)

    def transfer_image(self, target: np.ndarray, as_rgb: bool = True):
        """Shift ``target``'s per-channel statistics towards the stored ones.

        Every channel is centred on zero, multiplied by
        ``target_std / (stored_std + 1e-8)``, offset by the stored mean and
        clipped to ``[0, 255]``. The merged result is cast to ``uint8``.

        Args:
            target: multi-channel image array. With ``as_rgb=True`` the result
                is converted with ``cv.COLOR_LAB2RGB``, i.e. the channels are
                treated as LAB.
            as_rgb: convert the result from LAB to RGB before returning.

        Returns:
            ``uint8`` image array.
        """
        # The arithmetic below is done in place on each split channel, which
        # NumPy refuses for integer arrays; promote those to float first.
        if not np.issubdtype(target.dtype, np.floating):
            target = target.astype(np.float32)
        channels = []
        for i, channel in enumerate(cv.split(target)):
            channel -= channel.mean()
            channel *= channel.std()/(self.std[i] + 1e-8)
            channel += self.mean[i]
            channel = channel.clip(0, 255)
            channels.append(channel)
        image = cv.merge(channels).astype(np.uint8)
        if as_rgb:
            image = cv.cvtColor(image, cv.COLOR_LAB2RGB)
        return image
            

def read_lab(filename: str):
    """Read an image file and return it as a float32 LAB array.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    bgr = cv.imread(filename)
    # cv.imread signals failure by returning None instead of raising; without
    # this check the failure only shows up as an opaque cvtColor error.
    if bgr is None:
        raise OSError(f"Cannot read image file: {filename}")
    lab = cv.cvtColor(bgr, cv.COLOR_BGR2LAB).astype(np.float32)
    return lab


def channel_stats(image: np.ndarray):
    """Return ``(mean, std)``: two lists holding the mean and the standard
    deviation of each channel of ``image``, in channel order."""
    mean, std = [], []
    for plane in cv.split(image):
        mean.append(plane.mean())
        std.append(plane.std())
    return mean, std

In [None]:
import os
from pathlib import Path
# For every anchor image: compute its LAB channel statistics, store them as
# JSON in output_dir, and plot a row showing those statistics applied to each
# anchor (row i = reference, column j = image being re-coloured).
n = len(filenames)
canvas = axes(subplots=(n, n), figsize=(n*3, n*3))
output_dir = "/mnt/fast/data/color_transfers"
os.makedirs(output_dir, exist_ok=True)
# Decode and convert each anchor once, instead of n times in the nested loop.
labs = [read_lab(fn) for fn in filenames]
for i in range(n):
    lab = labs[i]
    mean, std = channel_stats(lab)
    t = ColorTransfer(mean, std, filenames[i])
    image_id = Path(filenames[i]).stem.split(".")[-1]
    t.write_json(os.path.join(output_dir, f"{image_id}.json"))
    for j in range(n):
        index = i*n + j
        # Pass a copy so the cached arrays stay pristine whether or not
        # transfer_image modifies its input.
        transferred = t.transfer_image(labs[j].copy())
        canvas.flat[index].imshow(transferred)
        canvas.flat[index].axis(False)
        if i == j:
            canvas.flat[index].set_title(anchors[i])

In [None]:
!ls -1 {output_dir}

In [None]:
# Cleanup is disabled on purpose: the next cell reads the JSON files back from
# output_dir, so deleting the folder here breaks a top-to-bottom run.
# Uncomment and run manually once the stored transfers are no longer needed.
# !rm -rf {output_dir}

In [None]:
import os
import uuid
# Re-create the transfer grid from the JSON files stored in output_dir.
# Each JSON path is derived from its anchor image id (the same naming the
# writing cell uses), because os.listdir returns entries in arbitrary order
# and would not line row i up with filenames[i].
json_files = [
    os.path.join(output_dir, f"{Path(fn).stem.split('.')[-1]}.json")
    for fn in filenames
]
n = len(filenames)
canvas = axes(subplots=(n, n), figsize=(n*3, n*3))
# Decode and convert each anchor once, instead of n times in the nested loop.
labs = [read_lab(fn) for fn in filenames]
for i in range(n):
    # The statistics come from disk here, so nothing is recomputed from the
    # reference image itself.
    t = ColorTransfer.read_json(json_files[i])
    for j in range(n):
        index = i*n + j
        # Pass a copy so the cached arrays stay pristine whether or not
        # transfer_image modifies its input.
        transferred = t.transfer_image(labs[j].copy())
        canvas.flat[index].imshow(transferred)
        canvas.flat[index].axis(False)

## Image Groups Preview

In [None]:
def images_summary(images: Dict) -> pd.DataFrame:
    """Flatten the per-image index into a DataFrame.

    One row per image id, with columns ``image_id``, ``masked``, ``colored``
    and ``ratio``. ``ratio`` is the image's ``"mask_image_ratio"`` when it is
    masked and NaN otherwise.
    """
    records = []
    for image_id, info in images.items():
        is_masked = info["masked"]
        is_colored = info["colored"]
        ratio = info["mask_image_ratio"] if is_masked else np.nan
        records.append({
            "image_id": image_id,
            "masked": is_masked,
            "colored": is_colored,
            "ratio": ratio,
        })
    return pd.DataFrame(records)

In [None]:
# One row per image id with its masked/colored flags and mask ratio; the
# grouping cells below query this frame.
info = images_summary(images)

In [None]:
# Split the summary by colour mode and then by mask coverage:
#   empty: ratio == 0, small: (0, 0.05], medium: (0.05, 0.20], large: > 0.20.
# Images without a mask have ratio NaN (see images_summary), so they match
# none of the ratio queries and end up in no group.
colored = info.query("colored")
colored_no_mask = colored.query("ratio == 0")
colored_small_mask = colored.query("ratio > 0 and ratio <= 0.05")
colored_medium_mask = colored.query("ratio > 0.05 and ratio <= 0.20")
colored_large_mask = colored.query("ratio > 0.20")

grayscale = info.query("not colored")
grayscale_no_mask = grayscale.query("ratio == 0")
grayscale_small_mask = grayscale.query("ratio > 0 and ratio <= 0.05")
grayscale_medium_mask = grayscale.query("ratio > 0.05 and ratio <= 0.20")
grayscale_large_mask = grayscale.query("ratio > 0.20")

# colour mode -> mask-size bucket -> list of image ids.
image_groups = {
    "colored": {
        "empty": colored_no_mask.image_id.tolist(),
        "small": colored_small_mask.image_id.tolist(),
        "medium": colored_medium_mask.image_id.tolist(),
        "large": colored_large_mask.image_id.tolist(),
    },
    "grayscale": {
        "empty": grayscale_no_mask.image_id.tolist(),
        "small": grayscale_small_mask.image_id.tolist(),
        "medium": grayscale_medium_mask.image_id.tolist(),
        # Was grayscale_medium_mask (copy-paste slip), which made the "large"
        # group a duplicate of "medium".
        "large": grayscale_large_mask.image_id.tolist(),
    }
}

In [None]:
# info.query("colored").ratio.plot.hist(bins=20)

In [None]:
# info.query("not colored").ratio.plot.hist(bins=20)

In [None]:
# Print the (rows, columns) shape of each colour/mask-size frame, colored
# frames first, to see how many images every group holds.
group_frames = [
    colored,
    colored_no_mask,
    colored_small_mask,
    colored_medium_mask,
    colored_large_mask,
    grayscale,
    grayscale_no_mask,
    grayscale_small_mask,
    grayscale_medium_mask,
    grayscale_large_mask,
]
for df in group_frames:
    print(df.shape)

In [None]:
# For every colour mode / mask-size group, draw up to n*n random tiles with
# the segmentation mask overlaid semi-transparently.
n = 7
for color, mask_groups in image_groups.items():
    for mask_size, image_ids in mask_groups.items():
        # random.sample raises ValueError when k exceeds the population size,
        # so cap k for groups holding fewer than n*n images.
        keys = random.sample(image_ids, k=min(n*n, len(image_ids)))
        canvas = axes(subplots=(n, n), figsize=(20, 20))
        # Hide every axis up front so unused cells of the grid stay blank.
        for ax in canvas.flat:
            ax.axis(False)
        for key, ax in zip(keys, canvas.flat):
            x = images[key]
            img = np.asarray(PIL.Image.open(x["img"]))
            seg = np.asarray(PIL.Image.open(x["seg"]))
            # The original also assigned `grayscale = img.ndim == 2` here,
            # which overwrote the `grayscale` DataFrame from the grouping cell
            # with a bool. The value was only used by a commented-out title.
            ax.imshow(img, cmap="gray" if color == "grayscale" else None)
            ax.imshow(seg, alpha=0.3)
            # ax.set_title(color)
        plt.gcf().suptitle(f"{color} ({mask_size})")

In [None]:
# x = samples["8242609fa_19584_10759_20608_11783"]
# img = np.asarray(PIL.Image.open(x["img"]))
# seg = np.asarray(PIL.Image.open(x["seg"]))
# plt.figure(figsize=(10,10))
# plt.imshow(img)
# plt.imshow(seg, alpha=0.3)
# plt.show()

In [None]:
# n = 7
# keys = random.sample(samples.keys(), k=n*n)
# canvas = axes(subplots=(n, n), figsize=(20, 20))
# for key, ax in zip(keys, canvas.flat):
#     x = samples[key]
#     img = np.asarray(PIL.Image.open(x["img"]))
#     seg = np.asarray(PIL.Image.open(x["seg"]))
#     grayscale = img.ndim == 2
#     ax.imshow(img, cmap="gray" if grayscale else None)
#     ax.imshow(seg, alpha=0.3)
#     ax.axis(False)
#     ax.set_title("grayscale" if grayscale else "colored")

## Loaders

In [None]:
from kidney.datasets.offline import create_data_loaders
from kidney.datasets.transformers import get_transformers, IntensityNormalization
from kidney.datasets.utils import read_segmentation_info
from pytorch_lightning.utilities import AttributeDict

In [None]:
# Dataset reader with default settings; used below to list the labeled keys
# and passed on to create_data_loaders.
reader = get_reader()

In [None]:
# Keys of all labeled samples; the next cell splits them into train/valid.
train_keys = reader.get_keys(SampleType.Labeled)

In [None]:
# Hold out the last labeled key for validation. `valid_keys` is a single key
# here, not a list; it is wrapped in a list when passed to create_data_loaders.
# NOTE(review): this cell is not idempotent. Because train_keys is reassigned
# from itself, re-running it alone drops one more key and changes the
# validation key. Re-run the cell above first.
train_keys, valid_keys = train_keys[:-1], train_keys[-1]
train_keys, valid_keys

In [None]:
get_transformers??

In [None]:
# Parameters handed to get_transformers, gathered in one attribute-style dict
# before the call.
transformer_params = AttributeDict(
    aug_pipeline="strong",
    aug_normalization_method=IntensityNormalization.TorchvisionSegmentation,
    dataset=PREPARED_DIR,
    model_input_size=1024,
    model_input_image_key="img",
    model_input_mask_key="seg",
)
transformers = get_transformers(transformer_params)

In [None]:
# Load the segmentation sample records for the prepared tiles. The meaning of
# file_format="bbox" is defined in kidney.datasets.utils and is not visible
# from this notebook.
samples = read_segmentation_info(PREPARED_DIR, file_format="bbox")

In [None]:
samples[:3]

In [None]:
# Build the data loaders. `valid_keys` holds a single key (see the split
# above), hence the list wrapper. num_workers=0 presumably keeps loading in
# the notebook process, which is easier to debug — TODO confirm.
loaders = create_data_loaders(
    reader=reader,
    valid_keys=[valid_keys],
    transformers=transformers,
    samples=samples,
    num_workers=0,
    batch_size=24,
)

In [None]:
loaders.keys()

In [None]:
# Pull a single batch from the "train" loader as a quick check that the
# loading/augmentation pipeline runs end to end.
batch = next(iter(loaders["train"]))