In [None]:
import os

# JAX reads this flag when it is first imported, so it must be set before any `jax` import.
os.environ["JAX_ENABLE_X64"] = "true"

# Work from the parent directory so relative paths such as "income_data/..." resolve.
# The starting directory is remembered once, so re-running this cell is idempotent
# instead of moving one more level up the tree on every execution.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, ".."))

In [None]:
import pickle

import jax.numpy as np  # JAX's NumPy API; plain NumPy is available as `onp`
import matplotlib.pyplot as plt
import numpy as onp
from tensorflow_probability.substrates import jax as tfp

# TensorFlow Probability distributions running on the JAX backend.
tfd = tfp.distributions

In [None]:
def eval_densities(xgrids, state):
    """Evaluate every group's mixture density on its own grid for one MCMC state.

    Args:
        xgrids: 2-D array of evaluation points; row ``g`` is the grid for group ``g``.
        state: a single MCMC draw exposing ``lam``, ``m``, ``j`` and ``atoms``. The
            mixture weights of group ``g`` are row ``g`` of ``(lam @ m) * j``,
            renormalised to sum to one. ``atoms[:, 0]`` is each Gaussian component's
            mean and ``atoms[:, 1]`` its variance (its square root is the scale).

    Returns:
        Array with the same two leading dimensions as ``xgrids``: the weighted sum of
        component densities at each point. NaN terms are skipped (``nansum``).
    """
    raw_weights = np.matmul(state.lam, state.m) * state.j
    group_weights = raw_weights / np.sum(raw_weights, axis=1, keepdims=True)

    component = tfd.Normal(loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1]))
    # (groups, points, 1) broadcast against the atoms -> (groups, points, atoms)
    pdf_per_atom = component.prob(xgrids[:, :, np.newaxis])
    weighted = pdf_per_atom * group_weights[:, np.newaxis, :]
    return np.nansum(weighted, axis=-1)

In [None]:
# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open("income_data/california_income_subsampled.pickle", "rb") as fp:
    data = np.array(pickle.load(fp).astype(np.float64))

# Work on the log scale. Non-positive incomes have no finite logarithm, so mark them
# missing (NaN) up front. A zero would otherwise become -inf, which is not NaN and
# therefore slips past the NaN-based filtering (nanmean/nanvar, `keep_cols`) used later.
data = np.log(np.where(data > 0, data, np.nan))

In [None]:
# Posterior draws stored by the MCMC run whose output file carries the "lat6" suffix.
with open("income_data/california_mcmc_out_lat6.pickle", "rb") as fp:
    tmp = pickle.load(fp)
states = tmp["states"]

In [None]:
# One identical grid row per group: 1000 points on [0, 16] (log-income scale).
xgrids = np.tile(np.linspace(0, 16, 1000), (data.shape[0], 1))
dens = eval_densities(xgrids, states[-1])  # densities under the last stored draw

## Model comparison: WAIC for different numbers of latent factors

In [None]:
from jax.scipy.special import logsumexp
from jax import jit, vmap
from functools import partial


@jit
def compute_waic(lpdf):
    """WAIC on the expected-log-predictive-density scale (higher is better).

    Args:
        lpdf: 2-D array of pointwise log-likelihoods, one row per posterior draw and
            one column per observation.

    Returns:
        ``lppd - p_waic``, where ``lppd`` sums over observations the log of the
        draw-averaged likelihood (computed stably via ``logsumexp``), and ``p_waic``
        sums the per-observation variance of ``lpdf`` across draws (``np.var``, ddof=0).
    """
    n_draws = lpdf.shape[0]
    lppd_terms = logsumexp(lpdf, axis=0) - np.log(n_draws)
    p_waic_terms = np.var(lpdf, axis=0)
    return np.sum(lppd_terms) - np.sum(p_waic_terms)


