In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import pandas as pd
import seaborn as sns
import re
import h5py

from functools import reduce
from itertools import combinations
from ttt.utils import listfiles
from mlr.database.utils import read_metadata, write_metadata
from matplotlib import cm

import ipywidgets as widgets
from ipywidgets import (IntSlider, RadioButtons, SelectMultiple, Layout, Checkbox, fixed, interact, FloatRangeSlider, Button)
from matplotlib.lines import Line2D

In [2]:
def read_training_history(path):
    """Load every top-level item of an HDF5 training-history file into memory.

    Parameters
    ----------
    path : str or pathlib.Path
        Path to the ``*_history.h5`` file written during training.

    Returns
    -------
    dict[str, numpy.ndarray]
        One array per top-level item, keyed by name (e.g. ``"val_loss"``,
        ``"val_f1-score"``).
    """
    # The context manager guarantees the file is closed even if reading a dataset raises.
    with h5py.File(path, 'r') as history_file:
        return {name: np.array(dataset) for name, dataset in history_file.items()}

In [3]:
root_dir_base = pathlib.Path("/media/files/segmentation_networks/")

# Experiment directories whose trained networks are compared in this notebook.
experiment_names = (
    "baseline_small_datasets_HQ_redux_1_6_2022",
    "small_NP_networks_redux_1_6_2022",
    "substrate_networks_redux_1_7_2022",
    "super_network_1_25_2022",
    "transfer_learning_all_with_callbacks_2_2_2022",
)

# Each experiment stores its networks under <experiment>/trained_models/.
root_dirs = [root_dir_base / name / "trained_models" for name in experiment_names]

# One flat list of model directories, grouped by experiment in the order above.
model_folders = [entry for trained_dir in root_dirs
                 for entry in listfiles(trained_dir)
                 if entry.is_dir()]

In [4]:
## load all the metadata from the networks
metadata_list = []
for folder in model_folders:
    model_metadata = read_metadata(folder.joinpath("metadata.json"))    
    model_id = re.search("[0-9]+_[0-9]+", str(folder).rsplit("/")[-1])[0]
    model_metadata["ID"] = model_id

    if "transfer" in str(folder):    
        tl = True
        history_path = "sm_unet_transferLearnWeights_" + model_metadata["backbone"]+"_history.h5"
        history_dict = read_training_history(folder.joinpath(history_path))
        model_metadata.update(history_dict)
        model_metadata["best_val_f1"] = history_dict["val_f1-score"][np.argmin(history_dict["val_loss"])]
        model_metadata["best_CdSe_f1"] = history_dict["CdSe_f1-score"][np.argmin(history_dict["CdSe_loss"])]
        model_metadata["best_Kath_f1"] = history_dict["Kath_f1-score"][np.argmin(history_dict["Kath_loss"])]
        model_metadata["CdSe_f1_at_val_peak"] = history_dict["CdSe_f1-score"][np.argmin(history_dict["val_loss"])]
        model_metadata["Kath_f1_at_val_peak"] = history_dict["Kath_f1-score"][np.argmin(history_dict["val_loss"])]
        model_metadata["series"] = "transfer"
        model_metadata["epoch"] = np.array([i/len(history_dict["val_f1-score"]) for i in np.arange(len(history_dict["val_f1-score"]))])
    else:
        tl = False
        model_metadata["series"] = str(folder).rsplit("/")[-3].split("_")[0]
    model_metadata["transfer_learned"] = tl
    model_metadata["folder"] = folder
    metadata_list.append(model_metadata)
        
df = pd.DataFrame(metadata_list)
df = pd.concat([df, pd.DataFrame(df['schedule'].to_list(), columns = ['schedule_rate', 'schedule_timing'])], axis=1)

In [5]:
def _ids_in_series(series_name):
    """Model IDs, in DataFrame order, whose ``series`` column equals ``series_name``."""
    return df.loc[df["series"] == series_name, "ID"].tolist()

baseline_IDs = _ids_in_series("baseline")
small_IDs = _ids_in_series("small")
sub_IDs = _ids_in_series("substrate")
super_IDs = _ids_in_series("super")

