In [48]:
import numpy as np
import pandas as pd
import sys
import os
from codebase.file_utils import save_obj, load_obj
from codebase.post_process import samples_to_df, get_post_df, remove_cn_dimension
import altair as alt
alt.data_transformers.disable_max_rows()

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
def clean_samples(ps0):
    ps = ps0.copy()
    num_chains = ps['alpha'].shape[1]
    num_samples = ps['alpha'].shape[0]
    for chain_number in range(ps['alpha'].shape[1]):
        for i in range(num_samples):
            sign1 = np.sign(ps['beta'][i,chain_number,0,0])
            sign2 = np.sign(ps['beta'][i,chain_number,3,1])
            ps['beta'][i,chain_number,:3,0] = ps['beta'][i,chain_number,:3,0] * sign1
            ps['beta'][i,chain_number,3:,1] = ps['beta'][i,chain_number,3:,1] * sign2

            if 'Phi_cov' in ps.keys():
                ps['Phi_cov'][i,chain_number,0,1] = sign1 * sign2 * ps['Phi_cov'][i,chain_number,0,1]
                ps['Phi_cov'][i,chain_number,1,0] = ps['Phi_cov'][i,chain_number,0,1]
    
    return ps

def get_point_estimates(ps0, param_name, estimate_name):
    ps = remove_cn_dimension(
        clean_samples(ps0)[param_name]
    )
    if estimate_name == 'mean':
        return np.mean(ps,axis=0)
    elif estimate_name == 'median':
        return np.median(ps, axis=0)

    
def get_credible_interval_beta(ps):
    df = get_post_df(ps)
    df_quant = df.groupby(['row', 'col'])[['value']].quantile(0.025).reset_index()
    df_quant.rename({'value':'q1'}, axis=1, inplace=True)
    df_quant2 = df.groupby(['row', 'col'])[['value']].quantile(0.975).reset_index()
    df_quant2.rename({'value':'q2'}, axis=1, inplace=True)

    df = df_quant.merge(df_quant2, on=['row', 'col'])

    dd = pd.DataFrame(data['beta'], columns=['0', '1'])
    dd['row'] = np.arange(dd.shape[0])
    dd = dd.melt(id_vars='row', var_name='col', value_name = 'data')
    dd['col'] = dd.col.astype(int)

    plot_data = df.merge(dd, on=['row', 'col'])
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)


    c1 = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
            alt.X('q1', title=None),
            alt.X2('q2', title=None))
    c1



    c2 = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
            alt.X('data', title=None)
    )
    return (c1+c2
            ).facet(
                   'index',
                columns=2
                )


def get_credible_interval_Phi(ps):
    df = get_post_df(ps)
    df = df[(df.row==0)&(df.col==1)]
    df_quant = df.groupby(['row', 'col'])[['value']].quantile(0.025).reset_index()
    df_quant.rename({'value':'q1'}, axis=1, inplace=True)
    df_quant2 = df.groupby(['row', 'col'])[['value']].quantile(0.975).reset_index()
    df_quant2.rename({'value':'q2'}, axis=1, inplace=True)

    plot_data = df_quant.merge(df_quant2, on=['row', 'col'])
    plot_data['data'] = data['Phi_cov'][0,1]
    plot_data
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)


    c1 = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
            alt.X('q1', title=None),
            alt.X2('q2', title=None))
    c1



    c2 = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
            alt.X('data', title=None)
    )
    return (c1+c2
            ).facet(
                   'index',
                columns=2
                )

In [26]:
log_dir = "./log/fabian-freq-estimators/bin/m2/"
data = load_obj('data2', log_dir)
data

