In [None]:
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
import os
import torch
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.net import Net
import awesome
from awesome.util.path_tools import get_project_root_path
from awesome.util.logging import basic_config
import matplotlib.pyplot as plt
from awesome.analytics.result_model import ResultModel
from awesome.analytics.noisy_unaries_result_model import NoisyUnariesResultModel
from awesome.run.functions import get_result, split_model_result, plot_image_scribbles, plot_mask_labels
from awesome.util.temporary_property import TemporaryProperty
from awesome.run.functions import get_result, split_model_result,register_alpha_map, plot_image_scribbles, plot_mask_labels, plot_mask
import numpy as np
from matplotlib.colors import to_hex, to_rgb
import matplotlib

%matplotlib inline
from tqdm.auto import tqdm
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from awesome.run.functions import get_mpl_figure
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.se import SE
from awesome.measures.ae import AE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.wrapper_module import WrapperModule
#load_ext matplotlib
#matplotlib tk
import normflows as nf
# Apply the project's logging configuration (awesome.util.logging.basic_config).
basic_config()

os.chdir(get_project_root_path()) # Being in the root directory of the project is important for the relative paths (./runs, ./output) used below to work consistently

In [None]:
from awesome.analytics.result_comparison import ResultComparison

paths = [
    "./runs/fbms_local/eval/unet/noisy_spatio_temporal"
]
models = []

# Load every run folder below each path; "old*" (archived) and "log*" folders are skipped.
# sorted(): os.listdir order is arbitrary, and the model order presumably drives the
# numbering / table order below, so make it deterministic.
for path in paths:
    for folder in sorted(os.listdir(path)):
        if folder.startswith(("old", "log")):
            continue
        models.append(NoisyUnariesResultModel.from_path(os.path.join(path, folder)))

import re
# Run-name layout: [#<cfg_num>_]<NET>[_]<+feat+feat...>_<dd_dd_dd>_<dd_dd_dd>.
# `[A-Za-z0-9_]` replaces the former `[A-z0-9]`, whose A-z range also matched the
# characters [ \ ] ^ ` in addition to the intended letters and underscore.
p = r"#?(?P<cfg_num>\d+)?_?(?P<net>[A-Za-z0-9_]+)_?(?P<feat>[\w\+\-]*)\_(?P<date>\d{2}_\d{2}_\d{2})\_(?P<time>\d{2}_\d{2}_\d{2})"
pattern = re.compile(p)

for model in models:
    match = pattern.fullmatch(model.name)
    feat = []
    if match:
        model_name = match.group('net').strip("_").replace("NET", "Net")
        features = match.group('feat')
        if features:
            feat = features.strip("+").split("+")
            # Runs without an explicit seed tag are treated as seed 42.
            if not any("seed" in x for x in feat):
                feat.append("seed42")
    else:
        # Previously this crashed on `None.replace(...)`; fall back to the raw run name.
        print('No match for', model.name)
        model_name = model.name
    model.display_name = model_name + " " + " ".join(feat)
    model.features = list(feat)
    model.config.result_directory = "final_mask"
    model.save_config()


# Resort the models by name to get a meaningful table order
# (names not listed in `_order` all get key 0, so the stable sort keeps load order).
_order = []

models = sorted(models, key=lambda m: _order.index(m.name) if m.name in _order else 0)

comparison = ResultComparison(models)
comparison.assign_numbers(force=True)

os.environ['PLOT_OUTPUT_DIR'] = comparison.output_folder

save_args = dict(transparent=False, save=True, dpi=300, ext=["png", "pdf"])

models

In [None]:
from typing import Any, Dict, Tuple
import pandas as pd
# Metric keys requested from the comparison's metric table.
metrics = [
    "eval/epoch/MeanForegroundBinaryMIOU",
    "eval/epoch/MeanPriorForegroundBinaryMIOU",
    "eval/epoch/MeanPixelAccuracy",
    "eval/epoch/MeanPriorPixelAccuracy",
]

# Display label per metric; this dict's order fixes the column order of the table below.
col_mapping = {
    "eval/epoch/MeanForegroundBinaryMIOU": "IoU",
    "eval/epoch/MeanPixelAccuracy": "Acc.",
    "eval/epoch/MeanPriorPixelAccuracy": "Prior Acc.",
    "eval/epoch/MeanPriorForegroundBinaryMIOU": "Prior IoU",
}

# Reference index -> label appended to every metric label.
index_mapping = {0: "Baseline"}

# The metric table suffixes each metric column with "_<ref index>", so expand the
# mapping to those suffixed keys (metric-major, reference-minor order).
new_colmapping = {
    f"{metric_key}_{ref_idx}": f"{metric_label} {ref_label}"
    for metric_key, metric_label in col_mapping.items()
    for ref_idx, ref_label in index_mapping.items()
}

col_mapping = new_colmapping

def extract_features(model: "ResultModel") -> Dict[str, Any]:
    """Parse a result model's feature tags into a flat dict of table columns.

    Recognised tags in ``model.features`` are removed one by one; what remains must be
    exactly ``[dataset_name, feature_type]`` (in that order).

    Args:
        model: Result model exposing ``features`` (list of tags), ``config.name``,
            ``run_config.seed`` and ``output_path``.

    Returns:
        Dict with keys ``joint``, ``model_name``, ``seed``, ``noise_percentage``,
        ``prior``, ``dataset_name`` and ``feature_type``.

    Raises:
        ValueError: If no noise-percentage tag (a tag containing ``"np"``) is present.
        AssertionError: If more than one unrecognised tag is left over.
    """
    res = dict()
    model_features = list(model.features)

    res['joint'] = "joint" in model_features
    res['model_name'] = model.config.name.split(" ")[0]
    if res['joint']:
        model_features.remove("joint")

    # Seed tag looks like "seed42"; fall back to the run config when absent.
    seed = next((x for x in model_features if "seed" in x), None)
    if seed is None:
        res['seed'] = model.run_config.seed
    else:
        model_features.remove(seed)
        res['seed'] = int(seed.replace("seed", ""))

    # Noise tag looks like "np0_3" -> 0.3.
    noise_percentage = next((x for x in model_features if "np" in x), None)
    if noise_percentage is None:
        raise ValueError(f"No noise percentage ('np...') feature in {model_features} of model {model.output_path}")
    model_features.remove(noise_percentage)
    res['noise_percentage'] = float(noise_percentage.replace("np", "").replace("_", "."))

    # Base prior. Previously the `else: "none"` branch hung off the `retrain_xy` check and
    # overwrote the prior set by REFIT / original / retrain; now "none" is only the default.
    # Later matches win, as in the original sequence of checks.
    prior = "none"
    for tag, prior_name in (("REFIT", "refit"), ("original", "original"),
                            ("retrain", "refit"), ("retrain_xy", "refit")):
        if tag in model_features:
            model_features.remove(tag)
            prior = prior_name
    if "convex" in model_features:
        model_features.remove("convex")
        prior = "convex"
    elif "diffeo" in model_features:
        model_features.remove("diffeo")
        prior = "diffeo"

    # Modifier tags are appended to the prior name in this fixed order.
    for modifier in ("only_prior", "all_frames", "deeper", "spatio-temporal", "noisy", "realnvp"):
        if modifier in model_features:
            model_features.remove(modifier)
            prior = prior + "+" + modifier
    res['prior'] = prior

    dataset_name = model_features.pop(0)
    res['dataset_name'] = dataset_name

    assert len(model_features) == 1, f"Multiple features {model_features} in model {model.output_path}"
    res['feature_type'] = model_features[0]

    return res