def return_orig_series(o_ID):
    """Name the synthetic-training series a starting-network ID belongs to.

    The ID lists are checked in a fixed priority order (baseline, small, substrate,
    super) and the first list containing ``o_ID`` wins.

    Returns
    -------
    str or None
        ``"baseline"``, ``"small"``, ``"substrate"`` or ``"super"``; ``None`` if the
        ID appears in none of the lists.
    """
    lookup_order = (
        ("baseline", baseline_IDs),
        ("small", small_IDs),
        ("substrate", sub_IDs),
        ("super", super_IDs),
    )
    for series_name, id_list in lookup_order:
        if o_ID in id_list:
            return series_name
    return None

def reduce_wrap(X):
    """Combine the elements of ``X`` with ``+`` (e.g. concatenate a list of strings).

    Parameters
    ----------
    X : iterable or missing value
        Typically a list stored in a DataFrame cell; may be NaN/None for rows that
        lack the field.

    Returns
    -------
    object
        ``X[0] + X[1] + ...``, or ``np.nan`` when ``X`` cannot be reduced (missing
        value, non-iterable, empty sequence, or elements that do not support ``+``).

    Only ``TypeError`` is treated as "cannot reduce"; any other exception propagates
    instead of being silently converted to NaN.
    """
    try:
        return reduce(lambda x, y: x + y, X)
    except TypeError:
        # np.nan rather than np.NAN: the upper-case alias was removed in NumPy 2.0.
        return np.nan

In [6]:
# Tag each row with the series of the network it started from, and collapse the
# substrate-thickness list into one value (via reduce_wrap) so it can be compared with ==.
df["orig_series"] = df["orig_model_ID"].map(return_orig_series)
df["substrate_thicknesses_str"] = df["substrate_thicknesses"].map(reduce_wrap)

Baseline networks can be matched on `target_dose`.

Small-NP networks can be matched on `N_defocus` and `N_structures`.

Substrate networks can be matched on `substrate_thicknesses`.

Super networks can be matched on `schedule_rate` and `alpha_0`.

These keys should be propagated to the transfer-learned networks' metadata.

Visualizing transfer learning generalization and performance dynamics with training histories:

For visualization:

The widget should have:
- a slider for model ID selection [x]
- a dropdown menu to select datasets [x]
- a dropdown menu to select validation splits [x]
- a radio button set to select the normalization scheme [x]
- a radio button to choose either aggregate plotting (e.g., 5 faint lines per base training network) or averaged plotting (seaborn error bars?)
- a checkbox to plot all models in the filters (disables the slider and plots all examples) [x]
- a checkbox to plot all models with the same synthetic training conditions [ ]


