In [1]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash
from PIL import Image
from tqdm import tqdm

# Stylesheet URL list handed to JupyterDash when the dashboard app is created
# further down in this notebook.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

# Registers `progress_apply` on pandas objects; the DataFrame processing cells
# below call it to show progress bars.
tqdm.pandas()
# NOTE(review): presumably allows gin configurables to be re-registered when
# cells are re-run in the same kernel — confirm against the gin documentation.
gin.enter_interactive_mode()

  from pandas import Panel


In [2]:
# Reload imported modules before every cell execution so that edits to the
# imported library files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
# Load the gin bindings first: the problem object below is created after the
# config file has been parsed.
gin.parse_config_file("one-off/config.gin")

tf.keras.backend.clear_session()

problem = lib_biased_mnist.BiasedMnistProblem()


def _cached_test_split(level):
    """Build a cached tf.data.Dataset from the filtered biased-MNIST test split.

    `level` is forwarded as the second positional argument of
    `lib_biased_mnist.get_biased_mnist_data`.
    NOTE(review): presumably the label/background correlation — confirm
    against lib_biased_mnist.
    """
    raw_tensors = lib_biased_mnist.get_biased_mnist_data(
        "~/.datasets/mnist/", level, train=False
    )
    filtered_tensors = problem.filter_tensors(*raw_tensors)
    return tf.data.Dataset.from_tensor_slices(filtered_tensors).cache()


# Two evaluation sets that differ only in the level passed to the loader.
in_dist = _cached_test_split(1.0)
oo_dist = _cached_test_split(0.1)

In [6]:
def model_generalization_error(X, y, y_biased, y_hat):
    """Mean squared gap between the accuracies of digit 0 and digit 1.

    The four arguments are forwarded unchanged to
    `lib_analysis.make_confusion_matrix`, whose result is indexed as
    `[digit, bg]` and yields an `(n_correct, n_total)` pair. For each of the
    ten `bg` values the accuracy of digit 0 and digit 1 is computed (a cell
    with no samples counts as accuracy 0.0), the difference is squared, and
    the mean over the ten values is returned.
    """
    counts = lib_analysis.make_confusion_matrix(X, y, y_biased, y_hat)

    def accuracy(digit, bg):
        n_correct, n_total = counts[digit, bg]
        # Empty cells are scored as 0.0 rather than dividing by zero.
        return 0.0 if n_total == 0 else n_correct / n_total

    squared_gaps = [
        (accuracy(0, bg) - accuracy(1, bg)) ** 2.0 for bg in range(10)
    ]
    return np.mean(squared_gaps)

In [7]:
def process_row(row: pd.DataFrame, root: Path):
    """Compute generalization errors for the models referenced by one row.

    Every entry of `row["model_paths"]` is resolved relative to `root` and
    loaded as an uncompiled Keras model. The models are run on the global
    `oo_dist` dataset through `lib_analysis.process_dataset`.

    Returns a dict with three keys:
      * `network_0_generalization_error` / `network_1_generalization_error`:
        the errors of the two individual networks, sorted ascending, so the
        index is a rank (0 = smaller error) and not the model's position in
        `model_paths`.
      * `ensemble_generalization_error`: the error of the prediction obtained
        by averaging `oo_y_hat` over axis 1 before taking the argmax.
    """
    models = [
        tf.keras.models.load_model(root / Path(rel_path), compile=False)
        for rel_path in row["model_paths"]
    ]

    batch_size = 1024
    oo_X, oo_y, oo_y_biased, oo_y_hat = lib_analysis.process_dataset(
        oo_dist, models, batch_size
    )

    # Argmax over the last axis once; column `im` is then network `im`.
    predicted = tf.math.argmax(oo_y_hat, axis=-1)
    per_network = []
    for im in range(2):
        per_network.append(
            model_generalization_error(oo_X, oo_y, oo_y_biased, predicted[:, im])
        )
    per_network.sort()

    ret = {
        f"network_{rank}_generalization_error": error
        for rank, error in enumerate(per_network)
    }

    ensemble_prediction = tf.math.argmax(
        tf.math.reduce_mean(oo_y_hat, axis=1), axis=-1
    )
    ret["ensemble_generalization_error"] = model_generalization_error(
        oo_X, oo_y, oo_y_biased, ensemble_prediction
    )
    return ret

