Load results from the `blade_runs` directory and save in a tidy format for R

In [1]:
from os.path import join
import pandas as pd
import numpy as np

In [2]:
from load_and_tidy_lib import VALID_METHODS, GetMetadataDataframe, GetMethodDataframe

In [3]:
# Root of the experiments checkout; results are read from, and tidy outputs
# written back to, the blade_runs directory underneath it.
base_folder = '/home/rgiordan/Documents/git_repos/DADVI/dadvi-experiments'
input_folder = join(base_folder, 'comparison/blade_runs/')
output_folder = input_folder

# Pairs of (results directory, method label), one per inference method.
# The sub-directory names are kept exactly as given, including the missing
# trailing slash on the last entry.
folder_method_list = tuple(
    (join(input_folder, results_subdir), method_label)
    for results_subdir, method_label in (
        ("nuts_results/", 'NUTS'),
        ("dadvi_results/", 'DADVI'),
        ("lrvb_results/", 'LRVB'),
        ("raabbvi_results/", 'RAABBVI'),
        ("sadvi_results/", 'SADVI'),
        ("sfullrank_advi_results/", 'SADVI_FR'),
        ('lrvb_doubling_results', 'LRVB_Doubling'),
    )
)


In [4]:
# Collect one posterior dataframe and one metadata dataframe per method,
# then stack each collection into a single dataframe.
posterior_dfs, metadata_dfs = [], []
for folder, method in folder_method_list:
    print(f'Loading {method}')
    # Posterior results are loaded before metadata for each method.
    posterior_dfs += [GetMethodDataframe(folder, method)]
    metadata_dfs += [GetMetadataDataframe(folder, method)]

posterior_df = pd.concat(posterior_dfs)
metadata_df = pd.concat(metadata_dfs)

Loading NUTS
Loading DADVI




Loading LRVB
Loading RAABBVI
Loading SADVI
Loading SADVI_FR
Loading LRVB_Doubling


  metadata_df = pd.concat(metadata_dfs)


In [5]:
# Write both combined dataframes as CSV files into the output folder,
# omitting the index column.
posterior_df.to_csv(
    path_or_buf=join(output_folder, 'posteriors_tidy.csv'),
    index=False,
)
metadata_df.to_csv(
    path_or_buf=join(output_folder, 'metadata_tidy.csv'),
    index=False,
)

# Explore the contents of the metadata.  

Maybe we want to save additional information.

In [6]:
# Map each method label to whatever GetMetadataDataframe returns when asked
# for the raw metadata (presumably a sequence of per-run records, since a
# later cell indexes it with [0] -- TODO confirm in load_and_tidy_lib).
raw_metadata = dict()
for folder, method in folder_method_list:
    print(f'Loading {method}')
    raw_metadata[method] = GetMetadataDataframe(
        folder, method, return_raw_metadata=True)



Loading NUTS
Loading DADVI
Loading LRVB
Loading RAABBVI
Loading SADVI
Loading SADVI_FR
Loading LRVB_Doubling


In [None]:

def GetObjectiveTrace(method, metadata):
    """Extract the optimization objective trace for one run of a method.

    Parameters
    ----------
    method : str
        Method label: one of 'NUTS', 'RAABBVI', 'DADVI', 'LRVB', 'SADVI',
        'SADVI_FR' or 'LRVB_Doubling'.
    metadata : dict
        Raw metadata for a single run.  Only read for 'RAABBVI' (keys
        'kl_hist' and 'kl_hist_i') and 'DADVI' (keys 'kl_hist' and
        'opt_sequence', the latter an iterable of dicts with
        'val_and_grad_calls' and 'hvp_calls').

    Returns
    -------
    float or tuple of numpy.ndarray
        ``float('NaN')`` for methods with no trace available; otherwise
        ``(step_hist, obj_hist)``, the step counts and the objective values
        recorded in the metadata.

    Raises
    ------
    AssertionError
        If ``method`` is not recognized, or if the DADVI step and objective
        histories have different lengths.
    KeyError
        If a required key is missing from ``metadata``.
    """
    missing_value = float('NaN')

    # Methods for which no trace is returned:
    #   NUTS, LRVB          -- an objective trace doesn't make sense.
    #   SADVI, SADVI_FR     -- TODO: save KL traces for SADVI.
    #   LRVB_Doubling       -- TODO: make sure this makes sense.
    methods_without_trace = ('NUTS', 'LRVB', 'SADVI', 'SADVI_FR', 'LRVB_Doubling')
    if method in methods_without_trace:
        return missing_value

    if method == 'RAABBVI':
        obj_hist = np.array(metadata['kl_hist'])
        step_hist = np.array(metadata['kl_hist_i'])
    elif method == 'DADVI':
        opt_sequence = metadata['opt_sequence']
        obj_hist = np.array(metadata['kl_hist'])
        # Total function-evaluation work at each recorded optimizer step.
        step_hist = np.array([o['val_and_grad_calls'] + o['hvp_calls']
                              for o in opt_sequence])
        # The original check referenced an undefined name (`call_hist`);
        # the intended comparison is between the two histories built above.
        if len(step_hist) != len(obj_hist):
            raise AssertionError(
                f'DADVI step history has length {len(step_hist)} but '
                f'objective history has length {len(obj_hist)}')
    else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f'Invalid method {method}')

    return step_hist, obj_hist


In [7]:
# Show the first raw metadata record of every method, separated by a banner.
for k, v in raw_metadata.items():
    print('=======================================\n', k, ':')
    # One call emits the record followed by the same trailing blank lines
    # that separate print(v[0]) / print('\n') calls would produce.
    print(v[0], '\n', sep='\n')

 NUTS :
{'runtime': 13.11353850364685}


 DADVI :
{'opt_result': {'opt_result':      fun: 693.8896547908593
     jac: array([-8.21633863e-07, -2.88874278e-06,  9.97489176e-08, -6.05746638e-06,
        1.20624896e-05, -1.58890746e-05, -2.77626218e-05, -2.63565063e-05])
 message: 'Optimization terminated successfully.'
    nfev: 71
    nhev: 254
     nit: 72
    njev: 64
  status: 0
 success: True
       x: DeviceArray([97.48527   ,  2.0894773 ,  0.45845976,  4.6221404 ,
             -0.45019984, -2.714926  , -4.087576  , -0.10485008],            dtype=float32), 'evaluation_count': {'n_hvp_calls': 253, 'n_val_and_grad_calls': 71}}, 'fixed_draws': array([[-4.16757847e-01, -5.62668272e-02, -2.13619610e+00,
         1.64027081e+00],
       [-1.79343559e+00, -8.41747366e-01,  5.02881417e-01,
        -1.24528809e+00],
       [-1.05795222e+00, -9.09007615e-01,  5.51454045e-01,
         2.29220801e+00],
       [ 4.15393930e-02, -1.11792545e+00,  5.39058321e-01,
        -5.96159700e-01],
       