# One row per result model with the requested metrics (unformatted numbers),
# then move the model name from the index into an 'index' column.
df = (
    comparison
    .metric_table(metrics, ref="all", mode="max", formatting=False)
    .reset_index()
)

def extract_ft_row(row: pd.Series) -> Tuple[str, bool, str, int, str, float, str]:
    """Map one metric-table row to its parsed feature columns.

    Looks up the model in the notebook-level ``models`` list whose ``name`` equals
    ``row['index']`` and returns, in order:
    (model_name, joint, feature_type, seed, dataset_name, noise_percentage, prior).

    Raises:
        ValueError: If no model in ``models`` has that name.
    """
    name = row['index']
    model = next((m for m in models if m.name == name), None)
    if model is None:
        raise ValueError(f"No result model named {name!r} in `models`")
    res = extract_features(model)
    return (res['model_name'], res['joint'], res['feature_type'], res['seed'],
            res['dataset_name'], res['noise_percentage'], res['prior'])

feature_columns = ['model_name', 'joint', 'feature_type', 'seed', 'dataset_name', 'noise_percentage', "prior"]
df[feature_columns] = df.apply(extract_ft_row, axis=1, result_type="expand")

# `.copy()`: later cells (compute_overlap_iou_for_model) add columns to `df`; selecting a
# column subset without copying makes those writes trigger SettingWithCopyWarning.
df = df[['index', 'model_name', 'dataset_name'] + list(col_mapping.keys()) + ['joint', 'feature_type', 'seed', 'noise_percentage', "prior"]].copy()

grouped = df.groupby(['model_name', 'joint', 'feature_type'])

display_df = df.rename(columns=col_mapping)
display(display_df)

In [None]:
from typing import Optional
from awesome.measures.miou import MIOU
miou = MIOU()


def compute_overlap_iou_for_model(result_model: NoisyUnariesResultModel, 
                                  df: pd.DataFrame, indices_tqdm: Optional[tqdm] = None, 
                                  map_location: Optional[torch.device] = None):
    """Compute per-image prior/segmentation overlap IoU for one run and store it in ``df``.

    For every sample of the run's training dataset, the prior and segmentation outputs are
    compared with ``MIOU(invert=True)``. The function also records whether the sample index
    is a key of the run's noisy-unaries dict. Results are written in place into the row of
    ``df`` whose ``'index'`` column equals ``result_model.name``. Missing result columns are
    created first.

    Args:
        result_model: Run to evaluate; its first runner (``get_runner(0)``) is used.
        df: Metric table, updated in place; must have an ``'index'`` column.
        indices_tqdm: Optional progress bar to reuse across calls; reset to this run's size.
        map_location: Forwarded to ``result_model.get_noisy_unaries_dict``.

    Returns:
        The per-image progress bar, so callers can reuse it and finally close it.

    Raises:
        ValueError: If no row of ``df`` has ``'index' == result_model.name``.
    """
    runner = result_model.get_runner(0)
    model = runner.agent._get_model()
    dataloader = runner.agent.training_dataset
    model_gets_targets = runner.agent.model_gets_targets
    indices = list(range(len(dataloader)))
    noisy_unaries_dict = result_model.get_noisy_unaries_dict(map_location=map_location)

    # Create result columns on first use; the array/list valued ones need object dtype.
    if 'noisy_unaries_length' not in df.columns:
        df['noisy_unaries_length'] = -1
    for col in ('noisy_unaries_index', 'noisy_unaries_overlap_ious', 'noisy_unaries_index_was_noise'):
        if col not in df.columns:
            df[col] = None
            df[col] = df[col].astype(object)
    if 'noisy_unaries_overlap_mean_iou' not in df.columns:
        # Float sentinel: an int column would need an incompatible-dtype upcast when the IoU is stored.
        df['noisy_unaries_overlap_mean_iou'] = -1.0

    # NOTE(review): invert=True is presumably needed because of the prior's foreground convention — confirm.
    overlap_miou = MIOU(invert=True)

    # argwhere/flatnonzero give a *position*; `df.at` needs the row *label*, which only
    # coincides with the position for a default RangeIndex.
    positions = np.flatnonzero((df['index'] == result_model.name).to_numpy())
    if len(positions) == 0:
        raise ValueError(f"No row with index == {result_model.name!r} in df")
    row_label = df.index[positions[0]]

    df.at[row_label, "noisy_unaries_length"] = len(noisy_unaries_dict)
    df.at[row_label, "noisy_unaries_index"] = list(noisy_unaries_dict.keys())

    ious = torch.zeros(len(indices))
    was_noise = torch.zeros(len(indices), dtype=torch.bool)

    if indices_tqdm is None:
        indices_tqdm = tqdm(total=len(indices), desc="Images")
    else:
        indices_tqdm.reset(total=len(indices))

    for i in indices:
        res, _, img, _, _ = get_result(model, dataloader, i, model_gets_targets=model_gets_targets)
        res = split_model_result(res, model, dataloader, img)
        res_prior = res.get("prior", None)
        res_pred = res["segmentation"]

        was_noise[i] = noisy_unaries_dict.get(i, None) is not None
        ious[i] = overlap_miou(res_prior, res_pred)
        indices_tqdm.update()

    df.at[row_label, "noisy_unaries_overlap_ious"] = ious.numpy()
    df.at[row_label, "noisy_unaries_index_was_noise"] = was_noise.numpy()
    df.at[row_label, "noisy_unaries_overlap_mean_iou"] = ious.mean().item()
    return indices_tqdm

In [None]:
# Fills the noisy-unaries overlap columns of `df` in place, one model at a time,
# reusing a single inner progress bar across models.
indices_tqdm = None
for model in tqdm(models, desc="Computing noisy unaries overlap IoU"):
    indices_tqdm = compute_overlap_iou_for_model(model, df, indices_tqdm=indices_tqdm, map_location=torch.device("cpu"))
# The inner bar only exists after the first call; guard against an empty `models` list.
if indices_tqdm is not None:
    indices_tqdm.close()

In [11]:
import pickle

from awesome.util.path_tools import numerated_file_name
path = "./output/noisy_unaries_df.pkl"


def save_df(path, df):
    """Pickle ``df`` to ``numerated_file_name(path)`` and return the path actually written.

    ``numerated_file_name`` presumably derives a non-clashing numbered variant of ``path`` — TODO confirm.
    """
    target_path = numerated_file_name(path)
    with open(target_path, "wb") as handle:
        pickle.dump(df, handle)
    return target_path


def load_df(path):
    """Unpickle and return the object stored at ``path``.

    Security: ``pickle.load`` can execute arbitrary code — only load trusted files.
    """
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded



In [13]:
# NOTE: replaces the `df` computed above with a previously pickled result table.
df = load_df(path="./output/noisy_unaries_df_tower.pkl")
df

