# Imports

In [7]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
from IPython.display import display, Image
import seaborn as sns
import pickle
from shutil import copy
from Loaders import Importer
from matplotlib import cm
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm
from uncertainties import ufloat

from src.utils import prepare_dfs, disp_general_around_thresh

%load_ext autoreload
%autoreload 2
%aimport src.utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Widgets for visual inspection of all 100 trained networks

**Fixed settings for all final training runs**

In [2]:
# Counts forwarded to prepare_dfs as `real_no` and `n_slices`.
n_realisations, n_slices = 10, 100

# Directory handed to prepare_dfs as `main_root`.
main_root = pathlib.Path("./00_demonstrations_data/100_CNNs")

# One dict per hyper-parameter combination to inspect; only a single
# combination is listed here.
hparams_sets = [
    {'batch_size': 64, 'lr': 1e-4, 'perturb': True},
]

In [3]:
# NOTE(review): prepare_dfs presumably attaches a 'df' entry to each dict in
# hparams_sets in place (that key is read back below) — confirm in src.utils.
prepare_dfs(
    hparams_sets,
    real_no=n_realisations,
    n_slices=n_slices,
    main_root=main_root,
)

# One name per hyper-parameter set, e.g. "bs=64_perturb=True_lr=1e-04".
folder_names = [
    f"bs={bs}_perturb={perturb}_lr={lr:.0e}"
    for bs, lr, perturb in (
        (hp['batch_size'], hp['lr'], hp['perturb']) for hp in hparams_sets
    )
]

# Columns of the first set's dataframe from the third one onward, reversed.
possible_cols = hparams_sets[0]['df'].columns[2:][::-1]

**Analyze**

In [10]:
# Interactive inspection panel around src.utils.disp_general_around_thresh.
# ipywidgets turns each keyword into a control for the argument of the same
# name: fixed(...) passes the value through without creating a widget, a
# list / column index becomes a dropdown, and a (min, max, step) tuple becomes
# an integer or float slider depending on the type of its entries.
# The trailing ';' suppresses the repr of interact's return value.
interact(
    disp_general_around_thresh,
    # Only the first hyper-parameter set is inspected.
    hparam_sets=fixed(hparams_sets[0]),
    column_name=possible_cols, 
    above=[True, False], 
    # loc_in_head ranges over 0..19, matching head_number=20.
    head_number=fixed(20), 
    loc_in_head=(0, 19, 1),
    thresh_proc=(0.0, 1.0, 0.01),
    comp_axis_1=possible_cols, 
    comp_axis_2=possible_cols, 
    # NOTE(review): "10_100" looks like f"{n_realisations}_{n_slices}" — confirm.
    real_slice_no=fixed("10_100"), 
    sig=(0.1, 5.0, 0.1),
    shown_stat=['generalization', 'history', 'UMAP', 'PCA', 'CAMs'],
    cam_ind=(0, 99, 1),
);

interactive(children=(Dropdown(description='above', options=(True, False), value=True), Dropdown(description='…

**General statistics**

In [26]:
# Hyper-parameters identifying which entry of hparams_sets to summarise.
batch_size = 64
lr = 1e-4

# Models with "Clean Accuracy" below this value are excluded from both groups.
clean_acc_thresh = 0
# Candidate thresholds; only the one assigned to `thresh` is used below.
# NOTE(review): lpips_thresh is presumably meant for an LPIPS column — confirm.
lpips_thresh = 0.18
rmse_thresh = 0.2
# Column used to split the models, and the cut applied to it:
# value <= thresh -> "well-generalizing", value > thresh -> "poorly-generalizing".
chosen_col = 'RMSE'
thresh = rmse_thresh


def _describe_group(label, df_group, n_total, metric_col):
    """Format the two-line summary printed for one group of models.

    Parameters
    ----------
    label : str
        Heading printed on the first line.
    df_group : pandas.DataFrame
        Rows of the group; must contain the columns 'Clean Accuracy',
        'Accuracy' and `metric_col`.
    n_total : int
        Denominator for the percentage (row count of the unfiltered frame).
    metric_col : str
        Name of the column the split was made on.

    Returns
    -------
    str
        "<label>\\nPercentage: ... | Test clean accuracy: ... | OOD Accuracy: ...
        | <metric_col>: ...", where each statistic is rendered as mean+/-std
        through `ufloat`.
    """
    # Guard the empty-frame case instead of raising ZeroDivisionError.
    share = len(df_group) / n_total * 100 if n_total else float('nan')

    def _mean_std(col):
        # ufloat renders the pair as "mean+/-std".
        return ufloat(df_group[col].mean(), df_group[col].std())

    return (
        f"{label}\nPercentage: {share:.1f}% | "
        f"Test clean accuracy: {_mean_std('Clean Accuracy')} | "
        f"OOD Accuracy: {_mean_std('Accuracy')} | "
        f"{metric_col}: {_mean_std(metric_col)}"
    )


for perturb in [True]:
    wanted_hparams = {
        'batch_size': batch_size,
        'lr': lr,
        'perturb': perturb,
    }
    # Dataframe of the first set whose hyper-parameters all match.
    df_found = next(
        (
            hp['df']
            for hp in hparams_sets
            if all(hp[key] == value for key, value in wanted_hparams.items())
        ),
        None,
    )
    if df_found is None:
        raise ValueError('Dataframe not found')

    # Filter with boolean masks rather than `.where(cond).dropna()`: the latter
    # NaN-fills non-matching rows and then drops *every* row containing a NaN,
    # which also discards rows with a missing value in an unrelated column and
    # upcasts integer columns to float.
    df_acc_constraint = df_found[df_found["Clean Accuracy"] >= clean_acc_thresh]
    # Two explicit comparisons (not `~mask`) so that rows whose metric is NaN
    # end up in neither group.
    df_generalizing = df_acc_constraint[df_acc_constraint[chosen_col] <= thresh]
    df_not_generalizing = df_acc_constraint[df_acc_constraint[chosen_col] > thresh]

    print(f"Batch size {batch_size} | Perturb {perturb} | Learning Rate {lr} | Threshold ({chosen_col}) = {thresh}\n")
    # Percentages are relative to all models in df_found, not only those that
    # pass the clean-accuracy constraint.
    n_models = len(df_found)
    print(_describe_group("Well-generalizing models", df_generalizing, n_models, chosen_col))
    print(_describe_group("Poorly-generalizing models", df_not_generalizing, n_models, chosen_col) + "\n")

Batch size 64 | Perturb True | Learning Rate 0.0001 | Threshold (RMSE) = 0.2

Well-generalizing models
Percentage: 22.0% | Test clean accuracy: 0.947+/-0.019 | OOD Accuracy: 0.860+/-0.030 | RMSE: 0.153+/-0.031
Poorly-generalizing models
Percentage: 78.0% | Test clean accuracy: 0.970+/-0.023 | OOD Accuracy: 0.61+/-0.10 | RMSE: 0.41+/-0.10