@partial(jit, static_argnums=(1,))
def eval_densities(xgrids, ngroups, lam, m, j, atoms):
    """Mixture density of the first ``ngroups`` groups for a single MCMC draw.

    Evaluated with broadcasting over groups rather than a per-group Python loop. Under
    ``jit`` such a loop is unrolled ``ngroups`` times, so compile time grew with the
    number of groups. The broadcast version traces a single computation.

    Args:
        xgrids: 2-D array of evaluation points; row ``i`` is used for group ``i``.
        ngroups: number of leading rows/groups to evaluate (static under ``jit``).
        lam, m, j: draw quantities; unnormalised mixture weights are ``(lam @ m) * j``.
        atoms: per-atom Gaussian parameters, mean in column 0 and variance in column 1.

    Returns:
        Array of shape ``(ngroups, xgrids.shape[1])``: the density of group ``i`` at
        ``xgrids[i, :]``. NaN grid points yield NaN (no NaN masking here).
    """
    weights = np.matmul(lam, m) * j
    weights = weights / weights.sum(axis=1)[:, np.newaxis]
    kernel = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(atoms[:, 1]))
    # (ngroups, n_points, 1) broadcast against the atoms -> (ngroups, n_points, n_atoms)
    comp_pdf = kernel.prob(xgrids[:ngroups, :, np.newaxis])
    return np.sum(comp_pdf * weights[:ngroups, np.newaxis, :], axis=-1)

In [None]:
# Thinning: keep every 5th stored iteration. Chains are read from one pickle per
# latent dimension, with the dimension substituted into the path template below.
thinned = np.arange(0, 9999, 5)
base_file = "income_data/california_mcmc_out_lat{0}.pickle"

In [None]:
base_file = "income_data/california_mcmc_out_lat{0}.pickle"
waics = []
thinned = np.arange(0, 9999, 5)
thinned_idx = thinned.tolist()  # plain ints for indexing the Python list of states

# Flat (row-major) indices of the observed entries of `data` (NaN marks missing ones).
# They line up with the (draws, groups * obs) reshape of the densities below.
keep_cols = np.where(~np.isnan(data.ravel()))[0]

for nlat in [2, 4, 6, 8, 10]:
    fname = base_file.format(nlat)
    with open(fname, "rb") as fp:
        tmp = pickle.load(fp)
        states = tmp["states"]

    # Select the thinned draws before stacking so the full chain is never materialised.
    thinned_states = [states[i] for i in thinned_idx]
    lam_chain = np.stack([s.lam for s in thinned_states])
    m_chain = np.stack([s.m for s in thinned_states])
    j_chain = np.stack([s.j for s in thinned_states])
    atom_chain = np.stack([s.atoms for s in thinned_states])

    # Density of every observation under every retained draw:
    # (draws, groups, obs) -> (draws, groups * obs), then keep observed entries only.
    dens = vmap(lambda x, y, z, v: eval_densities(data, data.shape[0], x, y, z, v))(
        lam_chain, m_chain, j_chain, atom_chain)
    dens = dens.reshape(dens.shape[0], -1)[:, keep_cols]

    waic = compute_waic(np.log(dens))
    waics.append(waic)
    print("nlat: {0}, waic: {1:.4f}".format(nlat, waic))

In [None]:
fig, ax = plt.subplots()
ax.plot([2, 4, 6, 8, 10], waics, marker="o")
ax.set_xlabel("number of latent factors (nlat)")
ax.set_ylabel("WAIC (elpd scale, higher is better)")
ax.set_title("Model comparison")
plt.show()

In [None]:
# Per-group (row) summaries of the observed log-incomes, ignoring NaNs, and the
# indices of the groups with the largest/smallest mean and variance.
means = np.nanmean(data, axis=1)
var = np.nanvar(data, axis=1)
max_m_ind, min_m_ind = np.argmax(means), np.argmin(means)
max_v_ind, min_v_ind = np.argmax(var), np.argmin(var)

In [None]:
# Extreme groups: (highest mean, lowest mean, highest variance, lowest variance).
inds = tuple([max_m_ind, min_m_ind, max_v_ind, min_v_ind])

In [None]:
# Reload the chain of the nlat = 4 run.
with open(base_file.format(4), "rb") as fp:
    tmp = pickle.load(fp)
    states = tmp["states"]