In [8]:
def add_generalization_error(df: pd.DataFrame, root: Path) -> pd.DataFrame:
    """Return `df` with the generalization-error columns of each row appended.

    `process_row(row, root)` is evaluated for every row (with a tqdm progress
    bar, which requires `tqdm.pandas()` to have been called) and the dicts it
    returns become new columns, one column per dict key.

    Args:
        df: frame with one row per run; each row is passed to `process_row`.
        root: directory forwarded to `process_row`.

    Returns:
        A new frame with the original columns followed by the error columns,
        keeping the index of `df`. An empty `df` is returned as an unchanged
        copy because there are no rows to derive the new columns from.
    """
    if len(df) == 0:
        return df.copy()

    genzation = df.progress_apply(lambda row: process_row(row, root), axis=1)

    # Build the new columns on df's own index. Looking the first result up by
    # the label 0 and concatenating against a fresh RangeIndex only works when
    # df happens to carry a default 0..n-1 index (e.g. it breaks after
    # filtering without reset_index).
    errors = pd.DataFrame(list(genzation), index=df.index)
    return pd.concat([df, errors], axis=1)

In [9]:
def expand_generalization_rows(row: pd.DataFrame):
    """Turn one wide row into three long rows, one per evaluated model.

    `row` must contain `network_0_generalization_error`,
    `network_1_generalization_error` and `ensemble_generalization_error`.
    (When used with `DataFrame.apply(..., axis=1)` it is a single row, i.e. a
    Series.) Each output row is a copy of `row` with two extra fields:
    `generalization_error` (the value for that model) and
    `generalization_error_model` (`network_0`, `network_1` or `ensemble`).
    The three wide error columns are removed from the result.
    """
    model_names = ["network_0", "network_1", "ensemble"]
    wide_columns = [f"{name}_generalization_error" for name in model_names]

    long_rows = []
    for name, wide_column in zip(model_names, wide_columns):
        long_row = row.copy()
        long_row["generalization_error"] = row[wide_column]
        long_row["generalization_error_model"] = name
        long_rows.append(long_row)

    return pd.DataFrame(long_rows).drop(wide_columns, axis=1)

In [10]:
# Root directory of the experiment results: passed to lib_analysis.read_problem
# and used as the base for resolving each row's relative model paths.
data_root = Path("../data/3c2e4")

In [11]:
# Stage 1: one row per run, as read from disk.
runs_raw = lib_analysis.read_problem(data_root, "biased_mnist")

# Stage 2: append the per-network and ensemble generalization errors
# (loads and evaluates the models of every run).
runs_with_errors = add_generalization_error(runs_raw, data_root)

# Stage 3: expand every run into one row per evaluated model and stack the
# pieces into a single frame with a fresh 0..n-1 index.
per_run_frames = list(
    runs_with_errors.progress_apply(expand_generalization_rows, axis=1)
)
DF = pd.concat(per_run_frames).reset_index(drop=True)

100%|██████████| 720/720 [03:57<00:00,  3.04it/s]
100%|██████████| 720/720 [00:06<00:00, 106.59it/s]


In [23]:
def statistics_for_series(df: pd.Series):
    """Evaluate the models of one run on `oo_dist` and return their statistics.

    `df` is a single row; its `model_paths` entries are resolved relative to
    the global `data_root` and loaded as uncompiled Keras models. The models
    are run on the global `oo_dist` dataset via
    `lib_analysis.process_dataset`, and the four values it returns are handed
    to `lib_analysis.print_statistics`, whose result is returned unchanged.
    """
    models = [
        tf.keras.models.load_model(data_root / Path(rel_path), compile=False)
        for rel_path in df["model_paths"]
    ]

    batch_size = 1024
    processed = lib_analysis.process_dataset(oo_dist, models, batch_size)
    oo_X, oo_y, oo_y_biased, oo_y_hat = processed
    return lib_analysis.print_statistics(oo_X, oo_y, oo_y_biased, oo_y_hat)

In [27]:
# Dashboard layout: one top-level tab per learning rate, inside it one tab per
# label correlation, and inside that two figures plus one statistics tab per
# lambda value.

