In [1]:
import numpy as np
import pandas as pd
import sys
import os
from codebase.file_utils import save_obj, load_obj
from codebase.post_process import samples_to_df, get_post_df, remove_cn_dimension
import altair as alt
alt.data_transformers.disable_max_rows()

%load_ext autoreload
%autoreload 2

In [3]:
# NOTE(review): `ppp` is not referenced by any cell visible in this notebook export —
# confirm whether it is still needed.
ppp = True
# Directory passed to `load_obj` for every object loaded in this notebook.
log_dir = "./log/fabian-freq-estimators/"
# Object saved under the name 'data0'. The plotting helpers below read
# data['beta'] and data['Phi_cov'] from this module-level variable.
data = load_obj('data0', log_dir)


In [4]:
# Display the loaded dictionary. The recorded output shows scalar settings
# (random_seed, N, K, J) alongside the arrays alpha, beta, sigma_z, Phi_corr,
# Phi_cov, z and y.
data

{'random_seed': 0,
 'N': 1000,
 'K': 2,
 'J': 6,
 'alpha': array([0., 0., 0., 0., 0., 0.]),
 'beta': array([[1. , 0. ],
        [0.8, 0. ],
        [0.8, 0. ],
        [0. , 1. ],
        [0. , 0.8],
        [0. , 0.8]]),
 'sigma_z': array([1., 1.]),
 'Phi_corr': array([[1. , 0.2],
        [0.2, 1. ]]),
 'Phi_cov': array([[1. , 0.2],
        [0.2, 1. ]]),
 'z': array([[-1.61951071, -1.11334743],
        [-2.17539248,  0.65913812],
        [-0.8285194 , -2.064689  ],
        ...,
        [-0.21506008, -0.091414  ],
        [-1.18581773, -0.98541301],
        [ 1.7136435 ,  0.05538257]]),
 'y': array([[-4.32504373, -4.38730471, -1.22198053, -2.70652732, -1.01967297,
         -2.03919375],
        [-3.43862813, -2.51570901, -0.59226408,  2.20002122,  1.12333169,
         -1.12253058],
        [ 0.21793111,  0.46024436, -0.22303659, -3.04390179, -0.49471265,
         -0.97153843],
        ...,
        [-2.33686579,  0.15087484, -0.66135544, -2.12297416, -2.8736911 ,
          1.32768299],


In [5]:
def clean_samples(ps0):
    """Align the signs of the sampled loadings without touching the input.

    For every (sample, chain) pair the sign of ``beta[..., 0, 0]`` is applied
    to rows 0-2 of loading column 0, and the sign of ``beta[..., 3, 1]`` is
    applied to rows 3 onward of loading column 1, so both anchor loadings
    end up non-negative.  When ``'Phi_cov'`` is present, its ``[0, 1]`` entry
    is multiplied by the product of the two signs and mirrored into ``[1, 0]``.

    Parameters
    ----------
    ps0 : dict-like
        Must provide ``'beta'`` indexed as ``[sample, chain, row, col]`` and
        may provide ``'Phi_cov'`` indexed as ``[sample, chain, 2, 2]``.

    Returns
    -------
    Shallow copy of ``ps0`` in which ``'beta'`` (and ``'Phi_cov'`` if present)
    are replaced by new, sign-aligned arrays.  All other entries are shared
    with ``ps0``.

    Notes
    -----
    * ``ps0.copy()`` on a dict is shallow, so writing into ``ps['beta']``
      element-wise would silently modify the caller's arrays.  The arrays
      that get rewritten are therefore copied explicitly first.
    * An anchor loading that is exactly zero has no sign to align to; such a
      sample is left unflipped instead of being multiplied by zero.
    """
    ps = ps0.copy()

    # Work on private copies so the caller's arrays stay untouched.
    beta = np.array(ps0['beta'], copy=True)

    # Signs are taken from the anchors *before* any entry is rewritten.
    sign1 = np.sign(beta[:, :, 0, 0])
    sign2 = np.sign(beta[:, :, 3, 1])
    # np.sign(0) == 0 would wipe out the whole loading block; treat as "no flip".
    sign1 = np.where(sign1 == 0, 1, sign1)
    sign2 = np.where(sign2 == 0, 1, sign2)

    # Broadcast the per-(sample, chain) sign across the affected rows.
    beta[:, :, :3, 0] = beta[:, :, :3, 0] * sign1[:, :, None]
    beta[:, :, 3:, 1] = beta[:, :, 3:, 1] * sign2[:, :, None]
    ps['beta'] = beta

    if 'Phi_cov' in ps.keys():
        phi_cov = np.array(ps0['Phi_cov'], copy=True)
        phi_cov[:, :, 0, 1] = sign1 * sign2 * phi_cov[:, :, 0, 1]
        # Keep the matrix symmetric.
        phi_cov[:, :, 1, 0] = phi_cov[:, :, 0, 1]
        ps['Phi_cov'] = phi_cov

    return ps

def get_point_estimates(ps0, param_name, estimate_name):
    """Return a point estimate of one parameter from sign-aligned samples.

    The samples are first passed through ``clean_samples``, the entry
    ``param_name`` is handed to ``remove_cn_dimension`` (imported helper;
    presumably drops the chain axis — TODO confirm), and the result is reduced
    along axis 0.

    Parameters
    ----------
    ps0 : dict-like
        Samples accepted by ``clean_samples``.
    param_name : str
        Key of the parameter to summarise, e.g. ``'beta'``.
    estimate_name : {'mean', 'median'}
        Which reduction to apply along axis 0.

    Returns
    -------
    numpy.ndarray
        The mean or median over axis 0.

    Raises
    ------
    ValueError
        If ``estimate_name`` is not one of the supported names.  Previously an
        unknown name fell through and silently returned ``None``.
    """
    reducers = {'mean': np.mean, 'median': np.median}

    # Validate first so a typo fails fast, before any sample processing.
    if estimate_name not in reducers:
        raise ValueError(
            "estimate_name must be one of %s, got %r"
            % (sorted(reducers), estimate_name)
        )

    ps = remove_cn_dimension(
        clean_samples(ps0)[param_name]
    )
    return reducers[estimate_name](ps, axis=0)

    
def get_credible_interval_beta(ps, true_beta=None):
    """Plot the 2.5%-97.5% range of every loading next to a reference value.

    ``ps`` is converted to a long data frame with ``get_post_df`` (expected to
    provide the columns ``row``, ``col`` and ``value``).  For each
    ``(row, col)`` entry the 0.025 and 0.975 quantiles of ``value`` are drawn
    as a bar, and the reference loading is overlaid as a red point.

    Parameters
    ----------
    ps : array-like
        Draws/estimates accepted by ``get_post_df``.
    true_beta : 2-D array-like, optional
        Reference loading matrix.  Defaults to the notebook-level
        ``data['beta']`` so existing calls keep working.

    Returns
    -------
    altair chart
        Layered bar + point chart, faceted by entry label in two columns.
    """
    df = get_post_df(ps)

    # Lower / upper quantile of every matrix entry.
    by_entry = df.groupby(['row', 'col'])[['value']]
    lower = by_entry.quantile(0.025).reset_index().rename(columns={'value': 'q1'})
    upper = by_entry.quantile(0.975).reset_index().rename(columns={'value': 'q2'})
    intervals = lower.merge(upper, on=['row', 'col'])

    # Reference values in the same long (row, col) layout.
    if true_beta is None:
        true_beta = data['beta']
    true_beta = np.asarray(true_beta)
    # Column labels are derived from the matrix so any number of factors works.
    dd = pd.DataFrame(
        true_beta, columns=[str(k) for k in range(true_beta.shape[1])]
    )
    dd['row'] = np.arange(dd.shape[0])
    dd = dd.melt(id_vars='row', var_name='col', value_name='data')
    dd['col'] = dd['col'].astype(int)

    plot_data = intervals.merge(dd, on=['row', 'col'])
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)

    c1 = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
            alt.X('q1', title=None),
            alt.X2('q2', title=None))

    c2 = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
            alt.X('data', title=None)
    )
    return (c1 + c2).facet('index', columns=2)


