In [1]:
import sys
import os
os.chdir("/home/usuario/Documents/Barcelona_Yr1/GraphicalModels_NetworkData/LiLicode/paper_code_github")
# sys.path.append("functions")
data_save_path = './Data/Simulations/'

In [2]:
# import models
# import my_utils

In [2]:
# imports
# !python -m spacy download en_core_web_sm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import json

In [3]:
import jax
import numpyro
# numpyro.set_platform('gpu')
print(jax.lib.xla_bridge.get_backend().platform)

import jax.numpy as jnp
import numpyro.distributions as dist
from jax.random import PRNGKey as Key
from numpyro.util import enable_x64

cpu


In [4]:
def generate_theta(A_true, w_slab_1, w_slab_0, slab_0_low, slab_0_up,  slab_1_low, slab_1_up, key_no):
    
    # select lower-triangular A
    p = A_true.shape[0]
    tril_idx = jnp.tril_indices(n=p, k=-1, m=p)
    A_tril = A_true[tril_idx]


    A_0_idx = jnp.where(A_tril==0.)[0]
    A_1_idx = jnp.where(A_tril==1.)[0]


    # generate theta vals for A=1  
    ones_no = int(A_tril.sum())
    slab_1_no = int(ones_no*w_slab_1)

    spike_1_no = ones_no - slab_1_no
    spike_1_idx = jax.random.choice(Key(key_no+5), jnp.arange(ones_no), (spike_1_no,), replace=False)

    potential_slab_1 = jnp.linspace(slab_1_low, slab_1_up, ones_no)
    theta_1 = potential_slab_1.at[spike_1_idx].set(0.)

    # generate theta vals for A = 0
    zeros_no = int((A_tril==0.).sum())
    slab_0_no = int(zeros_no*w_slab_0)

    spike_0_no = zeros_no - slab_0_no
    spike_0_idx = jax.random.choice(Key(key_no+4), jnp.arange(zeros_no), (spike_0_no,), replace=False)

    potential_slab_0 = jnp.linspace(slab_0_low, slab_0_up, zeros_no)
    theta_0 = potential_slab_0.at[spike_0_idx].set(0.)


    # combine to obtain theta
    theta_tril = A_tril.at[A_0_idx].set(theta_0)
    theta_tril = theta_tril.at[A_1_idx].set(theta_1)

    my_theta_init = jnp.diag(jnp.zeros((p,)))
    my_theta_lt = my_theta_init.at[tril_idx].set(theta_tril)

    my_theta = my_theta_lt + my_theta_lt.T + jnp.diag(jnp.ones((p,)))
    return my_theta

In [5]:
def simulate_data(n_obs, p, mu_true, theta_true, key_no):

    sim_res = {}

    Y = dist.MultivariateNormal(mu_true, 
                                precision_matrix=theta_true).expand((n_obs,)).sample(Key(key_no))
    sim_res['mu_true'] = mu_true
    sim_res['theta_true'] = theta_true
    sim_res['Y'] = Y
    sim_res['n'] = n_obs
    
    return sim_res

In [6]:
def network_simulate(p, A_true, flip_prop=0.):
    
    triu_idx = jnp.triu_indices(n=p, k=1, m=p)
    A_lt = A_true
    A_lt = A_lt.at[triu_idx].set(-999)
    
    pos = jnp.array(jnp.where(A_lt>0.))
    flip_pos = get_flip_idx(pos, prop=flip_prop)
    print(f'Flipping {flip_pos[0].shape[0]} positives out of {pos[0].shape[0]} to zero')
    
    zeros = jnp.array(jnp.where(A_lt==0.))
    flip_zeros = get_flip_idx(zeros, prop=flip_prop)
    print(f'Flipping {flip_zeros[0].shape[0]} zeros out of {zeros[0].shape[0]} to one')
    
    A_new = A_true.at[flip_pos].set(0.)
    A_new = A_new.at[flip_zeros].set(1.)
    
    A_new = jnp.tril(A_new) + jnp.tril(A_new).T - jnp.diag(jnp.diag(A_new))
    return A_new


def network_scale(A):
    p = A.shape[0]
    n_obs = p*(p-1)/2
    A_lt = jnp.tril(A, k=-1)
    A_m = A_lt.sum()/n_obs
    A_bar_lt = jnp.tril(A_lt - A_m, k=-1)
    A_var = ((A_bar_lt)**2).sum()/(n_obs)
    
    A_scaled = (A-A_m)/jnp.sqrt(A_var)
    return A_scaled