# Seed used when picking the single run whose statistics are shown in a lambda
# tab, so that re-running the cell displays the same run.
STATS_SAMPLE_SEED = 0

top_level_tabs = []

for lr in np.unique(DF["Problem.initial_lr"]):
    runs_for_lr = DF[DF["Problem.initial_lr"] == lr]
    tabs_for_outer = []

    # np.unique always returns its values in ascending order, so the
    # descending order has to be applied after it, not before.
    for l_corr in sorted(np.unique(DF["label_correlation"]), reverse=True):
        runs_for_tab = runs_for_lr[runs_for_lr["label_correlation"] == l_corr]
        hsic_runs = runs_for_tab[runs_for_tab["indep"] == "conditional_hsic"]

        def stats_for_lambda(lambda_, runs=hsic_runs):
            """Statistics of one sampled conditional_hsic run with this lambda.

            `runs` is bound at definition time to the conditional_hsic rows of
            the current learning rate and label correlation.
            """
            df = (
                runs[runs["lambda"] == lambda_]
                .sample(1, random_state=STATS_SAMPLE_SEED)
                .reset_index(drop=True)
            )
            return statistics_for_series(df.iloc[0])

        df_for_fig = runs_for_tab[
            runs_for_tab["generalization_error_model"] == "ensemble"
        ].copy()

        fig = lib_plot.end_to_end_plot(
            go.Figure(),
            df_for_fig,
            "generalization_error",
            ["generalization_error_model", "indep"],
            f"MSAcc between per-digit accuracies for label correlation {l_corr}; learning rate {lr}",
        )

        fig_diversity = lib_plot.end_to_end_plot(
            go.Figure(),
            df_for_fig,
            "train_diversity_loss",
            ["indep"],
            f"Train diversity loss for label correlation {l_corr}; learning rate {lr}",
        )

        # Only offer lambdas that actually occur for this combination;
        # sampling from an empty selection would raise.
        per_lambda_stats = html.Div(
            [
                dcc.Tabs(
                    [
                        dcc.Tab(
                            label=f"Lambda = {lambda_}",
                            children=stats_for_lambda(lambda_),
                        )
                        for lambda_ in sorted(np.unique(hsic_runs["lambda"]))
                    ]
                )
            ]
        )

        tabs_for_outer.append(
            dcc.Tab(
                label=f"label correlation {l_corr}",
                children=[
                    dcc.Graph(figure=fig),
                    dcc.Graph(figure=fig_diversity),
                    per_lambda_stats,
                ],
            )
        )

    top_level_tabs.append(
        dcc.Tab(label=f"learning rate {lr}", children=[dcc.Tabs(tabs_for_outer)],)
    )

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Tabs(top_level_tabs)])
app.run_server(mode="inline")

In [22]:
# Show the expanded frame (three rows per run: network_0, network_1, ensemble);
# pandas truncates the rendering of long/wide frames.
DF

