In [1]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash
from PIL import Image
from tqdm import tqdm

# Stylesheet(s) handed to the JupyterDash app constructed at the end of the notebook.
external_stylesheets = [
    "https://codepen.io/chriddyp/pen/bWLwgP.css",
]

# Registers tqdm's progress-bar helpers (e.g. ``progress_apply``) on pandas objects.
tqdm.pandas()
# Lets gin configurables be re-registered when notebook cells are re-executed.
gin.enter_interactive_mode()

  from pandas import Panel


In [2]:
# Reload edited modules (such as the local lib_* helpers imported above) before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
def read_df(data_root: Path, experiment_name: str) -> pd.DataFrame:
    """Load the analysis DataFrame for one experiment, caching it as a pickle.

    The cache lives at ``<data_root>/<experiment_name>/df.pickle``. If that file
    exists it is loaded and returned directly. Otherwise the frame is rebuilt with
    ``lib_analysis.read_problem`` followed by ``lib_analysis.add_statistics_to_df``,
    written to the cache path (pickle protocol 4), and returned.

    NOTE(review): ``pd.read_pickle`` can execute arbitrary code, so only point this
    at cache files you produced yourself. Delete the cache file to force a rebuild.
    """
    cache_file = data_root / experiment_name / "df.pickle"
    cache_name = str(cache_file)

    if not cache_file.exists():
        frame = lib_analysis.add_statistics_to_df(
            lib_analysis.read_problem(data_root, experiment_name)
        )
        frame.to_pickle(cache_name, protocol=4)
        return frame

    print(f"Found cached df at {cache_file}. Reusing...", flush=True)
    return pd.read_pickle(cache_name)

In [5]:
# Root of the per-experiment data directories, relative to this notebook.
# (A trailing slash makes no difference to pathlib.)
data_root = Path("../data")
DF = read_df(
    data_root=data_root,
    experiment_name="biased_mnist_10_digits",
)

Found cached df at ../data/biased_mnist_10_digits/df.pickle. Reusing...


In [6]:
def add_column_from_statistics(
    df: pd.DataFrame, f_per_row: Callable[[pd.Series], Dict[str, Any]]
) -> pd.DataFrame:
    """Append columns derived from each row's ID/OOD statistics.

    ``f_per_row`` is called once per row with a Series holding that row's
    ``id_statistics`` and ``ood_statistics`` entries. It must return a dict that
    maps each new column name to its value.

    Args:
        df: Frame with at least ``id_statistics`` and ``ood_statistics`` columns.
        f_per_row: Per-row extractor returning a ``Dict[str, Any]``.

    Returns:
        A new DataFrame (``df`` is not modified) with these properties:

        - It has a fresh ``RangeIndex``; the original index is dropped.
        - It has one extra column per key seen across all rows, in first-seen
          order. A row whose dict lacks a key gets NaN in that column.
        - If ``df`` is empty, it is ``df`` with a reset index and no extra
          columns.
    """
    base = df.reset_index(drop=True)
    # Iterate explicitly instead of using ``DataFrame.apply``. On an empty frame,
    # ``apply`` returns a DataFrame rather than a Series of dicts, which used to
    # break the column transpose.
    records = [
        f_per_row(row)
        for _, row in base[["id_statistics", "ood_statistics"]].iterrows()
    ]
    if not records:
        return base
    # Building from a list of dicts aligns keys row by row (NaN for missing
    # keys) instead of failing on ragged column lengths.
    new_columns = pd.DataFrame(records)
    return pd.concat([base, new_columns], axis=1)


def _as_numpy(value: Any) -> Any:
    """Return ``value.numpy()`` for tensor-like values, otherwise ``np.asarray(value)``."""
    return value.numpy() if hasattr(value, "numpy") else np.asarray(value)


def _disagreement_rate(disagreement: Dict) -> Any:
    """Return ``n_select / n_original``, or 0.0 when ``n_select`` is absent."""
    if "n_select" not in disagreement:
        return 0.0
    return _as_numpy(disagreement["n_select"]) / disagreement["n_original"]


def f_extract_disagreement_per_row(row: Dict) -> Dict[str, Any]:
    """Compute OOD and ID disagreement rates for one result row.

    ``row`` must map ``"ood_statistics"`` and ``"id_statistics"`` to dicts whose
    ``"disagreement"`` entry holds ``"n_original"`` and, optionally,
    ``"n_select"``. The latter may be a tensor exposing ``.numpy()`` or a plain
    number. A missing ``"n_select"`` yields a rate of 0.0. This now applies to
    the OOD split too (previously it raised ``KeyError``), which makes the OOD
    handling consistent with the ID handling.

    Returns:
        A dict with keys ``"ood_disagreement_rate"`` and
        ``"id_disagreement_rate"``, suitable for ``add_column_from_statistics``.
    """
    return {
        "ood_disagreement_rate": _disagreement_rate(
            row["ood_statistics"]["disagreement"]
        ),
        "id_disagreement_rate": _disagreement_rate(
            row["id_statistics"]["disagreement"]
        ),
    }


# Spot-check one row with the derived disagreement-rate columns attached.
# A fixed random_state keeps the displayed row stable across re-runs.
add_column_from_statistics(DF, f_extract_disagreement_per_row).sample(
    n=1, random_state=0
)

