In [None]:
import pp_mix_cpp
import matplotlib.pyplot as plt
import numpy as np
import pickle
import arviz as az

from collections import Counter

from interface import Sampler, to_numpy, writeChains, loadChains

In [None]:
# Read the data set and the prior hyper-parameters from local pickle files.
# NOTE: pickle.load can execute arbitrary code, so only point this at trusted files.
with open("data/child_origresp_interactions_missing.pickle", "rb") as fp:
    data = pickle.load(fp)

with open("data/prior_params_china.pickle", "rb") as fp:
    params = pickle.load(fp)

# Pull the individual pieces out of the data dictionary in one go.
resps, longcovs, fixedcovs, is_missing = (
    data[key] for key in ("resps", "longcovs", "fixedcovs", "is_missing"))

In [None]:
# Display the prior hyper-parameters loaded from the pickle file.
params

In [None]:
# Build a sampler for the "LinearDDP" model.
# NOTE(review): the meaning of the first argument (50) is defined by interface.Sampler and is
# not visible here — presumably a number of mixture components / truncation level; confirm.
sampler = Sampler(50, "LinearDDP")
# Prior hyper-parameters are passed through unchanged from the `params` dictionary.
sampler.set_prior(
    nu=params["nu"],
    sigma0=params["sigma0"],
    gamma0=params["gamma0"],
    beta0=params["beta0"],
    varb=params["varb"],
    varg=params["varg"]
)

# NOTE(review): the four leading integers (0, 100000, 10000, 10) look like MCMC run-length
# settings (adaptation / burn-in / iterations / thinning), but their order and meaning are
# fixed by Sampler.run_mcmc — confirm there before changing them.
chains = sampler.run_mcmc(
    0, 100000, 10000, 10, resps, longcovs, fixedcovs, is_missing)

# Write the returned chain to disk.
writeChains(chains, "chains/growth_linddp.recordio")

In [None]:
# Inspect the cluster allocations stored in the last saved state of the chain.
chains[-1].clus_allocs

In [None]:
# For every saved state, count how many distinct allocation labels are shared by
# more than 10 observations.
n_clus_chain = [
    int(np.count_nonzero(
        np.unique(state.clus_allocs, return_counts=True)[1] > 10))
    for state in chains
]

In [None]:
import matplotlib.pyplot as plt


# Bar chart of the relative frequency, across saved states, of each value in
# n_clus_chain (the per-state number of clusters with more than 10 members).
uniq, cnts = np.unique(n_clus_chain, return_counts=True)
plt.bar(uniq, cnts / np.sum(cnts))

In [None]:
from joblib import Parallel, delayed
from scipy.special import logsumexp
# log(2*pi): constant term of the Gaussian log-density evaluated in mvn_lpdf.
_LOG_2PI = np.log(2 * np.pi)
# Fallback log-density returned by eval_joint_lpdf when its Cholesky-based
# evaluation raises.
LOG_EPS = -10000

def gen_even_slices(n, n_packs, n_samples=None):
    """Yield contiguous slices that split ``range(n)`` into ``n_packs`` packs.

    The first ``n % n_packs`` packs are one element longer than the rest, so pack
    sizes differ by at most one. Packs that would be empty are skipped.

    Parameters
    ----------
    n : int
        Number of elements to partition.
    n_packs : int
        Number of packs requested; must be at least 1.
    n_samples : int, optional
        When given, the stop of every yielded slice is capped at this value.

    Yields
    ------
    slice
        ``slice(start, stop, None)`` for each non-empty pack, in order.

    Raises
    ------
    ValueError
        If ``n_packs`` is smaller than 1 (raised when iteration begins).
    """
    if n_packs < 1:
        raise ValueError("gen_even_slices got n_packs=%s, must be >=1"
                         % n_packs)

    base_size, n_larger = divmod(n, n_packs)
    lo = 0
    for pack_idx in range(n_packs):
        size = base_size + (1 if pack_idx < n_larger else 0)
        if size <= 0:
            continue
        hi = lo + size
        if n_samples is not None:
            hi = min(n_samples, hi)
        yield slice(lo, hi, None)
        lo = hi

