In [None]:
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
from awesome.analytics.result_model import ResultModel
from awesome.util.path_tools import get_project_root_path, get_package_root_path
import os
import torch
import re
from awesome.util.format import latex_postprocessor
import pandas as pd
from typing import Any, Dict
# Run from the project root: the relative paths used throughout this notebook
# ("./runs/...", "./data/checkpoints/...") resolve against the current working directory.
os.chdir(get_project_root_path())

In [None]:
from awesome.analytics.result_comparison import ResultComparison

# Evaluation run directories to compare; every sub-directory (except old*/log*) is one run.
paths = [
         "./runs/fbms_local/eval/unet/joint_realnvp/2024-01-11"
         ]
models = []

for path in paths:
    # sorted(): os.listdir order is arbitrary, and the model numbering assigned
    # further below depends on the order of `models`.
    for folder in sorted(os.listdir(path)):
        if folder.startswith(("old", "log")):
            continue
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            # Only sub-directories are run folders; skip stray files.
            continue
        models.append(ResultModel.from_path(folder_path))

# Run folder names look like "#<cfg>_<NET>+<feat>+<feat>_<yy_mm_dd>_<hh_mm_ss>".
# The character class is spelled out: "[A-z]" would also match "[", "\\", "]", "^" and "`".
p = r"#?(?P<cfg_num>\d+)?_?(?P<net>[A-Za-z0-9_]+)_?(?P<feat>(\+\w+)*)_(?P<date>\d{2}_\d{2}_\d{2})_(?P<time>\d{2}_\d{2}_\d{2})"
pattern = re.compile(p)

for model in models:
    match = pattern.match(model.name)
    feat = []
    if match:
        model_name = match.group('net').strip("_")
        features = match.group('feat')
        if features:
            feat = features.strip("+").split("+")
            # Runs without an explicit seed tag are labelled with the default seed.
            if not any("seed" in x for x in feat):
                feat.append("seed42")
    else:
        print('No match for', model.name)
        # Fall back to the raw folder name instead of crashing on None.replace(...).
        model_name = model.name
    model_name = model_name.replace("NET", "Net")
    model.display_name = model_name + " " + " ".join(feat)
    model.features = list(feat)
    # Side effect: persists the updated config of every run to disk.
    model.config.result_directory = "final_mask"
    model.save_config()


# Resort the models by name to get a meaningful table order.
# Names listed in `_order` sort by their position; unlisted names get key 0
# (the sort is stable, so they keep their load order).
_order = []

models = sorted(models, key=lambda m: _order.index(m.name) if m.name in _order else 0)

comparison = ResultComparison(models)
comparison.assign_numbers(force=True)

os.environ['PLOT_OUTPUT_DIR'] = comparison.output_folder

save_args = dict(transparent=False, save=True, dpi=300, ext=["png", "pdf"])

models

In [None]:
from typing import Tuple

# Metric keys to pull from every run's evaluation results.
metrics = [
    "eval/epoch/MeanForegroundBinaryMIOU",
    "eval/epoch/MeanPriorForegroundBinaryMIOU",
    "eval/epoch/MeanPixelAccuracy",
    "eval/epoch/MeanPriorPixelAccuracy",
    "eval/epoch/MeanCRFForegroundBinaryMIOU",
    "eval/epoch/MeanCRFPixelAccuracy",
]

# Readable column header for each metric key.
col_mapping = {
    "eval/epoch/MeanForegroundBinaryMIOU": "IoU",
    "eval/epoch/MeanCRFForegroundBinaryMIOU": "CRF IoU",
    "eval/epoch/MeanPixelAccuracy": "Acc.",
    "eval/epoch/MeanCRFPixelAccuracy": "CRF Acc.",
    "eval/epoch/MeanPriorPixelAccuracy": "Prior Acc.",
    "eval/epoch/MeanPriorForegroundBinaryMIOU": "Prior IoU",
}

# Column-suffix -> label appended to the header.
index_mapping = {
    0: "Baseline",
    15: "Joint",
}

# One entry per (metric, suffix) pair, e.g. "eval/epoch/MeanPixelAccuracy_15" -> "Acc. Joint".
# NOTE(review): later cells select un-suffixed headers ("IoU", "Prior IoU", ...);
# confirm which naming the metric table actually produces.
new_colmapping = {
    f"{metric_key}_{suffix}": f"{header} {suffix_label}"
    for metric_key, header in col_mapping.items()
    for suffix, suffix_label in index_mapping.items()
}