Unnamed: 0,test_prediction_loss,test_accuracy,test_diversity_loss,test_combined_loss,test_ensemble_accuracy,train_prediction_loss,train_accuracy,train_diversity_loss,train_combined_loss,train_ensemble_accuracy,...,Problem.initial_lr,Problem.decrease_lr_at_epochs,Problem.n_models,get_weight_regularizer.strength,yaml_config_file,gin_config_file,ood_statistics,id_statistics,ood_disagreement_rate,id_disagreement_rate
1697,"[22.31837272644043, 11.712209701538086]","[0.09960000216960907, 0.25619998574256897]",0.196011,37.166756,0.1607,"[0.030147647485136986, 0.0073644788935780525]","[0.9989500045776367, 0.9991000294685364]",0.185107,2.999228,0.9991,...,0.01,"[20, 40, 80]",2,0.01,/home/ericpts/work/data/biased_mnist/cka/mlp/l...,/home/ericpts/work/data/biased_mnist/cka/mlp/l...,{'overall': {'ensemble': {'accuracy': tf.Tenso...,{'overall': {'ensemble': {'accuracy': tf.Tenso...,0.3573,0.0012


In [7]:
def statistics_for_series(row: pd.Series):
    """Render a row's OOD and ID statistics as two Dash ``html.Div`` sections.

    Each section is an ``H4`` heading followed by the components that
    ``lib_analysis.format_statistics`` produces for the matching column. The
    out-of-distribution section comes first. Returns a list of the two Divs.
    """
    sections = (
        ("Out of Distribution", "ood_statistics"),
        ("In Distribution", "id_statistics"),
    )
    rendered = []
    for heading, column in sections:
        children = [html.H4(heading)]
        children.extend(lib_analysis.format_statistics(row[column]))
        rendered.append(html.Div(children))
    return rendered

In [12]:
# Build a tabbed Dash dashboard with this layout:
# - one top-level tab per learning rate;
# - inside it, one tab per label correlation;
# - each of those shows train-diversity and ID/OOD disagreement plots, plus
#   per-lambda statistics of one sampled run for the `indep` criterion.
top_level_tabs = []

indep = "cka"  # independence criterion whose runs back the per-lambda stats
N_LEARNING_RATES = 3  # only the N smallest learning rates (np.unique sorts) get a tab
SAMPLE_SEED = 0  # makes the sampled run shown per lambda reproducible


def _stats_for_lambda(df_indep: pd.DataFrame, lambda_: float) -> list:
    """Statistics components for one sampled run of ``df_indep`` with this lambda."""
    run = df_indep[df_indep["lambda"] == lambda_].sample(1, random_state=SAMPLE_SEED)
    return statistics_for_series(run.iloc[0])


for lr in np.unique(DF["Problem.initial_lr"])[:N_LEARNING_RATES]:
    print(f"Processing data for lr {lr}")
    tabs_for_outer = []
    for l_corr in np.unique(DF["label_correlation"]):
        df_for_fig = DF[
            (DF["label_correlation"] == l_corr) & (DF["Problem.initial_lr"] == lr)
        ].copy()
        # Derived once and shared by both disagreement plots (it was previously
        # recomputed for each plot).
        df_with_disagreement = add_column_from_statistics(
            df_for_fig, f_extract_disagreement_per_row
        )
        df_indep = df_for_fig[df_for_fig["indep"] == indep]

        fig_train_diversity = lib_plot.end_to_end_plot(
            go.Figure(),
            df_for_fig,
            "train_diversity_loss",
            ["indep"],
            f"Train diversity loss for label correlation {l_corr}; learning rate {lr}",
        )

        fig_id_disagreement = lib_plot.end_to_end_plot(
            go.Figure(),
            df_with_disagreement,
            "id_disagreement_rate",
            ["indep"],
            f"ID Disagreement rate for label correlation {l_corr}; learning rate {lr}",
        )

        fig_ood_disagreement = lib_plot.end_to_end_plot(
            go.Figure(),
            df_with_disagreement,
            "ood_disagreement_rate",
            ["indep"],
            f"OOD Disagreement rate for label correlation {l_corr}; learning rate {lr}",
        )

        # Only offer lambdas that actually have runs in this subset; using the
        # global lambda list made `.sample(1)` fail on missing combinations.
        per_lambda_stats = html.Div(
            [
                dcc.Tabs(
                    [
                        dcc.Tab(
                            label=f"Lambda = {lambda_}",
                            children=_stats_for_lambda(df_indep, lambda_),
                        )
                        for lambda_ in np.unique(df_indep["lambda"])
                    ]
                )
            ]
        )

        tabs_for_outer.append(
            dcc.Tab(
                label=f"label correlation {l_corr}",
                children=[
                    dcc.Graph(figure=fig_train_diversity),
                    dcc.Graph(figure=fig_id_disagreement),
                    dcc.Graph(figure=fig_ood_disagreement),
                    per_lambda_stats,
                ],
            )
        )

    top_level_tabs.append(
        dcc.Tab(label=f"learning rate {lr}", children=[dcc.Tabs(tabs_for_outer)])
    )

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Tabs(top_level_tabs)])
app.run_server(mode="inline")

Processing data for lr 0.0001
Processing data for lr 0.001
Processing data for lr 0.01
