## Limiting the number of cores that NumPyro can use

Reference: [Limit number of threads in numpy (Stack Overflow)](https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy)

In [1]:
#!pip install threadpoolctl

In [1]:
from threadpoolctl import threadpool_limits
from threadpoolctl import threadpool_info
from pprint import pprint
import numpy


[{'filepath': '/home/usuario/anaconda3/lib/python3.8/site-packages/numpy.libs/libopenblasp-r0-09e95953.3.13.so',
  'internal_api': 'openblas',
  'num_threads': 16,
  'prefix': 'libopenblas',
  'threading_layer': 'pthreads',
  'user_api': 'blas',
  'version': '0.3.13'}]


In [2]:
import os

# Run from the project root so that the relative paths used below ("functions",
# "./Data/Simulations/") resolve. The location can be overridden by setting the
# PAPER_CODE_DIR environment variable; when it is unset the original path is used.
project_root = os.environ.get(
    'PAPER_CODE_DIR',
    '/home/usuario/Documents/Barcelona_Yr1/GraphicalModels_NetworkData/LiLicode/paper_code_github/')
os.chdir(project_root)

### Import modules, utils

In [3]:
import sys
#sys.path.append("./graphical-regression-with-networks/numpyro/functions")
# Make the local helper modules in "functions" importable. The membership check
# keeps the cell idempotent: re-running it no longer appends duplicate entries.
if "functions" not in sys.path:
    sys.path.append("functions")

In [4]:
# Simulated inputs are read from, and diagnostics are written to, the same
# folder (relative to the current working directory).
sim_data_path = './Data/Simulations/'
data_save_path = './Data/Simulations/'


In [5]:
import models
import my_utils

cpu
Is 64 precision enabled?: True
cpu
Is 64 precision enabled?: True


In [6]:
import importlib
# Re-import the local modules so that edits to their source files are picked up
# without restarting the kernel.
importlib.reload(models)
# Last expression of the cell: the module object it returns is what the cell displays.
importlib.reload(my_utils)

cpu
Is 64 precision enabled?: True
cpu
Is 64 precision enabled?: True


<module 'my_utils' from 'functions/my_utils.py'>

In [7]:
# imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from numpyro.util import enable_x64
import time
from datetime import timedelta

In [8]:
import jax
import numpyro
# numpyro.set_platform('gpu')
# Report which platform JAX is running on. jax.default_backend() is the public
# API for this; the private jax.lib.xla_bridge path used previously is
# deprecated and no longer available in newer JAX releases.
print(jax.default_backend())

from jax import random, vmap
import jax.numpy as jnp

import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, log_likelihood
from jax.random import PRNGKey as Key
from numpyro.infer import init_to_feasible, init_to_value
from numpyro.diagnostics import print_summary
from numpyro.handlers import condition, substitute, block, seed
from optax import adam, exponential_decay
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta


cpu


In [9]:
# Request 64-bit precision through NumPyro's helper, then print the JAX config
# flag to confirm the setting is active.
enable_x64(use_x64=True)
print("Is 64 precision enabled?:", jax.config.jax_enable_x64)

Is 64 precision enabled?: True


In [10]:
# Simulation-study settings.
n_sims = 10        # number of simulated data sets the sampling loop iterates over
p = 10             # dimension: used in the file names and to size the arrays below
n = 2_000          # sample-size tag in the names of the simulated files on disk
n_cut = 100        # number of leading rows of Y kept for fitting
TP_thresh = 3      # not referenced in the visible cells

# Run full model with MCMC, reparametrisation and is_dense=False

In [11]:
# Sampler budget used for every simulation (warm-up and retained draws).
n_warmup, n_samples = 2000, 2000

# Values forwarded to the model as `mu_m` / `mu_s`.
mu_m, mu_s = 0.0, 1.0

#my_model = models.golazo_ss_repr
my_model = models.golazo_ss_repr_etaRepr   # model function handed to NUTS below
is_dense = False    # forwarded to NUTS(dense_mass=...)
verbose = False     # when True the sampling loop also prints recovery metrics

