In [None]:
# Reload edited project modules automatically before executing each cell.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# NOTE(review): presumably adds the notebook's parent folder to sys.path so the
# project packages (e.g. `kidney.*`) are importable — confirm in zeus.notebook_utils.
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import ast
import re
from itertools import product
from operator import itemgetter
from typing import Any, Dict, Type

import numpy as np
import pandas as pd
import wandb
from zeus.utils import named_match

from kidney.datasets.kaggle import get_reader
from kidney.datasets.offline import create_data_loaders
from kidney.datasets.utils import read_segmentation_info
from kidney.experiments.smp import SMPExperiment, parse_fold_keys
from kidney.utils.checkpoints import CheckpointsStorage, load_experiment

In [None]:
# W&B public API client; used to look up the training run and its logged history.
api = wandb.Api()

In [None]:
# W&B run to analyse, given as "<entity>/<project>/<run id>".
RUN_PATH = "devforfu/kidney/Fri_02_Apr__21_01_46"
run = api.run(RUN_PATH)

In [None]:
metrics = ("recall", "precision", "dice", "balanced accuracy", "loss")
subsets = ("trn", "val")

# Logged keys follow the "avg_<subset>_<metric>" naming scheme.
subset_keys = {
    subset: [f"avg_{subset}_{metric}" for metric in metrics]
    for subset in subsets
}

# Page through the run history only once. Each scan_history() call queries the
# W&B API again, and a single pass also builds both subsets from the same
# snapshot of a run that may still be logging.
records = {subset: [] for subset in subsets}
for row in run.scan_history(page_size=1000):
    for subset, keys in subset_keys.items():
        # Keys absent from this history row are recorded as NaN.
        records[subset].append({key: row.get(key, np.nan) for key in keys})

# One frame per subset, one row per history row, columns in `metrics` order.
# Passing `columns=` keeps the expected columns even if the history is empty.
dfs = {
    subset: pd.DataFrame(rows, columns=subset_keys[subset])
    for subset, rows in records.items()
}

In [None]:
# Drop history rows where none of a subset's metrics were logged (presumably
# steps at which only the other subset was logged), then renumber rows 0..n-1.
dfs = {
    subset: frame.dropna(how="all").reset_index(drop=True)
    for subset, frame in dfs.items()
}

In [None]:
# All training-metric columns first, then all validation-metric columns,
# each group in the order of `metrics`.
columns = [
    f"avg_{subset}_{metric}"
    for subset in ("trn", "val")
    for metric in metrics
]

In [None]:
# Pair train and validation rows by their (reset, 0..n-1) row index, keeping
# every training row, and order the columns as in `columns`.
df_metrics = dfs["trn"].join(dfs["val"], how="left").loc[:, columns]

In [None]:
# Names of the validation-metric columns only.
val_columns = list(filter(lambda name: "_val_" in name, columns))

In [None]:
# Show which validation columns are available.
val_columns

In [None]:
# Validation F1 = harmonic mean of precision and recall (NaN where both are 0).
# Rebind with .assign instead of adding the column in place.
precision, recall = df_metrics["avg_val_precision"], df_metrics["avg_val_recall"]
df_metrics = df_metrics.assign(
    avg_val_f1_score=2 * precision * recall / (precision + recall)
)

In [None]:
# Row position of the best value of each validation metric (NaN rows are
# skipped by argmax/argmin). Higher is better for every score, lower for the loss.
# NOTE(review): the position is later used to pick a checkpoint, which assumes
# one metrics row per saved epoch — confirm.
best_val = {
    "recall": df_metrics["avg_val_recall"].argmax(),
    "precision": df_metrics["avg_val_precision"].argmax(),
    "f1_score": df_metrics["avg_val_f1_score"].argmax(),
    "dice": df_metrics["avg_val_dice"].argmax(),
    "balanced_accuracy": df_metrics["avg_val_balanced accuracy"].argmax(),
    "loss": df_metrics["avg_val_loss"].argmin(),
}

In [None]:
# Best row position per validation metric.
best_val

In [None]:
def parse_string(filename: str, patterns: Dict) -> Dict[str, Any]:
    """Extract ``name=value`` fields from a string such as a checkpoint filename.

    Args:
        filename: String to match, e.g. ``"epoch=12_loss=0.53.ckpt"``.
        patterns: Ordered mapping of field name -> value template. A template is
            either a regex string, or the type ``int`` / ``float``, which is
            replaced by a regex for an unsigned integer / unsigned decimal
            number. Fields must appear in this order, separated by ``_``.

    Returns:
        The result of ``named_match`` for the combined pattern; callers index it
        by field name to get each captured value.
    """
    type_templates = {int: r"\d+", float: r"\d+\.\d+"}
    entries = []
    for name, template in patterns.items():
        if template is int or template is float:
            # The decimal point is escaped: a bare "." would match any character.
            template = type_templates[template]
        entries.append(fr"{name}=(?P<{name}>{template})")
    regex = "_".join(entries)
    return named_match(pattern=regex, string=filename)

# NOTE: machine-specific absolute path; adjust for your environment.
CHECKPOINTS = "/home/ck/experiments/smp/checkpoints/"
reader = get_reader()
storage = CheckpointsStorage(CHECKPOINTS)
# NOTE(review): `[-1]` takes the last entry returned by the storage, presumably
# the most recent run — confirm the ordering of fetch_available_checkpoints.
benchmark = storage.fetch_available_checkpoints("avg_val_loss", best_checkpoint_per_date=False)[-1]
checkpoint_files, meta_file = benchmark["checkpoints"], benchmark["meta"]

