In [1]:
from glob import glob
from os.path import join
import pandas as pd

In [2]:
import numpy as np
import pickle

def load_pickle_safely(pickle_file):
    """Read a single pickled object from ``pickle_file`` and return it.

    The file is opened in binary mode inside a ``with`` block, so the handle
    is closed even when unpickling raises.

    NOTE(review): nothing here validates the payload; ``pickle.load`` can
    execute arbitrary code, so only point this at trusted files.
    """
    with open(pickle_file, 'rb') as handle:
        contents = pickle.load(handle)
    return contents

def maybe_load_pickle(pickle_file):
    """Load a pickle file, returning ``None`` if it cannot be read.

    Args:
        pickle_file: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` when opening or unpickling the
        file raises an ``Exception``.
    """
    try:
        # Return the loaded object; previously it was assigned and dropped,
        # so the function yielded None even on success.
        return load_pickle_safely(pickle_file)
    except Exception:
        # Deliberate best-effort: a missing or unreadable file maps to None.
        return None

def compute_means(draw_dict):
    """Average every array in ``draw_dict`` over its first two axes.

    Args:
        draw_dict: Mapping from variable name to an array with at least two
            dimensions (presumably chains x draws x parameter shape -- TODO
            confirm).

    Returns:
        Dict with the same keys, each holding ``array.mean(axis=(0, 1))``.
    """
    means = {}
    for name, draws in draw_dict.items():
        means[name] = draws.mean(axis=(0, 1))
    return means

def compute_sds(draw_dict):
    """Take the standard deviation of every array over its first two axes.

    Args:
        draw_dict: Mapping from variable name to an array with at least two
            dimensions (presumably chains x draws x parameter shape -- TODO
            confirm).

    Returns:
        Dict with the same keys, each holding ``array.std(axis=(0, 1))``.
    """
    sds = {}
    for name, draws in draw_dict.items():
        sds[name] = draws.std(axis=(0, 1))
    return sds

def compute_z_score_mean(mean_dict, ref_mean_dict, ref_sd_dict):
    """Express each mean's deviation from the reference in reference-sd units.

    Args:
        mean_dict: Mapping from variable name to the mean being assessed.
        ref_mean_dict: Mapping from variable name to the reference mean.
        ref_sd_dict: Mapping from variable name to the reference sd.

    Returns:
        Dict keyed like ``mean_dict`` holding
        ``(mean - reference mean) / reference sd``.

    Raises:
        KeyError: If a key of ``mean_dict`` is missing from either reference.
    """
    z_scores = {}
    for name in mean_dict:
        difference = mean_dict[name] - ref_mean_dict[name]
        z_scores[name] = difference / ref_sd_dict[name]
    return z_scores

def compute_relative_error_sd(sd_dict, ref_sd_dict):
    """Compute the relative error of each sd with respect to the reference.

    Args:
        sd_dict: Mapping from variable name to the sd being assessed.
        ref_sd_dict: Mapping from variable name to the reference sd.

    Returns:
        Dict keyed like ``ref_sd_dict`` holding
        ``(sd - reference sd) / reference sd``.

    Raises:
        KeyError: If a key of ``ref_sd_dict`` is missing from ``sd_dict``.
    """
    relative_errors = {}
    for name in ref_sd_dict:
        reference_sd = ref_sd_dict[name]
        relative_errors[name] = (sd_dict[name] - reference_sd) / reference_sd
    return relative_errors

def flatten_dict(var_dict, names):
    """Concatenate the flattened arrays of ``var_dict`` in the order of ``names``.

    Args:
        var_dict: Mapping from variable name to a numpy array.
        names: Iterable of keys giving the concatenation order.

    Returns:
        One 1-D array made of each ``var_dict[name].reshape(-1)`` in turn.
    """
    pieces = []
    for name in names:
        pieces.append(var_dict[name].reshape(-1))
    return np.concatenate(pieces)


In [3]:
from os.path import split, splitext