col_mapping = new_colmapping

def extract_features(model: "ResultModel") -> Dict[str, Any]:
    """Decode a model's feature tags into the columns of the comparison table.

    Known tags in ``model.features`` are consumed one by one. Afterwards exactly
    two tags must remain: the dataset name followed by the feature type.
    ``model.features`` itself is not modified.

    Args:
        model: Result model exposing ``features``, ``config.name``,
            ``run_config.seed`` and ``output_path``.

    Returns:
        Dict with keys ``joint`` (bool), ``model_name`` (first word of
        ``config.name``), ``seed`` (int from a ``seed<N>`` tag, otherwise
        ``run_config.seed``), ``prior`` (base prior such as ``"diffeo"`` or
        ``"refit"``, optionally followed by ``+only_prior``, ``+all_frames``,
        ``+deeper``), ``dataset_name`` and ``feature_type``.

    Raises:
        AssertionError: If the leftover tags are not exactly
            ``[dataset_name, feature_type]``.
    """
    res = dict()
    model_features = list(model.features)  # work on a copy

    res['joint'] = "joint" in model_features
    res['model_name'] = model.config.name.split(" ")[0]
    if res['joint']:
        model_features.remove("joint")

    seed = next((x for x in model_features if "seed" in x), None)
    if seed is None:
        res['seed'] = model.run_config.seed
    else:
        model_features.remove(seed)
        res['seed'] = int(seed.replace("seed", ""))

    # Refit / original tags; if several are present the later one in this sequence wins.
    prior = None
    for tag, value in (("REFIT", "refit"), ("original", "original"), ("retrain", "refit")):
        if tag in model_features:
            model_features.remove(tag)
            prior = value

    if "retrain_xy" in model_features:
        model_features.remove("retrain_xy")
        prior = "refit"
    elif "convex" in model_features:
        model_features.remove("convex")
        prior = "convex"
    elif "diffeo" in model_features:
        model_features.remove("diffeo")
        prior = "diffeo"
    elif prior is None:
        # Fallback only when no tag above determined the prior.
        prior = "none"

    # Training-variant modifiers are appended to the base prior in a fixed order.
    for modifier in ("only_prior", "all_frames", "deeper"):
        if modifier in model_features:
            model_features.remove(modifier)
            prior = prior + "+" + modifier
    res['prior'] = prior

    assert len(model_features) == 2, (
        f"Expected [dataset_name, feature_type] after removing known tags, "
        f"got {model_features} in model {model.output_path}"
    )
    res['dataset_name'], res['feature_type'] = model_features

    return res

# Best ("max") value of every metric per model. After reset_index() the run
# name is expected in the "index" column (the lookup below relies on it).
df = comparison.metric_table(metrics,
                             ref="all",
                             mode="max",
                             formatting=False)

df = df.reset_index()

# Run name -> model, built once instead of scanning `models` for every row.
# reversed(): on duplicate names the first model in `models` wins, as with a list search.
_models_by_name = {m.name: m for m in reversed(models)}

def extract_ft_row(row: pd.Series) -> Tuple[str, bool, str, int, str, str]:
    """Return the decoded feature columns for one metric-table row.

    Args:
        row: Metric-table row whose ``'index'`` entry is the run name.

    Returns:
        ``(model_name, joint, feature_type, seed, dataset_name, prior)``.

    Raises:
        KeyError: If no loaded model has that run name.
    """
    name = row['index']
    if name not in _models_by_name:
        raise KeyError(f"No loaded model named {name!r}")
    res = extract_features(_models_by_name[name])
    return (res['model_name'], res['joint'], res['feature_type'], res['seed'], res['dataset_name'], res['prior'])

df[['model_name', 'joint', 'feature_type', 'seed', 'dataset_name', "prior"]] = df.apply(extract_ft_row, axis=1, result_type="expand")

df = df[['model_name', 'dataset_name'] + list(col_mapping.keys()) + ['joint', 'feature_type', 'seed', "prior"]]

grouped = df.groupby(['model_name', 'joint', 'feature_type'])

