In [8]:
import sys
import os
import warnings

# Project root containing `functions/` (project modules such as `models`,
# `my_utils`) and `data/`. Set GRAPHREG_PROJECT_ROOT to run on another machine;
# the default is the original author's checkout.
PROJECT_ROOT = os.environ.get(
    "GRAPHREG_PROJECT_ROOT",
    "/Users/llaurabat/Dropbox/BGSE_work/LJRZH_graphs/graphical-regression-with-networks/numpyro/FINAL_ALL",
)
if os.path.isdir(PROJECT_ROOT):
    os.chdir(PROJECT_ROOT)
else:
    # Fall back to the current directory (e.g. the notebook was launched from the
    # project root) instead of failing immediately with FileNotFoundError.
    warnings.warn(f"PROJECT_ROOT {PROJECT_ROOT!r} not found; using cwd {os.getcwd()!r}")
# Relative entry: resolved against the working directory when modules are imported.
sys.path.append("functions")

sim_data_path = './data/sim_data/'
data_save_path = './data/'

In [9]:
import models
import my_utils

In [10]:
# imports
import matplotlib.pyplot as plt
import pickle
from numpyro.util import enable_x64
import time
from datetime import timedelta
import jax
import numpyro
# numpyro.set_platform('gpu')  # uncomment to run on GPU
# `jax.lib.xla_bridge.get_backend()` is deprecated in recent JAX releases;
# `jax.default_backend()` is the public API returning the platform name ('cpu', 'gpu', ...).
backend_platform = jax.default_backend()
print(backend_platform)

from jax import random, vmap
import jax.numpy as jnp

from numpyro.infer import MCMC, NUTS
from jax.random import PRNGKey as Key
from numpyro.infer import init_to_feasible, init_to_value
from numpyro.handlers import condition, block


cpu


In [11]:
# Switch JAX to 64-bit default precision, then report the resulting config flag.
enable_x64(True)
is_x64 = jax.config.jax_enable_x64
print(f"Is 64 precision enabled?: {is_x64}")

Is 64 precision enabled?: True


In [12]:
# Simulation set-up: number of simulated datasets, number of variables, and the
# sample size `n` that appears in the simulated-data filenames.
n_sims, p, n = 50, 10, 2000
# Only the first n_cut rows of each simulated Y are used for fitting.
n_cut = 100
# NOTE(review): not referenced elsewhere in this notebook.
TP_thresh = 3

# Run the full model with NUTS (reparametrised model, diagonal mass matrix: `is_dense=False`)

In [13]:
# MCMC run lengths (warm-up iterations, retained samples).
n_warmup, n_samples = 1000, 5000

# Hyperparameters forwarded to the model as `mu_m` / `mu_s`.
mu_m, mu_s = 0., 1.

# Model to fit; with is_dense=False NUTS adapts a diagonal mass matrix.
my_model, is_dense = models.golazo_ss_repr_etaRepr, False

# NOTE(review): not referenced elsewhere in this notebook.
estimates_print = ["w_slab", "mean_slab", "scale_slab"]

In [14]:
# Values at which `mu` and `scale_spike` are conditioned (held fixed) when sampling.
mu_fixed = jnp.zeros(p)
scale_spike_fixed = 3e-3

In [15]:
# Hyperparameters for eta0 / eta1 / eta2 used with the A-85SEMIDEP network
# ("A_scaled_semi_dep85"): intercept (*_0) and coefficient (*_coefs) terms,
# presumably location (*_m) and scale (*_s) of their priors.
eta0_0_m_85SEMI, eta0_0_s_85SEMI = 0., 0.126
eta0_coefs_m_85SEMI, eta0_coefs_s_85SEMI = 0., 0.126

eta1_0_m_85SEMI, eta1_0_s_85SEMI = -2.197, 0.4
eta1_coefs_m_85SEMI, eta1_coefs_s_85SEMI = 0., 0.4

eta2_0_m_85SEMI, eta2_0_s_85SEMI = -2.444, 1.944
eta2_coefs_m_85SEMI, eta2_coefs_s_85SEMI = 0., 1.944


In [16]:
# NUTS initial values: identity `rho`, zero `mu`, unit `sqrt_diag`, and every
# tilde_eta* site at 0 (coefficient sites as length-1 arrays).
# NOTE(review): `mu` is also conditioned on mu_fixed in the sampling cell, so its
# initial value here is presumably unused.
rho_init = jnp.eye(p)
mu_init = jnp.zeros(p)
sqrt_diag_init = jnp.ones(p)

