# Imports

In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

import ast
import copy
import functools
import gc
import itertools
import logging
import operator
import os
import pathlib
import re
import socket
import sys
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from pprint import PrettyPrinter, pprint
from typing import *

In [3]:
%autoreload 2

import humanize
import matplotlib
import numpy as np
import pandas as pd
import scipy as sp
import tensorflow as tf
import yaml
from matplotlib import cm, patches, pyplot as plt
from numpy import ndarray
from numpy.random import RandomState
from progressbar import progressbar as pbar
from pymicro.file import file_utils
from sklearn import metrics, metrics as met, model_selection, preprocessing
from skimage import measure as skimage_measure
import tabulate
from tensorflow import keras
from tensorflow.keras import (
    callbacks as keras_callbacks,
    layers,
    losses,
    metrics as keras_metrics,
    optimizers,
    utils,
)
from tqdm import tqdm
from yaml import YAMLObject

In [4]:
%autoreload 2

from tomo2seg import (
    analyse as tomo2seg_analyse,
    callbacks as tomo2seg_callbacks,
    data as tomo2seg_data,
    datasets as tomo2seg_datasets,
    hosts,
    losses as tomo2seg_losses,
    schedule as tomo2seg_schedule,
    slack,
    slackme,
    utils as tomo2seg_utils,
    viz as tomo2seg_viz,
    volume_sequence,
)
from tomo2seg.data import EstimationVolume, Volume
from tomo2seg.logger import add_file_handler, dict2str, logger
from tomo2seg.model import Model as Tomo2SegModel

In [5]:
# this registers a custom exception handler for the whole current notebook
get_ipython().set_custom_exc((Exception,), slackme.custom_exc)

In [6]:
logger.setLevel(logging.INFO)


# Host

In [7]:
this_host = hosts.hosts[socket.gethostname()]

# Args

In [8]:
# [manual-input]

@dataclass
class Args:
    this_nb_name: str
    volume_name: str
    volume_version: str
    partition: str  # its alias...
    
    random_state_seed: int = 42
    runid: int = field(default_factory=lambda: int(time.time()))

args = Args(
    this_nb_name = "compare-models-00.ipynb",
    volume_name=tomo2seg_datasets.VOLUME_COMPOSITE_V1[0],
    volume_version=tomo2seg_datasets.VOLUME_COMPOSITE_V1[1],
    partition="test",
)

In [9]:
logger.debug(f"args\n{dict2str(asdict(args))}")

# estimation volumes

In [14]:
volume_fullname = tomo2seg_data.Volume.name_pieces2fullname(name=args.volume_name, version=args.volume_version)

logger.debug(f"{volume_fullname=}")

In [15]:
datadir_paths = [
    tomo2seg_data.data_dir / name
    for name in os.listdir(tomo2seg_data.data_dir)
]

estimation_volumes = []

for path in datadir_paths:
    
    try:
        ev = tomo2seg_data.EstimationVolume.from_fullname(path.name)
    
    except ValueError as ex:
        
        if "not an estimation volume" not in ex.args[0]:
            raise ex
            
        continue
    
    if ev.volume_fullname == volume_fullname and ev.partition.alias == args.partition:
        estimation_volumes.append(ev)
        
logger.info(f"{len(estimation_volumes)=}")

all_estimation_volumes = copy.deepcopy(estimation_volumes)

INFO::tomo2seg::{data.py:from_fullname:502}::[2021-02-19::12:11:54.449]
Creating volume object to get partition dimensions.

INFO::tomo2seg::{data.py:from_fullname:502}::[2021-02-19::12:11:54.513]
Creating volume object to get partition dimensions.

INFO::tomo2seg::{data.py:from_fullname:502}::[2021-02-19::12:11:54.548]
Creating volume object to get partition dimensions.

ERROR::tomo2seg::{data.py:from_fullname:465}::[2021-02-19::12:11:54.579]
not enough values to unpack (expected 8, got 1)
Traceback (most recent call last):
  File "/home/users/jcasagrande/projects/tomo2seg/tomo2seg/data.py", line 461, in from_fullname
    vol_name, vol_version, partition_name, model_master_name, model_version, model_fold, model_runid, runid = full_name.split(".")
ValueError: not enough values to unpack (expected 8, got 1)
ERROR::tomo2seg::{data.py:from_fullname:465}::[2021-02-19::12:11:54.580]
not enough values to unpack (expected 8, got 1)
Traceback (most recent call last):
  File "/home/users/jcasag

# models