display_df = df.rename(columns=col_mapping)
display(display_df)

In [None]:
display_markdown = display_df.to_markdown()  # plain-text table for copy/paste
print(display_markdown)

In [None]:
# Summary statistics of the numeric columns across all runs.
display_summary = display_df.describe()
display_summary

In [None]:
summary_markdown = display_df.describe().to_markdown()  # same statistics as markdown text
print(summary_markdown)

In [None]:
# Datasets that actually have evaluation results.
ev_ds = {name for name in display_df['dataset_name']}

# Datasets that are expected to be evaluated.
should = {
    'bear01', 'bear02',
    'cars2', 'cars3', 'cars6', 'cars7', 'cars8',
    'cats04', 'cats05',
    'horses01', 'horses03',
    'marple1', 'marple10', 'marple11', 'marple5',
    'meerkats01',
    'people04',
    'rabbits01',
}

# Expected datasets without results (empty set == complete).
should - ev_ds

In [None]:
renamed_df = df.rename(columns=col_mapping)

# No row filter is applied at the moment (a diffeo-only filter on `prior`
# is disabled), so `diffeo_df` contains every model despite its name.
diffeo_df = renamed_df
order = [
    'model_name', 'joint', 'feature_type', 'dataset_name',
    'IoU', 'Prior IoU', 'Acc.', 'Prior Acc.',
    'prior',
]
diffeo_df.loc[:, order]

In [None]:
diffeo_markdown = diffeo_df[order].to_markdown()  # selected columns as markdown text
print(diffeo_markdown)

In [None]:

# Turn the raw `prior` tags into readable category labels and sort the table
# by dataset, then by category order.
# .copy(): the .loc writes below must modify an independent frame, not a slice of diffeo_df.
display_df = diffeo_df[order].copy()

# All replacements use regex=False: the patterns contain "+", a regex quantifier,
# and pandas < 2.0 interprets str.replace patterns as regexes by default.
display_df.loc[:, "prior"] = display_df["prior"].values + display_df["joint"].apply(lambda x: "+joint" if x else "")
display_df.loc[:, "prior"] = display_df["prior"].str.replace("diffeo+", "", regex=False)

display_df.loc[:, "prior"] = display_df["prior"].str.replace("only_prior+all_frames+deeper", "Prior fit only (all+deeper)", regex=False)

# "refit" is labelled by feature type: plain edge features vs. edge + xy coordinates.
edge_rows = display_df['feature_type'] == "edge"
edgexy_rows = display_df['feature_type'] == "edgexy"
display_df.loc[edge_rows, "prior"] = display_df.loc[edge_rows, "prior"].str.replace("refit", "Refit", regex=False)
display_df.loc[edgexy_rows, "prior"] = display_df.loc[edgexy_rows, "prior"].str.replace("refit", "Refit XY", regex=False)

# Longer tag combinations first, so shorter patterns do not split them.
display_df.loc[:, "prior"] = display_df["prior"].str.replace("only_prior+all_frames", "Prior fit only (all)", regex=False)
display_df.loc[:, "prior"] = display_df["prior"].str.replace("only_prior+deeper", "Prior fit only (GT+deeper)", regex=False)
display_df.loc[:, "prior"] = display_df["prior"].str.replace("only_prior", "Prior fit only (GT)", regex=False)
display_df.loc[:, "prior"] = display_df["prior"].str.replace("deeper+joint", "Joint Training (deeper)", regex=False)
display_df.loc[:, "prior"] = display_df["prior"].str.replace("joint", "Joint Training", regex=False)

# Row order of the categories within each dataset.
prior_categories = [
                    "Refit",
                    "Refit XY",
                    "Prior fit only (GT)",
                    "Prior fit only (GT+deeper)",
                    "Prior fit only (all)",
                    "Prior fit only (all+deeper)",
                    "Joint Training",
                    "Joint Training (deeper)",
                    ]


display_df = display_df.drop(columns=["joint", "feature_type", "model_name"])
value_columns = ["IoU", "Prior IoU", "Acc.", "Prior Acc."]

dataset_names = sorted(display_df["dataset_name"].unique())

