In [None]:
# Re-import edited project modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# NOTE(review): presumably adds the repository root to sys.path so the `kidney`
# package imports below resolve — confirm in zeus.notebook_utils.syspath.
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import os
from dataclasses import dataclass
from operator import itemgetter
from pathlib import Path
from typing import Dict, List, Tuple

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from zeus.utils import list_files
from kidney.datasets.kaggle import get_reader
from kidney.utils.mask import rle_decode

In [None]:
# Folder whose entries are listed below as per-model prediction directories.
_dataset_root = os.environ["DATASET_ROOT"]
PREDICTIONS_DIR = os.path.join(_dataset_root, "predictions")

In [None]:
# Dataset reader; used below to fetch a sample (`fetch_one`) and its metadata
# (`fetch_meta`) by sample key.
reader = get_reader()

In [None]:
# Pick one model's prediction folder. Sorting makes "the first model"
# independent of filesystem listing order, so Restart & Run All is reproducible.
model_dirs = sorted(list_files(PREDICTIONS_DIR))
assert model_dirs, f"no model prediction folders found in {PREDICTIONS_DIR}"
model_dir = model_dirs[0]
model_dir

In [None]:
def read_predictions(root: str) -> pd.DataFrame:
    """Load every per-fold prediction CSV in ``root`` into a single DataFrame.

    Each file's stem must end with ``_<fold index>`` (e.g. ``..._0.csv``), and the
    file must hold an ``id`` column plus exactly one prediction column
    (presumably the RLE mask string — confirm against the code that writes them).

    Args:
        root: Directory containing the per-fold CSV files.

    Returns:
        DataFrame indexed by ``id`` with one column per fold, labelled
        ``0 .. n_folds - 1`` in ascending fold-index order. Only ids present in
        every file are kept.

    Raises:
        ValueError: If ``root`` contains no files.
    """
    folds = sorted(
        ((int(Path(fn).stem.split("_")[-1]), fn) for fn in list_files(root)),
        key=itemgetter(0),
    )
    if not folds:
        raise ValueError(f"no prediction files found in {root!r}")

    frames = [pd.read_csv(fn).set_index("id") for _, fn in folds]
    # A single concat replaces the chain of pairwise merges. Those merges hit
    # clashing "_x"/"_y" suffixes from the fourth file on, which pandas 2.x
    # rejects with MergeError. join="inner" keeps the merge semantics: only ids
    # shared by all folds survive.
    combined = pd.concat(frames, axis=1, join="inner")
    combined.columns = range(len(frames))
    return combined

In [None]:
rle_df = read_predictions(model_dir)
# `to_dict("index")` below requires unique ids; fail here with a clear message.
assert rle_df.index.is_unique, "duplicate sample ids across prediction files"
# Every cell is a full RLE string, so preview a few rows instead of the whole frame.
rle_df.head()

In [None]:
# Sample to inspect: must be an `id` row of `rle_df` and a key known to `reader`.
sample_key = "0486052bb"

In [None]:
# Load the sample; only its "image" entry is used below (to size the mask).
sample = reader.fetch_one(sample_key)

In [None]:
# Decoded masks are built at the image's spatial size: its first two dimensions.
sample_image = sample["image"]
mask_size = sample_image.shape[:2]

In [None]:
# Exploratory: list the metadata fields the reader exposes for this sample.
reader.fetch_meta(sample_key).keys()

In [None]:
@dataclass
class CombinedPrediction:
    """Base class for strategies that merge per-fold RLE predictions into one mask.

    Attributes:
        predictions: Mapping ``sample key -> {fold index: RLE string}``, e.g. the
            result of ``rle_df.to_dict("index")``.
        mask_size: Spatial size ``(height, width)`` of the decoded masks.
    """

    predictions: Dict[str, Dict[int, str]]
    mask_size: Tuple[int, int]

    def __call__(self, sample_key: str) -> np.ndarray:
        """Return the combined binary mask for ``sample_key`` (subclasses implement)."""
        raise NotImplementedError()


@dataclass
class MajorityVotePrediction(CombinedPrediction):
    """Combine fold predictions by voting.

    A pixel is foreground when strictly more than ``int(majority * n_folds)``
    folds mark it. With the default ``majority=0.5`` this is a strict majority
    (e.g. 3 of 4 or 3 of 5 folds). Note that ``majority=1.0`` can never be
    exceeded and yields an empty mask.
    """

    # Fraction of folds whose vote count must be exceeded. The @dataclass
    # decorator makes this a real constructor field instead of a class attribute.
    majority: float = 0.5

    def __call__(self, sample_key: str) -> np.ndarray:
        """Return a uint8 0/1 mask of shape ``mask_size`` for ``sample_key``.

        Raises:
            KeyError: If ``sample_key`` is not in ``predictions``.
        """
        rle_masks = self.predictions[sample_key]
        n_folds = len(rle_masks)
        majority_threshold = int(self.majority * n_folds)
        # Wider than uint8 so the vote count can never wrap around.
        votes = np.zeros(self.mask_size, dtype=np.int32)
        for rle in rle_masks.values():
            # pandas reads an empty CSV cell as NaN. Treat a missing or empty RLE
            # as an empty mask: that fold still counts towards n_folds.
            if not isinstance(rle, str) or not rle:
                continue
            # `!= 0` turns any decoded foreground value into exactly one vote.
            votes += rle_decode(rle, self.mask_size) != 0
        return (votes > majority_threshold).astype(np.uint8)

In [None]:
# Nested dict: {sample id: {fold index: RLE string}}.
rle_dict = rle_df.to_dict(orient="index")

In [None]:
# Majority-vote combiner over all folds, decoding masks at the image size.
prediction = MajorityVotePrediction(predictions=rle_dict, mask_size=mask_size)

In [None]:
majority_mask = prediction(sample_key)
# Sanity-check the combined mask: full image resolution and strictly binary.
assert majority_mask.shape == tuple(mask_size), majority_mask.shape
assert set(np.unique(majority_mask)) <= {0, 1}

In [None]:
# Longest side of the preview image, in pixels.
THUMBNAIL_SIZE = 2048

# Shrink for display while keeping the aspect ratio (a fixed square size
# distorts the slide). cv.resize takes dsize as (width, height). INTER_AREA is
# OpenCV's recommended filter for shrinking; on a 0/1 uint8 mask it averages
# and rounds, so the thumbnail stays 0/1.
mask_height, mask_width = majority_mask.shape[:2]
scale = THUMBNAIL_SIZE / max(mask_height, mask_width)
thumbnail_dsize = (max(1, round(mask_width * scale)), max(1, round(mask_height * scale)))
thumbnail = cv.resize(majority_mask, thumbnail_dsize, interpolation=cv.INTER_AREA)
thumbnail *= 255

_, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(thumbnail, cmap="gray")
ax.set_title(f"Majority-vote mask for {sample_key} ({mask_width}x{mask_height} px)")
ax.axis("off");

In [None]:
# Mean pixel value of the 0/255 thumbnail; divided by 255 it approximates the
# fraction of foreground pixels.
np.mean(thumbnail)

In [None]:
# Value range of the combined mask: (0, 1) when it has any foreground pixel.
np.min(majority_mask), np.max(majority_mask)