In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from pprint import pprint

import numpy
import pandas
# import napari
import seaborn
import torch
from imageio import imread
from ruamel.yaml import YAML
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt

from hylfm.eval.metrics import compute_metrics_individually, init_metrics
# ruamel YAML parser created with typ="safe"; shared by the cells that load YAML configs
yaml = YAML(typ="safe")

In [None]:
# Notebook-wide seaborn defaults; the plotting functions further down set their own style/context again.
seaborn.set_style("darkgrid")
seaborn.set_context("notebook")  # paper, notebook, talk, poster

In [None]:
# Alternative run, kept for reference:
# root = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/test/heart/z_out49/contin_validate_f4/20-11-10_14-02-53")
# z_offset = 29
# zmod = 139
# validation_step_dirs = sorted(root.glob("test_dynamic_*/run000"))

# Log directory of the run that is evaluated in this notebook.
root = Path(
    "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/z_out49/dualview_single_lfm_static_f4_center49/20-11-09_18-09-02"
)
z_offset = 29  # added to `step % zmod` when the z column is derived
zmod = 189  # the number of loaded slices has to be a multiple of this

# All run directories in sorted order, except for the last one.
run_dirs_sorted = sorted(root.glob("validate_train_01/run*"))
validation_step_dirs = run_dirs_sorted[:-1]

assert root.exists(), root
pprint([run_dir.name for run_dir in validation_step_dirs])

# Metric definitions are read from the repository's YAML config and instantiated once.
metrics_config_path = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/configs/metrics/heart_dynamic.yml")
metrics_config = yaml.load(metrics_config_path)
metrics_instances = init_metrics(metrics_config)

In [None]:
# The reference slices are read once, from the first run directory only.
ls_slice_paths = sorted(validation_step_dirs[0].glob("ds0-0/ls_slice/*.tif"))
ls_slices = numpy.stack([imread(tif_path) for tif_path in ls_slice_paths])
print(ls_slices.shape, flush=True)
assert (ls_slices.shape[0] % zmod) == 0, (ls_slices.shape[0], zmod)

# One prediction stack per run directory; each has to match the reference stack's shape.
all_preds = []
for step_dir in tqdm(validation_step_dirs):
    pred_paths = sorted(step_dir.glob("ds0-0/pred/*.tif"))
    preds = numpy.stack([imread(tif_path) for tif_path in pred_paths])
    assert preds.shape == ls_slices.shape, (preds.shape, ls_slices.shape)
    all_preds.append(preds)

In [None]:
# Evaluate all configured metrics for every (prediction run, slice) pair and
# collect them column-wise: column name -> list with one value per pair.
data = None
for pred_nr, preds in enumerate(tqdm(all_preds)):
    for step, (pred, ls_slice) in enumerate(zip(preds, ls_slices)):
        # add batch and channel dim
        pred, ls_slice = pred[None, None], ls_slice[None, None]

        tensors = {
            "pred": torch.from_numpy(pred),
            "ls_slice": torch.from_numpy(ls_slice),
        }
        metric_results = compute_metrics_individually(metrics_instances, tensors)
        row = {name: result.value for name, result in metric_results.items()}
        row["step"] = step
        row["z"] = step % zmod
        row["pred_nr"] = pred_nr

        # the first row defines the columns; afterwards values are only appended
        if data is None:
            data = {name: [] for name in row}

        for name, value in row.items():
            data[name].append(value)

In [None]:
# Turn the collected column lists into a table and peek at its last rows.
df = pandas.DataFrame(data)
df.tail()

In [None]:
# nbins = 8
# df['z_bin'] = pandas.cut(df['z'], bins=nbins, labels=numpy.arange(nbins))
# df.head()

In [None]:
# Derived columns; all of them are recomputed from "step" and "pred_nr", so re-running this cell is safe.
slice_in_swipe = df["step"] % zmod  # position of the slice within one pass of `zmod` slices
df["z"] = slice_in_swipe + z_offset
df["frame"] = df["pred_nr"] * 241 + slice_in_swipe + z_offset  # + (df["step"] // zmod) * 241
df["swipe_through"] = df["step"].floordiv(zmod)  # index of the pass a slice belongs to
df["time [s]"] = df["frame"].mul(0.025)  # 0.025 per frame, in the unit named by the column
df.tail(50)