def get_flip_idx(coordinates, prop=0.6):
    n_pos = coordinates.shape[1]
    n_flip = np.int(n_pos*prop)
    
    flip_idx = jax.random.choice(Key(5), jnp.arange(n_pos), (n_flip,), replace=False)
    flip = coordinates[:,flip_idx]
    flip = (flip[0], flip[1])
    
    return flip

In [7]:
p_list = [10, 20, 50, 100]
n = 2000
n_sims = 50
w_slab_1 = 0.95
w_slab_0_numerator = 0.5
w_slab_1, w_slab_0_numerator

slab_0_low=-0.1
slab_0_up=0.1
slab_1_low=0.2
slab_1_up=0.5

In [12]:
for p in p_list:
    print('--------------------------------------------------------------------------------')
    print(f"Dimensions: n={n}, p={p}")
    print('--------------------------------------------------------------------------------')
    
    w_slab_0 = w_slab_0_numerator/p


    mu_true = jnp.zeros(p)
    offset_1 = jnp.ones((p-1,))
    
    print(" ")
    print("Generate A true...")
    A_true = jnp.diag(offset_1,1) + jnp.diag(offset_1,-1) + jnp.diag(jnp.ones((p,)))
    A_scaled_true = network_scale(A_true)

    print(" ")
    print("Generate A indep...")
    A_indep = network_simulate(p=p, A_true=A_true, flip_prop=0.5)
    A_scaled_indep = network_scale(A_indep)

    print(" ")
    print("Generate A semi-dep...")
    A_semi_dep = network_simulate(p=p,  A_true=A_true, flip_prop=0.25)
    A_scaled_semi_dep = network_scale(A_semi_dep)
    
    print(" ")
    print("Generate A semi-dep85...")
    A_semi_dep85 = network_simulate(p=p,  A_true=A_true, flip_prop=0.15)
    A_scaled_semi_dep85 = network_scale(A_semi_dep85)

    print(" ")
    print("Generate theta...")
    theta_true = generate_theta(A_true=A_true, w_slab_1=w_slab_1, w_slab_0=w_slab_0, 
                                slab_0_low=slab_0_low, slab_0_up=slab_0_up,  
                                slab_1_low=slab_1_low, slab_1_up=slab_1_up,
                                key_no=p+2)

    print(" ")
    for s in range(n_sims):
        print(f"Simulation number: {s}")

        sim_res = simulate_data(n_obs=n, p=p, mu_true=mu_true, theta_true=theta_true, key_no=p+s+5)



        sim_res.update({"A_indep":A_indep, "A_scaled_indep":A_scaled_indep, 
                        "A_semi_dep":A_semi_dep, "A_scaled_semi_dep":A_scaled_semi_dep,
                        "A_semi_dep85":A_semi_dep85, "A_scaled_semi_dep85":A_scaled_semi_dep85,
                        "A_true":A_true, "A_scaled_true":A_scaled_true})
        


        # uncomment to save to JSON
        sim_res_json = {k:np.array(v).tolist() for k,v in sim_res.items()}
        with open(data_save_path + f'sim{s}_p{p}_n{n}.json' , 'w') as f:
             json.dump((sim_res_json), f)
        
        # comment to save to JSON ONLY
        with open(data_save_path + f'sim{s}_p{p}_n{n}.sav' , 'wb') as f:
            pickle.dump((sim_res), f)

--------------------------------------------------------------------------------
Dimensions: n=2000, p=10
--------------------------------------------------------------------------------
 
Generate A true...
 
Generate A indep...
Flipping 9 positives out of 19 to zero
Flipping 18 zeros out of 36 to one
 
Generate A semi-dep...
Flipping 4 positives out of 19 to zero
Flipping 9 zeros out of 36 to one
 
Generate A semi-dep85...
Flipping 2 positives out of 19 to zero
Flipping 5 zeros out of 36 to one
 
Generate theta...
 
Simulation number: 0
Simulation number: 1
Simulation number: 2
Simulation number: 3
Simulation number: 4
Simulation number: 5
Simulation number: 6
Simulation number: 7
Simulation number: 8
Simulation number: 9
Simulation number: 10
Simulation number: 11
Simulation number: 12
Simulation number: 13
Simulation number: 14
Simulation number: 15
Simulation number: 16
Simulation number: 17
Simulation number: 18
Simulation number: 19
Simulation number: 20
Simulation number: 21
Si