In [47]:
# Let's first work out which parameter we're actually interested in.
from dadvi.pymc.models.potus import get_potus_model

In [48]:
potus_model = get_potus_model('../data/potus_data.json')

In [49]:
import numpy as np
import json

data = json.load(open('../data/potus_data.json'))

np_data = {x: np.squeeze(np.array(y)) for x, y in data.items()}

In [50]:
from dadvi.pymc.pymc_to_jax import get_jax_functions_from_pymc
from dadvi.jax import build_dadvi_funs

In [51]:
jax_funs = get_jax_functions_from_pymc(potus_model)

In [52]:
jax_funs

{'log_posterior_fun': <function dadvi.pymc.pymc_to_jax.get_jax_functions_from_pymc.<locals>.flat_log_post_fun(flat_params)>,
 'unflatten_fun': <function jax._src.flatten_util.ravel_pytree.<locals>.<lambda>(flat)>,
 'n_params': 15098}

In [53]:
# This isn't super neat, but we should be able to reconstruct the parameters from any old vector:

In [54]:
import numpy as np

In [55]:
z = np.arange(jax_funs['n_params'])

In [56]:
param_dict = jax_funs['unflatten_fun'](z)

In [57]:
param_shapes = {x: y.shape for x, y in param_dict.items()}

In [58]:
param_shapes

{'mu_e_bias': (),
 'raw_e_bias': (254,),
 'raw_measure_noise_national': (361,),
 'raw_measure_noise_state': (1258,),
 'raw_mu_b': (51, 254),
 'raw_mu_b_T': (51,),
 'raw_mu_c': (161,),
 'raw_mu_m': (3,),
 'raw_mu_pop': (3,),
 'raw_polling_bias': (51,),
 'rho_e_bias_interval__': ()}

In [59]:
# Maybe I can use `raw_mu_b_T`, which is possibly the final vote share except for a known matrix multiplication:
"""
         mu_b_final = (
            pm.math.dot(cholesky_ss_cov_mu_b_T, raw_mu_b_T) + np_data["mu_b_prior"]
        )
"""
# I guess if I can get its covariance, I should be able to reconstruct things.
# So, next steps:
# 1. Work out the indices corresponding to `raw_mu_b_T`
# 2. Use CG to calculate the relevant covariances 
# 3. Am I screwed because I can only get marginal frequentist variances?
# 4. Maybe consult with Ryan

'\n         mu_b_final = (\n            pm.math.dot(cholesky_ss_cov_mu_b_T, raw_mu_b_T) + np_data["mu_b_prior"]\n        )\n'

In [60]:
indices = param_dict['raw_mu_b_T']

In [61]:
sample_index = indices[0]

In [62]:
from dadvi.core import compute_hessian_inv_column, compute_frequentist_covariance_using_score_mat
from dadvi.core import compute_score_matrix

In [139]:
# Load the optimal parameters
import pickle
from glob import glob

all_runs = glob('../potus_coverage/64/*/*/*/*.pkl')

reference_run = all_runs[1]
other_runs = all_runs[2:]

results = pickle.load(open(reference_run, 'rb'))

In [140]:
opt_params = results['opt_result']['opt_result'].x
z = results['fixed_draws']
dadvi_funs = build_dadvi_funs(jax_funs['log_posterior_fun'])

In [141]:
from tqdm import tqdm

In [142]:
score_mat = compute_score_matrix(opt_params, dadvi_funs.kl_est_and_grad_fun, z)