Unnamed: 0,test_prediction_loss,test_accuracy,test_diversity_loss,test_combined_loss,test_ensemble_accuracy,train_prediction_loss,train_accuracy,train_diversity_loss,train_combined_loss,train_ensemble_accuracy,...,BiasedMnistProblem.model_type,BiasedMnistProblem.background_noise_level,Problem.n_epochs,Problem.initial_lr,Problem.decrease_lr_at_epochs,Problem.n_models,get_weight_regularizer.strength,original_config,generalization_error,generalization_error_model
0,"[0.18410581350326538, 2.0631585121154785]","[0.944208025932312, 0.547517716884613]",1.109123,73.231110,0.806619,"[0.0006089136586524546, 0.008913685567677021]","[0.9996841549873352, 0.9982629418373108]",0.006149,0.403069,0.999131,...,mlp,0,100,0.01,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.045813,network_0
1,"[0.18410581350326538, 2.0631585121154785]","[0.944208025932312, 0.547517716884613]",1.109123,73.231110,0.806619,"[0.0006089136586524546, 0.008913685567677021]","[0.9996841549873352, 0.9982629418373108]",0.006149,0.403069,0.999131,...,mlp,0,100,0.01,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.900042,network_1
2,"[0.18410581350326538, 2.0631585121154785]","[0.944208025932312, 0.547517716884613]",1.109123,73.231110,0.806619,"[0.0006089136586524546, 0.008913685567677021]","[0.9996841549873352, 0.9982629418373108]",0.006149,0.403069,0.999131,...,mlp,0,100,0.01,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.302390,ensemble
3,"[1.1909431219100952, 1.472009539604187]","[0.778723418712616, 0.5924350023269653]",1.323529,87.368805,0.698818,"[0.00471786642447114, 0.006201746873557568]","[0.9992104172706604, 0.9984998106956482]",0.006343,0.416851,0.998658,...,mlp,0,100,0.01,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.342523,network_0
4,"[1.1909431219100952, 1.472009539604187]","[0.778723418712616, 0.5924350023269653]",1.323529,87.368805,0.698818,"[0.00471786642447114, 0.006201746873557568]","[0.9992104172706604, 0.9984998106956482]",0.006343,0.416851,0.998658,...,mlp,0,100,0.01,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.716972,network_1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2155,"[33.45603561401367, 9.348630905151367]","[0.9659574627876282, 0.9659574627876282]",0.913605,42.804668,0.969740,"[0.0, 0.0]","[1.0, 1.0]",0.859562,0.000000,1.000000,...,mlp,0,100,1.00,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.036766,network_1
2156,"[33.45603561401367, 9.348630905151367]","[0.9659574627876282, 0.9659574627876282]",0.913605,42.804668,0.969740,"[0.0, 0.0]","[1.0, 1.0]",0.859562,0.000000,1.000000,...,mlp,0,100,1.00,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.011122,ensemble
2157,"[44.6565055847168, 100.33548736572266]","[0.9385342597961426, 0.9456264972686768]",0.853298,144.992004,0.938061,"[0.0, 0.0]","[1.0, 1.0]",0.807168,0.000000,1.000000,...,mlp,0,100,1.00,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.022644,network_0
2158,"[44.6565055847168, 100.33548736572266]","[0.9385342597961426, 0.9456264972686768]",0.853298,144.992004,0.938061,"[0.0, 0.0]","[1.0, 1.0]",0.807168,0.000000,1.000000,...,mlp,0,100,1.00,"[20, 40, 80]",2,0.01,../data/3c2e4/biased_mnist/conditional_hsic/ml...,0.037955,network_1


In [None]:
def visualize_weights(model: tf.keras.Model):
    """Plot the absolute weights of the second-to-last layer as images.

    The kernel of `model.layers[-2]` is reshaped to
    `(28, 28, 3, n_hidden_neurons)` and drawn as a grid with one row per
    hidden neuron and one column per channel, showing `abs(weight)` in
    grayscale.

    Args:
        model: model with at least two layers whose second-to-last layer
            returns exactly `[kernel, bias]` from `get_weights()`, with a
            kernel of `28 * 28 * 3 * n_hidden_neurons` elements.

    Raises:
        ValueError: if the kernel cannot be reshaped to
            `(28, 28, 3, n_hidden_neurons)`, or if the layer does not have
            exactly two weight arrays.
    """
    feature_extractor = model.layers[-2]
    # The bias is not visualized; unpacking still checks there are two arrays.
    kernel, _ = feature_extractor.get_weights()

    n_hidden_neurons = kernel.shape[-1]

    # reshape (unlike np.resize) refuses a kernel of the wrong size instead of
    # silently repeating or truncating weights.
    as_images = np.reshape(kernel, (28, 28, 3, n_hidden_neurons))

    # squeeze=False keeps the axes array 2-D even for a single hidden neuron,
    # so axarr[i, c] is always valid.
    f, axarr = plt.subplots(n_hidden_neurons, 3, figsize=(15, 15), squeeze=False)

    f.tight_layout()

    for i in range(n_hidden_neurons):
        for c in range(3):
            # Only the magnitude of each weight is shown.
            for_neuron = np.abs(as_images[:, :, c, i])

            axarr[i, c].imshow(for_neuron, cmap="gray")
            axarr[i, c].axis("off")

    plt.subplots_adjust(wspace=0.01, hspace=0.0001, right=10, left=9.3)
    plt.show()