In [1]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash 
from PIL import Image
from tqdm import tqdm

# Stylesheet URL list handed to the JupyterDash app when the dashboard is built.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

# Registers tqdm's pandas integration; the `progress_apply` calls in later cells
# depend on this having run.
tqdm.pandas()
# NOTE(review): presumably allows gin configurables to be re-registered when
# cells are re-executed — confirm against the gin documentation.
gin.enter_interactive_mode()

  from pandas import Panel


In [2]:
# Re-import edited modules before each cell runs, so changes to the local
# lib_* helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Load the gin bindings before any configurable object is constructed.
gin.parse_config_file("config.gin")

# Start from a clean Keras session.
tf.keras.backend.clear_session()

problem = lib_biased_mnist.BiasedMnistProblem()


def _cached_test_dataset(correlation):
    """Build a cached tf.data.Dataset from the biased-MNIST test split.

    `correlation` is forwarded as the second positional argument of
    `lib_biased_mnist.get_biased_mnist_data` (presumably the digit/background
    correlation strength — TODO confirm). The returned tensors are passed
    through `problem.filter_tensors` before being sliced into a dataset.
    """
    tensors = lib_biased_mnist.get_biased_mnist_data(
        "~/.datasets/mnist/", correlation, train=False
    )
    return tf.data.Dataset.from_tensor_slices(problem.filter_tensors(*tensors)).cache()


# Two evaluation sets that differ only in the value passed to the data loader.
in_dist = _cached_test_dataset(1.0)
oo_dist = _cached_test_dataset(0.1)

In [4]:
def model_generalization_error(X, y, y_biased, y_hat):
    """Mean squared gap between the accuracies of digit 0 and digit 1.

    `lib_analysis.make_confusion_matrix` is queried with `[digit, bg]` and is
    expected to yield a `(n_correct, n_total)` pair for each of the 2 digits
    and 10 `bg` values (presumably background classes — TODO confirm).

    For every `bg` the accuracy of each digit is computed (0.0 when the cell
    holds no samples), the difference between the two accuracies is squared,
    and the mean of those ten squared differences is returned.
    """
    confusion = lib_analysis.make_confusion_matrix(X, y, y_biased, y_hat)

    def accuracy(digit, background):
        # An empty cell counts as zero accuracy rather than dividing by zero.
        hits, total = confusion[digit, background]
        return 0.0 if total == 0 else hits / total

    squared_gaps = [
        (accuracy(0, background) - accuracy(1, background)) ** 2.0
        for background in range(10)
    ]
    return np.mean(squared_gaps)

In [5]:
def process_row(row: pd.Series, root: Path):
    """Compute generalization errors for the two networks of one run and their ensemble.

    Args:
        row: One row of the results table (a Series when invoked through
            `DataFrame.apply(..., axis=1)`); its "model_paths" entry lists
            saved-model paths relative to `root`.
        root: Directory the model paths are resolved against.

    Returns:
        A dict with "network_0_generalization_error" and
        "network_1_generalization_error" (the two per-network errors in
        ascending order, so index 0 is always the smaller one) and
        "ensemble_generalization_error".

    Reads the module-level `oo_dist` dataset.
    """
    models = [
        tf.keras.models.load_model(root / Path(p), compile=False)
        for p in row["model_paths"]
    ]

    batch_size = 1024
    oo_X, oo_y, oo_y_biased, oo_y_hat = lib_analysis.process_dataset(oo_dist, models, batch_size)

    # Class predictions for every model, computed once instead of once per
    # network. NOTE(review): the indexing below assumes oo_y_hat is laid out as
    # (example, model, class) — confirm against lib_analysis.process_dataset.
    per_model_prediction = tf.math.argmax(oo_y_hat, axis=-1)

    # Only the first two models are scored; sorting makes "network_0" the one
    # with the lower error regardless of the order in "model_paths".
    per_network = sorted(
        model_generalization_error(oo_X, oo_y, oo_y_biased, per_model_prediction[:, im])
        for im in range(2)
    )
    ret = {
        f"network_{im}_generalization_error": error
        for im, error in enumerate(per_network)
    }

    # The ensemble prediction averages the model outputs over axis 1 before
    # taking the arg-max.
    ret["ensemble_generalization_error"] = model_generalization_error(
        oo_X,
        oo_y,
        oo_y_biased,
        tf.math.argmax(tf.math.reduce_mean(oo_y_hat, axis=1), axis=-1),
    )
    return ret

