In [1]:
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import inspect
import os
import sys
import time

import paragami
import vittles

from copy import deepcopy

import bnpregcluster_runjingdev.regression_mixture_lib as gmm_lib
import bnpregcluster_runjingdev.posterior_quantities_lib as post_lib

# Seed NumPy's global random number generator so that reruns of the notebook
# draw the same random numbers; the particular value 42 is arbitrary.
np.random.seed(42) # nothing special about this seed (we hope)!

In [21]:
class Args:
    """Empty attribute container; settings are attached after construction."""

    def __init__(self):
        # Deliberately defines no fields of its own.
        pass


# Notebook-wide settings: the directory holding the saved fits, and the full
# path of the refit archive that the following cells analyze.
args = Args()
args.fit_directory = '/home/rgiordan/Documents/git_repos/BNP_sensitivity/RegressionClustering/fits/cluster'
args.refit_filename = os.path.join(
    args.fit_directory,
    ''.join([
        'transformed_gene_regression_df4_degree3_genes700_',
        'num_components40_inflate1.0_shrunkTrue_alphascale1.0_',
        'functionalTrue_logphiexpit_refit.npz',
    ]))

def set_directory(filename):
    """Relocate ``filename`` into ``args.fit_directory`` when one is configured.

    Only the final path component of ``filename`` is kept; whatever directory
    it originally pointed at is discarded.  When ``args.fit_directory`` is
    None, ``filename`` is returned unchanged.
    """
    target_dir = args.fit_directory
    if target_dir is None:
        return filename
    return os.path.join(target_dir, os.path.basename(filename))


In [25]:
# Load the refit archive. Besides the re-optimized parameters it records the
# path of the initial fit it came from and the settings of the perturbation.
with np.load(args.refit_filename) as infile:
    # Path of the initial fit, redirected into args.fit_directory if that is set.
    initial_fitfile = set_directory(str(infile['input_filename']))
    # Patterns are stored as JSON strings; parameter values are stored flat and
    # folded back into structured form with free=False.
    gmm_params_pattern = paragami.get_pattern_from_json(
        str(infile['gmm_params_pattern_json']))
    reopt_gmm_params = gmm_params_pattern.fold(
        infile['reopt_gmm_params_flat'], free=False)
    prior_params_pattern = paragami.get_pattern_from_json(
        str(infile['reopt_prior_params_pattern_json']))
    reopt_prior_params = prior_params_pattern.fold(
        infile['reopt_prior_params_flat'], free=False)
    # Remaining entries are read as stored; `functional` and `log_phi_desc`
    # are converted to a Python bool and str respectively.
    reopt_time = infile['reopt_time']
    alpha_scale = infile['alpha_scale']
    functional = bool(infile['functional'])
    log_phi_desc = str(infile['log_phi_desc'])
    
# Stop here unless the refit was saved with a truthy `functional` flag.
# NOTE(review): presumably this marks a functional (log_phi) prior
# perturbation, which the later cells set up -- confirm.
assert functional

# Fail early with a clear message if the referenced initial fit is missing.
if not os.path.isfile(initial_fitfile):
    raise ValueError('Initial fit {} not found'.format(initial_fitfile))

# Load the initial (unperturbed) fit that the refit started from.
# Note: this rebinds gmm_params_pattern and prior_params_pattern, replacing
# the patterns that were read from the refit archive.
with np.load(initial_fitfile) as infile:
    gmm_params_pattern = paragami.get_pattern_from_json(
        str(infile['gmm_params_pattern_json']))
    # Optimal parameters of the initial fit, folded from their flat form.
    opt_gmm_params = gmm_params_pattern.fold(
        infile['opt_gmm_params_flat'], free=False)
    prior_params_pattern = paragami.get_pattern_from_json(
        str(infile['prior_params_pattern_json']))
    prior_params = prior_params_pattern.fold(
        infile['prior_params_flat'], free=False)
    # Stored Hessian; passed later as `hess0` to the sensitivity object.
    kl_hess = infile['kl_hess']
    df = infile['df']
    degree = infile['degree']
    # Path of the data file, redirected into args.fit_directory if that is set.
    datafile = set_directory(str(infile['datafile']))
    num_components = int(infile['num_components'])

# Fail early with a clear message if the referenced data file is missing.
if not os.path.isfile(datafile):
    raise ValueError('Datafile {} not found'.format(datafile))

# Collect the regression inputs that the GMM constructor takes as `reg_params`.
reg_params = dict()
with np.load(datafile) as infile:
    reg_params['beta_mean'] = infile['transformed_beta_mean']
    reg_params['beta_info'] = infile['transformed_beta_info']
    # Optional entries, with fallbacks for archives that do not contain them.
    # NOTE(review): relies on the object returned by np.load supporting the
    # mapping-style `.get` method -- confirm for the NumPy version in use.
    inflate_cov = infile.get('inflate_cov', 0)
    eb_shrunk = infile.get('eb_shrunk', False)

In [26]:
# Build the mixture model from the component count, prior parameters and
# regression inputs loaded above.
gmm = gmm_lib.GMM(num_components, prior_params, reg_params)

# Look up the log_phi object named by the description stored in the refit file.
log_phi = gmm_lib.get_log_phi(log_phi_desc)