In [None]:
# Only prediction runs with pred_nr 0-4 are kept for plotting.
df_filtered = df.loc[df["pred_nr"] < 5]
df_filtered["z"].min()

In [None]:
def plot_scans_grid(metric: str):
    """Scatter `metric` against z in a grid of facets and save the figure.

    Reads the notebook-level `df_filtered`. The grid has one column per
    `pred_nr` and one row per `swipe_through`; points are colored by their z
    value using the viridis colormap.

    The figure is written to ``refine_lfd_training_plots/<metric>_grid.png``.
    The ``_grid`` suffix keeps it from being overwritten by `plot_scans`, which
    saves its figure as ``<metric>.png`` in the same directory.

    Args:
        metric: name of the `df_filtered` column plotted on the y axis.
    """
    seaborn.set_style("darkgrid", {"axes.facecolor": ".8"})
    seaborn.set_context("talk")  # paper, notebook, talk, poster
    cmap_name = "viridis"
    g = seaborn.relplot(
        x="z",
        y=metric,
        hue="z",
        legend=False,
        col="pred_nr",
        row="swipe_through",
        palette=cmap_name,
        height=7,
        aspect=2,
        kind="scatter",
        data=df_filtered,
    )

    # Named `out_dir` so that the notebook-level `root` (log directory) is not shadowed.
    out_dir = Path("refine_lfd_training_plots")
    out_dir.mkdir(exist_ok=True)
    g.fig.savefig(out_dir / f"{metric}_grid.png")

plot_scans_grid("ms_ssim-scaled")

In [None]:
def plot_scans(metric: str):
    """Scatter `metric` over time, colored by z, and save the figure.

    Reads the notebook-level `df_filtered` and plots its "time [s]" column on
    the x axis. The figure is written to
    ``refine_lfd_training_plots/<metric>.png``.

    Args:
        metric: name of the `df_filtered` column plotted on the y axis.
    """
    seaborn.set_style("darkgrid", {"axes.facecolor": ".8"})
    seaborn.set_context("talk")  # paper, notebook, talk, poster
    cmap_name = "viridis"
    g = seaborn.relplot(
        x="time [s]",
        y=metric,
        hue="z",
        legend=False,
        palette=cmap_name,
        height=7,
        aspect=7,
        kind="scatter",
        data=df_filtered,
    )

    # The colorbar has to span exactly the z values that were plotted: seaborn
    # scales a numeric hue to the min/max of the plotted data. The "z" column
    # already contains z_offset, so the offset must not be added a second time.
    z_values = df_filtered["z"]
    z_norm = matplotlib.colors.Normalize(vmin=z_values.min(), vmax=z_values.max(), clip=False)
    # The mappable is not attached to any Axes, so the Axes to take space from is passed explicitly.
    g.fig.colorbar(matplotlib.cm.ScalarMappable(z_norm, cmap=cmap_name), label="z", ax=g.ax)
    g.fig.tight_layout()

    # Named `out_dir` so that the notebook-level `root` (log directory) is not shadowed.
    out_dir = Path("refine_lfd_training_plots")
    out_dir.mkdir(exist_ok=True)
    g.fig.savefig(out_dir / f"{metric}.png")

plot_scans("ms_ssim-scaled")

In [None]:
# Render and save the time-series scatter plot for each metric in turn.
metric_names = (
    "ms_ssim-scaled",
    "ssim-scaled",
    "nrmse-scaled",
    "psnr-scaled",
    "mse_loss-scaled",
    "smooth_l1_loss-scaled",
)
for metric_name in metric_names:
    plot_scans(metric_name)

In [None]:
# df_bins = df.groupby(["z_bin", "pred_nr"]).mean().reset_index()
# df_bins.head()
# g = seaborn.catplot(x="pred_nr", y="ms_ssim-scaled", hue="pred_nr",
#                 capsize=.2, palette="YlGnBu_d", height=3, aspect=1.0,
#                 kind="point", data=df)