Unnamed: 0,index,model_name,dataset_name,eval/epoch/MeanForegroundBinaryMIOU_0,eval/epoch/MeanPixelAccuracy_0,eval/epoch/MeanPriorPixelAccuracy_0,eval/epoch/MeanPriorForegroundBinaryMIOU_0,joint,feature_type,seed,noise_percentage,prior,noisy_unaries_length,noisy_unaries_index,noisy_unaries_overlap_ious,noisy_unaries_index_was_noise,noisy_unaries_overlap_mean_iou
0,#00_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.976329,0.613105,False,edge,131,0.3,diffeo+only_prior+spatio-temporal+noisy+realnvp,6,"[7, 5, 9, 17, 4, 16]","[0.44331086, 0.7048556, 0.82425034, 0.68361485...","[False, False, False, False, True, True, False...",0.509002
1,#01_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.97926,0.693721,False,edge,131,0.6,diffeo+only_prior+spatio-temporal+noisy+realnvp,11,"[7, 5, 9, 17, 4, 16, 2, 14, 3, 1, 6]","[0.7659068, 0.15190026, 0.08374349, 0.07879557...","[False, True, True, True, True, True, True, Tr...",0.343454
2,#02_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.942786,0.0,False,edge,47,0.5,diffeo+only_prior+spatio-temporal+noisy+realnvp,10,"[13, 16, 12, 5, 3, 2, 6, 10, 11, 15]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[False, False, True, True, False, True, True, ...",0.0
3,#03_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.985247,0.737333,False,edge,42,0.4,diffeo+only_prior+spatio-temporal+noisy+realnvp,8,"[1, 2, 6, 16, 12, 15, 9, 14]","[0.85892284, 0.13685334, 0.19525188, 0.5620234...","[False, True, True, False, False, False, True,...",0.500875
4,#04_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.977419,0.551204,False,edge,47,0.4,diffeo+only_prior+spatio-temporal+noisy+realnvp,8,"[13, 16, 12, 5, 3, 2, 6, 10]","[0.7206597, 0.71501505, 0.268971, 0.59363806, ...","[False, False, True, True, False, True, True, ...",0.52755
5,#05_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.990259,0.823874,False,edge,42,0.3,diffeo+only_prior+spatio-temporal+noisy+realnvp,6,"[1, 2, 6, 16, 12, 15]","[0.873004, 0.13800967, 0.19436826, 0.8034708, ...","[False, True, True, False, False, False, True,...",0.62283
6,#06_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.989506,0.843188,False,edge,47,0.2,diffeo+only_prior+spatio-temporal+noisy+realnvp,4,"[13, 16, 12, 5]","[0.9456474, 0.9387551, 0.92808884, 0.93275124,...","[False, False, False, False, False, True, Fals...",0.832793
7,#07_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.981817,0.677422,False,edge,131,0.4,diffeo+only_prior+spatio-temporal+noisy+realnvp,8,"[7, 5, 9, 17, 4, 16, 2, 14]","[0.734954, 0.8768719, 0.8202636, 0.76039326, 0...","[False, False, True, False, True, True, False,...",0.514166
8,#08_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.961171,0.396629,False,edge,131,0.5,diffeo+only_prior+spatio-temporal+noisy+realnvp,10,"[7, 5, 9, 17, 4, 16, 2, 14, 3, 1]","[0.7454195, 0.09768665, 0.09218179, 0.20142125...","[False, True, True, True, True, True, False, T...",0.120576
9,#09_UNET+cars3+edge+diffeo+only_prior+realnvp+...,UNet,cars3,0.779106,0.985758,0.988748,0.820842,False,edge,42,0.1,diffeo+only_prior+spatio-temporal+noisy+realnvp,2,"[1, 2]","[0.8690517, 0.13842389, 0.09811182, 0.69073164...","[False, True, True, False, False, False, False...",0.719126


In [14]:
import pickle

save_path = "./output/noisy_unaries_df.pkl"
# NOTE(review): writes straight to save_path (overwriting any existing file); save_df()
# instead routes the path through numerated_file_name.
with open(save_path, mode="wb") as out_file:
    pickle.dump(df, out_file)

In [None]:
# Runs that differ only in seed fall into the same group.
group_by = ['model_name', 'dataset_name', 'joint', 'feature_type', 'noise_percentage', "prior"]

grouped = df.groupby(group_by)
for group_key, group_frame in grouped:
    print(group_key)
    display(group_frame)

In [None]:
# Aggregate the mean overlap IoU over seeds for each configuration.
group_by = ['model_name', 'dataset_name', 'joint', 'feature_type', 'noise_percentage', "prior"]

use_cols = [*group_by, "noisy_unaries_overlap_mean_iou"]

grouped = df[use_cols].groupby(group_by)

# Peek at the first group only.
first_group = next(iter(grouped), None)
if first_group is not None:
    k, v = first_group
    print(k)
    display(v)

mean_vals = grouped.mean()
min_vals = grouped.min()
max_vals = grouped.max()
std_vals = grouped.std()

display(mean_vals)


In [None]:

from awesome.util.matplotlib import saveable


@saveable()
def plot_noise_bar(df, 
                   column: str, 
                   bar_color: Optional[Any] = None,
                   size: float = 7,
                   noise_level: Any = 4,
                   bar_width: float = 0.05):
    """Bar plot of the per-group mean of ``column`` against label noise, with min/max error bars.

    Args:
        df: A pandas GroupBy; ``mean``/``min``/``max`` are taken per group.
        column: Column to plot (e.g. the mean overlap IoU).
        bar_color: Matplotlib color for the bars.
        size: Figure size forwarded to ``get_mpl_figure``.
        noise_level: Level (position or name) of the group index holding the noise
            percentage; default 4 matches the notebook's ``group_by`` list.
        bar_width: Bar width in x-axis units (noise fraction).

    Returns:
        The matplotlib figure.
    """
    mean_vals = df.mean()
    min_vals = df.min()
    max_vals = df.max()

    fig, ax = get_mpl_figure(1, 1, size=size)
    x = mean_vals.index.get_level_values(noise_level).tolist()
    y = mean_vals[column]
    # Asymmetric error bars spanning [min, max] around the mean.
    min_max_err = np.stack([np.abs(min_vals[column].values - mean_vals[column].values),
                            max_vals[column].values - mean_vals[column].values])
    ax.bar(x=x, height=y, color=bar_color, width=bar_width)
    ax.errorbar(x, y, yerr=min_max_err, label="Mean IoU", fmt="", linestyle='', color="black", capthick=1)

    ax.xaxis.set_major_formatter("{:.0%}".format)

    ax.set_ylabel("IoU")
    ax.set_xlabel("Label Noise")
    ax.grid(axis="y", linestyle='-', linewidth=1)
    return fig

column = "noisy_unaries_overlap_mean_iou"

path = "output/noisy_spatio_temporal/bar_plot"
fig = plot_noise_bar(
    grouped,
    column,
    bar_color=plt.get_cmap("tab10")(1),
    override=True,
    save=True,
    path=path,
    ext=["png", "pdf"],
)

In [58]:
mean_grouped = grouped.mean()
# Explicit copy so the column rewrites below don't act on a column-subset of another frame.
mgrp_df = mean_grouped.reset_index()[["noise_percentage", column]].copy()

# e.g. 0.3 -> "30 \%" (LaTeX-escaped). Raw string: "\%" is an invalid Python escape sequence.
mgrp_df["noise_percentage"] = mgrp_df["noise_percentage"].apply(lambda v: "{:.0%}".format(v).replace("%", r" \%"))
mgrp_df[column] = mgrp_df[column].apply(lambda v: str(round(v, 3)))

col_mapping = {
    "noise_percentage": "Label Noise",
    column: "IoU"
}
mgrp_df = mgrp_df.rename(columns=col_mapping)
display(mgrp_df)

print(mgrp_df.to_latex())

Unnamed: 0,Label Noise,IoU
0,0 \%,0.826
1,10 \%,0.799
2,20 \%,0.796
3,30 \%,0.527
4,40 \%,0.514
5,50 \%,0.186
6,60 \%,0.189


\begin{tabular}{lll}
\toprule
 & Label Noise & IoU \\