# Select the thinned draws first, then stack only those (the full chain is never built).
thinned_states = [states[i] for i in thinned.tolist()]
lam_chain = np.stack([s.lam for s in thinned_states])
m_chain = np.stack([s.m for s in thinned_states])
j_chain = np.stack([s.j for s in thinned_states])
atom_chain = np.stack([s.atoms for s in thinned_states])

xgrids = np.stack([np.linspace(0, 16, 1000)] * data.shape[0])

# Per-draw densities on the grid for every group: (draws, groups, grid points).
dens = vmap(lambda x, y, z, v: eval_densities(xgrids, xgrids.shape[0], x, y, z, v))(
        lam_chain, m_chain, j_chain, atom_chain)

In [None]:
# Number of latent measures (lam @ m requires lam.shape[1] == m.shape[0]).
nlat = states[-1].lam.shape[1]

# squeeze=False keeps `axes` 2-D even when nlat == 1, so indexing always works.
fig, axes = plt.subplots(nrows=1, ncols=nlat, figsize=(12, 3), squeeze=False)
axes = axes[0]

xgrid = xgrids[0, :]

# Overlay every 100th stored draw of each (unnormalised) latent density.
for it in range(0, len(states), 100):
    state = states[it]

    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])
    # (points, 1, atoms) * (1, latents, atoms), summed over atoms -> (latents, points)
    dens_lat = eval_comps[:, np.newaxis, :] * (state.m * state.j)[np.newaxis, :, :]
    dens_lat = np.sum(dens_lat, axis=-1).T

    for i in range(nlat):
        axes[i].plot(xgrid, dens_lat[i, :], color="black", lw=2, alpha=0.3)

for i in range(nlat):
    axes[i].set_title("latent {0}".format(i))

plt.tight_layout()
# plt.savefig("../latex/images/income_latent_draws.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Post-processing target: the last stored draw.
M = states[-1].m
lam = states[-1].lam

# Gaussian component densities on `xgrid`, one row per atom (column 1 is a variance).
component_dens = tfd.Normal(
    loc=states[-1].atoms[:, 0],
    scale=np.sqrt(states[-1].atoms[:, 1])).prob(xgrid[:, np.newaxis]).T
delta = xgrid[1] - xgrid[0]


@jit
def obj_func(x):
    """Overlap between the latent densities after the change of basis ``x``.

    Rows of ``(x @ M) @ component_dens`` are normalised to integrate to one on
    ``xgrid`` (Riemann sum with step ``delta``). The objective is the sum over all
    pairs i > j of the squared L2 inner product <p_i, p_j>^2.

    NOTE: ``M``, ``component_dens`` and ``delta`` are notebook globals captured as
    constants when ``jit`` traces this function; re-assigning them later is not seen.
    """
    trans_dens = (x @ M) @ component_dens
    trans_dens = trans_dens / (np.sum(trans_dens, axis=1, keepdims=True) * delta)
    # gram[i, j] = <p_i, p_j>; the strict lower triangle lists each pair exactly once.
    gram = (trans_dens @ trans_dens.T) * delta
    return np.sum(np.tril(gram, k=-1) ** 2)


@jit
def constraints(x):
    """Inequality constraints g(x) <= 0 for non-negativity after the basis change.

    Returns the negated, flattened entries of ``lam @ inv(x)`` followed by those of
    ``x @ M``; all entries are <= 0 exactly when both matrices are non-negative.
    """
    return -np.concatenate([
        (lam @ np.linalg.inv(x)).ravel(),
        (x @ M).ravel()])


@jit
def max0(x):
    """Elementwise positive part, max(x, 0)."""
    return x * (x > 0)


@jit
def penalty(x, lambdas, rho=1.0):
    """Augmented-Lagrangian penalty for the constraints g(x) <= 0.

    Computes ``rho / 2 * sum(max(0, lambdas / rho + g(x))**2)``. The same ``rho`` must
    appear in the prefactor and in ``lambdas / rho``. Previously they disagreed
    (1.0 vs 0.5), which does not correspond to a valid augmented Lagrangian.
    """
    return 0.5 * rho * np.sum(max0(lambdas / rho + constraints(x)) ** 2)