def mvn_lpdf(x, mean, prec_chol, prec_log_det, dim):
    """Gaussian log-density written in terms of a factor of the precision matrix.

    Parameters
    ----------
    x, mean : array_like
        Evaluation point(s) and mean; the deviation ``x - mean`` is right-multiplied
        by ``prec_chol``.
    prec_chol : ndarray
        Factor ``C`` such that the quadratic form is ``dev @ C @ C.T @ dev``
        (e.g. the output of ``np.linalg.cholesky`` applied to the precision).
    prec_log_det : float
        Log-determinant of the precision matrix.
    dim : int
        Dimension used for the ``log(2*pi)`` normalising term.

    Returns
    -------
    float or ndarray
        ``-dim/2 * log(2*pi) + prec_log_det/2 - quad/2``, where ``quad`` is the sum
        of squares of ``(x - mean) @ prec_chol`` over the last axis.
    """
    whitened = np.dot(x - mean, prec_chol)
    quad_form = np.sum(np.square(whitened), axis=-1)
    return -0.5 * dim * _LOG_2PI + 0.5 * prec_log_det - 0.5 * quad_form


def get_joint_prec(tmax, Phi, Sigma):
    """Assemble the block-tridiagonal joint matrix over ``tmax`` time points.

    With ``dim = Sigma.shape[0]`` the result is ``(tmax*dim) x (tmax*dim)`` and has:

    * diagonal blocks ``(I + Phi)^T Sigma^{-1} (I + Phi)`` for the first ``tmax - 1``
      time points and ``Sigma^{-1}`` for the last one;
    * super-diagonal blocks ``Phi^T Sigma^{-1}`` with their transposes on the
      sub-diagonal.

    Parameters
    ----------
    tmax : int
        Number of time points (number of diagonal blocks).
    Phi : ndarray, shape (dim, dim)
    Sigma : ndarray, shape (dim, dim)

    Returns
    -------
    ndarray
        The assembled matrix, explicitly symmetrised as ``0.5 * (P + P.T)``.
    """
    dim = Sigma.shape[0]
    shifted = np.eye(dim) + Phi
    sigma_inv = np.linalg.inv(Sigma)

    # Linear solves are used for the products with Sigma^{-1}.
    diag_block = np.matmul(shifted.T, np.linalg.solve(Sigma, shifted))
    offdiag_block = np.linalg.solve(Sigma.T, Phi).T

    prec = np.zeros((tmax * dim, tmax * dim))
    for t in range(tmax - 1):
        cur = slice(t * dim, (t + 1) * dim)
        nxt = slice((t + 1) * dim, (t + 2) * dim)
        prec[cur, cur] = diag_block
        prec[cur, nxt] = offdiag_block
        prec[nxt, cur] = offdiag_block.T

    # The final time point only contributes Sigma^{-1}.
    last = slice((tmax - 1) * dim, None)
    prec[last, last] = sigma_inv
    return 0.5 * (prec + prec.T)

def eval_joint_lpdf(B, Gamma, Regressor, Sigma, y, x, z):
    """Joint Gaussian log-density of one subject's response trajectory.

    The mean is ``(B * x[:tmax, 0]).T + Gamma @ z`` flattened, where ``tmax`` is the
    number of rows of ``y``. ``Regressor @ z`` is reshaped into the square matrix
    ``Phi`` (side ``Sigma.shape[0]``) and handed to ``get_joint_prec`` together with
    ``Sigma``. Only entries of the flattened ``y`` that satisfy ``y >= 0`` are kept
    (negative values, and NaN, are dropped — presumably the encoding of missing
    observations; confirm against the data preparation).

    Parameters
    ----------
    B, Gamma, Regressor, Sigma : ndarray
        Parameters of one mixture component / MCMC state.
    y : ndarray
        Responses; the first axis indexes time points.
    x : ndarray
        Longitudinal covariates; only ``x[:tmax, 0]`` is used.
    z : ndarray
        Fixed covariates.

    Returns
    -------
    float
        The log-density, or ``LOG_EPS`` when the reduced precision matrix is not
        positive definite (Cholesky factorisation fails).
    """
    tmax = y.shape[0]
    dim = Sigma.shape[0]
    mean = ((B * x[:tmax, 0]).T + np.matmul(Gamma, z)).ravel()
    # Phi must be dim x dim for get_joint_prec; this equals reshape(2, 2) when
    # Sigma is 2 x 2.
    Phi = np.matmul(Regressor, z).reshape(dim, dim)
    prec = get_joint_prec(tmax, Phi, Sigma)

    y_flat = y.ravel()
    observed = np.flatnonzero(y_flat >= 0)
    y_obs = y_flat[observed]
    mean_obs = mean[observed]
    # NOTE(review): sub-setting rows/columns of the *precision* matrix corresponds to
    # conditioning on the dropped entries rather than marginalising them out —
    # confirm this is the intended treatment of missing responses.
    prec_obs = prec[np.ix_(observed, observed)]

    # Only a failed factorisation is treated as "density not evaluable"; any other
    # error (shape mismatch, interrupt, ...) is a real bug and must propagate.
    try:
        prec_chol = np.linalg.cholesky(prec_obs)
    except np.linalg.LinAlgError:
        return LOG_EPS

    prec_logdet = 2 * np.sum(np.log(np.diag(prec_chol)))
    return mvn_lpdf(y_obs, mean_obs, prec_chol, prec_logdet, y_obs.shape[0])


