In [9]:
import numpy as np
import pandas as pd
import sys
import os
from codebase.file_utils import save_obj, load_obj
from codebase.post_process import samples_to_df, get_post_df, remove_cn_dimension
import altair as alt
alt.data_transformers.disable_max_rows()

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# NOTE(review): `ppp` is not read by any cell visible in this notebook — confirm it is still needed.
ppp = True
# Directory of the run whose saved objects ('data1', 'ps<i>') are loaded by the cells below.
log_dir = "./log/20211011_191649_ez-sim/"
# Load the object saved under the name 'data1' from `log_dir` (presumably the run's
# settings and generated data — see the displayed output of the next cell).
data = load_obj('data1', log_dir)


In [11]:
# Display the loaded `data` object to inspect its keys and values.
data

{'random_seed': 1,
 'N': 10,
 'K': 2,
 'J': 6,
 'alpha': array([0., 0., 0., 0., 0., 0.]),
 'beta': array([[1. , 0. ],
        [0.8, 0. ],
        [0.8, 0. ],
        [0. , 1. ],
        [0. , 0.8],
        [0. , 0.8]]),
 'sigma_z': array([1., 1.]),
 'Phi_corr': array([[1. , 0.2],
        [0.2, 1. ]]),
 'Phi_cov': array([[1. , 0.2],
        [0.2, 1. ]]),
 'z': array([[-0.87130378, -1.64512124],
        [ 1.08772502, -0.26948486],
        [ 0.78527901, -2.12596275],
        [-0.87009587, -1.8329549 ],
        [-0.08941095, -0.40484229],
        [ 0.17040345, -2.43549133],
        [ 0.49264059,  0.00684599],
        [-0.18258172, -1.57384635],
        [ 0.68876873, -0.4216441 ],
        [-0.40130333,  0.33590608]]),
 'y': array([[-2.72421697,  1.23828688,  0.79485403, -0.83403896,  0.17449788,
         -2.43037866],
        [ 0.89148573, -0.68252477,  0.44129552,  0.58767656, -1.3433453 ,
         -0.85325266],
        [-0.33485216, -0.76434535, -0.4648933 , -2.14617267, -3.58476386,
    

In [17]:
def clean_samples(ps0):
    """Return a sign-aligned copy of the posterior samples in ``ps0``.

    For every (sample, chain) pair:

    * rows ``0..2`` of column 0 of ``beta`` are multiplied by the sign of
      ``beta[sample, chain, 0, 0]``;
    * rows ``3..`` of column 1 of ``beta`` are multiplied by the sign of
      ``beta[sample, chain, 3, 1]``;
    * if ``'Phi_cov'`` is present, its ``[0, 1]`` entry is multiplied by the
      product of the two signs and then copied into ``[1, 0]``.

    ``ps0.copy()`` is only a shallow copy, so the arrays that are modified
    are copied explicitly; the caller's ``ps0`` arrays are never written to.

    Parameters
    ----------
    ps0 : dict-like of str -> numpy array
        Must contain ``'beta'`` indexed as ``[sample, chain, row, col]``.
        May contain ``'Phi_cov'`` indexed as ``[sample, chain, row, col]``.

    Returns
    -------
    dict-like
        A shallow copy of ``ps0`` in which ``'beta'`` (and ``'Phi_cov'`` when
        present) are replaced by new, sign-aligned arrays. Other entries are
        the same objects as in ``ps0``.

    Notes
    -----
    ``np.sign`` returns 0 for an exactly-zero reference entry, in which case
    the corresponding block is multiplied by 0.
    """
    ps = ps0.copy()

    # Copy before editing so the input arrays stay intact.
    beta = np.array(ps['beta'], copy=True)

    # Signs of the two reference entries, one per (sample, chain).
    # np.sign allocates new arrays, so the in-place edits below do not affect them.
    sign1 = np.sign(beta[:, :, 0, 0])
    sign2 = np.sign(beta[:, :, 3, 1])

    # Broadcast each per-(sample, chain) sign over the affected rows.
    beta[:, :, :3, 0] *= sign1[:, :, np.newaxis]
    beta[:, :, 3:, 1] *= sign2[:, :, np.newaxis]
    ps['beta'] = beta

    if 'Phi_cov' in ps.keys():
        phi = np.array(ps['Phi_cov'], copy=True)
        phi[:, :, 0, 1] *= sign1 * sign2
        # Mirror the flipped off-diagonal entry into the lower triangle.
        phi[:, :, 1, 0] = phi[:, :, 0, 1]
        ps['Phi_cov'] = phi

    return ps

def get_point_estimates(ps0, param_name, estimate_name):
    """Compute a point estimate of one parameter from posterior samples.

    The samples are first passed through ``clean_samples``, the requested
    parameter is extracted and handed to ``remove_cn_dimension`` (presumably
    merging the chain axis into the sample axis — confirm against
    ``codebase.post_process``), and the result is reduced along axis 0.

    Parameters
    ----------
    ps0 : dict-like of str -> numpy array
        Posterior samples as accepted by ``clean_samples``.
    param_name : str
        Key of the parameter to summarise, e.g. ``'beta'`` or ``'Phi_cov'``.
    estimate_name : {'mean', 'median'}
        Which summary to compute.

    Returns
    -------
    numpy array
        The mean or median over axis 0 of the processed samples.

    Raises
    ------
    ValueError
        If ``estimate_name`` is not ``'mean'`` or ``'median'``. Previously
        this case silently returned ``None``.
    """
    reducers = {'mean': np.mean, 'median': np.median}

    # Validate up front so a typo fails before any sample processing is done.
    if estimate_name not in reducers:
        raise ValueError(
            "estimate_name must be one of %s, got %r"
            % (sorted(reducers), estimate_name)
        )

    ps = remove_cn_dimension(
        clean_samples(ps0)[param_name]
    )
    return reducers[estimate_name](ps, axis=0)

    
def get_credible_interval_beta(ps, true_beta=None):
    """Plot the 2.5%-97.5% quantile band of every ``beta`` entry against its reference value.

    ``ps`` is converted to a long-format frame by ``get_post_df``; that frame
    is expected to have ``row``, ``col`` and ``value`` columns. For each
    ``(row, col)`` pair the 0.025 and 0.975 quantiles of ``value`` are drawn
    as a horizontal bar, and the matching entry of ``true_beta`` is overlaid
    as a red point. One facet is produced per ``(row, col)`` pair.

    Parameters
    ----------
    ps : array-like
        Values accepted by ``get_post_df``.
    true_beta : 2-D array-like, optional
        Reference values indexed as ``[row, col]``. Defaults to the
        notebook-level ``data['beta']``, which keeps existing one-argument
        calls working.

    Returns
    -------
    altair chart
        Faceted layered chart (two facet columns).
    """
    if true_beta is None:
        # Backward-compatible default: read the reference from the notebook global.
        true_beta = data['beta']
    true_beta = np.asarray(true_beta)

    draws = get_post_df(ps)
    by_entry = draws.groupby(['row', 'col'])[['value']]
    lower = by_entry.quantile(0.025).reset_index().rename(columns={'value': 'q1'})
    upper = by_entry.quantile(0.975).reset_index().rename(columns={'value': 'q2'})
    bounds = lower.merge(upper, on=['row', 'col'])

    # Long-format reference values; column labels follow the array width
    # instead of being hard-coded to two columns.
    truth = pd.DataFrame(
        true_beta,
        columns=[str(k) for k in range(true_beta.shape[1])],
    )
    truth['row'] = np.arange(truth.shape[0])
    truth = truth.melt(id_vars='row', var_name='col', value_name='data')
    truth['col'] = truth.col.astype(int)

    plot_data = bounds.merge(truth, on=['row', 'col'])
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)

    interval_bars = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
        alt.X('q1', title=None),
        alt.X2('q2', title=None),
    )
    truth_points = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
        alt.X('data', title=None)
    )
    return (interval_bars + truth_points).facet('index', columns=2)


def get_credible_interval_Phi(ps, true_phi=None):
    """Plot the 2.5%-97.5% quantile band of the ``[0, 1]`` entry against its reference value.

    ``ps`` is converted to a long-format frame by ``get_post_df``; that frame
    is expected to have ``row``, ``col`` and ``value`` columns. Only rows with
    ``row == 0`` and ``col == 1`` are kept. Their 0.025 and 0.975 quantiles
    are drawn as a horizontal bar, and ``true_phi[0, 1]`` is overlaid as a
    red point.

    Parameters
    ----------
    ps : array-like
        Values accepted by ``get_post_df``.
    true_phi : 2-D array-like, optional
        Reference matrix; only its ``[0, 1]`` entry is used. Defaults to the
        notebook-level ``data['Phi_cov']``, which keeps existing one-argument
        calls working.

    Returns
    -------
    altair chart
        Faceted layered chart (two facet columns).
    """
    if true_phi is None:
        # Backward-compatible default: read the reference from the notebook global.
        true_phi = data['Phi_cov']
    true_phi = np.asarray(true_phi)

    draws = get_post_df(ps)
    # Keep only the single off-diagonal entry of interest.
    draws = draws[(draws.row == 0) & (draws.col == 1)]

    by_entry = draws.groupby(['row', 'col'])[['value']]
    lower = by_entry.quantile(0.025).reset_index().rename(columns={'value': 'q1'})
    upper = by_entry.quantile(0.975).reset_index().rename(columns={'value': 'q2'})

    plot_data = lower.merge(upper, on=['row', 'col'])
    plot_data['data'] = true_phi[0, 1]
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)

    interval_bars = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
        alt.X('q1', title=None),
        alt.X2('q2', title=None),
    )
    truth_points = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
        alt.X('data', title=None)
    )
    return (interval_bars + truth_points).facet('index', columns=2)