def load_moment_df(draw_folder):
    """Summarise every saved draw file under ``draw_folder`` by its moments.

    Looks for ``*.npz`` files in ``<draw_folder>/draw_dicts`` and, for each
    one, computes per-variable means and standard deviations with
    ``compute_means`` and ``compute_sds``.

    Args:
        draw_folder: Folder containing a ``draw_dicts`` sub-folder.

    Returns:
        A DataFrame with one row per file, sorted by path, and the columns
        ``draw_dict_path``, ``means``, ``sds`` and ``model_name`` (the file
        name without its extension). It has no rows if no files match.
    """
    # Sort so the row order does not depend on the file-system listing order.
    draw_paths = sorted(glob(join(draw_folder, "draw_dicts", "*.npz")))

    means = []
    sds = []
    for path in draw_paths:
        # Close each archive as soon as it is read and keep only the moments;
        # holding every model's raw draws at the same time is unnecessary.
        with np.load(path) as archive:
            draws = dict(archive)
        means.append(compute_means(draws))
        sds.append(compute_sds(draws))

    # Model name = file name without directory or extension.
    model_names = [splitext(split(path)[-1])[0] for path in draw_paths]

    return pd.DataFrame(
        {
            "draw_dict_path": draw_paths,
            "means": means,
            "sds": sds,
            "model_name": model_names,
        }
    )

def check_convergence_sadvi(metadata, max_iter=100000):
    """Report whether a SADVI run used fewer than ``max_iter`` steps.

    Args:
        metadata: Mapping with a ``'steps'`` entry.
        max_iter: Step count at or above which the run is not counted as
            converged (presumably the optimiser's iteration cap -- TODO
            confirm).

    Returns:
        The result of ``metadata['steps'] < max_iter``.
    """
    steps_taken = metadata['steps']
    return steps_taken < max_iter


def check_convergence_raabbvi(metadata, max_iter=19900):
    """Report whether a RAABBVI run stayed below ``max_iter`` steps.

    Args:
        metadata: Mapping whose ``'kl_hist_i'`` entry supports ``.max()``
            (presumably the iteration indices of the recorded KL history --
            TODO confirm).
        max_iter: Largest-index threshold at or above which the run is not
            counted as converged.

    Returns:
        The result of ``metadata['kl_hist_i'].max() < max_iter``.
    """
    return metadata['kl_hist_i'].max() < max_iter


def add_metadata(moment_df, method):
    """Attach runtime (and, for some methods, convergence) columns in place.

    File locations are derived from ``draw_dict_path`` by swapping the
    ``draw_dicts`` folder name and the ``.npz`` extension.

    * Non-NUTS methods: adds ``info_path`` (a ``.pkl`` file in the method's
      info sub-folder), ``metadata`` (the unpickled object) and ``runtime``
      (``metadata['runtime']``). ``converged`` is added only for methods
      starting with ``'SADVI'`` and for ``'RAABBVI'``.
    * NUTS: adds ``runtime_path`` (a ``.csv`` file in ``runtimes``) and
      ``runtime`` (first row of that file's ``'0'`` column).

    Args:
        moment_df: DataFrame with a ``draw_dict_path`` column; it is modified.
        method: One of the method labels checked by the assertion below.

    Returns:
        The same ``moment_df`` object, with the new columns.
    """
    # Sub-folder holding the pickled run info for each non-NUTS method.
    info_subdirs = {
        'RAABBVI': 'info',
        'DADVI': 'dadvi_info',
        'LRVB': 'lrvb_info',
        'SADVI': 'info',
        'SADVI_FR': 'info',
        'LRVB_Doubling': 'lrvb_info',
    }

    assert method in ['NUTS', 'RAABBVI', 'DADVI', 'LRVB', 'SADVI', 'SADVI_FR', 'LRVB_Doubling']

    draw_paths = moment_df["draw_dict_path"]

    if method not in list(info_subdirs):
        # It's NUTS; the runtime sits in a CSV next to the draws.
        runtime_paths = draw_paths.str.replace('draw_dicts', 'runtimes')
        runtime_paths = runtime_paths.str.replace('.npz', '.csv', regex=False)
        moment_df['runtime_path'] = runtime_paths
        moment_df['runtime'] = moment_df['runtime_path'].apply(
            lambda csv_path: pd.read_csv(csv_path)['0'].iloc[0]
        )

        # TODO: get rhat

        return moment_df

    info_paths = draw_paths.str.replace("draw_dicts", info_subdirs[method])
    info_paths = info_paths.str.replace(".npz", ".pkl", regex=False)
    moment_df["info_path"] = info_paths

    moment_df['metadata'] = moment_df['info_path'].apply(load_pickle_safely)
    moment_df['runtime'] = moment_df['metadata'].apply(lambda info: info['runtime'])

    # Only these methods have a convergence check defined.
    if method.startswith('SADVI'):
        moment_df['converged'] = moment_df['metadata'].apply(check_convergence_sadvi)
    elif method == 'RAABBVI':
        moment_df['converged'] = moment_df['metadata'].apply(check_convergence_raabbvi)

    return moment_df