Normalization can be one of: \
A) fraction of the best score achieved during the transfer-learning stage \
B) fraction of the score before transfer learning (the synthetic-trained network's score) \
C) fraction of the larger of the two reference scores used in A and B

In [7]:
column_index = df.keys()
print(column_index)

Index(['alpha_0', 'backbone', 'batch_size', 'exp_f1-score',
       'exp_f1-score_CdSe', 'exp_f1-score_katherine', 'exp_iou',
       'exp_iou_CdSe', 'exp_iou_katherine', 'exp_loss', 'exp_loss_CdSe',
       'exp_loss_katherine', 'schedule', 'seed', 'target_dose', 'ID', 'series',
       'transfer_learned', 'folder', 'N_defocus', 'N_structures',
       'substrate_thicknesses', 'expt_generator_seed', 'expt_val_split',
       'orig_model_ID', 'CdSe_f1-score', 'CdSe_iou_score', 'CdSe_loss',
       'Kath_f1-score', 'Kath_iou_score', 'Kath_loss', 'f1-score', 'iou_score',
       'loss', 'lr', 'val_f1-score', 'val_iou_score', 'val_loss',
       'best_val_f1', 'best_CdSe_f1', 'best_Kath_f1', 'CdSe_f1_at_val_peak',
       'Kath_f1_at_val_peak', 'epoch', 'schedule_rate', 'schedule_timing',
       'orig_series', 'substrate_thicknesses_str'],
      dtype='object')


In [8]:
# Keep only the networks that went through the transfer-learning stage.
plot_df = df.loc[df["transfer_learned"].eq(True)]

In [9]:
# Number of rows currently selected by the plotting widget. plot_row() rebinds this
# after each filter change and the slider-resizing callback reads it. (A `global`
# statement at module level is a no-op, so none is needed here.)
slider_size = len(plot_df)

def filter_dataframe(data_frame, series, splits, plot_all_starters, plot_syn_cond, N):
    """Select rows matching the chosen validation splits and synthetic-data series.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Must contain ``expt_val_split``, ``orig_series`` and ``orig_model_ID`` columns.
    series : iterable of str
        Widget labels such as ``"Baseline"``; matched case-insensitively (lower-cased)
        against ``orig_series``. If ``"All"`` is among them, no series filter is applied.
        An empty selection yields an empty frame.
    splits : iterable of float
        Accepted ``expt_val_split`` values. An empty selection yields an empty frame.
    plot_all_starters : bool
        If True, further restrict to rows sharing ``orig_model_ID`` with the N-th row
        of the split/series-filtered frame.
    plot_syn_cond : bool
        Not used here; accepted so the call site can pass the full widget state.
    N : int
        Positional index of the reference row; clamped to the filtered frame's length.

    Returns
    -------
    pandas.DataFrame
        The filtered rows (original index preserved).
    """
    mask = data_frame["expt_val_split"].isin(list(splits))
    if "All" not in series:
        mask &= data_frame["orig_series"].isin([s.lower() for s in series])
    filtered = data_frame[mask]

    if plot_all_starters and not filtered.empty:
        # The slider range can lag behind the filters, so keep N inside the frame.
        reference_id = filtered.iloc[min(N, len(filtered) - 1)]["orig_model_ID"]
        filtered = filtered[filtered["orig_model_ID"] == reference_id]
    return filtered

def match_training(data_frame, N):
    """Return the rows trained on the same synthetic conditions as the N-th row.

    The columns that define "same conditions" depend on the reference row's
    ``orig_series``:

    - ``"baseline"``  -> ``target_dose``
    - ``"small"``     -> ``N_defocus``, ``N_structures``
    - ``"substrate"`` -> ``substrate_thicknesses_str``
    - ``"super"``     -> ``schedule_rate``, ``alpha_0``

    Missing values (NaN) in a key are treated as matching other missing values.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Candidate rows; must contain ``orig_series`` and the key columns above.
    N : int
        Positional index (``iloc``) of the reference row.

    Returns
    -------
    pandas.DataFrame
        Rows whose key columns all equal the reference row's values (the reference
        row is always included).

    Raises
    ------
    ValueError
        If the reference row's ``orig_series`` is not one of the four series above.
    """
    match_keys = {
        "baseline": ["target_dose"],
        "small": ["N_defocus", "N_structures"],
        "substrate": ["substrate_thicknesses_str"],
        "super": ["schedule_rate", "alpha_0"],
    }

    reference = data_frame.iloc[N]
    o_series = reference["orig_series"]
    keys = match_keys.get(o_series) if isinstance(o_series, str) else None
    if keys is None:
        raise ValueError(
            f"Cannot match synthetic training conditions for orig_series={o_series!r}; "
            f"expected one of {sorted(match_keys)}"
        )

    mask = pd.Series(True, index=data_frame.index)
    for key in keys:
        value = reference[key]
        column = data_frame[key]
        if pd.api.types.is_scalar(value) and pd.isna(value):
            # NaN never compares equal to itself, so match missing values explicitly.
            mask &= column.isna()
        else:
            mask &= column == value
    return data_frame[mask]

def _normalization_divisor(row, history_key, synthetic_key, normalization):
    """Return the divisor applied to one history curve under the chosen normalization."""
    # "val_f1-score" -> "best_val_f1", "CdSe_f1-score" -> "best_CdSe_f1", etc.
    transfer_best_key = "best_" + history_key.rsplit("-")[0]
    # Order matters: "Max(Synthetic, Transfer)" also contains "Synthetic".
    if "Max" in normalization:
        return np.max([row[synthetic_key], row[transfer_best_key]])
    if "Peak" in normalization:
        return row[transfer_best_key]
    if "Synthetic" in normalization:
        return row[synthetic_key]
    return 1


def _select_plot_rows(r_df, N, plot_all_starters, aggregate_plot, plot_syn_cond):
    """Return (rows to draw, line width): thin lines for group views, one thick line otherwise."""
    if plot_all_starters or aggregate_plot:
        return [row for _, row in r_df.iterrows()], 0.8
    if plot_syn_cond:
        return [row for _, row in r_df.iterrows()], 1.2
    return [r_df.iloc[N]], 2.0


def plot_row(data_frame, N, **kwargs):
    """Plot (optionally normalized/shifted) transfer-learning F1 histories for the widget state.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Transfer-learned model rows carrying per-epoch history arrays.
    N : int
        Positional index, within the filtered frame, of the model to highlight; clamped
        to the filtered frame's length.
    **kwargs
        Widget state: ``dataset``, ``val_split``, ``exp_dataset``, ``normalization``,
        ``shift``, ``plot_all_starters``, ``plot_syn_cond``, ``aggregate_plot``,
        ``y_range``.

    Side effects: sets the global ``slider_size`` to the number of rows passing the
    dataset/split filters, and draws a matplotlib figure.
    """
    r_df = filter_dataframe(data_frame,
                            series=kwargs["dataset"],
                            splits=kwargs["val_split"],
                            plot_all_starters=kwargs["plot_all_starters"],
                            plot_syn_cond=kwargs["plot_syn_cond"],
                            N=N)

    # Publish the selection size so the slider callback can resize the model slider.
    global slider_size
    slider_size = len(r_df)

    if r_df.empty:
        print("No models match the current filters.")
        return

    # The slider max is updated only after this call returns, so N can briefly point
    # past the end of a newly shrunk selection; clamp instead of raising IndexError.
    N = min(N, len(r_df) - 1)

    if kwargs["plot_syn_cond"]:
        r_df = match_training(r_df, N)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_ylim(kwargs["y_range"])
    ax.set_xlim([0, 1])
    ax.grid()

    rows, linewidth = _select_plot_rows(r_df, N,
                                        kwargs["plot_all_starters"],
                                        kwargs["aggregate_plot"],
                                        kwargs["plot_syn_cond"])

    # History columns, the synthetic-network score columns used for the "Synthetic"
    # normalization and the shift, and the widget labels, in a fixed order so each
    # experimental dataset keeps the same color.
    history_keys = ("val_f1-score", "CdSe_f1-score", "Kath_f1-score")
    synthetic_keys = ("exp_f1-score", "exp_f1-score_CdSe", "exp_f1-score_katherine")
    labels = ("Kate Au", "Kate CdSe", "Katherine Au")
    palette = sns.color_palette("pastel")

    for row in rows:
        for color_idx, (history_key, synthetic_key, label) in enumerate(
                zip(history_keys, synthetic_keys, labels)):
            if label not in kwargs["exp_dataset"]:
                continue
            div = _normalization_divisor(row, history_key, synthetic_key,
                                         kwargs["normalization"])
            shift = row[synthetic_key] if kwargs["shift"] else 0.0
            ax.plot(row["epoch"], (row[history_key] - shift) / div,
                    color=palette[color_idx], linewidth=linewidth)

    legend_handles = [Line2D([0], [0], color=palette[i], lw=4) for i in range(len(labels))]
    ax.legend(legend_handles, list(labels),
              loc="upper right",
              bbox_to_anchor=(1.18, 1.02))
    ax.axhline(0.0, color=(1, 0, 0), linestyle="--")
    ax.axhline(1.0, color=(1, 0, 0), linestyle="--")
    ax.set_xlabel("Normalized epoch")
    # Only label the curves as a "shift" when the synthetic score was subtracted.
    ax.set_ylabel("Shift in Dice Score" if kwargs["shift"] else "Dice Score")
    ax.set_title(f"Normalization: {kwargs['normalization']}")
    plt.show()

def plot_lines_widget(data_frame):
    """Build the interactive transfer-learning history explorer for ``data_frame``.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Transfer-learned model rows, passed unchanged (via ``fixed``) to ``plot_row``.

    Returns
    -------
    The value returned by ``ipywidgets.interact`` for ``plot_row``; its ``.widget``
    attribute holds the control widgets.
    """
    w_data_filter = SelectMultiple(
        options=['All', 'Baseline', 'Small', 'Substrate', 'Super'],
        value=('All',),
        description='Syn. Dataset',
        disabled=False,
    )

    w_exp_filter = SelectMultiple(
        options=['Kate Au', 'Kate CdSe', 'Katherine Au'],
        value=('Kate Au', 'Kate CdSe', 'Katherine Au'),
        description='Exp. Dataset',
        disabled=False,
    )

    w_normalization = RadioButtons(
        options=['Synthetic Score', 'Peak Transfer Score', 'Max(Synthetic, Transfer)', 'None'],
        description='Normalization',
        disabled=False
    )

    w_split_filter = SelectMultiple(
        options=[0.2, 0.4, 0.5, 0.6, 0.8],
        value=(0.2,),
        description='Val. Data Split',
        disabled=False,
    )

    # Size the slider from the frame actually being plotted; never let max fall below
    # min (0), which IntSlider rejects. The range is narrowed later as filters change.
    w_line_select = IntSlider(value=0, min=0, max=max(len(data_frame) - 1, 0),
                              description="Model #")

    w_starting_network = Checkbox(value=False, description="Plot same starting network", disabled=True)
    w_training_conditions = Checkbox(value=False, description="Plot same synthetic training")
    w_aggregate = Checkbox(value=False, description="Aggregate plot")
    w_y_range = FloatRangeSlider(value=[0, 2.0],
                                 min=-1.0,
                                 max=2.5,
                                 step=0.1,
                                 description='y-axis range:',
                                 disabled=False,
                                 continuous_update=False,
                                 orientation='horizontal',
                                 readout=True,
                                 readout_format='.1f',
                                 )

    w_shift = Checkbox(value=False, description="Shift by synthetic score")

    return interact(plot_row, data_frame=fixed(data_frame),
                    N=w_line_select,
                    dataset=w_data_filter,
                    val_split=w_split_filter,
                    exp_dataset=w_exp_filter,
                    normalization=w_normalization,
                    shift=w_shift,
                    plot_all_starters=w_starting_network,
                    plot_syn_cond=w_training_conditions,
                    aggregate_plot=w_aggregate,
                    y_range=w_y_range,
                    )

In [10]:
W = plot_lines_widget(plot_df)

# Locate the model-selection slider among the interact() controls; fail loudly rather
# than silently picking the wrong child widget.
slider = next((child for child in W.widget.children if isinstance(child, IntSlider)), None)
if slider is None:
    raise RuntimeError("Could not find the model-selection IntSlider in the widget.")

def on_value_change(change):
    """Resize the model slider to the number of rows selected by the latest plot_row call."""
    # IntSlider rejects max < min (0), so never go below 0 for an empty selection.
    slider.max = max(slider_size - 1, 0)

for child in W.widget.children:
    child.observe(on_value_change, names='value')

# interact() already rendered once with the default filters; sync the slider range now
# instead of waiting for the first control change.
on_value_change(None)

interactive(children=(IntSlider(value=0, description='Model #', max=1274), SelectMultiple(description='Syn. Da…

In [None]:
# NOTE(review): scratch cell left over from another analysis. `center` and `neighbor`
# are never defined in this notebook, so running it raised NameError under
# Restart & Run All. Disabled; restore it only together with the code defining them.
# np.linalg.norm(center - neighbor)