In [1]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from functools import reduce
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash
from PIL import Image
from tqdm import tqdm

# Stylesheet handed to the JupyterDash app built in the dashboard cell.
external_stylesheets = [
    "https://codepen.io/chriddyp/pen/bWLwgP.css",
]

# Register tqdm's pandas integration (`progress_apply` & co.) and let gin
# configurables be re-registered when cells are executed again.
tqdm.pandas()
gin.enter_interactive_mode()

  from pandas import Panel


In [2]:
# Automatically re-import edited modules (e.g. the local lib_* helpers) before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
def read_df(data_root: Path, experiment_name: str) -> pd.DataFrame:
    """Load the analysis DataFrame for one experiment, using an on-disk cache.

    If ``<data_root>/<experiment_name>/df.pickle`` exists, it is unpickled and
    returned as-is. Otherwise the runs are read with
    ``lib_analysis.read_problem``, restricted to the learning rates and label
    correlations analysed in this notebook, passed through
    ``lib_analysis.add_statistics_to_df``, and written to that cache file.

    Note: the cache is not keyed on the filter values below; delete
    ``df.pickle`` after changing them, otherwise the stale frame is returned.

    Args:
        data_root: Directory under which each experiment has its own sub-directory.
        experiment_name: Name of the experiment sub-directory.

    Returns:
        The filtered DataFrame as returned by ``add_statistics_to_df``.
    """
    df_path = data_root / experiment_name / "df.pickle"
    if df_path.exists():
        print(f"Found cached df at {df_path}. Reusing...", flush=True)
        # Unpickling executes arbitrary code: only ever load this locally
        # written cache, never a file obtained from elsewhere.
        return pd.read_pickle(str(df_path))

    def filter_for_column_values(
        df: pd.DataFrame, col_name: str, values: List
    ) -> pd.DataFrame:
        """Keep the rows whose ``col_name`` exactly equals one of ``values``.

        An empty ``values`` yields an empty frame rather than an exception.
        """
        return df[df[col_name].isin(values)]

    runs = lib_analysis.read_problem(data_root, experiment_name)
    runs = filter_for_column_values(runs, "Problem.initial_lr", [0.0001, 0.001, 0.01])
    runs = filter_for_column_values(runs, "label_correlation", [0.9, 0.99, 0.999])

    runs = lib_analysis.add_statistics_to_df(runs)

    # Write to a temporary file and atomically move it into place, so an
    # interrupted write never leaves a truncated df.pickle behind that the
    # cache branch above would then try (and fail) to load on every run.
    df_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = df_path.with_name(df_path.name + ".tmp")
    runs.to_pickle(str(tmp_path), protocol=4)
    os.replace(tmp_path, df_path)
    return runs

In [None]:
# Load (or reuse the cached) results frame of the "biased_mnist" experiment.
data_root = Path("..") / "data"
DF = read_df(data_root=data_root, experiment_name="biased_mnist")

 76%|███████▌  | 864/1134 [1:20:46<21:51,  4.86s/it]  

In [None]:
def statistics_for_series(row: pd.Series):
    """Render the OOD and ID statistics stored in one result row as Dash divs.

    Args:
        row: A result row exposing ``"ood_statistics"`` and ``"id_statistics"``.

    Returns:
        A two-element list of ``html.Div``: out-of-distribution first, then
        in-distribution, each headed by an ``html.H4`` title followed by the
        components produced by ``lib_analysis.format_statistics``.
    """
    sections = (
        ("Out of Distribution", "ood_statistics"),
        ("In Distribution", "id_statistics"),
    )
    return [
        html.Div([html.H4(heading), *lib_analysis.format_statistics(row[column])])
        for heading, column in sections
    ]

In [None]:
# Interactive dashboard: one outer tab per learning rate, one inner tab per label
# correlation. Each inner tab shows end-to-end plots of three metrics (grouped by
# `indep`) followed by the statistics of one representative run per lambda for
# the `indep` method selected below.
top_level_tabs = []

indep = "conditional_cka"

# Seed used to pick the representative run for each lambda, so that
# Restart & Run All shows the same run every time.
SAMPLE_SEED = 0

# (metric column, title prefix) of each end-to-end plot shown in an inner tab.
PLOTTED_METRICS = [
    ("train_diversity_loss", "Train diversity loss"),
    ("id_disagreement_rate", "ID Disagreement rate"),
    ("ood_disagreement_rate", "OOD Disagreement rate"),
]


def per_lambda_stats_tabs(
    df_subset: pd.DataFrame, indep_method: str, seed: int = SAMPLE_SEED
) -> html.Div:
    """Return a tab strip with the statistics of one sampled run per lambda.

    Lambdas are enumerated from the ``indep_method`` runs of ``df_subset``
    itself (not from the whole frame), so a lambda that is missing for this
    learning rate / label correlation simply gets no tab instead of making
    ``DataFrame.sample`` fail on an empty selection.
    """
    method_runs = df_subset[df_subset["indep"] == indep_method]
    lambda_tabs = []
    for lambda_ in np.unique(method_runs["lambda"].dropna()):
        runs_for_lambda = method_runs[method_runs["lambda"] == lambda_]
        representative = runs_for_lambda.sample(1, random_state=seed).iloc[0]
        lambda_tabs.append(
            dcc.Tab(
                label=f"Lambda = {lambda_}",
                children=statistics_for_series(representative),
            )
        )
    return html.Div([dcc.Tabs(lambda_tabs)])


for lr in [0.0001, 0.001, 0.01]:
    print(f"Processing data for lr {lr}")
    tabs_for_outer = []
    for l_corr in [0.9, 0.99, 0.999]:
        df_subset = DF[
            (DF["label_correlation"] == l_corr) & (DF["Problem.initial_lr"] == lr)
        ]
        # The plotting helper gets its own copy; the statistics below read
        # `df_subset`, which is never handed to the plotting code.
        df_for_fig = df_subset.copy()

        metric_graphs = [
            dcc.Graph(
                figure=lib_plot.end_to_end_plot(
                    go.Figure(),
                    df_for_fig,
                    metric,
                    ["indep"],
                    f"{title} for label correlation {l_corr}; learning rate {lr}",
                )
            )
            for metric, title in PLOTTED_METRICS
        ]

        tabs_for_outer.append(
            dcc.Tab(
                label=f"label correlation {l_corr}",
                children=[*metric_graphs, per_lambda_stats_tabs(df_subset, indep)],
            )
        )

    top_level_tabs.append(
        dcc.Tab(label=f"learning rate {lr}", children=[dcc.Tabs(tabs_for_outer)])
    )

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)
app.layout = html.Div([dcc.Tabs(top_level_tabs)])
app.run_server(mode="inline")