In [6]:
def add_generalization_error(df: pd.DataFrame, root: Path) -> pd.DataFrame:
    """Append the per-row generalization-error columns produced by `process_row`.

    Args:
        df: Results table; each row is handed to `process_row`.
        root: Directory forwarded to `process_row` for resolving model paths.

    Returns:
        A new frame with the columns of `df` followed by one column per key of
        the dicts returned by `process_row`. The new columns are aligned on
        `df.index`, so the input does not need a default 0..n-1 index.
        An empty input is returned as an unchanged copy (no columns are added
        because there is no row to derive them from).

    Requires `tqdm.pandas()` to have been called so `progress_apply` exists.
    """
    if df.empty:
        return df.copy()

    # Applying a dict-returning function row-wise yields a Series of dicts.
    genzation = df.progress_apply(lambda row: process_row(row, root), axis=1)

    # Build the new columns on the caller's index; a fresh RangeIndex would be
    # mis-aligned by the column-wise concat whenever df was filtered or
    # re-indexed beforehand.
    errors = pd.DataFrame(genzation.tolist(), index=df.index)
    return pd.concat([df, errors], axis=1)

In [7]:
def expand_generalization_rows(row: pd.DataFrame):
    """Turn one wide row into three long rows, one per evaluated model.

    For each of "network_0", "network_1" and "ensemble" a copy of `row` is made
    whose "generalization_error" holds the value of the matching
    "<model>_generalization_error" entry and whose
    "generalization_error_model" holds the model name. The three wide
    "<model>_generalization_error" columns are removed from the result.

    Returns a DataFrame with three rows that all carry the label of `row`.
    """
    variants = ["network_0", "network_1", "ensemble"]
    wide_columns = [f"{variant}_generalization_error" for variant in variants]

    records = []
    for variant, wide_column in zip(variants, wide_columns):
        record = row.copy()
        record["generalization_error"] = row[wide_column]
        record["generalization_error_model"] = variant
        records.append(record)

    return pd.DataFrame(records).drop(wide_columns, axis=1)

In [8]:
# Directory of the experiment run being analysed; the results table is read from
# it and the saved-model paths in that table are resolved relative to it.
data_root = Path("../data/33bdc")

In [9]:
# Load every run recorded for the biased-MNIST problem.
runs_all = lib_analysis.read_problem(data_root, "biased_mnist")

# Keep only the runs with a background noise level of 0.
is_noise_free = runs_all["BiasedMnistProblem.background_noise_level"] == 0
runs_noise_free = runs_all[is_noise_free].reset_index(drop=True)

# Wide format: one row per run with a column per evaluated model.
runs_with_errors = add_generalization_error(runs_noise_free, data_root)

# Long format: one row per (run, evaluated model) pair.
per_run_frames = list(runs_with_errors.progress_apply(expand_generalization_rows, axis=1))
DF = pd.concat(per_run_frames).reset_index(drop=True)

100%|██████████| 288/288 [02:31<00:00,  1.90it/s]
100%|██████████| 288/288 [00:02<00:00, 124.07it/s]


In [10]:
def statistics_for_series(df: pd.Series):
    """Load the models of one run and report their statistics on `oo_dist`.

    The entries of `df["model_paths"]` are resolved against the module-level
    `data_root` and loaded without compiling. The models are evaluated on the
    module-level `oo_dist` dataset with a batch size of 1024, and the result of
    `lib_analysis.print_statistics` for that evaluation is returned.
    """
    ensemble = [
        tf.keras.models.load_model(data_root / Path(rel_path), compile=False)
        for rel_path in df["model_paths"]
    ]

    evaluation = lib_analysis.process_dataset(oo_dist, ensemble, 1024)
    oo_X, oo_y, oo_y_biased, oo_y_hat = evaluation
    return lib_analysis.print_statistics(oo_X, oo_y, oo_y_biased, oo_y_hat)