def _order_fnc(x):
    """Sort key for `sort_values`: map dataset names / prior labels to their rank.

    Prior labels missing from `prior_categories` sort after all listed ones
    instead of raising ValueError; they stay visible in the table.
    """
    if x.name == "dataset_name":
        return [dataset_names.index(v) for v in x]
    if x.name == "prior":
        rank = {label: pos for pos, label in enumerate(prior_categories)}
        return [rank.get(v, len(prior_categories)) for v in x]
    return 0

display_df = display_df.set_index(["dataset_name", "prior"]).sort_values(by=["dataset_name", "prior"], key=_order_fnc)

display_df

In [None]:
# Mean of every metric per prior category, averaged over all datasets.
prior_means = display_df.groupby(level="prior").mean()
prior_means

In [None]:
# GitHub-flavoured markdown table; the (dataset_name, prior) index becomes regular columns.
flat_df = display_df.reset_index()
out = flat_df.to_markdown(tablefmt='github', index=False)
print(out)

In [None]:
import logging
from awesome.util.temporary_property import TemporaryProperty
import awesome.run.functions as F
import matplotlib.pyplot as plt
import ipywidgets as widgets
from matplotlib.axes import Axes

model: ResultModel = models[0]


image_index = 0


# Position -> dataset name; the "Dataset" slider selects a position in this list.
dataset_names_unique = diffeo_df['dataset_name'].unique().tolist()

def qualitative_comparison(dataset_name_idx: int, image_index: int = 0):
    """Plot the inference and prior masks of every model evaluated on one dataset, side by side.

    Args:
        dataset_name_idx: Position in ``dataset_names_unique``.
        image_index: Sample to display; reset to 0 (with a warning) when it is
            out of range for a model's dataloader.

    Returns:
        The figure, or a message string when no model matches the dataset.
    """
    dataset_name = dataset_names_unique[dataset_name_idx]
    filtered_models_df = df[(df['dataset_name'] == dataset_name)]
    if len(filtered_models_df) == 0:
        return "No models found for input combination: dataset_name={}".format(dataset_name)

    display_models = [x for x in models if x.name in filtered_models_df['index'].values]

    size = 10
    fig, ax = plt.subplots(1, len(display_models), figsize=(size * len(display_models), size))

    # plt.subplots returns a bare Axes (not an array) for a single column.
    if isinstance(ax, Axes):
        ax = [ax]

    for i, model in enumerate(display_models):
        with TemporaryProperty(model, getitem_mask_mode="both"):
            runner = model.get_runner()
            dataloader = runner.dataloader
            if image_index >= len(dataloader):
                image_index = 0
                logging.warning(f"Image index out of bounds for dataset: {dataset_name} len is: {len(dataloader)}. Resetting to 0")
            if len(model) == 0:
                logging.warning(f"No results for model {model.name}. Skipping.")
                continue
            res_mask, prior_mask = model[image_index]
            image, ground_truth, _input, targets, fg, bg, prior_state = F.prepare_input_eval(dataloader, model, image_index)

            fig = F.plot_image_scribbles(image=image,
                                        inference_result=res_mask,
                                        foreground_mask=fg,
                                        background_mask=bg,
                                        prior_result=prior_mask,
                                        tight=True,
                                        background_value=0,
                                        ax=ax[i],
                                        )
            ax[i].set_title(model.display_name)
    return fig

# IntSlider max is inclusive: the last valid position is len - 1.
ds_widget = widgets.IntSlider(min=0, max=len(dataset_names_unique) - 1, value=0, description='Dataset:')
index_widget = widgets.IntSlider(min=0, max=100, value=0, description='Image index:')

def update_index_range(*args):
    """Resize the image-index slider to the dataloader length of the selected dataset."""
    # ds_widget holds a position into dataset_names_unique, not the dataset name itself.
    dataset_name = dataset_names_unique[ds_widget.value]
    models_df = df[(df['dataset_name'] == dataset_name)]
    if len(models_df) == 0:
        return
    m: ResultModel = next((x for x in models if x.name in models_df['index'].values), None)
    if m is None:
        return
    index_widget.max = max(len(m.get_runner().dataloader) - 1, 0)

ds_widget.observe(update_index_range, 'value')


out = widgets.interact_manual(qualitative_comparison, **{'dataset_name_idx': ds_widget, 'image_index': index_widget})