In [16]:
for ev in sorted(estimation_volumes, key=lambda x: x.model_name):
    print(f'"{ev.model_name}",')

"paper-unet-2d.f16-stripping-depth2.fold000.1611-686-349",
"paper-unet-2d.f16-stripping-depth4.fold000.1611-705-025",
"paper-unet-2d.f16-stripping-layernorm.fold000.1612-341-593",
"paper-unet-2d.f16-stripping-no_batchnorm.fold000.1611-679-491",
"paper-unet-2d.f16-stripping-no_data_augm.fold000.1611-701-054",
"paper-unet-2d.f16-stripping-no_dropout.fold000.1611-692-890",
"paper-unet-2d.f16-stripping-no_residual.fold000.1611-689-593",
"paper-unet-2d.f16-stripping-no_sigma.fold000.1611-697-019",
"paper-unet-2d.f16-stripping-rigid-updown.fold000.1611-683-325",
"paper-unet-2d.f16-stripping-sepconv.fold000.1611-674-526",
"paper-unet-2d.full-f08.fold000.1611-747-344",
"paper-unet-2d.full-f16.fold000.1611-743-205",
"paper-unet-2d.full-f24.fold000.1611-749-772",
"paper-unet-2d.full-f32.fold000.1612-356-814",
"paper-unet-2halfd.full-f08.fold000.1611-757-276",
"paper-unet-2halfd.full-f16.fold000.1611-738-397",
"paper-unet-2halfd.full-f24.fold000.1611-762-282",
"paper-unet-2halfd.full-f32.fold000.

In [17]:
# [manual-input]
models_to_compare = {
    # model vars
#     "2D (f0=08)": "paper-unet-2d.full-f08.fold000.1611-747-344",
#     "2D (f0=16)": "paper-unet-2d.full-f16.fold000.1611-743-205",
#     "2D (f0=24)": "paper-unet-2d.full-f24.fold000.1611-749-772",
    "2D (f0=32)": "paper-unet-2d.full-f32.fold000.1612-356-814",
    
#     "2.5D (f0=08)": "paper-unet-2halfd.full-f08.fold000.1611-757-276",
#     "2.5D (f0=16)": "paper-unet-2halfd.full-f16.fold000.1611-738-397",
#     "2.5D (f0=24)": "paper-unet-2halfd.full-f24.fold000.1611-762-282",
#     "2.5D (f0=32)": "paper-unet-2halfd.full-f32.fold000.1611-769-553",
    
#     "3D (f0=08)": "paper-unet-3d.full-f08.fold000.1611-801-655",
#     "3D (f0=16)": "paper-unet-3d.full-f16.fold000.1611-791-573",
#     "3D (f0=24)": "paper-unet-3d.full-f24.fold000.1611-807-271",
#     "3D (f0=32)": "paper-unet-3d.full-f32.fold000.1611-826-805",
    
    # model stripping
#     "depth=2": "paper-unet-2d.f16-stripping-depth2.fold000.1611-686-349",
#     "depth=4": "paper-unet-2d.f16-stripping-depth4.fold000.1611-705-025",
#     "no BatchNorm": "paper-unet-2d.f16-stripping-no_batchnorm.fold000.1611-679-491",
#     "no data augm.": "paper-unet-2d.f16-stripping-no_data_augm.fold000.1611-701-054",
#     "no Dropout": "paper-unet-2d.f16-stripping-no_dropout.fold000.1611-692-890",
#     "no residual": "paper-unet-2d.f16-stripping-no_residual.fold000.1611-689-593",
#     "no GaussianNoise": "paper-unet-2d.f16-stripping-no_sigma.fold000.1611-697-019",
#     "rigid Up/DownSampling": "paper-unet-2d.f16-stripping-rigid-updown.fold000.1611-683-325",
#     "SeparableConv": "paper-unet-2d.f16-stripping-sepconv.fold000.1611-674-526",
    "LayerNorm": "paper-unet-2d.f16-stripping-layernorm.fold000.1612-341-593",
}

models_to_compare_inv = dict(map(reversed, models_to_compare.items()))

estimation_volumes_dict = {
    models_to_compare_inv[ev.model_name]: ev
    for ev in estimation_volumes
    if ev.model_name in models_to_compare_inv
}

In [18]:
len(estimation_volumes_dict)

2

In [19]:
def get_nparams(model_name):
    t2s_model = Tomo2SegModel.build_from_model_name(model_name)
    summary = t2s_model.summary_path.read_text()
    trainable_params_line = summary.split("\n")[-4]
    trainable_params_str = trainable_params_line.split(" ")[-1]
    return int("".join(trainable_params_str.split(",")))

