In [1]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash
from PIL import Image
from tqdm import tqdm

# Dash stylesheet(s) passed to the JupyterDash app that renders the dashboard.
external_stylesheets = [
    "https://codepen.io/chriddyp/pen/bWLwgP.css",
]

# Enable tqdm progress bars for pandas operations.
tqdm.pandas()
# Put gin into interactive mode so configurables can be re-registered when
# notebook cells are re-executed.
gin.enter_interactive_mode()

  from pandas import Panel


In [2]:
# Name of the problem sub-directory under `data_root` whose results are analysed below.
PROBLEM = "biased_mnist"

In [3]:
# Reload edited modules (e.g. the local lib_* helpers) before each cell runs.
%load_ext autoreload
%autoreload 2

In [5]:
def read_df(data_root: Path, problem: Optional[str] = None) -> pd.DataFrame:
    """Load the results frame for one problem, reusing a pickle cache when present.

    On a cache miss the frame is built with ``lib_analysis.read_problem`` and
    enriched with ``lib_analysis.add_statistics_to_df``, then written to
    ``<data_root>/<problem>/df.pickle`` so later runs can skip the rebuild.

    Args:
        data_root: directory that contains one sub-directory per problem.
        problem: problem name. Defaults to the notebook-level ``PROBLEM``,
            looked up at call time so reassigning ``PROBLEM`` is honoured.

    Returns:
        The results DataFrame.

    Note:
        The cache is a pickle. Only point ``data_root`` at trusted data,
        because unpickling can execute arbitrary code. Delete the file to
        force a rebuild after the underlying results change.
    """
    if problem is None:
        problem = PROBLEM

    df_path = data_root / problem / "df.pickle"
    if df_path.exists():
        print(f"Found cached df at {df_path}. Reusing...", flush=True)
        return pd.read_pickle(str(df_path))

    results = lib_analysis.read_problem(data_root, problem)
    results = lib_analysis.add_statistics_to_df(results)

    # Make sure the cache directory exists before writing the pickle.
    df_path.parent.mkdir(parents=True, exist_ok=True)
    results.to_pickle(str(df_path))
    return results

In [6]:
# Results live under <data_root>/<PROBLEM>. `read_df` caches the parsed frame
# there as df.pickle and reuses it on later runs.
data_root = Path("..") / "data"
DF = read_df(data_root)

Found cached df at ../data/biased_mnist/df.pickle. Reusing...


In [8]:
# Peek at one run. A fixed seed keeps the displayed row stable across re-runs.
DF.sample(n=1, random_state=0)

