In [None]:
import os
import pandas as pd
import numpy as np
import  seaborn as sns
import matplotlib.pyplot as plt
import itertools
from matplotlib import cm
from matplotlib.gridspec import GridSpec
from mpl_toolkits import axes_grid1

%load_ext autoreload

%autoreload 2

from source.plotting_utils import plot_heatmap_colorbar, plot_heatmap_histogram, visualize_feature_hist

In [None]:
# Root directory of the interpretable-model outputs (ends with "/" so a run
# name can be appended directly) and the run to visualise.
# Other run names referenced by this notebook: "diagnosis", "prognosis".
logdir, run = "../outputs/har/interpretable/", "split"

In [None]:
center_weight_plots = True  # subtract per-column means from w and b before plotting
auto_scale = True  # derive colour limits from the data instead of fixed defaults


def get_auto_scale(type, m, center=center_weight_plots, safety=0.25):
    """Return ``(vmin, vmax)`` colour limits for a heatmap of ``m``.

    Parameters
    ----------
    type : {"b", "w"}
        ``"b"`` for bias matrices, ``"w"`` for weight matrices. (The name
        shadows the builtin but is kept so existing keyword callers work.)
    m : array-like
        Values to be plotted. NaNs are ignored.
    center : bool
        Only used when ``type == "w"``. If True the limits are symmetric
        around 0; otherwise they are symmetric around 1. Defaults to the
        value of ``center_weight_plots`` at definition time.
    safety : float
        Margin added beyond the largest deviation so extreme values do not
        sit exactly on the colour-bar edge.

    Returns
    -------
    tuple of float
        ``(vmin, vmax)``.

    Raises
    ------
    ValueError
        If ``type`` is neither ``"b"`` nor ``"w"`` (this previously returned
        None, which only failed later inside the plotting call).
    """
    values = np.asarray(m, dtype=float)

    if type == "b" or (type == "w" and center):
        # Symmetric around 0: the largest absolute value sets the range.
        lim = np.nanmax(np.abs(values))
        return -(lim + safety), lim + safety

    if type == "w":
        # Uncentred weights: symmetric around 1, sized by the largest
        # deviation from 1 on either side.
        lim = max(np.nanmax(values) - 1, 1 - np.nanmin(values))
        return 1 - (lim + safety), 1 + (lim + safety)

    raise ValueError(f"type must be 'b' or 'w', got {type!r}")

In [None]:
# Resolve the run directory. os.path.join tolerates a logdir with or without a
# trailing separator; the final "" keeps a trailing separator so later cells
# can append file names by plain string concatenation.
rundir = os.path.join(logdir, run, "")
plot_out_dir = os.path.join(rundir, "weight_plots", "")
os.makedirs(plot_out_dir, exist_ok=True)

# Feature-layer weights and biases; the first CSV column becomes the row index
# (used as heatmap y-labels), the header row the column labels.
w_df = pd.read_csv(os.path.join(rundir, "feature_w.csv"), index_col=0)
b_df = pd.read_csv(os.path.join(rundir, "feature_b.csv"), index_col=0)
w = w_df.to_numpy()
b = b_df.to_numpy()

In [None]:
# Optionally mean-centre each column so the heatmaps show deviations from the
# column average rather than absolute values.
if center_weight_plots:
    w, b = (mat - mat.mean(axis=0) for mat in (w, b))


In [None]:
# Colour-bar limits for the feature-weight heatmaps.
if auto_scale:
    vmin, vmax = get_auto_scale("w", w)
elif center_weight_plots:
    # centred weights: fixed range symmetric around 0
    vmin, vmax = -2.0, 2.0
else:
    # raw weights: fixed range centred on 1
    vmin, vmax = -0.5, 2.5

print("weight min, max:", vmin, vmax)

In [None]:
plot_heatmap_colorbar(
    m=w,
    xlabels=w_df.columns,
    ylabels=w_df.index,
    vmin=vmin,
    vmax=vmax,
    cmap="PuOr",  # palette used for weights in this notebook ("RdGy" for biases)
    path=f"{plot_out_dir}w_cbar_centered",
)

In [None]:
plot_heatmap_histogram(
    m=w,
    xlabels=w_df.columns,
    ylabels=w_df.index,
    vmin=vmin,
    vmax=vmax,
    cmap="PuOr",  # palette used for weights in this notebook ("RdGy" for biases)
    path=f"{plot_out_dir}w_hist_clean",
)

In [None]:
# Colour-bar limits for the feature-bias heatmaps: data-driven when
# auto_scale is on, otherwise a fixed range symmetric around 0.
vmin, vmax = get_auto_scale("b", b) if auto_scale else (-1.0, 1.0)

print("bias min, max:", vmin, vmax)

In [None]:
plot_heatmap_colorbar(
    m=b,
    xlabels=b_df.columns,
    ylabels=b_df.index,
    vmin=vmin,
    vmax=vmax,
    cmap="RdGy",  # palette used for biases in this notebook ("PuOr" for weights)
    path=f"{plot_out_dir}b_cbar_clean",
)

In [None]:
plot_heatmap_histogram(
    m=b,
    xlabels=b_df.columns,
    ylabels=b_df.index,
    vmin=vmin,
    vmax=vmax,
    cmap="RdGy",  # palette used for biases in this notebook ("PuOr" for weights)
    path=f"{plot_out_dir}b_hist_clean",
)

# Target Layer

In [None]:
# Target-layer weights are optional: a missing CSV is reported, not fatal.
# Only the file read is guarded, and only against FileNotFoundError, so any
# other failure (bad CSV, plotting error, undefined name) surfaces instead of
# being misreported as "no target weight found" by a bare ``except:``.
try:
    t_w_df = pd.read_csv(rundir + "target_w.csv", index_col=0)
except FileNotFoundError:
    print("no target weight found")
else:
    t_w = t_w_df.values

    if auto_scale:
        # target weights are not mean-centred here, so scale around 1
        vmin, vmax = get_auto_scale("w", t_w, center=False)
    else:
        vmin, vmax = -0.5, 2.5
    plot_heatmap_colorbar(
        m=t_w,
        cmap="PuOr",  # "PuOr"-weights,  "RdGy"-bias
        xlabels=t_w_df.columns,
        ylabels=t_w_df.index,
        path=plot_out_dir + "target_w_cbar",
        vmin=vmin,
        vmax=vmax
    )

In [None]:
# Target-layer biases are optional: a missing CSV is reported, not fatal.
# Only the file read is guarded, and only against FileNotFoundError, so any
# other failure (bad CSV, plotting error, undefined name) surfaces instead of
# being misreported as "no target bias found" by a bare ``except:``.
try:
    t_b_df = pd.read_csv(rundir + "target_b.csv", index_col=0)
except FileNotFoundError:
    print("no target bias found")
else:
    t_b = t_b_df.values

    # NOTE(review): biases are shifted by -1 before plotting. The original
    # marked this as "temporary" without explanation — confirm it is still
    # intended before relying on these plots.
    t_b = t_b - 1

    if auto_scale:
        # For type "b", get_auto_scale ignores ``center``; the argument is
        # kept only to mirror the target-weight cell.
        vmin, vmax = get_auto_scale("b", t_b, center=False)
    else:
        vmin, vmax = -1.5, 1.5

    plot_heatmap_colorbar(
        m=t_b,
        cmap="RdGy",  # "PuOr"-weights,  "RdGy"-bias
        xlabels=t_b_df.columns,
        ylabels=t_b_df.index,
        path=plot_out_dir + "target_b_cbar",
        vmin=vmin,
        vmax=vmax
    )