In [143]:
opt_means = opt_params[:opt_params.shape[0]//2]

In [144]:
mean_dict = jax_funs['unflatten_fun'](opt_means)

In [145]:
national_cov_matrix_error_sd = np.sqrt(
        np.squeeze(
            np_data["state_weights"].reshape(1, -1)
            @ (np_data["state_covariance_0"] @ np_data["state_weights"].reshape(-1, 1))
        )
    )
ss_cov_mu_b_T = (
    np_data["state_covariance_0"]
    * (np_data["mu_b_T_scale"] / national_cov_matrix_error_sd) ** 2
)
cholesky_ss_cov_mu_b_T = np.linalg.cholesky(ss_cov_mu_b_T)

rel_means = mean_dict['raw_mu_b_T']

In [146]:
mean_shares = cholesky_ss_cov_mu_b_T @ rel_means + np_data["mu_b_prior"]

In [147]:
national_average = mean_shares @ np_data['state_weights']

In [148]:
np_data['state_weights']

array([0.0023, 0.0158, 0.0081, 0.0183, 0.1025, 0.0207, 0.0118, 0.0024,
       0.0033, 0.0677, 0.0305, 0.0034, 0.0121, 0.0051, 0.0395, 0.02  ,
       0.0088, 0.0136, 0.0153, 0.0246, 0.0209, 0.0054, 0.0358, 0.0226,
       0.0209, 0.0097, 0.0037, 0.0353, 0.0026, 0.0061, 0.0054, 0.0278,
       0.0059, 0.0081, 0.0541, 0.0421, 0.0103, 0.014 , 0.0434, 0.0034,
       0.0154, 0.0028, 0.019 , 0.0644, 0.0081, 0.0299, 0.0023, 0.0246,
       0.0233, 0.005 , 0.0019])

In [149]:
national_average

DeviceArray(0.05745102, dtype=float64)

In [150]:
# There might be a bit more but this is a good start.
# Compute the gradient required now.

In [151]:
from jax.scipy.special import expit
from jax import grad

# Check whether this is the correct way to get the national pct!!
def compute_final_vote_share(full_var_params):

    # I think we only need the means:
    opt_means = full_var_params[:full_var_params.shape[0] // 2]

    mean_dict = jax_funs['unflatten_fun'](opt_means)

    rel_means = mean_dict['raw_mu_b_T']

    mean_shares = cholesky_ss_cov_mu_b_T @ rel_means + np_data["mu_b_prior"]

    return expit(mean_shares) @ np_data['state_weights']

In [152]:
rel_grad = grad(compute_final_vote_share)(opt_params)

In [153]:
rel_grad.min(), rel_grad.max()

(DeviceArray(-0.00023501, dtype=float64),
 DeviceArray(0.01753882, dtype=float64))

In [154]:
from dadvi.utils import cg_using_fun_scipy

In [155]:
# TODO Preconditioner
rel_hvp = lambda x: dadvi_funs.kl_est_hvp_fun(opt_params, z, x)
cg_result = cg_using_fun_scipy(rel_hvp, rel_grad, preconditioner=None)

In [156]:
h_inv_g = cg_result[0]

score_mat.shape

(30, 30196)

In [157]:
score_mat_means = score_mat.mean(axis=0, keepdims=True)
centred_score_mat = score_mat - score_mat_means

In [158]:
vec = centred_score_mat @ h_inv_g
M = score_mat.shape[0]
freq_sd = np.sqrt((vec.T @ vec) / (M * (M - 1)))

In [159]:
# Load the runs
from glob import glob

In [160]:
others_loaded = [pickle.load(open(x, 'rb')) for x in other_runs]

In [161]:
all_zs = list()

for other_run in others_loaded:

    other_opt_params = other_run['opt_result']['opt_result'].x

    other_res = compute_final_vote_share(other_opt_params)
    prev_res = compute_final_vote_share(opt_params)

    cur_z = (other_res - prev_res) / (np.sqrt(2) * freq_sd)

    all_zs.append(cur_z)

In [162]:
np.array(all_zs)

array([ 0.9784361 ,  0.53134087,  0.1685255 ,  0.79292589,  0.42817879,
       -1.45472035,  1.31158116, -0.9623922 ])

In [163]:
# What do they look like marginally over all of them
# Check pairs more systematically
# If not good, check better score matrix.
# Warm starts for this as well
compute_final_vote_share(opt_params)

DeviceArray(0.51371472, dtype=float64)

In [164]:
freq_sd

0.0003797790492909348