from os.path import join

# Root folder holding one results sub-folder per inference method.
base_folder = '../blade_runs/'

# (results sub-folder, method label) pairs, in the order they are loaded.
_results_subfolders = (
    ("nuts_results/", 'NUTS'),
    ("dadvi_results/", 'DADVI'),
    ("lrvb_results/", 'LRVB'),
    ("raabbvi_results/", 'RAABBVI'),
    ("sadvi_results/", 'SADVI'),
    ("sfullrank_advi_results/", 'SADVI_FR'),
    ('lrvb_doubling_results', 'LRVB_Doubling'),
)

folder_method_list = tuple(
    (join(base_folder, subfolder), label) for subfolder, label in _results_subfolders
)

# Moments plus metadata for every method, keyed by method label.
all_results = {}

for cur_folder, cur_method in folder_method_list:
    # Echo progress so a failing folder is easy to spot.
    print(cur_method, cur_folder)

    data = add_metadata(load_moment_df(cur_folder), cur_method)
    all_results[cur_method] = data


NUTS ../blade_runs/nuts_results/
DADVI ../blade_runs/dadvi_results/




LRVB ../blade_runs/lrvb_results/
RAABBVI ../blade_runs/raabbvi_results/
SADVI ../blade_runs/sadvi_results/
SADVI_FR ../blade_runs/sfullrank_advi_results/
LRVB_Doubling ../blade_runs/lrvb_doubling_results


In [None]:
def add_deviation_stats(model_df, reference_df):
    """Score a method's moments against a reference, matched on ``model_name``.

    Both frames need ``model_name``, ``means`` and ``sds`` columns. Models
    present in both get:

    * ``mean_deviations`` -- output of ``compute_z_score_mean``,
    * ``sd_deviations`` -- output of ``compute_relative_error_sd``,
    * ``var_names`` -- sorted keys of the reference means.

    Args:
        model_df: Results of the method being assessed.
        reference_df: Results treated as the reference.

    Returns:
        ``model_df`` left-merged with the three new columns; models without a
        reference match keep their row but have missing values there.
    """

    def _mean_z_scores(row):
        return compute_z_score_mean(
            row["means_model"], row["means_reference"], row["sds_reference"]
        )

    def _sd_relative_errors(row):
        return compute_relative_error_sd(row["sds_model"], row["sds_reference"])

    # Inner merge: only models available in both frames can be scored.
    paired = model_df.merge(
        reference_df, on="model_name", suffixes=("_model", "_reference")
    )

    paired["mean_deviations"] = paired.apply(_mean_z_scores, axis=1)
    paired["sd_deviations"] = paired.apply(_sd_relative_errors, axis=1)
    paired["var_names"] = paired["means_reference"].apply(
        lambda ref_means: sorted(ref_means)
    )

    # Carry just the new statistics back onto the method's own rows.
    deviation_cols = ["model_name", "mean_deviations", "sd_deviations", "var_names"]
    deviation_stats = paired[deviation_cols]

    return model_df.merge(deviation_stats, on='model_name', how='left')


def add_derived_stats(model_df):
    """Add flattened deviation vectors and their root-mean-square, in place.

    Requires the ``mean_deviations``, ``sd_deviations`` and ``var_names``
    columns. Adds, in this order, ``mean_deviations_flat``,
    ``sd_deviations_flat``, ``mean_rms`` and ``sd_rms``.

    Args:
        model_df: DataFrame to extend; it is modified.

    Returns:
        The same ``model_df`` object, with the four new columns.
    """

    def _root_mean_square(values):
        return np.sqrt(np.mean(values**2))

    # (source column, flattened column, rms column) for each statistic.
    column_plan = (
        ("mean_deviations", "mean_deviations_flat", "mean_rms"),
        ("sd_deviations", "sd_deviations_flat", "sd_rms"),
    )

    # First pass: one flat vector per row, ordered by that row's var_names.
    for source_col, flat_col, _ in column_plan:
        model_df[flat_col] = model_df.apply(
            lambda row, src=source_col: flatten_dict(row[src], row["var_names"]),
            axis=1,
        )

    # Second pass: collapse each flat vector to a single RMS value.
    for _, flat_col, rms_col in column_plan:
        model_df[rms_col] = model_df[flat_col].apply(_root_mean_square)

    return model_df