# Sample sites whose posterior means are printed in verbose mode.
estimates_print = ["w_slab", "mean_slab", "scale_slab"]

In [12]:
# Values the 'scale_spike' and 'mu' sites are conditioned on in the sampling loop.
scale_spike_fixed = 0.003
mu_fixed = jnp.zeros(p)

In [13]:
# 85semidep params
# (mean, scale) pairs passed to the model as eta{k}_0_m / eta{k}_0_s (intercept)
# and eta{k}_coefs_m / eta{k}_coefs_s (coefficients). Within each eta level the
# coefficient scale is set to the same value as the intercept scale.

eta0_0_m_85SEMI, eta0_0_s_85SEMI = 0.0, 0.126
eta0_coefs_m_85SEMI, eta0_coefs_s_85SEMI = 0.0, 0.126

eta1_0_m_85SEMI, eta1_0_s_85SEMI = -2.197, 0.4
eta1_coefs_m_85SEMI, eta1_coefs_s_85SEMI = 0.0, 0.4

eta2_0_m_85SEMI, eta2_0_s_85SEMI = -2.444, 1.944
eta2_coefs_m_85SEMI, eta2_coefs_s_85SEMI = 0.0, 1.944


In [14]:
# init strategy
# Starting values handed to NUTS via init_to_value: every listed site starts at
# the constant given here.
n_pairs = int(p*(p-1)/2)   # length of the vectors rho_tilde and u

#rho_init = jnp.diag(jnp.ones((p,)))
rho_tilde_init = jnp.zeros((n_pairs,))
u_init = 0.5 * jnp.ones((n_pairs,))

mu_init = jnp.zeros((p,))
sqrt_diag_init = jnp.ones((p,))

init_values_85SEMI = {
    #'rho':rho_init,
    'rho_tilde': rho_tilde_init,
    'u': u_init,
    'mu': mu_init,
    'sqrt_diag': sqrt_diag_init,
    'tilde_eta0_0': 0.0,
    'tilde_eta1_0': 0.0,
    'tilde_eta2_0': 0.0,
    'tilde_eta0_coefs': jnp.array([0.0]),
    'tilde_eta1_coefs': jnp.array([0.0]),
    'tilde_eta2_coefs': jnp.array([0.0]),
}
my_init_strategy_85SEMI = init_to_value(values=init_values_85SEMI)


#my_init_strategy = init_to_uniform()



In [27]:
# Fit the model to every simulated data set with NUTS and collect sampler
# diagnostics (ESS, r_hat, potential energy, wall-clock seconds) per simulation.
# One pickle is written per simulation and one holding all simulations at the end.
#
# NOTE(review): the limit below targets the 'blas' thread pool (the OpenBLAS
# library listed in the threadpool output near the top of the notebook). Whether
# it also bounds the threads JAX uses while sampling is not visible from this
# notebook -- confirm with a process monitor before relying on it.
# NOTE(security): pickle.load can execute code stored in the file; only load
# simulation files that were produced by this project.
diagnostics_all = {}

# Parameters for which ESS / r_hat are recorded (identical for every simulation,
# so the list is built once instead of on every iteration).
params = ['eta0_0', 'eta0_coefs', 'eta1_0', 'eta1_coefs',
          'eta2_0', 'eta2_coefs', 'rho_tilde', 'rho_lt', 'sqrt_diag']