def get_credible_interval_Phi(ps, true_phi=None):
    """Plot the 2.5%-97.5% range of the ``[0, 1]`` entry next to a reference value.

    ``ps`` is converted to a long data frame with ``get_post_df`` (expected to
    provide the columns ``row``, ``col`` and ``value``).  Only the entry with
    ``row == 0`` and ``col == 1`` is kept; its 0.025 and 0.975 quantiles are
    drawn as a bar and the reference value is overlaid as a red point.

    Parameters
    ----------
    ps : array-like
        Draws/estimates accepted by ``get_post_df``.
    true_phi : 2-D array-like, optional
        Reference matrix whose ``[0, 1]`` entry is plotted.  Defaults to the
        notebook-level ``data['Phi_cov']`` so existing calls keep working.

    Returns
    -------
    altair chart
        Layered bar + point chart, faceted by entry label.
    """
    df = get_post_df(ps)
    # The off-diagonal entry is the only one of interest here.
    df = df[(df.row == 0) & (df.col == 1)]

    by_entry = df.groupby(['row', 'col'])[['value']]
    lower = by_entry.quantile(0.025).reset_index().rename(columns={'value': 'q1'})
    upper = by_entry.quantile(0.975).reset_index().rename(columns={'value': 'q2'})
    plot_data = lower.merge(upper, on=['row', 'col'])

    if true_phi is None:
        true_phi = data['Phi_cov']
    plot_data['data'] = np.asarray(true_phi)[0, 1]
    plot_data['index'] = 'row ' + plot_data.row.astype(str)+' .col '+plot_data.col.astype(str)

    c1 = alt.Chart(plot_data).mark_bar(opacity=0.6).encode(
            alt.X('q1', title=None),
            alt.X2('q2', title=None))

    c2 = alt.Chart(plot_data).mark_point(opacity=1, color='red').encode(
            alt.X('data', title=None)
    )
    return (c1 + c2).facet('index', columns=2)

