In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import zeus.notebook_utils.syspath as syspath
syspath.add_parent_folder()

In [None]:
import os
import torch
from zeus.core.random import super_seed
from zeus.torch_tools.checkpoints import find_latest_dir
from zeus.utils import TimestampFormat

In [None]:
from kidney.datasets.kaggle import get_reader
from kidney.datasets.offline import create_data_loaders
from kidney.datasets.utils import read_segmentation_info
from kidney.experiments.smp import SMPExperiment, parse_fold_keys
from kidney.utils.checkpoints import CheckpointsStorage, load_experiment

In [None]:
import ast
import re
from typing import Any, Dict, Type
from zeus.utils import named_match

def parse_string(filename: str, patterns: Dict[str, Any]) -> Dict[str, Any]:
    """Extract ``name=value`` fields from a checkpoint-style file name.

    Builds one regex of the form ``name1=(?P<name1>...)_name2=(?P<name2>...)``
    from ``patterns`` (in insertion order) and delegates matching and value
    conversion to ``zeus.utils.named_match``.

    Args:
        filename: String to search, e.g. ``"epoch=14_avg_val_loss=0.0403.ckpt"``.
        patterns: Mapping of field name -> regex string, or the type ``int`` /
            ``float`` as a shortcut for an unsigned integer / decimal pattern.
            Field names double as regex group names, so they must be valid
            Python identifiers.

    Returns:
        The result of ``named_match`` for the combined pattern (the assertion
        cell below expects a dict of converted values).
    """
    # The decimal point must be escaped: an unescaped "." matches any character.
    shortcuts = {int: r"\d+", float: r"\d+\.\d+"}
    entries = []
    for name, template in patterns.items():
        if template is int or template is float:
            template = shortcuts[template]
        entries.append(fr"{name}=(?P<{name}>{template})")
    return named_match(pattern="_".join(entries), string=filename)

In [None]:
# Sanity checks for parse_string: matched values come back converted to
# int/float, and single fields can be picked out of a full checkpoint name
# (that is how the sorting and table-building cells below use it).
_example_name = "epoch=14_avg_val_loss=0.0403.ckpt"
assert (
    parse_string(_example_name, {'epoch': r'\d+', 'avg_val_loss': r'\d+\.\d+'}) ==
    {'epoch': 14, 'avg_val_loss': 0.0403}
)
assert (
    parse_string(_example_name, {'epoch': int, 'avg_val_loss': float}) ==
    {'epoch': 14, 'avg_val_loss': 0.0403}
)
assert parse_string(_example_name, {'epoch': int})['epoch'] == 14
assert parse_string(_example_name, {'avg_val_loss': float})['avg_val_loss'] == 0.0403

# Computing metrics on checkpoints

In [None]:
CHECKPOINTS = "/home/ck/experiments/smp/checkpoints/"

In [None]:
reader = get_reader()

In [None]:
storage = CheckpointsStorage(CHECKPOINTS)

In [None]:
benchmark = storage.fetch_available_checkpoints("avg_val_loss", best_checkpoint_per_date=False)[-1]

In [None]:
checkpoint_files, meta_file = benchmark["checkpoints"], benchmark["meta"]

In [None]:
device = torch.device("cuda:1")

In [None]:
total = len(checkpoint_files)

In [None]:
from operator import itemgetter

sorted_files = [
    filename
    for filename, _ in 
    sorted([
        (fn, parse_string(fn, {"epoch": int})["epoch"]) 
        for fn in checkpoint_files
    ], key=itemgetter(1))
]

In [None]:
sorted_files[-3:]

In [None]:
# Evaluate every checkpoint on the train/valid loaders rebuilt from its own
# saved params, collecting per-batch metric values.
metrics_per_checkpoint = []

for i, checkpoint_file in enumerate(sorted_files):
    print(f"[{i+1:3d}/{total:3d}] inference: {checkpoint_file}")

    experiment, meta = load_experiment(SMPExperiment, checkpoint_file, meta_file)

    experiment.to(device)
    # Inference must run in eval mode: otherwise dropout stays active and any
    # BatchNorm layers normalise with per-batch statistics (and keep updating
    # their running averages), so the metrics would not describe the saved
    # weights. Assumes SMPExperiment is a torch.nn.Module, as its use of
    # `.to()` and `__call__` here suggests.
    experiment.eval()
    params = meta["params"]
    super_seed(params.seed)

    loaders = create_data_loaders(
        reader=reader,
        valid_keys=parse_fold_keys(params.fold) if params.fold is not None else None,
        transformers=meta["transformers"],
        samples=read_segmentation_info(params.dataset, file_format=params.file_format),
        num_workers=params.num_workers,
        batch_size=params.batch_size,
        multiprocessing_context=params.data_loader_multiprocessing_context
    )

    with torch.no_grad():
        metrics = {"train": [], "valid": []}
        for name, loader in loaders.items():
            for batch in loader:
                batch = {key: tensor.to(device) for key, tensor in batch.items()}
                outputs = experiment(batch)
                # One value per metric per batch; spaces in metric names become
                # underscores so they can be used as column suffixes later.
                batch_metrics = {
                    metric.name.replace(' ', '_'): metric(outputs, batch).item()
                    for metric in experiment.metrics
                }
                metrics[name].append(batch_metrics)

    metrics_per_checkpoint.append({
        "order": i,
        "filename": checkpoint_file,
        "batch_metrics": metrics
    })

    del experiment, loaders