with threadpool_limits(limits=6, user_api='blas'):
    for s in range(n_sims):
        print('--------------------------------------------------------------------------------')
        print(f" Simulation number: {s} \n Dimensions: p = {p}, n = {n_cut} \n Run Network-SS with A-85SEMIDEP")
        with open(sim_data_path + f'sim{s}_p{p}_n{n}.sav', 'rb') as fr:
            sim_res = pickle.load(fr)

        A_list = [jnp.array(sim_res["A_scaled_semi_dep85"])]
        my_model_args = {"A_list":A_list,
                         "eta0_0_m":eta0_0_m_85SEMI, "eta0_0_s":eta0_0_s_85SEMI,
                         "eta0_coefs_m":eta0_coefs_m_85SEMI, "eta0_coefs_s":eta0_coefs_s_85SEMI,
                         "eta1_0_m":eta1_0_m_85SEMI, "eta1_0_s":eta1_0_s_85SEMI,
                         "eta1_coefs_m":eta1_coefs_m_85SEMI, "eta1_coefs_s":eta1_coefs_s_85SEMI,
                         "eta2_0_m":eta2_0_m_85SEMI, "eta2_0_s":eta2_0_s_85SEMI,
                         "eta2_coefs_m":eta2_coefs_m_85SEMI, "eta2_coefs_s":eta2_coefs_s_85SEMI,
                         "mu_m":mu_m, "mu_s":mu_s}

        # select data: only the first n_cut rows of the simulated Y are used
        Y = jnp.array(sim_res['Y'])
        Y = Y[:n_cut,:]
        theta_true = jnp.array(sim_res['theta_true'])
        tril_idx = jnp.tril_indices(n=p, k=-1, m=p)
        # boolean mask of the strictly-lower-triangular entries that are non-zero
        nonzero_true = (jnp.abs(theta_true[tril_idx]) != 0.)

        # set model: condition 'scale_spike' and 'mu' on fixed values and hide
        # those two sites from the inference algorithm
        fixed_params_dict = {"scale_spike":scale_spike_fixed,
                             "mu":mu_fixed}
        blocked_params_list = ["scale_spike", "mu"]
        my_model_run = block(condition(my_model, fixed_params_dict),
                             hide=blocked_params_list)

        nuts_kernel = NUTS(my_model_run, init_strategy=my_init_strategy_85SEMI,
                           dense_mass=is_dense)

        # run model and time it (the timer covers MCMC construction and mcmc.run);
        # the seed depends on the simulation index
        start_time = time.time()
        mcmc = MCMC(nuts_kernel, num_warmup=n_warmup, num_samples=n_samples)
        mcmc.run(rng_key = Key(s+44), Y=Y, **my_model_args,
                 extra_fields=('potential_energy','accept_prob',
                               'num_steps', 'adapt_state'))

        end_time = time.time()
        seconds_elapsed = end_time - start_time

        print(str(timedelta(seconds=seconds_elapsed)))

        # save samples
        res_all_samples = mcmc.get_samples()

        # record diagnostics
        diagnostics_dict = {'ESS':{}, 'r_hat':{}}

        for par in params:
            # A leading axis of length 1 is added before the draws are passed to
            # summary(). The summary is computed once and both statistics are
            # read from it (it used to be computed twice per parameter).
            par_summary = numpyro.diagnostics.summary(
                jnp.expand_dims(res_all_samples[par], 0))['Param:0']

            ESS = par_summary['n_eff']
            diagnostics_dict['ESS'][par] = ESS

            r_hat = par_summary['r_hat']
            diagnostics_dict['r_hat'][par] = r_hat

        diagnostics_dict.update({'potential_energy':mcmc.get_extra_fields()['potential_energy']})
        diagnostics_dict.update({'seconds_elapsed':seconds_elapsed})

        # save the diagnostics of this simulation
        with open(data_save_path + f'diagnostics_ss_A85semi_{s}_p{p}_n{n_cut}.sav' , 'wb') as f:
            pickle.dump((diagnostics_dict), f)

        if verbose:
            # output indicators: slab probability per draw, then averaged over draws
            w_slab_draws = res_all_samples['w_slab']
            # computed once here instead of once per draw inside the loop
            w_spike_draws = 1 - w_slab_draws
            prob_slab_all = []
            for cs in range(n_samples):
                prob_slab = my_utils.get_prob_slab(rho_lt=res_all_samples['rho_lt'][cs],
                                                   mean_slab=res_all_samples['mean_slab'][cs],
                                                   scale_slab=res_all_samples['scale_slab'][cs],
                                                   scale_spike=fixed_params_dict['scale_spike'],
                                                   w_slab=w_slab_draws[cs],
                                                   w_spike=w_spike_draws[cs])
                prob_slab_all.append(prob_slab)
            prob_slab_est = (jnp.array(prob_slab_all)).mean(0)
            print('prob_slab_est', prob_slab_est.shape)

            nonzero_preds_5 = (prob_slab_est>0.5).astype(int)
            print(f'is_nonzero with thresh {0.5}', nonzero_preds_5)
            nonzero_preds_95 = (prob_slab_est>0.95).astype(int)
            print(f'is_nonzero with thresh {0.95}', nonzero_preds_95)

            # confusion-matrix counts at threshold 0.5
            TP_5 = jnp.where((nonzero_preds_5 == True)&(nonzero_true == True))[0].shape[0]
            FP_5 = jnp.where((nonzero_preds_5 == True)&(nonzero_true == False))[0].shape[0]
            FN_5 = jnp.where((nonzero_preds_5 == False)&(nonzero_true == True))[0].shape[0]
            TN_5 = jnp.where((nonzero_preds_5 == False)&(nonzero_true == False))[0].shape[0]

            # confusion-matrix counts at threshold 0.95 (computed but not printed below)
            TP_95 = jnp.where((nonzero_preds_95 == True)&(nonzero_true == True))[0].shape[0]
            FP_95 = jnp.where((nonzero_preds_95 == True)&(nonzero_true == False))[0].shape[0]
            FN_95 = jnp.where((nonzero_preds_95 == False)&(nonzero_true == True))[0].shape[0]
            TN_95 = jnp.where((nonzero_preds_95 == False)&(nonzero_true == False))[0].shape[0]

            theta = res_all_samples['theta'].mean(0)
            print('mse on estimated posterior mean for theta vs true theta: ', my_utils.get_MSE(theta, theta_true))
            print(" ")

            for k in estimates_print:
                est_mean = res_all_samples[f'{k}'].mean(0)
                print(f'estimated mean posterior for {k} is {est_mean}')

            print(" ")
            print(f'Lower triangle of estimated theta is {theta[tril_idx]}')
            print(" ")
            print(f'Total for threshold 0.5:{p*(p-1)/2}, TP:{TP_5}, FP:{FP_5}, FN:{FN_5}, TN:{TN_5}')

            TPR = my_utils.get_TPR(TP=TP_5, FN=FN_5)
            FPR = my_utils.get_FPR(FP=FP_5, TN=TN_5)
            FNR = my_utils.get_FNR(FN=FN_5, TP=TP_5)
            print(f'TPR: {TPR}, FPR: {FPR}, FNR: {FNR}')
            # The two rates below are best-effort: a failure in the helper is
            # reported as "N/A". `except Exception` (rather than a bare except)
            # keeps that behaviour while letting KeyboardInterrupt / SystemExit
            # propagate, so a long run can still be interrupted.
            try:
                FDiscR = my_utils.get_FDiscR(FP=FP_5, TP=TP_5)
                print(f'FDiscR: {FDiscR}')
            except Exception:
                print('FDiscR: N/A, no positives')
            try:
                FNonDiscR = my_utils.get_FNonDiscR(TN=TN_5, FN=FN_5)
                print(f'FNonDiscR: {FNonDiscR}')
            except Exception:
                print('FNonDiscR: N/A, no negatives')

        diagnostics_all[s] = diagnostics_dict
    # save the diagnostics of all simulations in one file
    with open(data_save_path + f'diagnostics_ss_A85semi_all_p{p}_n{n_cut}.sav' , 'wb') as f:
        pickle.dump((diagnostics_all), f)

  0%|          | 0/4000 [00:00<?, ?it/s]

