In [None]:
# (Re)load the autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to project code are picked up without a restart.
%reload_ext autoreload
%autoreload 2
# NOTE(review): presumably puts the notebook's parent folder on sys.path so the
# project packages imported in the next cell resolve — confirm in
# zeus.notebook_utils.syspath.
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import glob
import os
from collections import defaultdict
from dataclasses import dataclass
from os.path import join
from pprint import pprint as pp
from typing import Dict, List

import numpy as np
import pandas as pd
import PIL.Image
import seaborn as sns

from zeus.utils.misc import named_match
from zeus.plotting.style import notebook_style
from zeus.plotting.utils import axes
from kidney.utils.plotting import overlay

In [None]:
# Reset seaborn's rc changes first, then call the project notebook style with
# larger tick labels on both axes. The return value is bound to `_` so the
# cell prints nothing.
sns.reset_orig()
_ = notebook_style({
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
})

In [None]:
# Folder holding the `img.*.png` / `seg.*.png` patch files.
# Set the KIDNEY_PATCHES_DIR environment variable to point the notebook at a
# different location; without it the original hard-coded path is used.
DATA = os.environ.get('KIDNEY_PATCHES_DIR', '/mnt/fast/data/kidney_patches')

In [None]:
# Index every patch file by its "dx.dy" identifier. Each entry collects the
# path of the image ('image'), the path of its segmentation ('mask') and the
# (dx, dy) pair parsed from the file name ('position').
images_info = defaultdict(dict)

for image_type in ('img', 'seg'):
    # Raw f-string: `\.` and `\d` must reach the regex engine untouched, since
    # in a plain string literal they are invalid escape sequences. The dot
    # before "png" is escaped as well so it only matches a literal dot.
    pattern = rf'{image_type}\.(?P<dx>\d+)\.(?P<dy>\d+)\.(?P<stride>\d+)\.png'
    for path in glob.glob(f'{DATA}/{image_type}.*.png'):
        m = named_match(pattern, path)
        dx, dy = m['dx'], m['dy']
        identifier = f"{dx}.{dy}"
        images_info[identifier][
            'mask' if image_type == 'seg' else 'image'
        ] = path
        images_info[identifier]['position'] = dx, dy

In [None]:
# Patch identifiers in dict insertion order, i.e. the order in which glob
# returned the files. NOTE(review): the hand-picked indices in the grouping
# cell further down index into this list, so they are only meaningful while
# this order stays the same.
keys = list(images_info)

In [None]:
# Sanity check: show the first three identifiers with their collected info,
# separated by a blank line.
for patch_id in keys[:3]:
    print(patch_id)
    pp(images_info[patch_id])
    print()

In [None]:
# 8x8 overview of the first 64 patches with the mask overlaid on the image.
# Each panel title is the patch's position in `keys`; the manual grouping
# further down refers to these numbers.
for i, ax in enumerate(axes(subplots=(8, 8), figsize=(12, 12)).flat):
    ax.axis('off')
    if i >= len(keys):
        # Fewer patches than panels: leave the rest blank instead of failing
        # with an IndexError on keys[i].
        continue
    info = images_info[keys[i]]
    image = np.asarray(PIL.Image.open(info['image']))
    mask = np.asarray(PIL.Image.open(info['mask']))
    combined = overlay(image, mask)
    ax.imshow(combined)
    ax.set_title(i)

In [None]:
# Hand-picked groups of patches. Every number is a position in `keys`, which
# is also the title shown above the corresponding panel of the overview grid.
# NOTE(review): positions 10, 20, 39, 55 and 63 are not assigned to any group
# — confirm that this is intentional.
blacks = [
    keys[idx]
    for idx in (1, 3, 11, 25, 40, 41, 43, 44, 53)
]
whites = [
    keys[idx]
    for idx in (
        2, 5, 9, 13, 15, 19, 22, 26, 28, 29, 30,
        33, 36, 38, 42, 45, 46, 48, 57, 59, 62,
    )
]
strong = [
    keys[idx]
    for idx in (
        0, 4, 6, 7, 8, 12, 16, 17, 18, 21, 23, 27, 31,
        32, 34, 35, 37, 49, 50, 51, 52, 54, 56, 58, 60, 61,
    )
]
weak = [
    keys[idx]
    for idx in (14, 24, 47)
]