# Order checkpoint files by the epoch number in their names. The captured value
# is cast to int so the order is numeric even if the match returns the text
# ("10" would otherwise sort before "2"). sorted() is stable, so files with the
# same epoch keep their original relative order.
sorted_files = sorted(
    checkpoint_files,
    key=lambda filename: int(parse_string(filename, {"epoch": int})["epoch"]),
)

In [None]:
# Checkpoint file at each metric's best row.
# NOTE(review): assumes row i of df_metrics corresponds to sorted_files[i], i.e.
# exactly one saved checkpoint per epoch in epoch order — confirm.
best_checkpoints = dict(
    (metric, sorted_files[row]) for metric, row in best_val.items()
)

In [None]:
# Best checkpoint file per validation metric.
best_checkpoints

In [None]:
from zeus.utils import list_files

# Folder holding the predictions CSVs (NOTE: machine-specific absolute path).
OUTPUTS_DIR = "/mnt/fast/data/kidney/outputs"

# Use the first file listed. Fail loudly on an empty folder instead of leaving
# `csv_file` undefined, or stale from an earlier execution of this cell.
csv_file = next(iter(list_files(OUTPUTS_DIR)), None)
if csv_file is None:
    raise FileNotFoundError(f"no files found in {OUTPUTS_DIR}")

In [None]:
import pandas as pd

In [None]:
# Predictions table, indexed by sample id.
df = pd.read_csv(csv_file, index_col="id")

In [None]:
# NOTE(review): the checkpoints cell already assigns `reader = get_reader()`;
# this re-creation looks redundant under Restart & Run All.
reader = get_reader()

In [None]:
# Peek at the sample for the first id in the predictions CSV. Indexing explicitly
# makes an empty CSV fail here (IndexError) instead of leaving `sample`
# undefined, or stale from an earlier execution of this cell.
key = df.index[0]
sample = reader.fetch_one(key)

In [None]:
# Dimensions of the image array fetched above.
sample["image"].shape

In [None]:
from kidney.utils.mask import rle_decode
from kidney.datasets.kaggle import SampleType

In [None]:
# Labeled samples to decode and visualise. Only these are materialised, because
# each entry keeps the fetched image and both masks in memory. Must include the
# key plotted below.
PREVIEW_KEYS = ("0486052bb",)

train_keys = reader.get_keys(SampleType.Labeled)
unknown = sorted(set(PREVIEW_KEYS) - set(train_keys))
assert not unknown, f"not labeled samples: {unknown}"

predictions = {}
for key in PREVIEW_KEYS:
    sample = reader.fetch_one(key)
    h, w = sample["image"].shape[:2]
    predictions[key] = {
        "image": sample["image"],
        # Decode the CSV's run-length-encoded prediction at the image's size.
        "pred": rle_decode(df.loc[key].predicted, (h, w)),
        "gt": sample["mask"],
    }

In [None]:
import cv2 as cv
import matplotlib.pyplot as plt

In [None]:
key = "0486052bb"
size = (2048, 2048)  # cv.resize expects (width, height)

# Downscale for display. INTER_AREA averages source pixels, which avoids the
# aliasing of the default INTER_LINEAR on large downscales. Label masks use
# INTER_NEAREST so they keep their original values. They are cast to uint8
# first because cv.resize does not accept e.g. bool or int64 arrays.
# NOTE(review): the cast assumes mask values fit in uint8 (e.g. 0/1) — confirm.
img = cv.resize(predictions[key]["image"], size, interpolation=cv.INTER_AREA)
seg_pred = cv.resize(
    np.asarray(predictions[key]["pred"], dtype=np.uint8), size, interpolation=cv.INTER_NEAREST
)
seg_true = cv.resize(
    np.asarray(predictions[key]["gt"], dtype=np.uint8), size, interpolation=cv.INTER_NEAREST
)

In [None]:
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# Overlay ground truth and prediction in fixed, distinct colours.
# - imshow normalises each call's data range independently, so encoding the
#   masks as 1 vs 2 would render both at the top of the same colormap.
# - Masking the background pixels keeps them transparent, so they do not tint
#   the image.
overlays = (
    ("ground truth", seg_true, "tab:green"),
    ("prediction", seg_pred, "tab:red"),
)
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(img)
for label, mask, color in overlays:
    ax.imshow(
        np.ma.masked_where(mask != 1, mask),
        cmap=ListedColormap([color]),
        alpha=0.4,
        interpolation="nearest",
    )
ax.legend(
    handles=[Patch(color=color, alpha=0.4, label=label) for label, _, color in overlays],
    loc="upper right",
)
ax.set_title(f"Sample {key}: ground truth vs. prediction")
ax.set_axis_off()
plt.show()

In [None]:
# x = samples["8242609fa_19584_10759_20608_11783"]
# img = np.asarray(PIL.Image.open(x["img"]))
# seg = np.asarray(PIL.Image.open(x["seg"]))
# plt.figure(figsize=(10,10))
# plt.imshow(img)
# plt.imshow(seg, alpha=0.3)
# plt.show()