--------------------------------------------------------------------------------
 Simulation number: 0 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:34<00:00, 114.93it/s, 127 steps of size 1.88e-02. acc. prob=0.80]
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:35.682493
--------------------------------------------------------------------------------
 Simulation number: 1 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:53<00:00, 74.56it/s, 511 steps of size 1.02e-02. acc. prob=0.90]  
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:54.107589
--------------------------------------------------------------------------------
 Simulation number: 2 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:38<00:00, 103.46it/s, 255 steps of size 1.54e-02. acc. prob=0.70]


0:00:39.287269
--------------------------------------------------------------------------------
 Simulation number: 3 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:40<00:00, 98.43it/s, 255 steps of size 1.88e-02. acc. prob=0.42] 
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:41.174495
--------------------------------------------------------------------------------
 Simulation number: 4 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:52<00:00, 76.51it/s, 511 steps of size 1.04e-02. acc. prob=0.91]  
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:52.744007
--------------------------------------------------------------------------------
 Simulation number: 5 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:57<00:00, 69.47it/s, 511 steps of size 8.86e-03. acc. prob=0.92]  


0:00:58.342351
--------------------------------------------------------------------------------
 Simulation number: 6 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:56<00:00, 71.22it/s, 511 steps of size 1.03e-02. acc. prob=0.85]  
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:56.758653
--------------------------------------------------------------------------------
 Simulation number: 7 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:42<00:00, 94.49it/s, 255 steps of size 1.39e-02. acc. prob=0.89] 
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:42.830770
--------------------------------------------------------------------------------
 Simulation number: 8 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:53<00:00, 74.19it/s, 511 steps of size 1.13e-02. acc. prob=0.88] 
  0%|          | 0/4000 [00:00<?, ?it/s]