# Create the prior perturbation from log_phi and the model's gh_loc /
# gh_weights (presumably Gauss-Hermite nodes and weights -- TODO confirm),
# then register its get_e_log_perturbation method with the model.
prior_pert = gmm_lib.PriorPerturbation(log_phi, gmm.gh_loc, gmm.gh_weights)
gmm.set_perturbation_fun(prior_pert.get_e_log_perturbation)
prior_pert.set_epsilon(0.0) # We evaluate derivatives at epsilon = 0

In [28]:
# Perturbation size used for the prediction below: a length-one array holding
# the alpha_scale value read from the refit archive.
epsilon = np.array([alpha_scale])
print(epsilon)

[1.]


In [29]:
# Settings handed to the posterior-quantity factory below.
n_samples = 10000
threshold = 0
predictive = False

# Build a function that maps folded GMM parameters to the posterior quantity
# compared across the original, refit and predicted parameters later on.
get_posterior_quantity = post_lib.get_posterior_quantity_function(
    predictive, gmm, n_samples, threshold)

# Value at the original optimum (shown as the cell output).
get_posterior_quantity(opt_gmm_params)

39.9698

In [30]:
def get_kl_from_vb_epsilon(params, epsilon):
    """Return the model's KL for ``params`` at perturbation size ``epsilon``.

    ``epsilon`` is first written to the shared ``prior_pert`` object via
    ``set_epsilon``; it is not restored afterwards, so the value persists
    after this call returns.  The result is ``gmm.get_params_kl(params)``.
    """
    prior_pert.set_epsilon(epsilon)
    kl_value = gmm.get_params_kl(params)
    return kl_value

# Wrap the KL objective so that its first argument (the GMM parameters) is
# accepted as a flat vector in the free parameterization; epsilon is passed
# through unchanged as the second argument.
get_kl_from_vb_free_prior_free = \
    paragami.FlattenFunctionInput(
        original_fun=get_kl_from_vb_epsilon,
        patterns = gmm.gmm_params_pattern,
        free = True,
        argnums = 0)

# Evaluate the wrapped objective at the original optimum (shown as the cell
# output).
# NOTE(review): get_kl_from_vb_epsilon calls prior_pert.set_epsilon, so this
# call leaves prior_pert at `epsilon` rather than the 0.0 set earlier --
# confirm that later cells do not rely on the earlier value.
opt_params0 = gmm.gmm_params_pattern.flatten(opt_gmm_params, free=True)
get_kl_from_vb_free_prior_free(opt_params0, epsilon)

# get_perturbation_free = \
#     paragami.FlattenFunctionInput(original_fun=
#         prior_pert.get_e_log_perturbation_epsilon,
#         patterns = gmm.gmm_params_pattern,
#         free = True,
#         argnums = 0)
# get_perturbation_free(opt_params0, epsilon)

-44983.793800960986

In [31]:
# Sanity check: print get_e_log_perturbation_epsilon at the original optimum
# for epsilon = 0.0, 0.5 and 1.0, space-separated on a single line.
print(*[
    prior_pert.get_e_log_perturbation_epsilon(opt_gmm_params, eps)
    for eps in (0.0, 0.5, 1.0)
])

-0.0 -1.943623562870231 -3.887247125740462


In [32]:
# Build the sensitivity object around the original optimum (free
# parameterization) and hyperparameter value 0, reusing the KL Hessian that
# was stored with the initial fit, and report how long construction took.
taylor_order = 1
t0 = time.time()
vb_sens = vittles.ParametricSensitivityTaylorExpansion(
    objective_function=get_kl_from_vb_free_prior_free,
    input_val0=gmm.gmm_params_pattern.flatten(opt_gmm_params, free=True),
    hyper_val0=np.array([0.0]),
    order=taylor_order,
    hess0=kl_hess,
)

print('linear response Hessian time: {:.03f} secs'.format(time.time() - t0))

linear response Hessian time: 0.008 secs


In [33]:
# Wrap the Taylor-series evaluator so that its flat, free-parameterization
# output is folded back into structured GMM parameters.
predict_gmm_params = \
    paragami.FoldFunctionOutput(
        original_fun=vb_sens.evaluate_taylor_series,
        patterns=gmm.gmm_params_pattern,
        free=True,
        retnums=[0])

# Time the linear-response prediction at the refit's perturbation size.
# Elapsed time is end minus start; the original subtracted in the opposite
# order and therefore recorded a negative duration.
lr_time = time.time()
pred_gmm_params = predict_gmm_params(epsilon)
lr_time = time.time() - lr_time

# Compare the posterior quantity at the original optimum, at the actual
# refit, and at the linear-response prediction.
e_num0 = get_posterior_quantity(opt_gmm_params)
e_num1 = get_posterior_quantity(reopt_gmm_params)
e_num_pred = get_posterior_quantity(pred_gmm_params)

print('Orig e: \t{}\nRefit e:\t{}\nPred e:\t\t{}\nActual diff:\t{:0.5}\nPred diff:\t{:0.5}'.format(
    e_num0, e_num1, e_num_pred,
    e_num1 - e_num0,
    e_num_pred - e_num0))

Orig e: 	39.9698
Refit e:	39.9696
Pred e:		39.9696
Actual diff:	-0.0002
Pred diff:	-0.0002