In [20]:
def get_model_history(model_name):
    t2s_model = Tomo2SegModel.build_from_model_name(model_name)
    df = pd.read_csv(t2s_model.history_path).set_index("epoch")
    df["batches"] = df["train.epoch_size"].cumsum()
    df["crops"] = (df["train.epoch_size"] * df["train.batch_size"]).cumsum()
    import operator
    crop_nvoxels = functools.reduce(operator.mul, ast.literal_eval(df["train.crop_shape"][0]))
    df["voxels"] = (df["train.epoch_size"] * df["train.batch_size"] * crop_nvoxels).cumsum()
    df["seconds"] = df["seconds"].cumsum()
    df["hours"] = df["seconds"] / 60 / 60
    df["val_loss_cummin"] = np.array([
        df["val_loss"].values[:i].min()
        for i in [1] + list(range(1, df["val_loss"].shape[0])) 
    ])
    return df

In [21]:
h = get_model_history("paper-unet-2d.full-f08.fold000.1611-747-344")

In [24]:
from tomo2seg.analyse_pred import AnalysePredOuputs
is_new_dir = True

def get_records(estimation_volumes_dict, metrics=("jaccard", )):
    
    records = []

    metric_cols = {}

    for model_alias, ev in estimation_volumes_dict.items():

        record = {}
        
        hist = get_model_history(ev.model_name)

        record["model"] = ev.model_name
        record["alias"] = model_alias
        record["nparams"] = get_nparams(ev.model_name)
        
        if is_new_dir:
            outputs = AnalysePredOuputs(ev.dir / "pred-analysis")
            ev_classif_report = pd.read_csv(outputs.classification_report_table_csv).set_index("class/average")
        else:
            ev_classif_report = pd.read_csv(ev.classification_report_table_exact_csv_path).set_index("class/average")

        for m in metrics:

            mcols = metric_cols[m] = []

            for row in ev_classif_report.index:
                col = f"{m}.{row}"
                mcols.append(col)
                record[col] = ev_classif_report.loc[row][m]
                
        record["estimation_volume"] = ev.fullname
        record["training-hours"] = hist["hours"].values[-1]
        records.append(record)
        
    return records
    

In [25]:
df = pd.DataFrame.from_records(get_records(estimation_volumes_dict)).set_index("model")    

In [26]:
df.sort_index()[["alias", "nparams", "jaccard.matrix", "jaccard.fiber", "jaccard.porosity", "jaccard.macro", "training-hours"]]

Unnamed: 0_level_0,alias,nparams,jaccard.matrix,jaccard.fiber,jaccard.porosity,jaccard.macro,training-hours
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
paper-unet-2d.f16-stripping-layernorm.fold000.1612-341-593,LayerNorm,3302883,0.987484,0.952605,0.645879,0.861989,3.79456
paper-unet-2d.full-f32.fold000.1612-356-814,2D (f0=32),13195203,0.988192,0.95479,0.663305,0.868763,2.830493


# training histories

In [27]:
dft_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

def plot(xaxis, yaxis, models_to_compare, color_map):
    
    hists = {
        mod: get_model_history(mod)
        for mod in models_to_compare
    }
    
    fig, axs = plt.subplots(
        nrows := 1, ncols := 2, figsize=(ncols * (sz := 6) * 2, nrows * sz), dpi=120,
    #     sharex=True,
        gridspec_kw=dict(hspace=sz/12, wspace=sz/36)
    )

    ax, axzoom = axs

    for mod, hist in hists.items():

        xs = hist[xaxis] if xaxis != "epoch" else hist.index
        ys = hist[yaxis]

        color = next(v for k, v in color_map.items() if k in mod)

        plot_kwargs = dict(
            label=mod,
            ls="--" if "sep" in mod else "-",
            color=color,
            linewidth=.7,
        )

        ax.plot(xs, ys, **plot_kwargs)
        axzoom.plot(xs, ys, **plot_kwargs)

    # configs in common
    for ax_ in axs:
        ax_.yaxis.set_major_formatter(plt.FormatStrFormatter("%.2f"))
        ax_.set_ybound(lower=0)
        ax_.set_xlabel(xaxis)
        ax_.legend()
        ax_.set_ylabel("jaccard2 (lower is better)")

    ax.set_ybound(upper=.7)
    ax.set_title(f"{yaxis} history")

    axzoom.set_yscale("log")
    axzoom.set_ybound(upper=.03)
    axzoom.set_title("zoom")