0:00:54.519362
--------------------------------------------------------------------------------
 Simulation number: 9 
 Dimensions: p = 10, n = 100 
 Run Network-SS with A-85SEMIDEP


sample: 100%|██████████| 4000/4000 [00:35<00:00, 113.37it/s, 72 steps of size 2.22e-02. acc. prob=0.60] 


0:00:36.033659


# Average diagnostics across simulations

In [28]:
# load diagnostics_all
# Reads back the file written at the end of the sampling loop, so the averaging
# cells below can be run without sampling again.
# NOTE(security): pickle.load can execute code stored in the file; only load
# files produced by this notebook.
all_diag_file = data_save_path + f'diagnostics_ss_A85semi_all_p{p}_n{n_cut}.sav'
with open(all_diag_file, 'rb') as fh:
    diagnostics_all = pickle.load(fh)

In [29]:
# Parameters whose ESS / r_hat are averaged across simulations: intercept and
# coefficient of each of the three eta levels, plus rho_lt.
params = [f'eta{level}_{part}' for level in range(3) for part in ('0', 'coefs')]
params.append('rho_lt')


In [30]:
# Average the potential energy and the elapsed seconds over the simulations.
# The per-simulation values are stacked along a new leading axis and the mean
# is taken over that axis.
avg_diagnostics = {}
for d in ('potential_energy', 'seconds_elapsed'):
    per_sim_values = [diagnostics_all[s][d] for s in range(n_sims)]
    avg_diagnostics[d] = jnp.array(per_sim_values).mean(0)
    

In [31]:
# Start (or reset) the per-parameter containers that the next cell fills in.
for stat_name in ('ESS', 'r_hat'):
    avg_diagnostics[stat_name] = {}

In [32]:
# For each statistic and each parameter, stack the per-simulation values along
# a new leading axis and average over it.
for d in ('ESS', 'r_hat'):
    for par in params:
        stacked = jnp.array([diagnostics_all[s][d][par] for s in range(n_sims)])
        avg_diagnostics[d][par] = stacked.mean(0)

In [34]:
# Persist the averaged diagnostics in the same folder as the per-simulation files.
avg_diag_file = data_save_path + f'diagnostics_ss_A85semi_avg_p{p}_n{n_cut}.sav'
with open(avg_diag_file, 'wb') as fh:
    pickle.dump(avg_diagnostics, fh)
    

In [35]:
# Mean of the averaged ESS over the six eta parameters. The terms are added in
# the same left-to-right order as before; the displayed result above has
# shape (1,).
eta_ess_keys = ('eta0_0', 'eta0_coefs', 'eta1_0', 'eta1_coefs', 'eta2_0', 'eta2_coefs')
sum(avg_diagnostics['ESS'][key] for key in eta_ess_keys) / len(eta_ess_keys)

DeviceArray([600.93242441], dtype=float64)

In [36]:
# Averaged ESS of rho_lt, further averaged over all of its entries.
jnp.mean(avg_diagnostics['ESS']['rho_lt'])

DeviceArray(1055.68338941, dtype=float64)