In [11]:
# Build a two-level tab layout: one outer tab per background noise level, and
# inside it one tab per label correlation holding the error plot plus the
# statistics of a sampled run for lambda = 0 and lambda = 1.
top_level_tabs = []

for bg_level in np.unique(DF["BiasedMnistProblem.background_noise_level"]):
    tabs_for_bg_level = []
    # np.unique already returns the distinct values in ascending order, so no
    # extra sorting is needed.
    for l_corr in np.unique(DF["label_correlation"]):

        def stats_for_lambda(lambda_):
            """Statistics of one randomly chosen conditional-HSIC run for `lambda_`.

            Uses the `l_corr` and `bg_level` of the enclosing loop iteration;
            it must therefore be called inside that iteration.
            """
            candidates = DF[
                (DF["label_correlation"] == l_corr)
                & (DF["indep"] == "conditional_hsic")
                & (DF["lambda"] == lambda_)
                & (DF["BiasedMnistProblem.background_noise_level"] == bg_level)
            ]
            # NOTE(review): the sample is unseeded, so a different run may be
            # shown on every execution; pass random_state for reproducibility.
            run = candidates.sample(1).reset_index(drop=True).iloc[0]

            # Model loading and evaluation are shared with the helper defined
            # in an earlier cell instead of being duplicated here.
            return statistics_for_series(run)

        fig = lib_plot.end_to_end_plot(
            go.Figure(),
            DF[
                (DF["label_correlation"] == l_corr)
                & (DF["BiasedMnistProblem.background_noise_level"] == bg_level)
            ].copy(),
            "generalization_error",
            ["generalization_error_model", "indep"],
            f"MSE between per-digit accuracies for label correlation {l_corr}; background noise level {bg_level}",
        )

        per_lambda_stats = html.Div(
            [
                dcc.Tabs(
                    [
                        dcc.Tab(
                            label="Vanilla Ensemble", children=stats_for_lambda(0),
                        ),
                        dcc.Tab(label="Lambda = 1", children=stats_for_lambda(1)),
                    ]
                )
            ]
        )

        tabs_for_bg_level.append(
            dcc.Tab(
                label=f"label correlation {l_corr}",
                children=[dcc.Graph(figure=fig), per_lambda_stats],
            )
        )

    top_level_tabs.append(
        dcc.Tab(
            label=f"background noise level {bg_level}",
            children=[dcc.Tabs(tabs_for_bg_level)],
        )
    )

# Serve the assembled layout inline in the notebook.
app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Tabs(top_level_tabs)])
app.run_server(mode="inline")

In [12]:
def visualize_weights(model: tf.keras.Model):
    """Show the absolute kernel weights of the second-to-last layer as images.

    The kernel of `model.layers[-2]` is resized to (28, 28, 3, n_hidden), where
    n_hidden is the size of the kernel's last axis, and drawn as a grid with
    one row per hidden neuron and one column per channel.

    Args:
        model: Keras model with at least two layers whose second-to-last layer
            holds exactly two weight arrays (kernel and bias).
    """
    # The unpacking doubles as a check that two layers / two weight arrays are
    # present; only the kernel of the first of the two layers is used.
    feature_extractor, _ = model.layers[-2:]
    kernel, _ = feature_extractor.get_weights()

    n_hidden_neurons = kernel.shape[-1]

    # NOTE(review): np.resize silently repeats or truncates data when the
    # kernel does not hold exactly 28*28*3*n_hidden values; np.reshape would
    # fail loudly instead — confirm the expected kernel shape.
    as_images = np.resize(kernel, (28, 28, 3, n_hidden_neurons))

    # squeeze=False keeps the axes array 2-D even for a single hidden neuron,
    # so the [row, column] indexing below always works.
    f, axarr = plt.subplots(n_hidden_neurons, 3, figsize=(15, 15), squeeze=False)

    f.tight_layout()

    for i in range(n_hidden_neurons):
        for c in range(3):
            # Magnitude only: the sign of a weight is not shown.
            for_neuron = np.abs(as_images[:, :, c, i])

            axarr[i, c].imshow(for_neuron, cmap="gray")
            axarr[i, c].axis("off")

    plt.subplots_adjust(wspace=0.01, hspace=0.0001, right=10, left=9.3)
    plt.show()