## conv types

In [28]:
compare_conv_type = [
    "unet2d-sep.vanilla03-f16.fold000.1606-575-226",
    "unet2d.vanilla03-f16.fold000.1606-505-109",
    "unet2halfd-sep.vanilla03-f16.fold000.1606-729-672",
    "unet2halfd.vanilla03-f16.fold000.1606-683-705",
    "unet3d.vanilla03-f08.fold000.1606-842-005",
    "unet3d.vanilla03-f16.fold000.1606-750-939",
]

color_map = {
    "2d": dft_colors[0],
    "2halfd": dft_colors[1],
    "3d.vanilla03-f08": dft_colors[2],
    "3d.vanilla03-f16": dft_colors[3],
}

plot("epoch", "val_loss", compare_conv_type, color_map)
# plot("hours", "val_loss", compare_conv_type, color_map)
plot("epoch", "val_loss_cummin", compare_conv_type, color_map)
plot("voxels", "val_loss_cummin", compare_conv_type, color_map)
plot("hours", "val_loss_cummin", compare_conv_type, color_map)

FileNotFoundError: [Errno 2] No such file or directory: '/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-sep/unet2d-sep.vanilla03-f16.fold000.1606-575-226/history.csv'

## conv types (no separable)

In [None]:
compare_conv_type = [
    "unet2d.vanilla03-f16.fold000.1606-505-109",
    "unet2halfd.vanilla03-f16.fold000.1606-683-705",
#     "unet3d.vanilla03-f08.fold000.1606-842-005",
    "unet3d.vanilla03-f16.fold000.1606-750-939",
]

color_map = {
    "2d": dft_colors[0],
    "2halfd": dft_colors[1],
    "3d.vanilla03-f08": dft_colors[2],
    "3d.vanilla03-f16": dft_colors[3],
}

# plot("epoch", "val_loss", compare_conv_type, color_map)
# plot("hours", "val_loss", compare_conv_type, color_map)
plot("epoch", "val_loss_cummin", compare_conv_type, color_map)
plot("voxels", "val_loss_cummin", compare_conv_type, color_map)
plot("hours", "val_loss_cummin", compare_conv_type, color_map)

## crop sizes

### 2d

In [None]:
compare_crop_sizes = [
    "unet2d.crop48-f16.fold000.1607-530-580",
    "unet2d.crop112-f16.fold000.1607-533-765",
    "unet2d.vanilla03-f16.fold000.1606-505-109",
]

color_map = {
    "crop48": dft_colors[0],
    "crop112": dft_colors[1],
    "vanilla03": dft_colors[2],
}

In [None]:
# plot("epoch", "val_loss", compare_crop_sizes, color_map)
# plot("hours", "val_loss", compare_crop_sizes, color_map)
plot("epoch", "val_loss_cummin", compare_crop_sizes, color_map)
plot("hours", "val_loss_cummin", compare_crop_sizes, color_map)

### 2.5d

In [None]:
compare_crop_sizes = [
    "unet2halfd.crop112-f16.fold000.1607-788-628",
    "unet2halfd.vanilla03-f16.fold000.1606-683-705",

    "unet2halfd-sep.crop112-f16.fold000.1607-789-290",
    "unet2halfd-sep.vanilla03-f16.fold000.1606-729-672",
]

color_map = {
    "crop112": dft_colors[1],
    "vanilla03": dft_colors[2],
}

In [None]:
# plot("epoch", "val_loss", compare_crop_sizes, color_map)
# plot("hours", "val_loss", compare_crop_sizes, color_map)
plot("epoch", "val_loss_cummin", compare_crop_sizes, color_map)
plot("hours", "val_loss_cummin", compare_crop_sizes, color_map)

### 3d

In [None]:
compare_crop_sizes = [
    "unet3d.vanilla03-f08.fold000.1606-842-005",
    "unet3d.crop96-f08.fold000.1607-109-265",
    "unet3d.crop112-f12.fold000.1607-466-349",
    "unet3d.crop304-f16.fold000.1607-790-699",
]

color_map = {
    "vanilla03": dft_colors[0],
    "crop96": dft_colors[1],
    "crop112": dft_colors[2],
    "crop304": dft_colors[3],
}

In [None]:
# plot("epoch", "val_loss", compare_crop_sizes, color_map)
# plot("hours", "val_loss", compare_crop_sizes, color_map)
plot("epoch", "val_loss_cummin", compare_crop_sizes, color_map)
plot("hours", "val_loss_cummin", compare_crop_sizes, color_map)