\midrule
0 & 0 \% & 0.826 \\
1 & 10 \% & 0.799 \\
2 & 20 \% & 0.796 \\
3 & 30 \% & 0.527 \\
4 & 40 \% & 0.514 \\
5 & 50 \% & 0.186 \\
6 & 60 \% & 0.189 \\
\bottomrule
\end{tabular}



In [111]:
from awesome.run.functions import count_parameters

def get_model_size(model: ResultModel) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Return parameter-count tables for the segmentation and prior modules of a run.

    Args:
        model: Result model; its first runner (``get_runner(0)``) provides the network.

    Returns:
        ``(segmentation_table, prior_table)``, each a DataFrame built from
        ``count_parameters`` on the respective sub-module.
    """
    runner = model.get_runner(0)
    net = runner.agent._get_model()
    segm = pd.DataFrame(count_parameters(net.segmentation_module))
    prior = pd.DataFrame(count_parameters(net.prior_module))
    return segm, prior

get_model_size(models[0])

In [None]:
_iou_col = "noisy_unaries_overlap_mean_iou"
_iou_mean = mean_vals[_iou_col].values
np.stack([min_vals[_iou_col].values - _iou_mean, max_vals[_iou_col].values - _iou_mean])

In [None]:
# Per-group minimum of the mean overlap IoU.
min_vals["noisy_unaries_overlap_mean_iou"].to_numpy()

In [None]:
def plot_result(result_model, indices):
    """Display image/scribble overlays with prediction and prior for the given sample indices.

    Args:
        result_model: Result model; its first runner (``get_runner(0)``) is used.
        indices: Iterable of sample indices into the runner's training dataset.
    """
    run = result_model.get_runner(0)
    net = run.agent._get_model()
    dataset = run.agent.training_dataset
    gets_targets = run.agent.model_gets_targets
    # Loaded but not used below.
    noisy_unaries_dict = result_model.get_noisy_unaries_dict()

    for idx in indices:
        raw_out, _gt, image, fg_mask, bg_mask = get_result(net, dataset, idx, model_gets_targets=gets_targets)
        parts = split_model_result(raw_out, net, dataset, image)
        prior_out = parts.get("prior", None)
        seg_out = parts["segmentation"]

        figure = plot_image_scribbles(image=image, inference_result=seg_out,
                                      foreground_mask=fg_mask, background_mask=bg_mask,
                                      prior_result=prior_out)
        display(figure)
        plt.close(figure)

plot_result(models[1], range(18))

In [None]:
# Same configuration grouping as above (runs differing only in seed share a group).
group_by = ['model_name', 'dataset_name', 'joint', 'feature_type', 'noise_percentage', "prior"]

grouped = df.groupby(group_by)
for group_key, group_frame in grouped:
    print(group_key)
    display(group_frame)

# Group by seed

In [None]:
import numpy as np
from awesome.run.functions import get_mpl_figure, get_result, split_model_result
from awesome.model.path_connected_net import PathConnectedNet
from typing import Any
plt.close("all")

# NOTE(review): `runner` is never assigned at the top level of this notebook (only as a
# function local), so it must be defined manually before this cell runs. TODO: make explicit.
grid_shapes = dict()
model = runner.agent._get_model()
dataloader = runner.dataloader

# All frames in order; t_max normalizes frame indices to [0, 1].
index = list(range(len(dataloader)))

t_n = len(index)
t_max = len(dataloader) - 1

images, segmentations, priors_no_sig = [], [], []

# Extract raw (no sigmoid) priors and un-blurred images while the properties are overridden.
with TemporaryProperty(model, use_prior_sigmoid=False), TemporaryProperty(dataloader, do_image_blurring=False):
    for i in index:
        raw_result, ground_truth, img, _, _ = get_result(model, dataloader, i, False)
        res_no_sig = split_model_result(raw_result, model, dataloader, img)

        priors_no_sig.append(res_no_sig.get("prior_raw", None))

        res_pred = res_no_sig["segmentation"]

        images.append(img)
        segmentations.append(res_pred)

images = torch.stack(images)
segmentations = torch.stack(segmentations)
priors_no_sig = torch.stack(priors_no_sig)

# Cache the normalized coordinate grid per spatial shape.
shp = priors_no_sig.shape[-2:]
if shp not in grid_shapes:
    grid_shapes[shp] = PathConnectedNet.create_normalized_grid(shp).cpu().numpy()
grid = grid_shapes[shp]

pred = priors_no_sig  # B x C x H x W
# Spatio-temporal grid: spatial grid channels plus one constant normalized-time channel per frame.
time_planes = []
for t in index:
    spatial = torch.tensor(grid[0])
    time_channel = torch.full((1, *pred.shape[-2:]), t / t_max)
    time_planes.append(torch.cat([spatial, time_channel], dim=0))
t_grid = torch.stack(time_planes)

def plot_spatio_temporal_object(grid: Any, unaries: Any, size: float = 5):
    """Draw the 0.5 iso-contour of each frame's unaries, stacked along the z (time) axis.

    Args:
        grid: Coordinate grid(s); for frame ``i``, ``grid[i][0]`` / ``grid[i][1]`` are the
            x / y coordinates and ``grid[i][2].max()`` is used as the contour's z offset.
            A 3D input is treated as a single frame.
        unaries: Per-frame unaries, contoured from ``unaries[i][0]``; a 3D input is
            treated as a single frame.
        size: Figure size forwarded to ``get_mpl_figure`` (previously accepted but ignored).

    Returns:
        The matplotlib figure.
    """
    if isinstance(grid, torch.Tensor):
        grid = grid.cpu().numpy()
    if isinstance(unaries, torch.Tensor):
        unaries = unaries.cpu().numpy()

    if len(grid.shape) < 4:
        grid = grid[None]
    if len(unaries.shape) < 4:
        unaries = unaries[None]

    fig, ax = get_mpl_figure(subplot_kw=dict(projection='3d'), size=size)

    for i in range(grid.shape[0]):
        g = grid[i]
        z = unaries[i][0]
        offset = g[2].max()  # the time channel sets the contour's height
        ax.contour(g[0], g[1], z, levels=[0.5], colors="red", offset=offset, linewidths=2)

    x_left, x_right = ax.get_xlim()
    y_low, y_high = ax.get_ylim()

    zoom = 1
    elevation = 130
    azimuth = 90
    roll = 0

    # Box aspect must be positive: data width / data height (was y_low - y_high, i.e. negative).
    ax.set_box_aspect(aspect=((x_right - x_left) / (y_high - y_low), 1, 1), zoom=zoom)
    ax.view_init(elev=elevation, azim=azimuth, roll=roll)

    ax.invert_zaxis()
    return fig

#fig = plot_spatio_temporal_object(t_grid, pred)
#fig


In [None]:
from awesome.run.functions import plot_3d_tubes
import itertools

subsamplings = [3, 4, 6]
top_image_alphas = [0, 0.2]

# NOTE(review): relies on a top-level `runner`, which no earlier cell of this notebook defines.
for subsampling in subsamplings:
    for top_image_alpha in top_image_alphas:
        alpha_tag = str(top_image_alpha).replace('.', '_')
        path = f"./output/spatio_temporal/tubes_{runner.dataloader.__dataset__.dataset_name}_subs_{subsampling}_alpha_{alpha_tag}"
        fig = plot_3d_tubes(
            priors_no_sig,
            images,
            top_image_alpha=top_image_alpha,
            subsample_factor=subsampling,
            subsample_image_mode="grid_sample",
            grid_sample_mode="nearest",
            transparent=True,
            path=path,
            save=True,
            ext=["png", "pdf"],
            override=True,
        )
        display(fig)
        plt.close(fig)

In [None]:
from awesome.run.functions import plot_3d_tubes

# NOTE(review): the cwd is the project root, so this saves under ./notebooks/output,
# unlike the other cells, which use ./output — confirm intended.
path = f"./notebooks/output/spatio_temporal/tubes_no_sub_{runner.dataloader.__dataset__.dataset_name}"
fig = plot_3d_tubes(
    priors_no_sig, images,
    subsample_x=1, subsample_y=1,
    path=path, save=True, ext=["png", "pdf"], override=True,
)
display(fig)

plt.close(fig)


In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.colors import to_rgba
from skimage import measure

plt.close("all")


subsample_x = 6
subsample_y = 6


def plot_3d_tubes(logits, images, subsample_x = 6, subsample_y = 6, top_image_index: int = 18):
    """Render the prior's zero level-set over time as a 3D surface with two image planes.

    NOTE: this definition shadows ``plot_3d_tubes`` imported from ``awesome.run.functions``
    in earlier cells.

    Args:
        logits: Stacked raw prior outputs indexed as ``logits[t, 0, y, x]``
            (presumably T x C x H x W — TODO confirm). Previously ignored in favour of the
            global ``priors_no_sig``.
        images: Stacked channels-first images indexed as ``images[t, :, y, x]``.
        subsample_x: Stride along x (width).
        subsample_y: Stride along y (height).
        top_image_index: Frame drawn as an opaque image plane at its time slice; clamped to
            the last frame, so short sequences no longer raise IndexError.

    Returns:
        The matplotlib figure.
    """
    vol = logits[:, 0, ::subsample_y, ::subsample_x].detach().cpu().numpy()

    # Zero level-set of the (t, y, x) volume, transposed so vertices come out as (x, y, t).
    verts, faces, _normals, _values = measure.marching_cubes(vol.T, 0)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    image_index = min(top_image_index, len(images) - 1)
    rgb_img = images[image_index, :, ::subsample_y, ::subsample_x].permute(1, 2, 0).detach().cpu().numpy()

    x = np.arange(0, rgb_img.shape[-2], 1)
    y = np.arange(0, rgb_img.shape[-3], 1)
    xx, _ = np.meshgrid(x, y)

    facecolor = plt.get_cmap("tab10")(1)

    # Fancy indexing: `verts[faces]` yields one triangle (3 vertices) per face.
    mesh = Poly3DCollection(verts[faces], shade=True, facecolors=facecolor, edgecolor='none', linewidth=0, antialiased=False)
    ax.add_collection3d(mesh)

    def create_grid_verticies(grid, z_loc):
        # One unit quad per pixel at height z_loc. argwhere on grid.T yields (x, y) pairs in
        # x-major order, matching the Fortran-order reshape of the pixel colors below.
        indices = np.argwhere(grid.T >= 0)
        repeated = indices[:, None, :].repeat(4, axis=1)
        repeated[:, 1, 1] += 1  # corner (x, y + 1)
        repeated[:, 2, :] += 1  # corner (x + 1, y + 1)
        repeated[:, 3, 0] += 1  # corner (x + 1, y)
        zz = np.full_like(repeated[:, :, 0], z_loc)[..., None]
        return np.concatenate([repeated, zz], axis=-1)

    # Opaque image plane at the chosen frame's time slice.
    fcol = rgb_img.reshape((math.prod(rgb_img.shape[:-1]), 3), order="F")
    secmesh = Poly3DCollection(create_grid_verticies(xx, image_index), shade=True, facecolors=fcol, edgecolor='none', linewidth=0, antialiased=False)
    ax.add_collection3d(secmesh)

    # First frame: only pixels with prior <= 0 are drawn (alpha 1); all others are transparent.
    front_index = 0
    front_img = images[front_index, :, ::subsample_y, ::subsample_x].permute(1, 2, 0).detach().cpu().numpy()
    fcol_front = front_img.reshape((math.prod(front_img.shape[:-1]), 3), order="F")
    fcol_front = np.concatenate([fcol_front, np.zeros_like(fcol_front[:, :1])], axis=-1)
    front_prior = logits[front_index, :, ::subsample_y, ::subsample_x][0].detach().cpu().numpy()
    fg = (front_prior <= 0).reshape(-1, order="F")
    fcol_front[fg, 3] = 1

    thirdmesh = Poly3DCollection(create_grid_verticies(xx, front_index), shade=True, facecolors=fcol_front, edgecolor='none', linewidth=0, antialiased=False)
    ax.add_collection3d(thirdmesh)

    ax.set_xlim(0, vol.shape[-1])
    ax.set_ylim(0, vol.shape[-2])
    ax.set_zlim(0, vol.shape[-3])

    ax.set_zlabel('Time [t]')

    ax.view_init(elev=40, azim=90, roll=0)

    ax.invert_xaxis()
    ax.invert_zaxis()

    # Hide x/y axis lines and ticks; keep every 5th frame as a time tick.
    ax.grid(False)
    ax.xaxis.line.set_lw(0.)
    ax.set_xticks([])

    ax.yaxis.line.set_lw(0.)
    ax.set_yticks([])

    t = np.arange(0, len(images), 1)[::5]
    ax.set_zticks(t, labels=t.astype(int))

    ax.set_aspect('equalxy')

    return fig

plt.show()

In [None]:

# NOTE(review): scratch cell — at top level `ax` is not yet bound in a top-to-bottom run
# (it is a local of the plotting functions), so this raises NameError on a fresh kernel.
ax.get_zticks()

In [None]:
# NOTE(review): scratch cell — `rgb_img` is only a local of plot_3d_tubes; NameError on a fresh kernel.
np.moveaxis(rgb_img, 2, 0).shape

In [None]:
# NOTE(review): scratch cell — `rgb_img` is only a local of plot_3d_tubes; NameError on a fresh kernel.
rgb_img

In [None]:
# NOTE(review): scratch cell — `verts`/`faces` are locals of plot_3d_tubes; NameError on a fresh kernel.
verts[faces[0][None,]].shape

In [None]:
# NOTE(review): scratch cell — `verts` is a local of plot_3d_tubes; NameError on a fresh kernel.
verts.shape

In [None]:
# NOTE(review): scratch cell — `xx` is a local of plot_3d_tubes; NameError on a fresh kernel.
indices = np.argwhere(xx[:-1, :-1] >= 0)

In [None]:
# Scratch version of the per-pixel quad construction (z fixed at 0).
# NOTE(review): `xx` is a local of plot_3d_tubes; NameError on a fresh kernel.
indices = np.argwhere(xx[:-1, :-1] >= 0)
repeated = np.repeat(indices[:, None, :], 4, axis=1)
repeated[:, 1, 1] += 1  # bump the second coordinate
repeated[:, 2, :] += 1  # bump both coordinates
repeated[:, 3, 0] += 1  # bump the first coordinate

# Constant z = 0 for every corner.
zz = np.zeros((*repeated.shape[:2], 1), dtype=repeated.dtype)
vertices_img = np.concatenate([repeated, zz], axis=-1)

zz.shape

In [None]:
# Shape of the stacked image tensor built in the spatio-temporal extraction cell.
images.shape

In [None]:
# NOTE(review): scratch cell — `vol` is a local of plot_3d_tubes; NameError on a fresh kernel.
vol.shape

In [None]:
# NOTE(review): scratch cell — `vol` is a local of plot_3d_tubes; NameError on a fresh kernel.
vol.min()

In [None]:
# NOTE(review): `ellip_double` is not defined anywhere in this notebook; leftover scratch cell.
ellip_double.shape

In [None]:
# Minimum of the raw (no-sigmoid) prior stack.
pred.min()

In [None]:
# NOTE(review): scratch cell — `verts` is a local of plot_3d_tubes; NameError on a fresh kernel.
verts.shape

In [None]:
# Number of voxels exactly equal to 1.
# NOTE(review): `vol` is a local of plot_3d_tubes; NameError on a fresh kernel.
len(np.argwhere(vol == 1))

In [None]:
# Voxel view of where the subsampled prior volume is <= 0.
# NOTE(review): `vol` is a local of plot_3d_tubes; NameError on a fresh kernel.
vox_fig = plt.figure()
ax = vox_fig.add_subplot(projection='3d')
ax.voxels(vol <= 0, facecolors="red")

plt.show()

In [None]:
# NOTE(review): `ellip_double` is not defined anywhere in this notebook; leftover scratch cell.
ellip_double

In [None]:
# Boolean mask of prior values that are exactly zero.
(pred == 0)

In [None]:
# NOTE(review): `y` is only assigned inside functions in this notebook; NameError on a fresh kernel.
y.shape

In [None]:
# Spatio-temporal grid (spatial grid channels + normalized time channel per frame).
t_grid

In [None]:
import cv2 as cv
from awesome.run.functions import plot_as_image


# Extract each frame's object contours from the raw prior and lift them to (x, y, t) space.
all_contours = []
all_hierarchy = []

times = [t / t_max for t in index]

for i in range(0, len(pred)):
    p = pred[i][0]
    t = times[i]
    p_np = p.detach().cpu().numpy()
    h, w = p_np.shape[-2:]

    # Clip before the uint8 cast: raw (no-sigmoid) priors may lie outside [0, 1], and casting
    # out-of-range floats to uint8 is undefined (wraps / warns).
    img_u8 = np.clip((1 - p_np) * 255, 0, 255).astype(np.uint8)
    _, thresh = cv.threshold(img_u8, 123, 1, cv.THRESH_BINARY)
    contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)

    # We are losing some information here, e.g. holes in the object are not represented,
    # because we only use the object's contour.

    # OpenCV contour points are (x, y) = (column, row): normalize x by W - 1 and y by H - 1
    # (previously divided by (H - 1, W - 1), i.e. swapped for non-square images).
    scale = np.array([w - 1, h - 1])
    local_contours = [np.concatenate([c[:, 0, :] / scale, np.full((c.shape[0], 1), t)], axis=1) for c in contours]
    all_contours.append(local_contours)
    all_hierarchy.append(hierarchy)

In [None]:
# For every contour of each frame, find the closest contour (and closest point pair) of the
# previous frame as the starting point for triangulating between the two frames.
all_triangles = []

for idx_time_frame in range(len(all_contours)):
    if idx_time_frame == 0:
        # First frame, leave it nonclosed
        continue
    last_contour_data = all_contours[idx_time_frame - 1]
    current_contour_data = all_contours[idx_time_frame]

    for current_contour in current_contour_data:
        min_contour_idx = None
        min_contour_dist = np.inf
        min_dot_current = None
        min_dot_last = None
        for idx_last_contour, last_contour in enumerate(last_contour_data):
            # Pairwise distances: rows = points of the current contour, cols = points of the last one.
            min_mat = np.linalg.norm(current_contour[:, None, :] - last_contour[None, :, :], axis=-1)
            # np.argmin returns a flat index; unravel it to (row, col) before indexing
            # (indexing with the flat index selected a whole row and `arg_idx[0]` failed).
            arg_idx = np.unravel_index(np.argmin(min_mat), min_mat.shape)
            dist = min_mat[arg_idx]

            if dist < min_contour_dist:
                min_contour_dist = dist
                min_contour_idx = idx_last_contour
                min_dot_current, min_dot_last = arg_idx

        if min_contour_idx is None:
            # No contour found
            continue

        # We found a contour which is closest, now we need to triangulate
        # TODO: triangulation between current_contour and last_contour is not implemented yet.
        last_contour = last_contour_data[min_contour_idx]





In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Minimal sanity check of `plot_trisurf`: render a single triangle in the z = 0 plane.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

x_offset = 0.1

# Triangle vertices before shifting; every vertex is then moved by `x_offset` along x.
base_triangle = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.5, 0.0, 0.0],
])
test_points = base_triangle + np.array([x_offset, 0.0, 0.0])

# Transposing gives the x, y and z coordinate rows in that order.
ax.plot_trisurf(*test_points.T)

ax.set_aspect('equal')
ax.set(xlabel='X Label', ylabel='Y Label', zlabel='Z Label')

plt.show()

In [None]:
# Switch to the interactive ipympl ("widget") backend so the 3D figures below can be rotated.
%matplotlib widget

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import Delaunay, ConvexHull

# Delaunay-triangulate each pair of consecutive frames of (x, y, t) contour points and
# scatter the points in 3D.
plt.close("all")
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

tri = None
# Point chunks in the order they are fed to Delaunay. For every frame i this is first
# (all contours of frame i-1 + first contour of frame i), then each remaining contour of frame i.
ctr = []
# Per-frame simplices, shifted so that they index into np.concatenate(ctr).
simplices_list = []
point_offset = 0

for i in range(1, len(all_contours)):
    sublist = all_contours[i]
    for item in sublist:
        if tri is None:
            # Seed the triangulation with every contour of the previous frame plus this contour
            first_set = np.concatenate(list(all_contours[i - 1]) + [item])
            tri = Delaunay(first_set, incremental=True)
            ctr.append(first_set)
        else:
            tri.add_points(item)
            ctr.append(item)
    if tri is None:
        # Frame i has no contours, so nothing was triangulated for this frame pair
        continue
    # Simplex indices are local to this frame pair's point set; shift them to global indices.
    simplices_list.append(tri.simplices + point_offset)
    point_offset += len(tri.points)
    # Finish incremental processing (frees Qhull resources) before starting the next frame pair
    tri.close()
    tri = None

points = np.concatenate(ctr) if ctr else np.empty((0, 3))
# 3D Delaunay simplices are tetrahedra (4 vertex indices each), not surface triangles,
# so they cannot be passed to plot_trisurf directly.
simplices = np.concatenate(simplices_list) if simplices_list else np.empty((0, 4), dtype=int)

# Each entry of `ctr` is a chunk of points (see above), not a whole frame.
for chunk_idx, chunk in enumerate(ctr):
    ax.plot(chunk[:, 0], chunk[:, 1], chunk[:, 2], 'o', label='chunk {}'.format(chunk_idx))

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

ax.legend()
plt.show()

In [None]:
import open3d as o3d

# `o3d.data.BunnyMesh()` only describes (and downloads) the dataset and exposes its file `.path`;
# it is not a geometry. Read it into a TriangleMesh before drawing.
bunny_data = o3d.data.BunnyMesh()
mesh = o3d.io.read_triangle_mesh(bunny_data.path)
# Vertex normals are needed for shaded rendering of the mesh.
mesh.compute_vertex_normals()
print(mesh)
# No explicit zoom/front/lookat/up here. The previously hard-coded camera looked at
# (1.89, 3.26, 0.93), which appears to come from Open3D's point-cloud tutorial rather than
# this mesh. Open3D fits the view to the geometry's bounding box by default.
o3d.visualization.draw_geometries([mesh])

In [None]:
# List the names exposed by `open3d.data` (e.g. to find the available sample datasets).
dir(o3d.data)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Reference example: parametric sphere of radius 10 drawn with `plot_surface`.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

radius = 10
u = np.linspace(0, 2 * np.pi, 100)  # azimuth angle
v = np.linspace(0, np.pi, 100)      # polar angle

# Column (u) times row (v) broadcasting builds the same 100 x 100 grids as np.outer.
x = radius * (np.cos(u)[:, None] * np.sin(v)[None, :])
y = radius * (np.sin(u)[:, None] * np.sin(v)[None, :])
z = radius * np.tile(np.cos(v), (u.size, 1))

ax.plot_surface(x, y, z)

ax.set_aspect('equal')

plt.show()

In [None]:
# Show the sphere's z grid from the cell above as an image.
# NOTE(review): `plot_as_image` is not imported or defined in any visible cell; presumably a helper
# defined earlier in the notebook — confirm, otherwise this raises NameError on a fresh kernel.
plot_as_image(z)

In [None]:
import numpy as np
import pyvista as pv

# Stack every contour of every frame into one (x, y, t) point cloud and show it.
ctr = [item for sublist in all_contours for item in sublist]

points = np.concatenate(ctr)
cloud = pv.PolyData(points)
cloud.plot()

# `delaunay_3d` only accepts a scalar `alpha`. To approximate the intended per-axis alpha,
# scale each axis by alpha_min / alpha_axis and triangulate with alpha_min. The resulting
# surface is then mapped back to the original coordinates.
alpha_per_axis = np.array([0.01, 0.01, 0.1])
axis_scale = alpha_per_axis.min() / alpha_per_axis
volume = pv.PolyData(points * axis_scale).delaunay_3d(alpha=float(alpha_per_axis.min()))
shell = volume.extract_geometry()
shell.points = shell.points / axis_scale

shell.plot(show_axes=False)

In [None]:
# NOTE(review): presumably required by pyvista's trame-based Jupyter rendering used by the
# `.plot()` calls above — confirm. Pin a version (trame-vuetify==X.Y.Z) and move this to the
# top of the notebook so a fresh kernel has it before pyvista plots are rendered.
%pip install trame-vuetify

In [None]:
fig, ax = get_mpl_figure()

# Draw the second OpenCV contour as a red outline. contours[k] has shape (N, 1, 2);
# squeezing drops the singleton axis so the columns are the x and y coordinates.
contour_xy = contours[1].squeeze()
x, y = contour_xy[:, 0], contour_xy[:, 1]
ax.plot(x, y, color="red", linewidth=2)
fig

In [None]:
# Inspect column 0 (the x coordinates, as plotted above) of the first OpenCV contour.
contours[0].squeeze()[:, 0]

In [None]:
t = torch.linspace(0, 1, 5)
# This cell previously ended in an unfinished expression (a SyntaxError that broke
# Restart & Run All; `torch.full` was also missing its size argument).
# NOTE(review): completed on the assumption that the goal is to append one constant "time"
# channel with value t_k to `res` for every t_k. This assumes `res` is a channels-first tensor
# (C, H, W) defined earlier — TODO confirm.
[
    torch.cat(
        [res, torch.full((1, *res.shape[1:]), float(t_k), dtype=res.dtype, device=res.device)],
        dim=0,
    )
    for t_k in t
]

In [None]:
# NOTE(review): `grid` is assigned in the Glow cells further below; on a fresh kernel this raises
# NameError unless `grid` was defined earlier in the notebook. Move this after those cells or delete it.
grid.shape

## Old experiments with Glow (normalizing-flow prior)

In [None]:
# `UNet` is otherwise only imported in the next cell; import it here so this cell runs on a fresh kernel.
from awesome.model.unet import UNet

# --- Experiment configuration ---
xytype = "edge"
dataset_kind = "train"
sequence_name = "bear01"  # FBMS-59 sequence to load
all_frames = True
subset = None  # e.g. slice(0, 5) to restrict the dataset
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Checkpoint directory of the UNet segmentation model for each switch value.
segmentation_checkpoint_dirs = {
    "original": "./data/checkpoints/labels_with_uncertainty_flownet2_based",
    "retrain": "./data/checkpoints/refit_unet_uncertainty/23_11_13",
    "retrain_xy": "./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13",
}
if segmentation_model_switch not in segmentation_checkpoint_dirs:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
segmentation_model_state_dict_path = (
    f"{segmentation_checkpoint_dirs[segmentation_model_switch]}/model_{sequence_name}_unet.pth"
)
# The "original" checkpoints are fed BGR images, the retrained ones RGB.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"

# NOTE(review): `UnariesConversionLoss` and `SE` are not imported in any visible cell — confirm
# they are imported earlier in the notebook.
prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{sequence_name}"

real_dataset = FBMSSequenceDataset(
    dataset_path=data_path,
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    do_weak_label_preprocessing=True,
    do_uncertainty_label_flip=True,
    test_weak_label_integrity=False,
    all_frames=all_frames,  # honour the configuration flag above
)
dataset = AwesomeDataset(
    dataset=real_dataset,
    xytype=xytype,
    feature_dir=f"{data_path}/Feat",
    dimension="3d",  # 2d for fcnet
    mode="model_input",
    model_input_requires_grad=True,
    batch_size=1,
    split_ratio=1,
    shuffle_in_dataloader=False,
    image_channel_format=image_channel_format,
    do_image_blurring=True,
    subset=subset,
    spatio_temporal=False,
)

# Load the checkpoint onto the CPU so CUDA-saved weights also load on CPU-only machines;
# load_state_dict copies them into the (CPU) model, and device placement happens afterwards.
segmentation_model = UNet(4, 1)
segmentation_model.load_state_dict(torch.load(segmentation_model_state_dict_path, map_location="cpu"))


import normflows as nf


def init_glow(channels: int,
              hidden_channels: int,
              n_flows: int,
              height: int,
              width: int,
              scale: bool = True,
              scale_map: Literal["sigmoid", "exp"] = "sigmoid",
              ) -> nf.NormalizingFlow:
    """Build a single-scale Glow normalizing flow over inputs of shape (channels, height, width).

    Args:
        channels: Number of input channels.
        hidden_channels: Hidden channels of each Glow block's coupling network.
        n_flows: Number of stacked ``GlowBlock``s.
        height: Input height.
        width: Input width.
        scale: Whether the coupling layers use a scale term (passed to ``GlowBlock``).
        scale_map: Map applied to the coupling scale (passed to ``GlowBlock``).

    Returns:
        An ``nf.NormalizingFlow`` whose base distribution is uniform on [0, 1] over the input shape.
    """
    input_shape = (channels, height, width)

    # Uniform base distribution on [0, 1] over the full input shape.
    q0 = nf.distributions.base.Uniform(input_shape, 0, 1)

    flows = [
        nf.flows.GlowBlock(channels, hidden_channels,
                           split_mode='channel',
                           scale_map=scale_map, leaky=0.01,
                           scale=scale, net_actnorm=False)
        for _ in range(n_flows)
    ]

    # Plain (single-scale) chain of Glow blocks. q0 is also passed as the target distribution p.
    # NOTE(review): confirm that using the base distribution as target is intended.
    return nf.NormalizingFlow(q0, flows, q0)


In [None]:
# Imports and global objects for the Glow / path-connected prior experiments below.
# NOTE(review): `UNet` is imported here but already used in the previous cell; on a fresh kernel
# that cell needs its own import.
import logging
from awesome.agent.torch_agent import TorchAgent
from awesome.dataset.prior_dataset import PriorManager
from awesome.measures.unaries_weighted_loss import UnariesWeightedLoss
from awesome.model.wrapper_module import WrapperModule
from awesome.model.unet import UNet
from awesome.model.path_connected_net import PathConnectedNet
from awesome.model.convex_net import ConvexNextNet
from normflows import NormalizingFlow
from normflows.flows import GlowBlock
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from awesome.util.torch import TensorUtil



    
# Input channels for both the flow and the convex net.
# NOTE(review): presumably the 2 coordinate channels of the prior's input grid — confirm.
channels = 2

# Spatial size used to build the flow: everything after the first axis of the first sample's
# first input. Assumes that input is (C, H, W) — TODO confirm.
image_shape = dataset[0][0][0].shape[1:]

flow_model = init_glow(channels=channels, hidden_channels=256, n_flows=3, 
                       height=image_shape[0], 
                       width=image_shape[1],
                       scale=True)
convex_model = ConvexNextNet(n_hidden=130, 
                             n_hidden_layers=2,
                             in_features=channels)

# The prior combines the convex net and the flow net; see PathConnectedNet for how they are composed.
path_connected_model = PathConnectedNet(convex_model, flow_model)

# Segmentation UNet (from the previous cell) + prior, wrapped into one module.
wrapper_module = WrapperModule(
    segmentation_module=segmentation_model,
    prior_module=path_connected_model,
    prior_arg_mode="param_clean_grid"
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Optimizer settings for prior pretraining (weight_decay is applied to the flow net only there).
lr = 1e-3
weight_decay = 1e-7

# Not used in the visible cells.
previous_state = None
previous_center_of_mass = None

prior_module = wrapper_module.prior_module


# Flags read by the pretraining loop below.
use_prior_sigmoid = True
use_logger = False
use_step_logger = False
batch_progress_bar = None

# Same loss construction as `prior_criterion` in the previous cell.
criterion = UnariesConversionLoss(SE(reduction="mean"))

# Return value is discarded; assumes TensorUtil.to moves modules in place — TODO confirm.
TensorUtil.to(wrapper_module, device=device)

In [None]:
# Fit only the flow net of the prior to the grid of the first sample by minimizing the
# flow's forward KL divergence (`forward_kld`).
inputs, labels, indices, prior_state = TorchAgent.decompose_training_item(dataset[0], training_dataset=dataset)

max_iter = 1000

prior_module = wrapper_module.prior_module
model = prior_module.flow_net

optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5)

# inputs[2] is the grid used as training sample; add a leading batch dimension.
grid = inputs[2].to(device)[None, ...]
model = model.to(device)

model.train()

# Collect losses in a list (appending to a numpy array in the loop is quadratic).
loss_values = []
for i in tqdm(range(max_iter)):
    optimizer.zero_grad()
    loss = model.forward_kld(grid)

    # Skip the update (but still record the loss) when it is nan or inf.
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()

    loss_values.append(loss.item())
    del loss

loss_hist = np.asarray(loss_values, dtype=float)


In [None]:

# Pretrain the prior (flow net + convex net) separately for every image. The segmentation
# model's output ("unaries", computed with prior evaluation disabled) is the target for
# `criterion`; each image gets a fresh optimizer and up to `num_epochs` steps.
num_epochs = 1000

data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
it = data_loader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_progress_bar = True

if use_progress_bar:
    it = tqdm(it, desc="Pretraining images")

for i, item in enumerate(it):
    inputs, labels, indices, prior_state = TorchAgent.decompose_training_item(item, training_dataset=dataset)
    device_inputs: torch.Tensor = TensorUtil.to(inputs, device=device)

    # Switch prior weights for this image if needed (handled by the PriorManager context manager).
    with PriorManager(wrapper_module,
                      prior_state=prior_state,
                      prior_cache=dataset.__prior_cache__,
                      model_device=device,
                      training=True):

        # Segmentation output only: prior evaluation disabled, no autograd graph.
        with torch.no_grad(), TemporaryProperty(wrapper_module, evaluate_prior=False):
            if isinstance(device_inputs, list):
                unaries = wrapper_module(*device_inputs)
            else:
                unaries = wrapper_module(device_inputs)

        # Inputs for the prior
        prior_args, prior_kwargs = wrapper_module.get_prior_args(device_inputs[0],
                                                                 *device_inputs[1:],
                                                                 segm=unaries[0, ...],
                                                                 )
        actual_input = prior_args[0].detach().clone()

        # Thresholding at 0.5 yields a single unique value when the unaries are entirely
        # foreground or entirely background; there is nothing to fit then, so skip the image.
        if len(torch.unique(unaries >= 0.5)) == 1:
            logging.warning(f"Unaries of segmentation model contain only foreground or only background. Skipping image. {i}")
            continue

        epochs = num_epochs

        # Per-image step iterator; kept separate from `it`, which iterates over the images.
        epoch_iter = range(epochs)
        if use_progress_bar:
            desc = f'Image {i + 1}: Pretraining'
            if batch_progress_bar is None:
                batch_progress_bar = tqdm(
                    total=epochs,
                    desc=desc,
                    leave=True)
            else:
                batch_progress_bar.reset(total=epochs)
                batch_progress_bar.set_description(desc)

        # Weight decay on the flow net only; the convex net uses the optimizer default.
        optimizer = torch.optim.Adam([
            dict(params=prior_module.flow_net.parameters(), weight_decay=weight_decay),
            dict(params=prior_module.convex_net.parameters()),
        ], lr=lr)

        device_prior_output = None

        with torch.set_grad_enabled(True):
            for step in epoch_iter:
                optimizer.zero_grad()
                device_prior_output = prior_module(actual_input, *prior_args[1:], **prior_kwargs)
                device_prior_output = wrapper_module.process_prior_output(
                    device_prior_output, use_sigmoid=use_prior_sigmoid)[None, ...]  # Add batch dim again

                loss: torch.Tensor = criterion(device_prior_output, unaries)

                if not torch.isfinite(loss):
                    logging.warning(
                        f"Loss is nan or inf. Skipping step {step} of image {i}")
                    break
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                # NOTE(review): `logger` is not defined in the visible cells; only reached if use_logger is True.
                if use_logger and use_step_logger:
                    logger.log_value(
                        loss_value, f"PretrainingLoss/Image_{i}", step=step)

                prior_module.enforce_convexity()
                if batch_progress_bar is not None:
                    batch_progress_bar.set_postfix(
                        loss=loss_value, refresh=False)
                    batch_progress_bar.update()
                        



In [None]:

# Push the grid through the trained flow net and inspect the minimum of the output.
# NOTE(review): `inputs` is whatever the last executed training cell left behind.
grid = inputs[2]

prior_module = wrapper_module.prior_module
norm_flow = prior_module.flow_net

was_training = norm_flow.training
norm_flow.eval()
try:
    with torch.no_grad():
        grid = grid.to(device)[None, ...]
        out_grid = norm_flow(grid)
finally:
    # Restore the previous mode so later training cells are not silently run in eval mode.
    norm_flow.train(was_training)

out_grid.min()


In [None]:
# Evaluate the wrapper model on sample 0 and plot the segmentation together with the prior.
# NOTE(review): `step` is the last step index left over from the pretraining loop.
res, ground_truth, img, fg, bg = get_result(wrapper_module, dataset, 0, model_gets_targets=False)
res = split_model_result(res, wrapper_module, dataset, img, compute_crf=False)
res_pred, res_prior = res["segmentation"], res.get("prior")
fig = plot_image_scribbles(
    img, res_pred, fg, bg, res_prior,
    save=True,
    size=5,
    tight_layout=True,
    title=f"Epoch: {step}",
    legend=False,
)
fig

In [None]:
# Display the prior module's repr for inspection.
prior_module