{'random_seed': 2,
 'N': 2000,
 'K': 2,
 'J': 6,
 'alpha': array([0., 0., 0., 0., 0., 0.]),
 'beta': array([[1. , 0. ],
        [0.8, 0. ],
        [0.8, 0. ],
        [0. , 1. ],
        [0. , 0.8],
        [0. , 0.8]]),
 'sigma_z': array([1., 1.]),
 'Phi_corr': array([[1., 0.],
        [0., 1.]]),
 'Phi_cov': array([[1., 0.],
        [0., 1.]]),
 'z': array([[-0.41675785, -0.05626683],
        [-2.1361961 ,  1.64027081],
        [-1.79343559, -0.84174737],
        ...,
        [-0.97975806,  0.8457959 ],
        [-0.82765158,  1.98027514],
        [ 1.80837527, -0.31915924]]),
 'y': array([[-0.41675785, -0.33340628, -0.33340628, -0.05626683, -0.04501346,
         -0.04501346],
        [-2.1361961 , -1.70895688, -1.70895688,  1.64027081,  1.31221665,
          1.31221665],
        [-1.79343559, -1.43474847, -1.43474847, -0.84174737, -0.67339789,
         -0.67339789],
        ...,
        [-0.97975806, -0.78380645, -0.78380645,  0.8457959 ,  0.67663672,
          0.67663672],
        

In [28]:
for i in range(20):
    try:
        qtl = load_obj('q_beta'+str(i), log_dir)
        results = (qtl[0] <= data['beta'])&(qtl[1] >= data['beta'])
        p = results
        print(i,'\n', p)
    except:
        print(i,'\n','no data')

0 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
1 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
2 
 [[ True  True]
 [False  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
3 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
4 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
5 
 no data
6 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
7 
 no data
8 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
9 
 no data
10 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
11 
 [[ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]
 [ True  True]]
12 
 [[ True  True]
 [ True  True]
 [False  True]
 [ True  True]
 [ True  True]
 [ True  True]]
13 
 [[ Tru

In [23]:
for i in range(20):
    try:
        qtl = load_obj('q_Phi_cov'+str(i), log_dir)
        results = (qtl[0] <= data['Phi_cov'])&(qtl[1] >= data['Phi_cov'])
        p = results
        print(i,'\n', p)
    except:
        print(i,'\n','no data')

0 
 [[ True  True]
 [ True  True]]
1 
 [[ True  True]
 [ True  True]]
2 
 [[ True  True]
 [ True  True]]
3 
 [[ True  True]
 [ True  True]]
4 
 [[ True  True]
 [ True  True]]
5 
 no data
6 
 [[ True  True]
 [ True  True]]
7 
 no data
8 
 [[ True  True]
 [ True  True]]
9 
 no data
10 
 [[ True  True]
 [ True  True]]
11 
 [[ True  True]
 [ True  True]]
12 
 [[ True  True]
 [ True  True]]
13 
 [[ True  True]
 [ True  True]]
14 
 [[ True  True]
 [ True  True]]
15 
 [[ True  True]
 [ True  True]]
16 
 [[ True  True]
 [ True  True]]
17 
 [[ True  True]
 [ True  True]]
18 
 [[ True  True]
 [ True  True]]
19 
 [[ True  True]
 [ True  True]]


In [60]:
seed = 12
ps0 = load_obj('ps'+str(seed), log_dir)
data = load_obj('data'+str(seed), log_dir)

def test_coverage_of_selected_chains(ps0, chain_list):
    ps = dict()
    ps['alpha'] = ps0['alpha'][:,chain_list]
    ps['beta'] = ps0['beta'][:,chain_list]
    ps['Phi_cov'] = ps0['Phi_cov'][:,chain_list]
    ps =  clean_samples(ps)
    quant = dict()
    quant['beta'] = np.quantile(
        remove_cn_dimension(ps['beta']),
        [0.025, 0.975],
        axis=0
        )
    quant['alpha'] = np.quantile(
        remove_cn_dimension(ps['alpha']),
        [0.025, 0.975],
        axis=0
        )
    quant['Phi_cov'] = np.quantile(
        remove_cn_dimension(ps['Phi_cov']),
        [0.025, 0.975],
        axis=0
        )

    return quant
quant = test_coverage_of_selected_chains(ps0, [2,1])

qtl = quant['beta']
results = (qtl[0] <= data['beta'])&(qtl[1] >= data['beta'])
p = results
print(seed)
print(p)

qtl = quant['Phi_cov']
results = (qtl[0] <= data['Phi_cov'])&(qtl[1] >= data['Phi_cov'])
p = results
print(p)

qtl = quant['alpha']
results = (qtl[0] <= data['alpha'])&(qtl[1] >= data['alpha'])
p = results
print(p)


12
[[ True  True]
 [ True  True]
 [False  True]
 [ True  True]
 [ True  True]
 [ True  True]]
[[ True  True]
 [ True  True]]
[ True  True  True  True  True  True]


In [61]:
estimates = clean_samples(load_obj('ps12', log_dir))['beta']
get_credible_interval_beta(estimates)

In [62]:
estimates = clean_samples(load_obj('ps12', log_dir))['Phi_cov']
get_credible_interval_Phi(estimates)