In [None]:
def read_image_as_numpy(path: str) -> np.ndarray:
    """Open the image file at `path` with PIL and return it as a NumPy array."""
    picture = PIL.Image.open(path)
    pixels = np.asarray(picture)
    return pixels

In [None]:
def pixel_histogram(image: np.ndarray, bin_width: int = 4) -> pd.Series:
    """Count the pixels of `image` per intensity bin over the range 0-255.

    The array is flattened and its values are cut into left-closed bins
    ``[0, w), [w, 2w), ...`` where ``w`` is `bin_width`. Each bin is labelled
    with its lower edge. Values that fall outside the bins are not counted.

    Args:
        image: Array of pixel values; any shape, it is flattened first.
        bin_width: Width of one bin. The default of 4 gives 64 bins.

    Returns:
        A Series named ``'count'`` with one entry per bin, indexed by the
        bin's lower edge (empty bins are included with a count of 0).

    Raises:
        ValueError: If `bin_width` is smaller than 1.
    """
    if bin_width < 1:
        raise ValueError(f'bin_width must be a positive integer, got {bin_width!r}')
    # One more edge than labels: the last edge closes the bin that contains 255.
    edges = range(0, 256 + bin_width, bin_width)
    lower_edges = range(0, 256, bin_width)
    binned = pd.cut(
        image.ravel(),
        bins=edges,
        labels=lower_edges,
        right=False
    )
    return binned.value_counts().rename('count')

In [None]:
def mean_pixel_histogram_for_keys(meta: Dict, keys: List[str]) -> pd.Series:
    """Average the pixel histograms of the images selected by `keys`.

    For every key the file at ``meta[key]['image']`` is read and turned into
    a pixel histogram. The histograms are summed, divided by the number of
    keys and converted to ``int``.

    Args:
        meta: Mapping from key to a dict that has an ``'image'`` path entry.
        keys: Keys of the images to average over; must not be empty.

    Returns:
        The mean histogram with integer counts.

    Raises:
        ValueError: If `keys` is empty, since there is nothing to average.
    """
    if len(keys) == 0:
        raise ValueError('keys must contain at least one image key')
    total = None
    for key in keys:
        image = read_image_as_numpy(meta[key]['image'])
        hist = pixel_histogram(image)
        total = hist if total is None else (total + hist)
    return (total / len(keys)).astype(int)

In [None]:
def plot_pixel_frequency_diagram(pixel_hist: pd.Series, title: str = '', threshold: int = 1000, ax=None):
    """Draw a pixel histogram as a bar plot with bars ordered by ascending count.

    Args:
        pixel_hist: Series of counts indexed by pixel-value bin, e.g. the
            result of `pixel_histogram`.
        title: Title of the plot.
        threshold: Height at which a dashed horizontal line is drawn.
        ax: Passed on to the project `axes` helper together with a figure size.

    Returns:
        The axes object returned by ``sns.barplot``.
    """
    data = (
        pixel_hist
        # Pin the column names used below so they do not depend on how the
        # caller named the series or its index.
        .rename('count')
        .rename_axis('index')
        .reset_index()
        .sort_values(by='count')
        .reset_index(drop=True)
    )
    ax = sns.barplot(
        x='index', y='count',
        data=data, order=data['index'],
        ax=axes(ax=ax, figsize=(11, 6))
    )
    ax.hlines(threshold, 0, len(pixel_hist), linestyles='--')
    # Rotate the existing tick labels in place rather than reading their text
    # back and setting it again.
    ax.tick_params(axis='x', labelrotation=90)
    ax.set_xlabel('Pixel Value')
    ax.set_title(title)
    return ax

In [None]:
# One panel per hand-picked group: the mean pixel histogram of its patches,
# with the same dashed threshold line in every panel.
subplots = axes(subplots=(2, 2), figsize=(16, 12)).flat
groups = zip([
    ('Black', blacks),
    ('White', whites),
    ('Weak', weak),
    ('Strong', strong)
], subplots)
threshold = 5_000
# The loop variable is deliberately not called `keys`: that would overwrite
# the notebook-level `keys` list that earlier cells index into.
for (title, group_keys), ax in groups:
    histogram = mean_pixel_histogram_for_keys(images_info, group_keys)
    plot_pixel_frequency_diagram(histogram, title=title, ax=ax, threshold=threshold)