In [None]:
from jax import jacfwd, grad, jit
from nrmifactors.postprocess import ralm

x0 = np.eye(M.shape[0])

# Objective, constraints and derivatives for the draw held in the globals
# `M` / `lam` / `component_dens` (used directly by the solver cells below).
grad_f = grad(obj_func)
f = obj_func
constr_eval = constraints(x0)
grad_cons = jacfwd(constraints)

# Settings forwarded positionally to `ralm`.
mu = 0.9
stepsize = 1e-6
init_thr = 1e-2
target_thr = 1e-6
min_lambda = 1e-4
init_lambdas = np.zeros_like(constr_eval) + 10  # one multiplier per constraint
max_lambda = 1000
init_rho = 10
dmin = 1e-6


@jit
def _draw_objective(x, weighted_m, comp_dens, step):
    """Sum over pairs i > j of <p_i, p_j>^2 for rows p of (x @ weighted_m) @ comp_dens.

    Each row is first normalised to integrate to one on a grid with spacing ``step``.
    """
    trans_dens = (x @ weighted_m) @ comp_dens
    trans_dens = trans_dens / (np.sum(trans_dens, axis=1, keepdims=True) * step)
    gram = (trans_dens @ trans_dens.T) * step
    return np.sum(np.tril(gram, k=-1) ** 2)


@jit
def _draw_constraints(x, lam_d, m_d):
    """Constraints g(x) <= 0: negated entries of ``lam_d @ inv(x)`` and ``x @ m_d``."""
    return -np.concatenate([(lam_d @ np.linalg.inv(x)).ravel(), (x @ m_d).ravel()])


# Derivatives are jitted once at module level and per-draw data enters as arguments,
# so all draws with the same array shapes reuse the same compiled code.
_draw_objective_grad = jit(grad(_draw_objective))
_draw_constraints_jac = jit(jacfwd(_draw_constraints))


def get_opt_q(state, init_point, grid=None, maxiter=100):
    """Find the change of basis for a single MCMC draw with ``ralm``.

    Unlike the module-level ``obj_func`` / ``constraints``, which read the global
    ``M``, ``lam`` and ``component_dens``, the objective and constraints here are
    built from ``state`` itself.

    Args:
        state: MCMC draw exposing ``lam``, ``m``, ``j`` and ``atoms``.
        init_point: starting matrix for the solver, e.g. the identity or a previous
            solution used as a warm start.
        grid: evaluation grid for the component densities. Defaults to 1000 points on
            [0, 16], the log-income grid used elsewhere in this notebook.
        maxiter: iteration cap forwarded to ``ralm``.

    Returns:
        The result of ``ralm`` (used downstream as the basis matrix).
    """
    if grid is None:
        grid = np.linspace(0, 16, 1000)
    step = grid[1] - grid[0]

    # Normalised jump sizes, built as a new array so `state.j` is never modified in
    # place (an in-place `/=` would mutate it if it were a NumPy array).
    J = state.j / np.sum(state.j)
    # NOTE(review): jumps weight the atoms as in the mixture weights (lam @ m) * j;
    # the global `obj_func` omits them. Confirm which convention is intended.
    weighted_m = state.m * J
    comp_dens = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(grid[:, np.newaxis]).T
    lam_d, m_d = state.lam, state.m

    def f_d(x):
        return _draw_objective(x, weighted_m, comp_dens, step)

    def grad_f_d(x):
        return _draw_objective_grad(x, weighted_m, comp_dens, step)

    def cons_d(x):
        return _draw_constraints(x, lam_d, m_d)

    def grad_cons_d(x):
        return _draw_constraints_jac(x, lam_d, m_d)

    # NOTE: `init_lambdas` is sized from the global draw's constraints. This assumes
    # every draw has the same (groups, latents, atoms) dimensions.
    return ralm(
        f_d, grad_f_d, cons_d, grad_cons_d, init_point, mu, stepsize,
        init_thr, target_thr, init_lambdas, min_lambda, max_lambda,
        init_rho, dmin, maxiter=maxiter)

In [None]:
from nrmifactors.postprocess import dissipative_lie_rattle_fast