init_values_85SEMI = dict(
    rho=rho_init,
    mu=mu_init,
    sqrt_diag=sqrt_diag_init,
    tilde_eta0_0=0.,
    tilde_eta1_0=0.,
    tilde_eta2_0=0.,
    tilde_eta0_coefs=jnp.array([0.]),
    tilde_eta1_coefs=jnp.array([0.]),
    tilde_eta2_coefs=jnp.array([0.]),
)
my_init_strategy_85SEMI = init_to_value(values=init_values_85SEMI)


In [17]:
# Fit the model with NUTS on every simulated dataset using the A-85SEMIDEP network,
# and record per-simulation sampler diagnostics: ESS, split-R-hat, the potential
# energy trace, and wall-clock time. Each simulation's diagnostics are pickled as
# soon as it finishes; the dict over all simulations is pickled at the end.

# Parameters whose ESS / r_hat are recorded. The previous list repeated
# 'eta0_0'/'eta0_coefs' three times, so eta1 and eta2 were never monitored.
# NOTE(review): assumes the model exposes 'eta1_*' / 'eta2_*' sites alongside
# 'eta0_*' (suggested by the eta1_*/eta2_* hyperparameters) - confirm in models.py.
params = ['eta0_0', 'eta0_coefs', 'eta1_0', 'eta1_coefs', 'eta2_0', 'eta2_coefs',
          'rho_tilde', 'rho_lt', 'sqrt_diag']

# `scale_spike` and `mu` are fixed by conditioning and hidden (blocked) from the
# sampler. The wrapped model does not depend on the simulation, so build it once.
fixed_params_dict = {"scale_spike": scale_spike_fixed,
                     "mu": mu_fixed}
blocked_params_list = ["scale_spike", "mu"]
my_model_run = block(condition(my_model, fixed_params_dict),
                     hide=blocked_params_list)

# Hyperparameters forwarded to the model; identical for every simulation.
eta_hyper_args_85SEMI = {
    "eta0_0_m": eta0_0_m_85SEMI, "eta0_0_s": eta0_0_s_85SEMI,
    "eta0_coefs_m": eta0_coefs_m_85SEMI, "eta0_coefs_s": eta0_coefs_s_85SEMI,
    "eta1_0_m": eta1_0_m_85SEMI, "eta1_0_s": eta1_0_s_85SEMI,
    "eta1_coefs_m": eta1_coefs_m_85SEMI, "eta1_coefs_s": eta1_coefs_s_85SEMI,
    "eta2_0_m": eta2_0_m_85SEMI, "eta2_0_s": eta2_0_s_85SEMI,
    "eta2_coefs_m": eta2_coefs_m_85SEMI, "eta2_coefs_s": eta2_coefs_s_85SEMI,
    "mu_m": mu_m, "mu_s": mu_s,
}


def _chain_diagnostics(samples, param_names):
    """Return ``{'ESS': {...}, 'r_hat': {...}}`` for a single chain of samples.

    ``numpyro.diagnostics.summary`` expects a leading chain axis, so one is added
    with ``expand_dims``; a bare (non-dict) array is reported under 'Param:0'.
    The summary is computed once per site and both statistics read from it.

    Raises KeyError if a name in ``param_names`` is not a site in ``samples``.
    """
    out = {'ESS': {}, 'r_hat': {}}
    for par in param_names:
        stats = numpyro.diagnostics.summary(jnp.expand_dims(samples[par], 0))['Param:0']
        out['ESS'][par] = stats['n_eff']
        out['r_hat'][par] = stats['r_hat']
    return out