In [14]:
# Load the object saved as 'beta_mean' and reshape it to (20, 1, 6, 2) before
# plotting.  NOTE(review): presumably 20 replications x 1 chain x J=6 x K=2 —
# confirm the meaning of the leading dimension.
estimates = load_obj('beta_mean', log_dir).reshape(20, 1, 6, 2)
get_credible_interval_beta(estimates)

In [18]:
# Same plot for the object saved as 'beta_median', reshaped to (20, 1, 6, 2).
# NOTE(review): rebinding `estimates` overwrites the value from the previous cell.
estimates = load_obj('beta_median', log_dir).reshape(20, 1, 6, 2)
get_credible_interval_beta(estimates)

In [25]:
# Load the object saved as 'Phi_mean', reshape it to (20, 1, 2, 2) and plot the
# 2.5%-97.5% range of its [0, 1] entry against data['Phi_cov'][0, 1].
# NOTE(review): presumably 20 replications x 1 chain x 2 x 2 — confirm.
estimates = load_obj('Phi_mean', log_dir).reshape(20, 1, 2, 2)
get_credible_interval_Phi(estimates)

In [26]:
# Same plot for the object saved as 'Phi_median', reshaped to (20, 1, 2, 2).
# NOTE(review): rebinding `estimates` overwrites the value from the previous cell.
estimates = load_obj('Phi_median', log_dir).reshape(20, 1, 2, 2)
get_credible_interval_Phi(estimates)