def eval_multiple_lpdf(Bchain, Gchain, Regchain_vec, SigmaChain, y, x, z):
    """Evaluate ``eval_joint_lpdf`` for every (state, component) pair.

    Parameters
    ----------
    Bchain, Gchain, SigmaChain : ndarray
        Per-state parameters; indexed along their first axis.
    Regchain_vec : ndarray
        Indexed as ``[state, component, :]``; each vector is reshaped to ``(4, 16)``
        before being passed on as the regressor matrix.
    y, x, z : ndarray
        Data of a single subject, forwarded unchanged.

    Returns
    -------
    ndarray, shape (Bchain.shape[0], Regchain_vec.shape[1])
        Log-densities, one row per state and one column per component.
    """
    n_states = Bchain.shape[0]
    n_comp = Regchain_vec.shape[1]
    out = np.zeros((n_states, n_comp))
    # np.ndindex walks the grid in row-major order: states outer, components inner.
    for s, k in np.ndindex(n_states, n_comp):
        out[s, k] = eval_joint_lpdf(
            Bchain[s], Gchain[s], Regchain_vec[s, k, :].reshape(4, 16),
            SigmaChain[s], y, x, z)
    return out

def eval_mixture_dens(chains, resps, longcovs, fixedcovs, n_jobs=6):
    """Per-component log-densities for every subject and every saved state.

    The chain is split into contiguous packs (one per effective worker) and
    ``eval_multiple_lpdf`` is run on each pack in parallel, subject by subject.

    Parameters
    ----------
    chains : sequence
        Saved MCMC states exposing ``beta``, ``gamma``, ``sigma`` and
        ``lindpp_regressors``.
    resps, longcovs, fixedcovs : sequence
        Per-subject responses, longitudinal covariates and fixed covariates;
        indexed in parallel.
    n_jobs : int, default 6
        Number of joblib workers; also determines how many packs the chain is split
        into. The default reproduces the previously hard-coded setting.

    Returns
    -------
    ndarray, shape (len(resps), len(chains), n_components)
        ``out[i, s, k]`` is the log-density of subject ``i`` under component ``k``
        of state ``s``.
    """
    # Local import: resolves special values such as n_jobs=-1 to a worker count.
    from joblib import effective_n_jobs

    n_states = len(chains)
    n_comp = len(chains[0].lindpp_regressors)

    betachain = np.stack([to_numpy(state.beta) for state in chains])
    gammachain = np.stack([to_numpy(state.gamma) for state in chains])
    sigmachain = np.stack([to_numpy(state.sigma) for state in chains])

    regchain = np.zeros(
        (n_states, n_comp, chains[0].lindpp_regressors[0].size))
    for i, state in enumerate(chains):
        regchain[i, :, :] = np.stack(
            [to_numpy(reg) for reg in state.lindpp_regressors])

    # The split of the chain does not depend on the subject: compute it once.
    chain_slices = list(gen_even_slices(n_states, effective_n_jobs(n_jobs)))

    fd = delayed(eval_multiple_lpdf)
    eval_normals = np.zeros((len(resps), n_states, n_comp))
    # Reuse one Parallel context for all subjects instead of creating one per subject.
    with Parallel(n_jobs=n_jobs) as parallel:
        for i in range(len(resps)):
            print("\r {0}/{1}".format(i+1, len(resps)), flush=True, end=" ")
            curr_dens = parallel(
                fd(betachain[s, :, :], gammachain[s, :, :], regchain[s, :, :],
                   sigmachain[s, :, :], resps[i], longcovs[i], fixedcovs[i])
                for s in chain_slices)
            # Results come back in pack order, so stacking restores the chain order.
            eval_normals[i, :, :] = np.vstack(curr_dens)

    return eval_normals
    
