In [None]:
from glob import glob
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from load_results_lib import load_moment_df, add_metadata
from load_results_lib import add_deviation_stats, add_derived_stats
from load_results_lib import VALID_METHODS

In [4]:
base_folder = '../blade_runs/'

# One (results folder, method label) pair per inference method; every folder
# lives directly under base_folder.
folder_method_list = tuple(
    (join(base_folder, subfolder), label)
    for subfolder, label in (
        ("nuts_results/", 'NUTS'),
        ("dadvi_results/", 'DADVI'),
        ("lrvb_results/", 'LRVB'),
        ("raabbvi_results/", 'RAABBVI'),
        ("sadvi_results/", 'SADVI'),
        ("sfullrank_advi_results/", 'SADVI_FR'),
        ('lrvb_doubling_results', 'LRVB_Doubling'),
    )
)

# Load each method's moment table and tag it with its method label,
# keyed by that label.
all_results = {}
for cur_folder, cur_method in folder_method_list:
    print(cur_method, cur_folder)
    all_results[cur_method] = add_metadata(load_moment_df(cur_folder), cur_method)

print('Done!')

NUTS ../blade_runs/nuts_results/
DADVI ../blade_runs/dadvi_results/




LRVB ../blade_runs/lrvb_results/
RAABBVI ../blade_runs/raabbvi_results/
SADVI ../blade_runs/sadvi_results/
SADVI_FR ../blade_runs/sfullrank_advi_results/
LRVB_Doubling ../blade_runs/lrvb_doubling_results
Done!


In [9]:
# Every loaded method must be a known method, and every known method must have
# been loaded; on failure, report exactly which labels disagree.
assert set(all_results) == set(VALID_METHODS), (
    'Loaded methods do not match VALID_METHODS; '
    f'symmetric difference: {sorted(set(all_results) ^ set(VALID_METHODS))}'
)

In [42]:
# For each method, list the models that at least one other method has results
# for but this method does not.
method_models = {x: all_results[x]['model_name'].tolist() for x in VALID_METHODS}
all_models = set().union(*method_models.values())

for method in VALID_METHODS:
    print(method)
    # Sorted so the listing is stable across kernel restarts: iteration order of
    # a set of strings depends on hash randomization.
    missing_models = sorted(all_models.difference(method_models[method]))
    if missing_models:
        print('Missing models:')
        print('\n'.join(missing_models))
        print('\n')

NUTS
Missing models:
election88_full
electric_1c
mesquite_vash
hiv
electric_one_pred
hiv_inter
electric_1b
mesquite_vas
microcredit
electric_multi_preds
electric_1a
electric


RAABBVI
Missing models:
electric_1c
mesquite_vash
hiv
electric_one_pred
hiv_inter
tennis
potus
electric_1b
electric_multi_preds
electric_1a
electric


DADVI
Missing models:
electric_1c
hiv
electric_one_pred
hiv_inter
electric_1b
electric_multi_preds
electric_1a
electric


LRVB
Missing models:
election88_full
electric_1c
hiv
electric_one_pred
hiv_inter
potus
electric_1b
electric_multi_preds
electric_1a
electric


SADVI
Missing models:
electric_1c
hiv
electric_one_pred
hiv_inter
test
electric_1b
electric_multi_preds
electric_1a
electric


SADVI_FR
Missing models:
electric_1c
hiv
electric_one_pred
hiv_inter
potus
electric_1b
electric_multi_preds
electric_1a
electric


LRVB_Doubling
Missing models:
potus




In [47]:
# Double-check that NUTS has no election88_full rows (it is listed among NUTS's
# missing models above).
bool(all_results['NUTS']['model_name'].eq('election88_full').any())

False

In [None]:
method_1 = 'LRVB_Doubling'
# The 'means' and 'sds' entries of the first row of method_1's results.
# .iloc is positional; plain [0] would look up the index *label* 0, which only
# coincides with the first row while the index is an unbroken 0..n-1 range.
(all_results[method_1]["means"].iloc[0],
 all_results[method_1]["sds"].iloc[0])

In [None]:
# NUTS rows for the wells_daae_c model.
all_results['NUTS'].loc[all_results['NUTS']['model_name'] == 'wells_daae_c']

In [None]:
# NUTS had no results for electric_multi_preds in the run above (it is in NUTS's
# missing-models list), so this selection is expected to be empty.
all_results['NUTS'].query("model_name == 'electric_multi_preds'")

In [None]:
# NOTE(review): not referenced anywhere else in this notebook; presumably the
# RAABBVI iteration cap used for these runs — confirm, or remove.
raabbvi_maxiter = 19900

# The two methods compared in the cells below.
method_1 = 'LRVB_Doubling'
method_2 = 'RAABBVI'

# Per-model deviations of method_1 from the NUTS reference.
# NOTE(review): unlike method_2_df (built with .dropna()), rows with missing
# values are kept here; a .dropna() was previously tried and disabled — TODO
# confirm this asymmetry is intended.
method_1_df = add_deviation_stats(all_results[method_1], all_results['NUTS'])

method_1_df

In [None]:
# 'means' entry of the first row, selected by position: with plain [0], pandas
# looks up index label 0, which raises once that row has been filtered out.
method_1_df['means'].iloc[0]