In [None]:
# Persist the per-checkpoint metrics so the analysis below can be redone without re-running inference.
torch.save(obj=metrics_per_checkpoint, f="/home/ck/benchmark.pth")

# Loading saved benchmark

In [None]:
from collections import defaultdict, OrderedDict
import altair as alt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from zeus.plotting.style import notebook_style
from zeus.plotting.utils import axes, calculate_layout

In [None]:
notebook_style(override={'axes.grid': True, 'figure.figsize': (12, 8)}).apply()

In [None]:
benchmark = torch.load("/home/ck/benchmark.pth")

In [None]:
table = []
for checkpoint in benchmark:
    # One row per checkpoint: its position in the sorted list (stored as
    # "epoch"), the validation loss parsed from its file name, and the
    # mean/std over batches of every metric for each subset.
    val_loss = parse_string(checkpoint["filename"], {"avg_val_loss": float})["avg_val_loss"]
    row = OrderedDict(epoch=checkpoint["order"], avg_val_loss=val_loss)
    for subset, batches in checkpoint["batch_metrics"].items():
        per_metric = defaultdict(list)
        for batch_values in batches:
            for metric_name, value in batch_values.items():
                per_metric[metric_name].append(value)
        # Means for every metric first, then stds: this fixes the column order.
        row.update((f"{subset}_mean_{key}", np.mean(vals)) for key, vals in per_metric.items())
        row.update((f"{subset}_std_{key}", np.std(vals)) for key, vals in per_metric.items())
    table.append(row)

In [None]:
table = pd.DataFrame(table)

In [None]:
prefixes = ["train_mean_", "valid_mean_"]
wide = pd.wide_to_long(table, prefixes, i="epoch", j="metric", suffix="\w+")
wide = wide[prefixes].rename(columns=dict(zip(prefixes, ["train", "valid"])))
wide

In [None]:
def plot(metric, ax=None, data=None):
    """Plot the train/valid curves of one metric against epoch.

    Args:
        metric: Value of the "metric" index level to plot (e.g. "dice").
        ax: Optional matplotlib axes, passed through zeus' ``axes()`` helper
            (presumably it creates new axes when ``None`` — TODO confirm).
        data: Frame indexed by (epoch, metric) with "train"/"valid" columns.
            Defaults to the notebook-level ``wide`` frame.

    Returns:
        The axes the curves were drawn on.
    """
    if data is None:
        data = wide
    ax = axes(ax=ax)
    ax = data.xs(metric, level=1).plot(ax=ax)
    ax.set_title(metric.replace("_", " ").title())
    ax.set_ylabel(metric)
    return ax

In [None]:
# Sorted so the panel layout is identical on every run: iterating a set of
# strings depends on per-process hash randomisation.
metric_names = sorted(wide.index.get_level_values("metric").unique())

# squeeze=False keeps `axs` a 2-D array even for a 1x1 layout, so `.flat` works.
_, axs = plt.subplots(*calculate_layout(len(metric_names)), figsize=(30, 20), squeeze=False)

# Hide every panel, then re-enable only those that receive a metric.
for ax in axs.flat:
    ax.axis(False)

for ax, metric in zip(axs.flat, metric_names):
    plot(metric, ax=ax)
    ax.axis(True)

In [None]:
metrics_df = pd.melt(wide.reset_index(), id_vars=["epoch", "metric"], var_name="subset")

In [None]:
metrics_df

In [None]:
# One line chart per subset, metrics distinguished by colour, shown side by side.
chart_train, chart_valid = (
    alt.Chart(metrics_df.query(f"subset == '{subset}'"))
    .mark_line()
    .encode(x="epoch", y="value", color="metric")
    .properties(title=subset)
    for subset in ("train", "valid")
)
chart_train | chart_valid

In [None]:
# Validation curves (batch means) of the headline metrics: one column per metric, one row per epoch.
wide["valid"].unstack(level="metric")[["dice", "balanced_accuracy", "recall", "precision"]]

In [None]:
precision = wide.xs("precision", level=1)
recall = wide.xs("recall", level=1)

# F1 of the batch-averaged precision and recall (not the batch average of F1),
# computed on frames aligned by epoch rather than by row position. Rows where
# precision + recall == 0 come out as NaN.
f1_score = (
    (2 * precision * recall / (precision + recall))
    .reset_index()
    .assign(metric="f1_score")
    [["epoch", "metric", "train", "valid"]]
)

In [None]:
metrics_df = pd.concat([wide.reset_index(), f1_score]).reset_index(drop=True)

In [None]:
best = []
for metric in ("dice", "precision", "recall", "f1_score", "balanced_accuracy"):
    df = metrics_df.query(f"metric == '{metric}'")
    best_index = df["valid"].argmax()
    record = df.iloc[best_index]
    best.append({"metric": metric, "epoch": record.epoch, "best": record.valid})
best = pd.DataFrame(best)

In [None]:
# "epoch" in `best` is the checkpoint's "order" (its index in the sorted file
# list). Resolve it with the filenames saved in the loaded benchmark, so this
# also works after a kernel restart where only the loading section was run
# (`sorted_files` is defined only by the inference section).
filename_by_order = {checkpoint["order"]: checkpoint["filename"] for checkpoint in benchmark}
best["filename"] = best.epoch.map(lambda epoch: filename_by_order[epoch])

In [None]:
best

In [None]:
best.filename.tolist()

In [None]:
meta_file