text_widget = widgets.Textarea(
    value=", ".join([str(i) + ": " +str(x) for i, x in enumerate(dataset_names_unique)]), description="Dataset Mapping:", width=300
)

row2 = widgets.HBox([text_widget])
ui = widgets.VBox([row2])
display(ui)
display(out)

In [None]:
# Position -> dataset name lookup used by the qualitative-comparison slider / dataset_name_idx.
dataset_names_unique

In [None]:
# Qualitative comparison for dataset_names_unique[2], image index 0.
qualitative_comparison(2, 0)

In [None]:
import matplotlib.pyplot as plt

# `res_mask` is only a local variable inside `qualitative_comparison`, so it does
# not exist at notebook scope; load the inference mask explicitly here.
# NOTE(review): the model / image chosen below are placeholders — adjust as needed.
debug_model = models[0]
debug_image_index = 0
with TemporaryProperty(debug_model, getitem_mask_mode="both"):
    res_mask, prior_mask = debug_model[debug_image_index]

plt.imshow(res_mask)

In [None]:
from awesome.run.functions import get_result, split_model_result, plot_image_scribbles

# `extract_models` is created by the checkpoint-extraction cells further down.
assert "extract_models" in globals(), "Run one of the checkpoint-extraction cells below first (defines `extract_models`)."

model_index = 15   # which entry of `extract_models` to inspect
sample_index = 10  # dataloader sample to visualise

runner = extract_models[model_index].get_runner()
dataloader = runner.dataloader
agent = extract_models[model_index].get_agent(-1)
model = agent._get_model()
model_gets_targets = agent.model_gets_targets

res, ground_truth, img, fg, bg = get_result(model, dataloader, sample_index, model_gets_targets=model_gets_targets)
res = split_model_result(res, model, dataloader, img)
res_prior = res.get("prior", None)
res_pred = res["segmentation"]
boxes = res.get("boxes", None)
labels = res.get("labels", None)

p = os.path.join(runner.agent.agent_folder, "pretrain_priors")
os.makedirs(p, exist_ok=True)

iterations = 2000
# The output file is keyed by the visualised sample index (assumed intent — TODO confirm).
fig = plot_image_scribbles(image=img,
                    inference_result=res_pred,
                    foreground_mask=fg,
                    background_mask=bg,
                    prior_result=res_prior,
                    save=True,
                    path=os.path.join(p, f"prior_{sample_index}_{iterations}.png"),
                    size=10,
                    title=f"Prior Epoch: {iterations}", open=True)
fig

# Extracting state dict from checkpoint

In [None]:
from datetime import datetime
from collections import OrderedDict

# Export the segmentation-module weights of the REFIT "edge" runs (excluding
# "edgexy") as standalone checkpoints: one file per dataset in a dated folder.
extract_models = [
    m for m in comparison.models
    if 'REFIT' in m.name and 'edge' in m.name and 'edgexy' not in m.name
]
pretrain_path = "./data/checkpoints/refit_unet_uncertainty/{}/model_{}_unet.pth"
date = datetime.now().strftime("%y_%m_%d")
for model in extract_models:
    # Second word of the config name is used as the dataset name.
    datasetname = model.config.name.split(" ")[1]
    agent = model.get_agent(-1)
    state_dict = agent._get_model().segmentation_module.state_dict()
    target_file = pretrain_path.format(date, datasetname)
    os.makedirs(os.path.dirname(target_file), exist_ok=True)
    torch.save(state_dict, target_file)

In [None]:
from datetime import datetime
from collections import OrderedDict

# Export the segmentation-module weights of the REFIT "edgexy" runs as standalone
# checkpoints: one file per dataset in a dated folder.
extract_models = [run for run in comparison.models if 'REFIT' in run.name and 'edgexy' in run.name]
pretrain_path = "./data/checkpoints/refit_spatial_unet_uncertainty/{}/model_{}_unet.pth"
date = datetime.now().strftime("%y_%m_%d")
for model in extract_models:
    datasetname = model.config.name.split(" ")[1]  # second word of the config name
    agent = model.get_agent(-1)
    segmentation_module = agent._get_model().segmentation_module
    state_dict = segmentation_module.state_dict()
    checkpoint_file = pretrain_path.format(date, datasetname)
    checkpoint_dir = os.path.dirname(checkpoint_file)
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state_dict, checkpoint_file)