In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from pprint import pprint
from typing import Dict, Sequence

import numpy
import pandas
# import napari
import seaborn
import torch
from imageio import imread
from ruamel.yaml import YAML
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt

from hylfm.eval.metrics import compute_metrics_individually, init_metrics
yaml = YAML(typ="safe")

In [None]:
def get_validate_df(name, step_dirs, z_mod):
    """Compute per-slice metrics for "validate" runs, treating each run directory as one swipe-through.

    Each entry of `step_dirs` must be a `run000` directory whose parent directory name ends in
    `_<number>`; that number becomes the run's `swipe_through` value. Light-sheet slices
    (`ds0-0/ls_slice/*.tif`) and predictions (`ds0-0/pred/*.tif`) are loaded in sorted filename
    order and compared pairwise.

    Args:
        name: run name; only used in progress-bar labels and error messages.
        step_dirs: iterable of `pathlib.Path` run directories.
        z_mod: number of slices per volume; every run's slice count must be a multiple of it.

    Returns:
        DataFrame with one row per (run, slice) holding the metric values returned by
        `compute_metrics_individually`, plus `idx` (slice index within the run),
        `swipe_through` (number parsed from the parent directory name) and `pred_nr` (always 0).
    """
    metrics_config = yaml.load(Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/configs/metrics/heart_dynamic.yml"))
    metrics_instances = init_metrics(metrics_config)

    all_preds = []
    all_ls_slices = []
    pred_nrs = []
    for step_dir in tqdm(step_dirs, desc=f"load raw data for {name}"):
        assert step_dir.name == "run000"
        # the parent directory name (e.g. "test_dynamic_03") carries the swipe-through number
        pred_nr = int(step_dir.parent.name.split("_")[-1])
        pred_nrs.append(pred_nr)
        ls_slices = numpy.stack([imread(p) for p in sorted(step_dir.glob("ds0-0/ls_slice/*.tif"))])
        assert (ls_slices.shape[0] % z_mod) == 0, (ls_slices.shape[0], z_mod)
        all_ls_slices.append(ls_slices)
        preds = numpy.stack([imread(p) for p in sorted(step_dir.glob("ds0-0/pred/*.tif"))])
        assert preds.shape == ls_slices.shape, (preds.shape, ls_slices.shape)
        all_preds.append(preds)

    # one dict per (run, slice); built into a DataFrame once at the end
    records = []
    for pred_nr, preds, ls_slices in tqdm(
        zip(pred_nrs, all_preds, all_ls_slices), total=len(all_preds), desc=f"comp. metrics for {name}"
    ):
        for idx, (pred, ls_slice) in enumerate(zip(preds, ls_slices)):
            # add batch and channel dim
            tensors = {
                "pred": torch.from_numpy(pred[None, None]),
                "ls_slice": torch.from_numpy(ls_slice[None, None]),
            }
            record = {k: m.value for k, m in compute_metrics_individually(metrics_instances, tensors).items()}
            record["idx"] = idx
            record["pred_nr"] = pred_nr
            records.append(record)

    assert records, f"no slices found for {name}"
    df = pandas.DataFrame(records)
    # for validate runs the run number is the swipe-through; there is no refinement step
    df["swipe_through"] = df["pred_nr"]
    df["pred_nr"] = 0
    return df

def get_refine_ls_slices(z_mod, fish2: bool):
    """Load and stack the light-sheet reference slices shared by the refine runs.

    Args:
        z_mod: number of slices per volume; the loaded slice count must be a multiple of it.
        fish2: pick the fish-2 validation run directory instead of the training validation one.

    Returns:
        numpy array of all `ds0-0/ls_slice/*.tif` images (sorted by filename), stacked on axis 0.
    """
    ls_dir = (
        Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/test/heart/validate_fish2/from_static_heart/20-11-12_15-11-48/test_dynamic_00/run000")
        if fish2
        else Path(
            "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/z_out49/dualview_single_lfm_static_f4_center49/20-11-09_18-09-02/validate_train_01/run000"
        )
    )

    tif_paths = sorted(ls_dir.glob("ds0-0/ls_slice/*.tif"))
    stacked = numpy.stack(list(map(imread, tif_paths)))
    n_slices = stacked.shape[0]
    assert n_slices % z_mod == 0, (n_slices, z_mod)
    return stacked



def get_refine_df(name, step_dirs, z_mod, ls_slices):
    """Compute per-slice metrics of refine runs against one shared stack of light-sheet slices.

    Each entry of `step_dirs` is a `run<NNN>` directory; NNN becomes `pred_nr`. Every run's
    predictions (`ds0-0/pred/*.tif`, sorted by filename) are stacked and must have exactly the
    shape of `ls_slices`; prediction i is compared with `ls_slices[i]`.

    Args:
        name: run name; used in progress-bar labels and error messages.
        step_dirs: iterable of `pathlib.Path` run directories named `run<NNN>`.
        z_mod: number of slices per swipe-through; `ls_slices.shape[0]` must be a multiple of it.
        ls_slices: numpy array of light-sheet slices stacked along axis 0.

    Returns:
        DataFrame with one row per (run, slice): the metric values, `idx` (slice index),
        `pred_nr` (run number) and `swipe_through` (`idx // z_mod`).
    """
    assert (ls_slices.shape[0] % z_mod) == 0, (ls_slices.shape[0], z_mod)

    metrics_config = yaml.load(Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/configs/metrics/heart_dynamic.yml"))
    metrics_instances = init_metrics(metrics_config)

    all_preds = []
    pred_nrs = []
    for step_dir in tqdm(step_dirs, desc=f"load raw data for {name}"):
        assert step_dir.name.startswith("run")
        pred_nrs.append(int(step_dir.name.replace("run", "")))

        preds = numpy.stack([imread(p) for p in sorted(step_dir.glob("ds0-0/pred/*.tif"))])
        assert preds.shape == ls_slices.shape, (preds.shape, ls_slices.shape)
        all_preds.append(preds)

    # one dict per (run, slice); built into a DataFrame once at the end
    records = []
    for pred_nr, preds in tqdm(zip(pred_nrs, all_preds), total=len(all_preds), desc=f"comp. metrics for {name}"):
        for idx, (pred, ls_slice) in enumerate(zip(preds, ls_slices)):
            # add batch and channel dim
            tensors = {
                "pred": torch.from_numpy(pred[None, None]),
                "ls_slice": torch.from_numpy(ls_slice[None, None]),
            }
            record = {k: m.value for k, m in compute_metrics_individually(metrics_instances, tensors).items()}
            record["idx"] = idx
            record["pred_nr"] = pred_nr
            records.append(record)

    assert records, f"no predictions found for {name}"
    df = pandas.DataFrame(records)
    # slices are ordered swipe-through by swipe-through, z_mod slices each
    df["swipe_through"] = df["idx"] // z_mod
    return df

def get_df(name, *, pred_nrs=(0,), ls_slices: Dict[tuple, numpy.ndarray]):
    """Compute the metrics DataFrame for one named run and annotate it with z position and timing.

    Args:
        name: "validate_from_static_heart", "validate_fish2/from_static_heart", or one of the
            "refine_from_*" names listed in `times_map` below.
        pred_nrs: indices into the sorted run directories of refine / fish2 runs (any sequence
            of ints, tuples included). Ignored for "validate_from_*" runs, which always use the
            first four sorted `test_dynamic_*/run000` directories.
        ls_slices: cache of light-sheet reference slices keyed by `(z_mod, is_fish2)`. It is
            filled in place on a miss so that several calls can share the loaded stacks.

    Returns:
        The DataFrame from `get_validate_df` / `get_refine_df` with added columns `z`, `frame`,
        `time [s]` and `run_name`.

    Raises:
        NotImplementedError: if `name` is not a known run.
    """
    if name == "validate_from_static_heart":
        root = Path(
            # "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/test/heart/z_out49/contin_validate_f4/20-11-10_14-02-53"
            "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/test/heart/z_out49/contin_validate_f4/20-11-11_19-35-43"
        )

        assert root.exists(), root
        z_min = 29
        z_mod = 189
        step_dirs = sorted(root.glob("test_dynamic_*/run000"))
    elif name == "validate_fish2/from_static_heart":
        root = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/test/heart/validate_fish2/from_static_heart/20-11-12_15-11-48")
        z_min = 29
        z_mod = 189
        step_dirs = [root / "test_dynamic_00/run000"]
    elif name.startswith("refine_from"):
        z_min = 29
        z_mod = 189
        common_root = Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/z_out49/")
        names_map = {"refine_from_lfd_heart": "dualview_single_lfm_static_f4_center49"}
        times_map = {
            "refine_from_bad_static_heart": "20-11-11_20-00-16",
            "refine_from_lfd_heart": "20-11-09_18-09-02",
            "refine_from_medium_beads": "20-11-11_13-09-30",
            "refine_from_mednlarge_beads": "20-11-11_19-55-02",
            "refine_from_static_heart": "20-11-11_19-48-09",
        }
        if name not in times_map:
            raise NotImplementedError(name)

        root = common_root / names_map.get(name, name) / times_map[name]
        assert root.exists(), root
        # the last run directory is excluded
        step_dirs = sorted(root.glob("validate_train_01/run*"))[:-1]
    else:
        raise NotImplementedError(name)

    if name.startswith("validate_from"):
        _get_df = get_validate_df
        step_dirs = step_dirs[:4]
        get_df_kwargs = {}
    elif name.startswith("refine") or name.startswith("validate_fish2"):
        _get_df = get_refine_df
        # Plain list indexing: numpy.asarray(step_dirs)[(0,)] would do *tuple* indexing and return
        # a single Path instead of a one-element selection.
        step_dirs = [step_dirs[i] for i in pred_nrs]

        is_fish2 = "fish2" in name
        # key on both z_mod and the data set so fish2 and non-fish2 slices never share a slot
        cache_key = (z_mod, is_fish2)
        if cache_key not in ls_slices:
            ls_slices[cache_key] = get_refine_ls_slices(z_mod, is_fish2)

        get_df_kwargs = {"ls_slices": ls_slices[cache_key]}
    else:
        raise NotImplementedError(name)

    df = _get_df(name, step_dirs, z_mod, **get_df_kwargs)
    # Within each swipe-through, slice index 0 maps to the top plane (z_min + z_mod - 1) and z
    # decreases with idx.
    # NOTE(review): the fixed -120 offset is not explained here; confirm its origin.
    df["z"] = z_min + z_mod - 1 - (df["idx"] % z_mod) - 120
    # NOTE(review): 241 looks like frames per swipe-through and 0.025 like seconds per frame;
    # confirm against the acquisition settings.
    df["frame"] = df["swipe_through"] * 241 + 241 - df["z"]
    df["time [s]"] = df["frame"] * 0.025
    df["run_name"] = name
    return df

def add_df(df, ls_slices, *names, pred_nrs=(0,)):
    """Append the metrics of further runs to an existing metrics DataFrame.

    The original version read undefined globals `names` and `pred_nrs` and always raised
    NameError. The runs to add are now passed explicitly.

    Args:
        df: existing metrics DataFrame (e.g. a previous `get_dfs` result).
        ls_slices: light-sheet slice cache dict, passed to every `get_df` call.
        *names: run names to compute and append.
        pred_nrs: run indices forwarded to `get_df` for refine / fish2 runs.

    Returns:
        `df` followed by the rows of every named run, with a fresh RangeIndex.
    """
    dfs = [df]
    # list() so that numpy-based selection in get_df does list (not tuple) indexing
    dfs.extend(get_df(name, pred_nrs=list(pred_nrs), ls_slices=ls_slices) for name in names)
    return pandas.concat(dfs, ignore_index=True)

def get_dfs(*names, pred_nrs):
    """Compute the metrics of several runs and concatenate them into one DataFrame.

    Args:
        *names: run names understood by `get_df`.
        pred_nrs: run indices forwarded to `get_df` for refine / fish2 runs.

    Returns:
        Concatenation of the per-run DataFrames. `ignore_index=True` gives a fresh RangeIndex;
        otherwise each run's 0..n-1 labels would repeat, making label-based `.loc` lookups and
        assignments ambiguous.

    Raises:
        ValueError: if no run names are given (nothing to concatenate).
    """
    # one cache dict passed to every get_df call so loaded light-sheet slices can be reused
    ls_slices_cache = {}
    per_run = [get_df(name, pred_nrs=pred_nrs, ls_slices=ls_slices_cache) for name in names]
    return pandas.concat(per_run, ignore_index=True)

In [None]:
# Fish 2: metrics of the static-heart network on the `test_dynamic_00` run (run index 0).
df = get_dfs("validate_fish2/from_static_heart", pred_nrs=[0])
df.head()

In [None]:
# Static-heart network without refinement vs. refinement run 0.
df = get_dfs("validate_from_static_heart", "refine_from_static_heart", pred_nrs=[0])
df["network"] = df.run_name
# NOTE(review): pred_nr appears to be >= 0 for every row (validate rows get pred_nr=0), so this
# mask selects all rows and validate rows are labelled "refinement step: 0" too. Confirm intent.
mask = df.pred_nr >= 0
# vectorized label; also safe when the mask is empty (row-wise apply on an empty frame is not)
df.loc[mask, "network"] = "refinement step: " + df.loc[mask, "pred_nr"].astype(str)

# df = get_dfs("validate_from_static_heart", pred_nrs=[0,35])
# df = get_dfs("refine_from_lfd_heart", pred_nrs=[0,35])
df.head()

In [None]:
# All refine runs at refinement steps 0, 1, 10 and 98.
df = get_dfs("refine_from_bad_static_heart", "refine_from_static_heart", "refine_from_lfd_heart", "refine_from_mednlarge_beads", "refine_from_medium_beads", pred_nrs=[0, 1, 10, 98])
df["network"] = df.run_name
mask = df.pred_nr >= 0
# vectorized label; also safe when the mask is empty (row-wise apply on an empty frame is not)
df.loc[mask, "network"] = "refinement step: " + df.loc[mask, "pred_nr"].astype(str)

In [None]:
# nbins = 8
# df['z_bin'] = pandas.cut(df['z'], bins=nbins, labels=numpy.arange(nbins))
# df.head()

In [None]:
# Alias (not a copy) of `df`; plot_scans_grid reads this global as its data.
# Swap in the commented filter to restrict which refinement steps are plotted.
df_filtered = df
# df_filtered = df[df.pred_nr < 5]
df_filtered.z.min()

In [None]:
def plot_scans_grid(metric: str, data=None, y_label=None, out_dir="refine_lfd_training_plots"):
    """Scatter `metric` against z in a (pred_nr x swipe_through) grid, colored by run, and save it.

    Args:
        metric: metric column to plot on the y axis, e.g. "ms_ssim-scaled".
        data: metrics DataFrame to plot; defaults to the notebook-global `df_filtered`.
        y_label: y-axis label; defaults to a readable name for known metrics, else `metric`.
        out_dir: directory (created if missing) the figure is saved to as `<metric>.png`.
    """
    if data is None:
        data = df_filtered

    # Readable axis labels. The y label was previously hard-coded to "MS-SSIM" for every metric.
    default_y_labels = {
        "ms_ssim-scaled": "MS-SSIM",
        "ssim-scaled": "SSIM",
        "nrmse-scaled": "NRMSE",
        "psnr-scaled": "PSNR",
        "mse_loss-scaled": "MSE",
        "smooth_l1_loss-scaled": "smooth L1",
    }
    if y_label is None:
        y_label = default_y_labels.get(metric, metric)

    seaborn.set_style("darkgrid", {"axes.facecolor": ".8"})
    seaborn.set_context("talk")  # paper, notebook, talk, poster
    g = seaborn.relplot(
        x="z",
        y=metric,
        hue="run_name",
        legend="brief",  # brief, full
        row="pred_nr",
        col="swipe_through",
        height=4,
        aspect=2,
        kind="scatter",
        data=data,
    )
    # NOTE(review): dashed reference line at z=25; its meaning is not documented here.
    g.map(plt.axvline, x=25, color=".7", dashes=(2, 1), zorder=0)
    g.set_axis_labels("z [μm]", y_label)
    g.set_titles("train step: {row_name} | Swipe-through {col_name}")
    g.fig.tight_layout()
    root = Path(out_dir)
    root.mkdir(exist_ok=True)
    g.fig.savefig(root / f"{metric}.png")

plot_scans_grid("ms_ssim-scaled")

In [None]:
# Grid plots for the remaining metrics (ms_ssim-scaled is plotted in the cell above).
for _metric in ("ssim-scaled", "nrmse-scaled", "psnr-scaled", "mse_loss-scaled", "smooth_l1_loss-scaled"):
    plot_scans_grid(_metric)
del _metric

In [None]:
# def plot_scans(metric: str):
#     seaborn.set_style("darkgrid", {"axes.facecolor": ".8"})
#     seaborn.set_context("talk")  # paper, notebook, talk, poster
#     cmap_name = "viridis"
#     g = seaborn.relplot(x="time [s]", y=metric, hue="z", legend=False,
#                     palette=cmap_name, height=7, aspect=7,
#                     kind="scatter", data=df_filtered)
#     g.fig.colorbar(matplotlib.cm.ScalarMappable(matplotlib.colors.Normalize(vmin=z_offset, vmax=df["z"].max()+z_offset, clip=False), cmap=cmap_name), label='z')
# #     g.fig.axes[0].set_xlim(0, 9399*0.025)
#     g.fig.tight_layout()
#     root = Path("refine_lfd_training_plots")
#     root.mkdir(exist_ok=True)
#     g.fig.savefig(root / f"{metric}.png")
#
# plot_scans("ms_ssim-scaled")

In [None]:
# plot_scans("ms_ssim-scaled")
# plot_scans("ssim-scaled")
# plot_scans("nrmse-scaled")
# plot_scans("psnr-scaled")
# plot_scans("mse_loss-scaled")
# plot_scans("smooth_l1_loss-scaled")

In [None]:
# df_bins = df.groupby(["z_bin", "pred_nr"]).mean().reset_index()
# df_bins.head()
# g = seaborn.catplot(x="pred_nr", y="ms_ssim-scaled", hue="pred_nr",
#                 capsize=.2, palette="YlGnBu_d", height=3, aspect=1.0,
#                 kind="point", data=df)

In [None]:
# NOTE(review): repeats the earlier ms_ssim-scaled call; it regenerates and overwrites the same PNG.
plot_scans_grid("ms_ssim-scaled")