Unnamed: 0,test_prediction_loss,test_accuracy,test_diversity_loss,test_combined_loss,test_ensemble_accuracy,train_prediction_loss,train_accuracy,train_diversity_loss,train_combined_loss,train_ensemble_accuracy,...,BiasedMnistProblem.background_noise_level,Problem.n_epochs,Problem.initial_lr,Problem.decrease_lr_at_epochs,Problem.n_models,get_weight_regularizer.strength,yaml_config_file,gin_config_file,ood_statistics,id_statistics
888,"[2.29874849319458, 5.004456043243408]","[0.11429999768733978, 0.31529998779296875]",0.007118,9.125334,0.3153,"[2.2937798500061035, 0.04043296352028847]","[0.11273333430290222, 0.9943666458129883]",0.002405,2.949967,0.994367,...,0,100,0.001,"[20, 40, 80]",2,0.01,/home/ericpts/work/data/biased_mnist/condition...,/home/ericpts/work/data/biased_mnist/condition...,{'overall': {'ensemble': {'accuracy': tf.Tenso...,{'overall': {'ensemble': {'accuracy': tf.Tenso...


In [12]:
# In- and out-of-distribution test sets for the biased-MNIST problem, bias included.
# NOTE(review): `in_dist` / `oo_dist` are not used by any visible cell; drop them
# if nothing else in the notebook needs them.
problem = lib_biased_mnist.BiasedMnistProblem()
in_dist, oo_dist = (
    problem.generate_id_testing_data(include_bias=True),
    problem.generate_ood_testing_data(include_bias=True),
)


def statistics_for_series(row: pd.Series) -> list:
    """Render the OOD and ID statistics of a single run as two Dash sections.

    Args:
        row: one row of the results frame. It must have ``ood_statistics`` and
            ``id_statistics`` entries, each passed to
            ``lib_analysis.format_statistics``.

    Returns:
        A two-element list of ``html.Div``. The first is the "Out of
        Distribution" section, the second is the "In Distribution" section.
    """
    sections = (
        ("Out of Distribution", "ood_statistics"),
        ("In Distribution", "id_statistics"),
    )
    return [
        html.Div([html.H4(heading), *lib_analysis.format_statistics(row[column])])
        for heading, column in sections
    ]

In [11]:
# Build a nested tab dashboard: learning rate -> label correlation -> per-lambda
# statistics of one representative `conditional_hsic` run, then serve it inline.

# Seed used to pick the representative run per configuration, so re-running the
# notebook shows the same runs.
STATS_SAMPLE_SEED = 0


def _select_run(
    df: pd.DataFrame, lr: float, l_corr: float, lambda_: float
) -> Optional[pd.Series]:
    """Return one `conditional_hsic` run for (lr, l_corr, lambda_), or None if there is none."""
    matches = df[
        (df["label_correlation"] == l_corr)
        & (df["indep"] == "conditional_hsic")
        & (df["lambda"] == lambda_)
        & (df["Problem.initial_lr"] == lr)
    ]
    if matches.empty:
        # DataFrame.sample(1) raises on an empty frame; let the caller render a placeholder.
        return None
    return matches.sample(1, random_state=STATS_SAMPLE_SEED).iloc[0]


def _lambda_tabs(df: pd.DataFrame, lr: float, l_corr: float, lambdas) -> "html.Div":
    """One tab per lambda with the statistics of a sampled run, or a placeholder if none exists."""
    tabs = []
    for lambda_ in lambdas:
        run = _select_run(df, lr, l_corr, lambda_)
        if run is None:
            children = [html.P("No conditional_hsic run for this configuration.")]
        else:
            children = statistics_for_series(run)
        tabs.append(dcc.Tab(label=f"Lambda = {lambda_}", children=children))
    return html.Div([dcc.Tabs(tabs)])


def _label_correlation_tab(
    df: pd.DataFrame, lr: float, l_corr: float, lambdas
) -> "dcc.Tab":
    """Tab with the train-diversity plot and per-lambda statistics for one (lr, l_corr)."""
    df_for_fig = df[
        (df["label_correlation"] == l_corr) & (df["Problem.initial_lr"] == lr)
    ].copy()
    fig_train_diversity = lib_plot.end_to_end_plot(
        go.Figure(),
        df_for_fig,
        "train_diversity_loss",
        ["indep"],
        f"Train diversity loss for label correlation {l_corr}; learning rate {lr}",
    )
    return dcc.Tab(
        label=f"label correlation {l_corr}",
        children=[
            dcc.Graph(figure=fig_train_diversity),
            _lambda_tabs(df, lr, l_corr, lambdas),
        ],
    )


# Loop-invariant: the lambda values shown as tabs are the same for every configuration.
all_lambdas = np.unique(DF["lambda"])

top_level_tabs = []
for lr in np.unique(DF["Problem.initial_lr"]):
    print(f"Processing data for lr {lr}")
    tabs_for_outer = [
        _label_correlation_tab(DF, lr, l_corr, all_lambdas)
        for l_corr in np.unique(DF["label_correlation"])
    ]
    top_level_tabs.append(
        dcc.Tab(label=f"learning rate {lr}", children=[dcc.Tabs(tabs_for_outer)])
    )

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Tabs(top_level_tabs)])
app.run_server(mode="inline")

Processing data for lr 0.0001
Processing data for lr 0.001
Processing data for lr 0.01
Processing data for lr 0.1
Processing data for lr 1.0