In [None]:
# Display the SADVI results table built by the loading loop above.
all_results['SADVI']

In [None]:
# Same value as the default max_iter of check_convergence_raabbvi; it is not
# referenced again in this cell.
raabbvi_maxiter = 19900

# The two methods to compare; both are scored against the NUTS results.
method_1 = 'LRVB_Doubling'
method_2 = 'RAABBVI'

# For each method: add deviations from NUTS, drop rows with any missing value
# (this includes models with no NUTS match), then add the RMS summaries.
method_1_df = add_derived_stats(
    add_deviation_stats(all_results[method_1], all_results['NUTS']).dropna()
)

method_2_df = add_derived_stats(
    add_deviation_stats(all_results[method_2], all_results['NUTS']).dropna()
)


In [None]:
# Mean of the SADVI 'converged' column, i.e. the share of runs for which
# check_convergence_sadvi returned True.
all_results['SADVI']['converged'].mean()

In [None]:
# One row per model present in both method frames; every shared column other
# than model_name is suffixed with the method it came from.
comparison = pd.merge(
    method_1_df,
    method_2_df,
    on='model_name',
    suffixes=(f'_{method_1}', f'_{method_2}'),
)

comparison

In [None]:
# List the fields stored in a RAABBVI run's metadata. The previous name,
# `raabvi_with_deviations`, is not defined anywhere in this notebook, so the
# cell raised NameError on a fresh kernel; the loaded RAABBVI results carry
# the same 'metadata' column.
all_results['RAABBVI'].iloc[0]['metadata'].keys()

In [None]:
import matplotlib.pyplot as plt

# Scatter of the two methods' mean RMS values, one labelled point per model.
x_col = f'mean_rms_{method_1}'
y_col = f'mean_rms_{method_2}'

f, ax = plt.subplots(1, 1)

ax.scatter(comparison[x_col], comparison[y_col])

# y = x reference line spanning method_1's observed range.
xmin, xmax = comparison[x_col].min(), comparison[x_col].max()
ax.plot([xmin, xmax], [xmin, xmax])

# Label every point with its model name.
for name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'RMSE mean scaled by posterior sd, {method_1}')
ax.set_ylabel(f'RMSE mean scaled by posterior sd, {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('./mean_comparison.png', dpi=300)


In [None]:
# Scatter of the two methods' sd RMS values, one labelled point per model.
x_col = f'sd_rms_{method_1}'
y_col = f'sd_rms_{method_2}'

f, ax = plt.subplots(1, 1)

ax.scatter(comparison[x_col], comparison[y_col])

# y = x reference line spanning method_1's observed range.
xmin, xmax = comparison[x_col].min(), comparison[x_col].max()
ax.plot([xmin, xmax], [xmin, xmax])

# Label every point with its model name.
for name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'RMSE sd scaled by posterior sd, {method_1}')
ax.set_ylabel(f'RMSE sd scaled by posterior sd, {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('./sd_comparison.png', dpi=300)

In [None]:

# Scatter of the two methods' runtimes, one labelled point per model.
x_col = f'runtime_{method_1}'
y_col = f'runtime_{method_2}'

f, ax = plt.subplots(1, 1)

ax.scatter(comparison[x_col], comparison[y_col])

# y = x reference line spanning method_1's observed range.
xmin, xmax = comparison[x_col].min(), comparison[x_col].max()
ax.plot([xmin, xmax], [xmin, xmax])

# Label every point with its model name.
for name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'Runtime, {method_1}')
ax.set_ylabel(f'Runtime, {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('runtime_comparison.png', dpi=300)

In [None]:
# Preview the first rows of the merged two-method comparison table.
comparison.head()

In [None]:
# Copy the 'M' entry of each run's metadata into its own column of the
# LRVB_Doubling results (the stored frame itself is modified).
# NOTE(review): what 'M' measures is not visible here -- confirm.
lrvb_doubling_df = all_results['LRVB_Doubling']
lrvb_doubling_df['M'] = lrvb_doubling_df['metadata'].apply(lambda info: info['M'])

In [None]:
# The 20 LRVB_Doubling runs with the largest 'M', with their runtimes.
# Relies on the 'M' column added in the previous cell.
all_results['LRVB_Doubling'][['model_name', 'runtime', 'M']].sort_values('M', ascending=False).head(20)

In [None]:
# LRVB runtimes per model, sorted from fastest to slowest.
all_results['LRVB'][['model_name', 'runtime']].sort_values('runtime')