In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from pprint import pprint

import numpy
import pandas
# import napari
import seaborn
import torch
from imageio import imread
from ruamel.yaml import YAML
from tqdm import tqdm

from hylfm.eval.metrics import compute_metrics_individually, init_metrics

# Shared YAML loader; typ="safe" selects ruamel's safe loader, which builds only plain YAML types (no arbitrary Python objects).
yaml = YAML(typ="safe")

# import matplotlib.pyplot as plt

In [None]:
# Training run whose `test_dynamic_*` stage outputs are evaluated in this notebook.
root = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/z_out49/growdynds_f4/20-11-02_11-41-23")
# root = Path("/mnt/c/Users/fbeut/Desktop/lnet_stuff/growdynds")
# Fail here with the offending path: a missing root otherwise yields an empty glob
# and a confusing IndexError when the first stage directory is indexed.
assert root.exists(), root

# Period of the step index; the metrics cell uses `step % zmod` as the z coordinate.
# NOTE(review): presumably the number of z planes per volume — confirm against the dataset.
zmod = 209
test_stage_dirs = sorted(root.glob("test_dynamic_*/"))
assert test_stage_dirs, f"no test_dynamic_* directories found in {root}"

pprint([p.name for p in test_stage_dirs])

In [None]:
# Metric definitions for the dynamic heart evaluation, instantiated via hylfm's init_metrics.
metrics_config_path = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/configs/metrics/heart_dynamic.yml")
metrics_config = yaml.load(metrics_config_path)
metrics_instances = init_metrics(metrics_config)

In [None]:
def _load_tif_stack(directory: Path, pattern: str) -> numpy.ndarray:
    """Read every file under `directory` matching glob `pattern` and stack them along a new axis 0.

    Files are ordered by their sorted (lexicographic) path.
    NOTE(review): assumes file names are zero-padded so lexicographic order equals step order.

    Raises:
        AssertionError: if nothing matches, instead of numpy's opaque
            "need at least one array to stack".
    """
    paths = sorted(directory.glob(pattern))
    assert paths, f"no files matching {pattern!r} in {directory}"
    return numpy.stack([imread(p) for p in paths])


assert root.exists(), root
assert test_stage_dirs, f"no test_dynamic_* directories found in {root}"
# The reference ls_slice stack is read from the first stage only.
# NOTE(review): assumes every test stage was evaluated on the same ls_slice data — confirm.
ls_slices = _load_tif_stack(test_stage_dirs[0], "run000/ds0-0/ls_slice/*.tif")
print(ls_slices.shape)
assert (ls_slices.shape[0] % zmod) == 0, (ls_slices.shape, zmod)
all_preds = []
for test_stage_dir in tqdm(test_stage_dirs):
    preds = _load_tif_stack(test_stage_dir, "run000/ds0-0/pred/*.tif")
    # Predictions must align one-to-one with the reference slices for the per-step metrics.
    assert preds.shape == ls_slices.shape, (test_stage_dir.name, preds.shape, ls_slices.shape)
    all_preds.append(preds)

In [None]:
# One record per (prediction set, step): the individually computed metrics plus bookkeeping columns.
records = []
for pred_nr, preds in enumerate(tqdm(all_preds)):
    for step, (pred, ls_slice) in enumerate(zip(preds, ls_slices)):
        # add batch and channel dim
        tensors = {
            "pred": torch.from_numpy(pred[None, None]),
            "ls_slice": torch.from_numpy(ls_slice[None, None]),
        }
        record = {k: m.value for k, m in compute_metrics_individually(metrics_instances, tensors).items()}
        record["step"] = step
        record["z"] = step % zmod  # position within the repeating period of zmod steps
        record["pred_nr"] = pred_nr
        records.append(record)

# Without this, an empty run would leave nothing usable for the DataFrame cell.
assert records, "no predictions to evaluate"
# Every record must carry the same keys, otherwise the column lists below would be ragged.
assert all(record.keys() == records[0].keys() for record in records), "metric keys differ between steps"
# Column-oriented layout: {column name: [value for each record]}.
data = {column: [record[column] for record in records] for column in records[0]}

In [None]:
# Build the per-step metrics table from the column-oriented `data` dict.
df = pandas.DataFrame(data)
df.tail()

In [None]:
# Group z into `nbins` equal-width bins (over the observed z range), labelled 0..nbins-1.
nbins = 8
z_bin_labels = numpy.arange(nbins)
df["z_bin"] = pandas.cut(x=df["z"], bins=nbins, labels=z_bin_labels)
df.head()

In [None]:
# df_bins = df.groupby(["z_bin", "pred_nr"]).mean().reset_index()
# df_bins.head()

In [None]:
# Use seaborn's dark-grid look for all plots below.
seaborn.set_style(style="darkgrid")

In [None]:
# Metric columns of `df` to plot, in display order.
# NOTE(review): names must match the metric keys produced from configs/metrics/heart_dynamic.yml.
METRIC_COLUMNS = [
    "ms_ssim-scaled",
    "ssim-scaled",
    "smooth_l1_loss-scaled",
    "nrmse-scaled",
    "psnr-scaled",
    "mse_loss-scaled",
]


def plot_metric_by_z_bin(frame: pandas.DataFrame, metric: str) -> seaborn.FacetGrid:
    """Point-plot `metric` against `pred_nr`, with one facet (and hue) per `z_bin`.

    Args:
        frame: per-step metrics table with `pred_nr`, `z_bin` and `metric` columns.
        metric: name of the column to put on the y axis.

    Returns:
        The seaborn FacetGrid holding the figure.
    """
    return seaborn.catplot(
        x="pred_nr", y=metric, hue="z_bin", col="z_bin", col_wrap=4,
        capsize=.2, palette="YlGnBu_d", height=3, aspect=1.0,
        kind="point", data=frame,
    )


# One figure per metric; the inline backend renders every figure created in this cell.
for metric in METRIC_COLUMNS:
    g = plot_metric_by_z_bin(df, metric)

In [None]:
# other plot drafts

In [None]:
# g = seaborn.lmplot(
#     data=df_bins,
#     x="pred_nr", y="ms_ssim-scaled", hue="z_bin",
#     height=5
# )

# # Show the results of a linear regression within each dataset
# seaborn.lmplot(x="pred_nr", y="ms_ssim-scaled", col="z_bin", hue="z_bin", data=df_bins,
#            col_wrap=3, ci=None, palette="muted", height=4,
#            scatter_kws={"s": 50, "alpha": 1})