def eval_ldpp_dens(chains, resps, longcovs, fixedcovs):
    """Log mixture density of every subject under every saved state.

    The component log-densities from ``eval_mixture_dens`` are combined with the
    log of each state's ``dp_weights`` through a log-sum-exp over the component
    axis.

    Returns
    -------
    ndarray, shape (len(resps), len(chains))
        One row per subject and one column per saved state.
    """
    mixture_weights = np.vstack([to_numpy(state.dp_weights) for state in chains])
    component_lpdfs = eval_mixture_dens(chains, resps, longcovs, fixedcovs)
    weighted_lpdfs = component_lpdfs + np.log(mixture_weights)
    return logsumexp(weighted_lpdfs, axis=-1)


def lpml(log_densities):
    """Log pseudo-marginal likelihood computed from pointwise log-densities.

    For every column (datum) the inverse conditional predictive ordinate is the mean
    over rows (samples) of ``1 / density``; the result is the sum over columns of
    ``-log`` of that mean.

    The computation is done entirely in log space: exponentiating very negative
    log-densities underflows to 0 and would turn the reciprocal into ``inf``.

    Parameters
    ----------
    log_densities : array_like, shape (n_samples, n_data)
        Log-density of each datum under each posterior sample.

    Returns
    -------
    float
        The LPML value.
    """
    log_densities = np.asarray(log_densities)
    n_samples = log_densities.shape[0]
    # log(mean_s exp(-log p_s)) without ever forming exp(-log p_s) explicitly.
    log_inv_cpos = logsumexp(-log_densities, axis=0) - np.log(n_samples)
    return np.sum(-log_inv_cpos)


def waic(log_densities):
    """WAIC-type score: log pointwise predictive density minus its penalty.

    ``lpd`` is the sum over data of ``log(mean_s exp(log_densities[s, i]))`` and the
    penalty ``p_waic`` is the sum over data of the variance, across samples, of
    ``log_densities[:, i]`` (``np.var`` default, i.e. ``ddof=0``).

    Parameters
    ----------
    log_densities : array_like, shape (n_samples, n_data)
        Log-density of each datum under each posterior sample.

    Returns
    -------
    float
        ``lpd - p_waic`` (larger is better).
    """
    log_densities = np.asarray(log_densities)
    n_samples = log_densities.shape[0]
    # Pointwise log predictive density, one value per datum.
    log_pred_dens = logsumexp(log_densities, axis=0) - np.log(n_samples)
    lpd = np.sum(log_pred_dens)
    # The penalty uses the per-datum variance over samples of the raw log-densities,
    # not the variance of the already-averaged pointwise values.
    p_waic = np.sum(np.var(log_densities, axis=0))
    return lpd - p_waic

In [None]:
# Log mixture density of every subject under every saved state
# (rows: subjects, columns: states). This evaluates the whole chain and is slow.
ldpp_densities = eval_ldpp_dens(chains, resps, longcovs, fixedcovs)

In [None]:
# Length of one flattened regressor; eval_multiple_lpdf reshapes it to (4, 16).
chains[0].lindpp_regressors[0].size

In [None]:
# eval_ldpp_dens returns one row per subject and one column per saved state, while
# lpml averages over axis 0 and therefore needs the samples on the rows: transpose.
lpml(ldpp_densities.T)

In [None]:
# waic expects an (n_samples, n_data) array, but eval_ldpp_dens returns
# (n_data, n_samples): transpose before scoring.
waic(ldpp_densities.T)

In [None]:
# Scratch cell: log(1e-200) is about -460.5.
np.log(1e-200)

In [None]:
# Load a previously saved chain through interface.loadChains.
# NOTE(review): the argument is a directory-like path, whereas writeChains was given a
# file name — confirm that loadChains accepts this and that it is the intended location.
lsb_chains = loadChains("data/")