In [None]:
# Build both methods' deviation + derived stats straight from all_results.
# method_1_df is rebuilt from its source instead of from its own previous value,
# so re-running this cell never applies add_derived_stats twice.
method_1_df = add_derived_stats(
    add_deviation_stats(all_results[method_1], all_results['NUTS'])
)

method_2_df = add_derived_stats(
    add_deviation_stats(all_results[method_2], all_results['NUTS']).dropna()
)


In [None]:
# Mean of the 'converged' column for SADVI, then RAABBVI (the fraction of runs
# flagged converged, assuming the column is boolean).
for vi_method in ('SADVI', 'RAABBVI'):
    print(all_results[vi_method]['converged'].mean())

In [None]:
# Raw SADVI results table.
sadvi_results = all_results['SADVI']
sadvi_results

In [None]:
# Inner-join the two methods' per-model stats on model_name; columns present in
# both frames get a _<method> suffix.
comparison = pd.merge(
    method_1_df,
    method_2_df,
    on='model_name',
    suffixes=(f'_{method_1}', f'_{method_2}'),
)

comparison

In [None]:
#raabvi_with_deviations.iloc[0]['metadata'].keys()

In [None]:
# Per-model scaled RMSE of the posterior means (deviations from NUTS),
# method_1 on x vs method_2 on y.
f, ax = plt.subplots(1, 1)

x_col = f'mean_rms_{method_1}'
y_col = f'mean_rms_{method_2}'

# y = x reference line; points below it are models where method_2's error is
# smaller. Span it over BOTH columns (pandas min/max skip NaN) so the line is not
# cut short when method_2's values fall outside method_1's range.
lo = comparison[[x_col, y_col]].min().min()
hi = comparison[[x_col, y_col]].max().max()

ax.scatter(comparison[x_col], comparison[y_col])
ax.plot([lo, hi], [lo, hi])

# Label each point with its model name. zip over columns avoids itertuples'
# silent renaming of fields that are not valid identifiers.
for model_name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(model_name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'RMSE mean scaled by posterior sd, {method_1}')
ax.set_ylabel(f'RMSE mean scaled by posterior sd, {method_2}')
ax.set_title(f'Posterior mean error vs NUTS: {method_1} vs {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('./mean_comparison.png', dpi=300)


In [None]:
# Per-model scaled RMSE of the posterior sds (deviations from NUTS),
# method_1 on x vs method_2 on y.
f, ax = plt.subplots(1, 1)

x_col = f'sd_rms_{method_1}'
y_col = f'sd_rms_{method_2}'

# y = x reference line spanning both columns (pandas min/max skip NaN), so it
# covers method_2's values too, not only method_1's.
lo = comparison[[x_col, y_col]].min().min()
hi = comparison[[x_col, y_col]].max().max()

ax.scatter(comparison[x_col], comparison[y_col])
ax.plot([lo, hi], [lo, hi])

# Label each point with its model name (zip avoids itertuples field renaming).
for model_name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(model_name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'RMSE sd scaled by posterior sd, {method_1}')
ax.set_ylabel(f'RMSE sd scaled by posterior sd, {method_2}')
ax.set_title(f'Posterior sd error vs NUTS: {method_1} vs {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('./sd_comparison.png', dpi=300)

In [None]:

# Per-model runtime, method_1 on x vs method_2 on y.
f, ax = plt.subplots(1, 1)

x_col = f'runtime_{method_1}'
y_col = f'runtime_{method_2}'

# y = x reference line spanning both columns (pandas min/max skip NaN); points
# below it are models where method_2 ran faster.
lo = comparison[[x_col, y_col]].min().min()
hi = comparison[[x_col, y_col]].max().max()

ax.scatter(comparison[x_col], comparison[y_col])
ax.plot([lo, hi], [lo, hi])

# Label each point with its model name (zip avoids itertuples field renaming).
for model_name, x_val, y_val in zip(comparison['model_name'], comparison[x_col], comparison[y_col]):
    ax.annotate(model_name, (x_val, y_val))

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel(f'Runtime, {method_1}')
ax.set_ylabel(f'Runtime, {method_2}')
ax.set_title(f'Runtime: {method_1} vs {method_2}')

ax.grid(alpha=0.5, linestyle='--')

f.set_size_inches(12, 8)
f.tight_layout()

# plt.savefig('runtime_comparison.png', dpi=300)

In [None]:
# First five rows of the merged comparison table.
comparison.iloc[:5]

In [None]:
# Copy the 'M' entry out of each row's metadata mapping into its own column
# (mutates the LRVB_Doubling frame stored in all_results).
lrvb_doubling = all_results['LRVB_Doubling']
lrvb_doubling['M'] = lrvb_doubling['metadata'].apply(lambda meta: meta['M'])

In [None]:
# The 20 LRVB_Doubling runs with the largest M, with their runtimes.
(
    all_results['LRVB_Doubling']
    .sort_values(by='M', ascending=False)
    .loc[:, ['model_name', 'runtime', 'M']]
    .head(20)
)

In [None]:
# LRVB runtime per model, in ascending order of runtime.
all_results['LRVB'].sort_values(by='runtime').loc[:, ['model_name', 'runtime']]