In [2]:
import argparse
import concurrent.futures
import json
import os
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import dash_core_components as dcc
import dash_html_components as html
import gin
import lib_analysis
import lib_biased_mnist
import lib_plot
import lib_problem
import lib_toy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import sklearn
import sklearn.metrics
import tensorflow as tf
import tensorflow_probability as tfp
import yaml
from jupyter_dash import JupyterDash
from PIL import Image
from tqdm import tqdm

external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

tqdm.pandas()
gin.enter_interactive_mode()

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
def read_df(data_root: Path, experiment_name: str) -> pd.DataFrame:
    df_path = data_root / experiment_name / "df.pickle"
    if df_path.exists():
        print(f"Found cached df at {df_path}. Reusing...", flush=True)
        return pd.read_pickle(str(df_path))

    DF = lib_analysis.read_problem(data_root, experiment_name)
    DF = lib_analysis.add_statistics_to_df(DF)

    DF.to_pickle(str(df_path), protocol=4)
    return DF

In [5]:
data_root = Path("../data/")
DF = read_df(data_root, "biased_mnist")

gin.parse_config_file(
    DF["gin_config_file"].iloc[0]
)

problem = lib_biased_mnist.BiasedMnistProblem()
in_dist = problem.generate_id_testing_data(include_bias=True)
oo_dist = problem.generate_ood_testing_data(include_bias=True)

Found cached df at ../data/biased_mnist/df.pickle. Reusing...


In [6]:
def compute_ensemble_distribution(y_hat):
    m_1, m_2 = y_hat[:, 0, :], y_hat[:, 1, :]
    y_hat_ensemble = (m_1 + m_2) / 2
    return y_hat_ensemble


def pmax(y_hat_ensemble, _y_hat):
    return tf.math.reduce_max(y_hat_ensemble, axis=1)


def entropy(y_hat_ensemble, _y_hat):
    return -tf.math.reduce_sum(y_hat_ensemble * tf.math.log(y_hat_ensemble), axis=1)


def max_diff(y_hat_ensemble, y_hat):
    per_model = []
    for im in range(2):
        per_model.append(
            tf.math.reduce_max(tf.math.abs(y_hat_ensemble - y_hat[:, im, :]), axis=1)
        )
    ret = tf.math.reduce_max(per_model, axis=0)
    return ret


def average_diff(y_hat_ensemble, y_hat):
    per_model = []
    for im in range(2):
        per_model.append(
            tf.math.reduce_mean(tf.math.abs(y_hat_ensemble - y_hat[:, im, :]), axis=1)
        )
    ret = tf.math.reduce_mean(per_model, axis=0)
    return ret

In [7]:
def compute_auroc(in_y_hat, oo_y_hat, f_ensemble_score, higher_score_is_ood=True):
    id_score = f_ensemble_score(compute_ensemble_distribution(in_y_hat), in_y_hat)
    oo_score = f_ensemble_score(compute_ensemble_distribution(oo_y_hat), oo_y_hat)

    if higher_score_is_ood:
        id_label, oo_label = 0, 1
    else:
        id_label, oo_label = 1, 0

    y_true = tf.convert_to_tensor(
        ([id_label] * id_score.shape[0]) + ([oo_label] * oo_score.shape[0])
    )
    y_score = tf.concat([id_score, oo_score], axis=0)
    return sklearn.metrics.roc_auc_score(y_true, y_score)

In [8]:
def compute_metrics(in_y_hat, oo_y_hat):
    return {
        "pmax": compute_auroc(in_y_hat, oo_y_hat, pmax, higher_score_is_ood=False),
        "entropy": compute_auroc(in_y_hat, oo_y_hat, entropy),
        "max_diff": compute_auroc(in_y_hat, oo_y_hat, max_diff),
        "average_diff": compute_auroc(in_y_hat, oo_y_hat, average_diff),
    }

In [9]:
def compute_metrics_for_row(row: pd.Series):
    models = []
    for p in row["model_paths"]:
        m = tf.keras.models.load_model(p, compile=False)
        models.append(m)

    in_X, in_y, in_y_biased, in_y_hat = lib_analysis.process_dataset(in_dist, models)
    oo_X, oo_y, oo_y_biased, oo_y_hat = lib_analysis.process_dataset(oo_dist, models)
    return compute_metrics(in_y_hat, oo_y_hat)

In [10]:
DF = lib_analysis.add_columns_to_df(
    DF, DF.progress_apply(compute_metrics_for_row, axis=1)
)

 22%|██▏       | 253/1134 [06:51<23:53,  1.63s/it]


ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

In [None]:
DF.to_pickle('/home/ericpts/work/data/biased_mnist_10_digits/df.pickle')