In [13]:

# Collect the point estimate of `param` from each saved posterior-sample object
# ('ps0', 'ps1', ...) into one array with a leading simulation axis.
param = 'beta'
estimate_name = 'mean'

nsim = 2
# Take the per-simulation shape from the loaded settings (J x K)
# rather than repeating it as literal numbers.
estimates = np.empty((nsim, data['J'], data['K']))
for i in range(nsim):
    estimates[i] = get_point_estimates(
        load_obj('ps'+str(i), log_dir),
        param,
        estimate_name
    )
estimates

array([[[ 1.        ,  0.        ],
        [ 0.37829298,  0.        ],
        [ 0.0339923 ,  0.        ],
        [ 0.        ,  1.        ],
        [ 0.        , -0.01007444],
        [ 0.        ,  0.16247398]],

       [[ 1.        ,  0.        ],
        [-0.15113563,  0.        ],
        [-0.04792802,  0.        ],
        [ 0.        ,  1.        ],
        [ 0.        , -0.13875943],
        [ 0.        ,  0.15891921]]])

In [14]:
# For each beta entry, show the 2.5%-97.5% quantile band of `estimates`
# with the corresponding value from data['beta'] overlaid as a red point.
get_credible_interval_beta(estimates)

In [15]:

# Collect the point estimate of `param` from each saved posterior-sample object
# ('ps0', 'ps1', ...) into one array with a leading simulation axis.
# NOTE: this rebinds `estimates`, replacing the beta estimates computed earlier.
param = 'Phi_cov'
estimate_name = 'mean'

nsim = 2
# Phi_cov is K x K; take K from the loaded settings
# rather than repeating it as a literal number.
estimates = np.empty((nsim, data['K'], data['K']))
for i in range(nsim):
    estimates[i] = get_point_estimates(
        load_obj('ps'+str(i), log_dir),
        param,
        estimate_name
    )
estimates

array([[[0.15253119, 0.01709247],
        [0.01709247, 0.13903847]],

       [[0.16296468, 0.0079389 ],
        [0.0079389 , 0.14056054]]])

In [18]:
# Show the 2.5%-97.5% quantile band of the [0, 1] entry of `estimates`
# with data['Phi_cov'][0, 1] overlaid as a red point.
get_credible_interval_Phi(estimates)