diagnostics_all = {}
for s in range(n_sims):
    print('--------------------------------------------------------------------------------')
    print(f" Simulation number: {s} \n Dimensions: p = {p}, n = {n_cut} \n Run Network-SS with A-85SEMIDEP")
    # NOTE: pickle.load can execute arbitrary code; only load locally generated files.
    with open(sim_data_path + f'sim{s}_p{p}_n{n}.sav', 'rb') as fr:
        sim_res = pickle.load(fr)

    A_list = [jnp.array(sim_res["A_scaled_semi_dep85"])]
    my_model_args = {"A_list": A_list, **eta_hyper_args_85SEMI}

    # Only the first n_cut of the stored observations are used for fitting.
    Y = jnp.array(sim_res['Y'])[:n_cut, :]

    nuts_kernel = NUTS(my_model_run, init_strategy=my_init_strategy_85SEMI,
                       dense_mass=is_dense)

    # Run and time the sampler; the per-simulation key offset (44) is kept so
    # results reproduce earlier runs.
    start_time = time.time()
    mcmc = MCMC(nuts_kernel, num_warmup=n_warmup, num_samples=n_samples)
    mcmc.run(rng_key=Key(s + 44), Y=Y, **my_model_args,
             extra_fields=('potential_energy', 'accept_prob',
                           'num_steps', 'adapt_state'))
    seconds_elapsed = time.time() - start_time
    print(str(timedelta(seconds=seconds_elapsed)))

    res_all_samples = mcmc.get_samples()

    diagnostics_dict = _chain_diagnostics(res_all_samples, params)
    diagnostics_dict['potential_energy'] = mcmc.get_extra_fields()['potential_energy']
    diagnostics_dict['seconds_elapsed'] = seconds_elapsed

    # Save this simulation's diagnostics immediately so partial progress survives.
    with open(data_save_path + f'diagnostics_ss_A85semi_{s}_p{p}_n{n_cut}.sav', 'wb') as f:
        pickle.dump(diagnostics_dict, f)

    diagnostics_all[s] = diagnostics_dict

with open(data_save_path + f'diagnostics_ss_A85semi_all_p{p}_n{n_cut}.sav', 'wb') as f:
    pickle.dump(diagnostics_all, f)

--------------------------------------------------------------------------------
 Simulation number: 0 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 6/6 [00:05<00:00,  1.04it/s, 1 steps of size 2.34e+00. acc. prob=0.00]
  0%|          | 0/6 [00:00<?, ?it/s]

0:00:11.059661
--------------------------------------------------------------------------------
 Simulation number: 1 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 6/6 [00:06<00:00,  1.02s/it, 1 steps of size 2.34e+00. acc. prob=0.00]


0:00:06.864805


# Average diagnostics across simulations

In [18]:
# Reload the combined per-simulation diagnostics pickled by the MCMC loop, so the
# averaging below can run without re-sampling.
all_diagnostics_path = data_save_path + f'diagnostics_ss_A85semi_all_p{p}_n{n_cut}.sav'
with open(all_diagnostics_path, 'rb') as f_in:
    diagnostics_all = pickle.load(f_in)

In [19]:
# Parameters whose per-simulation ESS / r_hat are averaged below. The previous
# list repeated 'eta0_0'/'eta0_coefs' three times instead of listing eta1 and eta2.
params_requested = ['eta0_0', 'eta0_coefs', 'eta1_0', 'eta1_coefs', 'eta2_0', 'eta2_coefs',
                    'rho_tilde', 'rho_lt', 'sqrt_diag']
# Keep only parameters recorded for every loaded simulation, so diagnostics files
# written before eta1/eta2 were monitored still aggregate; report what is dropped.
params = [par for par in params_requested
          if all(par in diag['ESS'] and par in diag['r_hat']
                 for diag in diagnostics_all.values())]
params_missing = [par for par in params_requested if par not in params]
if params_missing:
    print("No ESS/r_hat recorded for:", params_missing)


In [20]:
# Elementwise mean over simulations of the potential-energy trace and of the
# wall-clock run time.
avg_diagnostics = {
    key: jnp.array([diagnostics_all[sim][key] for sim in range(n_sims)]).mean(0)
    for key in ['potential_energy', 'seconds_elapsed']
}
    

In [21]:
# Empty containers for per-parameter averages of ESS and r_hat.
for stat_name in ('ESS', 'r_hat'):
    avg_diagnostics[stat_name] = {}

In [22]:
# Elementwise mean over simulations of each parameter's ESS and r_hat.
for stat_name in ['ESS', 'r_hat']:
    for par in params:
        per_sim_values = [diagnostics_all[sim][stat_name][par] for sim in range(n_sims)]
        avg_diagnostics[stat_name][par] = jnp.array(per_sim_values).mean(0)

In [23]:
# Persist the averaged diagnostics in data_save_path, next to the per-simulation files.
avg_diagnostics_path = data_save_path + f'diagnostics_ss_A85semi_avg_p{p}_n{n_cut}.sav'
with open(avg_diagnostics_path, 'wb') as f_out:
    pickle.dump(avg_diagnostics, f_out)
    