# Optimise the basis for the global draw using only the objective and its gradient
# (no constraints are passed), starting from the identity matrix.
x0 = np.eye(M.shape[0])
opt_x, niter = dissipative_lie_rattle_fast(
    f, grad_f, x0,
    0.9, 1e-5, 1e-6,
    maxiter=1000,
)

In [None]:
# Basis for the first draw, then every 10th draw warm-started from it. States are
# taken by slicing the list directly: converting the list of state objects with
# onp.array (as before) makes NumPy try to build an ndarray out of them.
q0 = get_opt_q(states[0], x0)
qs = [get_opt_q(draw, q0) for draw in states[::10]]

In [None]:
try:
    from scipy.integrate import trapezoid
except ImportError:  # SciPy < 1.6 only provides the old name (removed in 1.14)
    from scipy.integrate import trapz as trapezoid


# squeeze=False keeps `axes` indexable even with a single latent factor.
fig, axes = plt.subplots(nrows=1, ncols=M.shape[0], figsize=(12, 3), squeeze=False)
axes = axes[0]

q = q0

# Restrict to the log-income window (5, 14).
wh = np.where((xgrid > 5) & (xgrid < 14))[0]
latent_mass = []

# Mean of each normalised factor density; named so the per-group `means` is kept intact.
latent_means = []

# Rotated latent densities (one row per factor), computed once outside the loop.
rotated_dens = q @ M @ component_dens

for j in range(M.shape[0]):
    factor_dens = rotated_dens[j, wh]
    latent_mass.append(np.sum(factor_dens * delta))
    factor_dens = factor_dens / np.sum(factor_dens * delta)
    latent_means.append(trapezoid(xgrid[wh] * factor_dens, xgrid[wh]))
    axes[j].plot(xgrid[wh], factor_dens)
plt.tight_layout()
# plt.savefig("../latex/images/income_latent_factors.pdf", bbox_inches="tight")

In [None]:
# Group-averaged density for the global draw. The loadings must come from the same
# draw as `M`; `lam` is that draw's loading matrix (a leftover plotting-loop
# variable would silently mix in a different draw).
latent_dens = M @ component_dens
all_dens = lam @ latent_dens
all_dens = all_dens / np.sum(all_dens * delta, axis=1, keepdims=True)
mean_dens = np.mean(all_dens, axis=0)

In [None]:
fig, ax = plt.subplots()
ax.plot(xgrid, mean_dens)
ax.set_xlabel("log income")
ax.set_ylabel("density")
ax.set_title("Group-averaged density")
plt.show()

In [None]:
# squeeze=False keeps `axes` indexable even with a single latent factor.
fig, axes = plt.subplots(nrows=1, ncols=M.shape[0], figsize=(12, 3), squeeze=False)
axes = axes[0]

q = q0

wh = np.where((xgrid > 5) & (xgrid < 14))[0]
latent_mass = []

# Rotated latent densities (one row per factor), computed once outside the loop.
rotated_dens = q @ M @ component_dens

# Renormalise the group-averaged density over the same window as the factors, so the
# plotted differences compare densities that both integrate to one on (5, 14).
mean_win = mean_dens[wh] / np.sum(mean_dens[wh] * delta)

for j in range(M.shape[0]):
    factor_dens = rotated_dens[j, wh]
    latent_mass.append(np.sum(factor_dens * delta))
    factor_dens = factor_dens / np.sum(factor_dens * delta)
    axes[j].plot(xgrid[wh], factor_dens - mean_win)
    axes[j].set_ylim(-0.08, 0.11)
plt.tight_layout()

plt.savefig("../latex/images/income_latent_factors_diff.pdf", bbox_inches="tight")

In [None]:
import numpy as onp

# Loadings in the rotated basis: (lam @ inv(q)), with column j multiplied by the
# mass of latent factor j over the plotting window.
scaled_loadings = (lam @ np.linalg.inv(q)) * np.asarray(latent_mass)
opt_lambda = onp.array(scaled_loadings)  # plain NumPy copy for pickling

with open("income_data/opt_lambda_scaled.pickle", "wb") as fp:
    pickle.dump(opt_lambda, fp)