In [None]:
import os
os.environ["JAX_ENABLE_X64"] = "true"



import numpy as onp
import jax.numpy as np
import matplotlib.pyplot as plt
from jax import random
from nrmifactors import algorithm as algo
from nrmifactors.state import State
import nrmifactors.priors as priors

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


In [None]:
import logging
logger = logging.getLogger("root")


class CheckTypesFilter(logging.Filter):
    """Logging filter that suppresses records mentioning ``check_types``.

    The check is done on the fully formatted message (``record.getMessage()``);
    every other record is let through.
    """

    def filter(self, record):
        if "check_types" in record.getMessage():
            return False
        return True


logger.addFilter(CheckTypesFilter())

In [None]:
from scipy.stats import skewnorm

ndata = 50  # observations per group
ngroups = 100  # number of groups

key = random.PRNGKey(202204)

# Each group is a mixture of three Normals with shared means `locs` and common
# scale 1.5; the mixture weights of each group are drawn from Dirichlet(0.5, 0.5, 0.5).
locs = np.array([-2.0, 0.0, 2.0])
data = []
probs = []  # true per-group mixture weights

for i in range(ngroups):
    # A fresh subkey is split off before every draw; the resulting data depend on
    # this exact split order.
    key, subkey = random.split(key)
    probas = tfd.Dirichlet(np.array([0.5, 0.5, 0.5])).sample(seed=subkey)
    probs.append(probas)
    key, subkey = random.split(key)
    # `(ndata)` is just the int `ndata` (not a tuple): ndata component labels.
    clus = tfd.Categorical(probs=probas).sample((ndata), seed=subkey)
    key, subkey = random.split(key)
    curr = tfd.Normal(locs[clus], np.ones_like(clus) * 1.5).sample(seed=subkey)
    data.append(curr)

# Shape (ngroups, ndata).
data = np.stack(data)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20, 3))

# Raw observations of four example groups, shown on a common x-range.
for ax, group in zip(axes, (0, 2, 5, -2)):
    ax.hist(onp.array(data[group, :]))
    ax.set_xlim(-6, 6)

In [None]:
from sklearn.cluster import KMeans

natoms = 20

# Initial atoms from k-means on the pooled observations. Column 0 of `init_atoms`
# holds the cluster centres, and column 1 is a starting value of 0.3 for every atom.
# Density code elsewhere in the notebook takes sqrt of column 1 as the Normal scale,
# so column 1 acts as a variance.
# k-means initialisation is random, so fix `random_state` to make the
# initial clustering (and hence the whole chain) reproducible.
km = KMeans(natoms, random_state=202204)
pooled = data.reshape(-1, 1)
km.fit(pooled)
clus = km.predict(pooled).reshape(data.shape)
means = km.cluster_centers_

init_atoms = np.hstack([means, np.ones_like(means) * 0.3])

In [None]:
key = random.PRNGKey(202204)
nlat = 20

# NOTE(review): this reuses the seed of the data-simulation cell, so the key handed
# to the Gamma sampler below is the same root key that was split to simulate `data`.

prior = priors.NrmiFacPrior(
    kern_prior=priors.NNIGPrior(0.0, 0.01, 5.0, 5.0),
    lam_prior_iid=priors.GammaPrior(4.0, 4.0),
    lam_prior_mgp=priors.MGPPrior(50.0, 2.0, 3.0, 0, -0.05, 0.05),
    lam_prior="mgp",
    m_prior=priors.GammaPrior(2.0, 2.0),
    j_prior=priors.GammaPrior(2.0, 2.0),
)

# Starting values of the chain.
lam = np.ones((ngroups, nlat), dtype=float) / nlat
m = tfd.Gamma(0.1, 2.0).sample((nlat, natoms), seed=key).astype(float)
j = 0.5 * np.ones(natoms, dtype=float)
u = np.ones(ngroups, dtype=float)

state = State(
    iter=0,
    atoms=init_atoms,
    j=j,
    lam=lam,
    phis=1.0 / lam,
    deltas=np.ones(nlat),
    m=m,
    clus=clus,
    u=u,
)

# Bookkeeping for NaN entries of `data`: their positions, and the number of
# non-NaN observations in each group (row), as floats.
nan_idx = np.where(np.isnan(data))
nobs_by_group = np.sum(~np.isnan(data), axis=1).astype(float)

In [None]:
# PRNG key for the chain. Seed 202204 was already consumed (directly and via splits)
# to simulate the data and to draw the initial `m`. Restarting from the bare root
# key would reuse those streams, so fold in a constant to obtain a fresh key.
key = random.fold_in(random.PRNGKey(202204), 1)

In [None]:
# Adaptation run for the "mgp" lambda prior: `algo.adapt_mgp` receives the current
# state, two integer settings, the data with its NaN bookkeeping, the prior and the
# key, and returns an updated state and key.
# NOTE(review): the meaning of 1000 and 50 (e.g. iterations / adaptation interval)
# is defined in nrmifactors.algorithm — confirm there.
state, key = algo.adapt_mgp(state, 1000, 50, data, nan_idx, nobs_by_group, prior, key)

In [None]:
from copy import deepcopy

niter = 60000
nburn = 50000
thin = 1

states = []

for i in range(niter):
    print("\r{0}/{1}".format(i+1, niter), flush=True, end=" ")
    state, key = algo.run_one_step(state, data, nan_idx, nobs_by_group, prior, key)
    # Discard exactly `nburn` burn-in iterations (0 .. nburn-1), then keep one draw
    # every `thin` iterations counted from the end of burn-in.
    # The previous test `i > nburn` also dropped iteration `nburn`, leaving one draw
    # fewer than the (niter - nburn) / thin that later cells index into.
    if i >= nburn and (i - nburn) % thin == 0:
        # Defensive copy of the stored draw (kept from the original).
        states.append(deepcopy(state))

In [None]:
import pickle

# `open` does not create missing directories; make sure the output folder exists.
os.makedirs("simu2", exist_ok=True)
with open("simu2/chains_mgp3.pickle", "wb") as fp:
    pickle.dump(states, fp)

In [None]:
def eval_densities(xgrid, lam, m, j, atoms):
    """Evaluate every group's mixture density on ``xgrid``.

    Mixture weights are ``(lam @ m) * j``, normalised to sum to one per row.
    Component ``a`` is ``Normal(atoms[a, 0], sqrt(atoms[a, 1]))``, so the second
    column of ``atoms`` is used as a variance.

    Returns:
        Array of shape ``(lam.shape[0], len(xgrid))``; row ``g`` is group ``g``'s
        density evaluated at each grid point.
    """
    weights = np.matmul(lam, m) * j
    weights = weights / weights.sum(axis=1, keepdims=True)
    # (len(xgrid), n_atoms): component densities at every grid point.
    eval_comps = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(atoms[:, 1])).prob(xgrid[:, np.newaxis])
    # A single matmul replaces the former (len(xgrid), n_groups, n_atoms) intermediate.
    return weights @ eval_comps.T

In [None]:
group_idx = [0, 1, 2, 4]

fig, axes = plt.subplots(nrows=1, ncols=len(group_idx), figsize=(20, 5))

idx = [-1, -10, -100, -1000]
xgrid = np.linspace(-10, 10, 1000)

# Observed data of each selected group, as a normalised histogram.
for ax, g in zip(axes, group_idx):
    ax.hist(onp.array(data[g, :]), density=True, alpha=0.3)

# Overlay the estimated density of each group for a few draws near the end of the
# chain (oldest first), labelled by their negative index into `states`.
for lag in sorted(idx):
    draw = states[lag]
    group_dens = eval_densities(xgrid, draw.lam, draw.m, draw.j, draw.atoms)
    for ax, g in zip(axes, group_idx):
        ax.plot(xgrid, group_dens[g, :], label="j: {0}".format(lag))

for ax in axes:
    ax.set_xlim(-15, 15)

axes[0].legend(fontsize=12)
plt.tight_layout()
# plt.savefig("simu1/dens_estimate.pdf", bbox_inches="tight")
plt.show()

In [None]:
nlat = states[-1].lam.shape[1]

# squeeze=False keeps `axes` an array even if only one latent factor survives,
# so `.flat` / integer indexing below always works.
fig, axes = plt.subplots(nrows=1, ncols=int(nlat), figsize=(20, 5), squeeze=False)
axes = axes.flat

for j in sorted(idx):
    state = states[j]

    # (len(xgrid), n_atoms) component densities of this draw.
    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])
    # Unnormalised latent densities, shape (nlat, len(xgrid)):
    # row i is sum_a m[i, a] * j[a] * N(x; atom a).
    dens_lat = (state.m * state.j) @ eval_comps.T

    for i in range(nlat):
        axes[i].plot(xgrid, dens_lat[i, :], label="j: {0}".format(len(states) + j))

axes[0].legend(fontsize=12)

plt.tight_layout()
plt.show()

In [None]:
# Stack the last 1000 draws of lam once, shape (n_draws, ngroups, nlat). The
# previous version indexed a device array separately for every (group, factor, draw)
# triple, i.e. ngroups * nlat * 1000 tiny dispatches.
lam_trace = onp.stack([onp.asarray(s.lam) for s in states[-1000:]])

# squeeze=False keeps `axes` an array even when nlat == 1.
fig, axes = plt.subplots(nrows=1, ncols=nlat, figsize=(20, 5), squeeze=False)

axes = axes.flat

# One panel per latent factor; one trace line per group.
for i in range(ngroups):
    for j in range(nlat):
        axes[j].plot(lam_trace[:, i, j], label="j: {0}".format(j + 1))

# Post-processing

In [None]:
import pickle

# with open("simu2/chains_mgp2.pickle", "rb") as fp:
#    states = pickle.load(fp)

In [None]:
from jax import jit
from jax import jacfwd, grad
from nrmifactors.postprocess import ralm

# Spacing of the (uniform) grid `xgrid`, used as the quadrature step.
delta = np.diff(xgrid[:2])[0]

@jit
def obj_func(x, J, M, component_dens):
    """Sum over latent pairs ``i > k`` of the grid integral of ``(p_i * p_k) ** 2``.

    ``p_i`` are the rows of ``((x @ M) * J) @ component_dens``, normalised to
    integrate to one with grid step ``delta`` (module-level global, read at trace
    time and cached by ``jit``).
    """
    trans_dens = ((x @ M) * J) @ component_dens
    trans_dens = trans_dens / (np.sum(trans_dens, axis=1, keepdims=True) * delta)
    # Same pairs (i, k) with k < i as the former nested Python loop, but gathered in
    # one vectorised op instead of K*(K-1)/2 ops unrolled at trace time.
    rows, cols = np.tril_indices(trans_dens.shape[0], k=-1)
    return np.sum((trans_dens[rows] * trans_dens[cols]) ** 2) * delta

@jit
def obj_func2(x, J, M, component_dens):
    """Total pairwise overlap of the latent densities induced by rotation ``x``.

    The rows of ``((x @ M) * J) @ component_dens`` are normalised into
    probability vectors ``p_i`` over the grid, and the function returns
    ``sum_{i > k} sum_g sqrt(p_i(g) * p_k(g))``, i.e. the sum of the pairwise
    Bhattacharyya coefficients on the grid.

    Normalising to a density with grid step ``d`` and multiplying the integral by
    ``d`` cancels exactly. The step is therefore omitted, and the result no longer
    depends on a global ``delta`` frozen by ``jit``.
    Assumes every row of the transformed densities has positive mass.
    """
    trans_dens = ((x @ M) * J) @ component_dens
    probs = trans_dens / np.sum(trans_dens, axis=1, keepdims=True)
    # Strict lower-triangle pairs (i, k), k < i: the same pairs as the former nested
    # loop, without unrolling K*(K-1)/2 ops at trace time. Diagonal terms are never
    # formed, so no extra sqrt-at-zero gradients are introduced.
    rows, cols = np.tril_indices(probs.shape[0], k=-1)
    return np.sum(np.sqrt(probs[rows] * probs[cols]))

@jit
def constraints(x, lam_mat=None, m_mat=None):
    """Negated constraint vector for a rotation ``x``.

    Returns ``-concat(ravel(lam_mat @ inv(x)), ravel(x @ m_mat))``.
    NOTE(review): presumably the solver treats entries <= 0 as feasible, i.e.
    ``lam_mat @ inv(x) >= 0`` and ``x @ m_mat >= 0`` — confirm against ``ralm``.

    Args:
        x: square rotation matrix.
        lam_mat: loadings; defaults to the global ``lam``.
        m_mat: factor/atom weights; defaults to the global ``M``.

    When the defaults are used, the globals are read at trace time and then cached
    by ``jit`` for later calls with the same shapes. Pass them explicitly when
    evaluating a different posterior draw.
    """
    lam_mat = lam if lam_mat is None else lam_mat
    m_mat = M if m_mat is None else m_mat
    return -np.concatenate([
        (lam_mat @ np.linalg.inv(x)).ravel(),
        (x @ m_mat).ravel()])

@jit
def max0(x):
    """Elementwise positive part: ``x`` where ``x > 0``, zero elsewhere."""
    return np.where(x > 0, x, 0)

@jit
def penalty(x, lambdas):
    """Quadratic penalty ``0.5 * sum(max0(2 * lambdas + constraints(x)) ** 2)``.

    NOTE(review): read as an augmented-Lagrangian term
    ``(rho / 2) * sum(max(0, lambdas / rho + g(x)) ** 2)``, the prefactor
    corresponds to rho = 1.0 but the divisor to rho = 0.5. Confirm which is intended.
    """
    # lambdas / 0.5 == 2.0 * lambdas and 0.5 * 1.0 == 0.5 exactly in floating point.
    shifted = 2.0 * lambdas + constraints(x)
    return 0.5 * np.sum(max0(shifted) ** 2)


# Settings passed positionally to `ralm` in `get_opt_q`.
stepsize = 1e-6
init_thr = 1e-2
target_thr = 1e-6
min_lambda = 1e-4
dmin = 1e-6
max_lambda = 2
init_rho = 1
mu = 0.5

# One initial multiplier per scalar constraint. The constraint vector stacks
# lam @ inv(Q) (same size as lam) and Q @ M (same size as M, since Q is square).
# Previously this was built from `constr_eval`, which no visible cell defines, and
# was assigned twice (the first value, 1, was immediately overwritten by 0.01).
init_lambdas = 0.01 * np.ones(states[0].lam.size + states[0].m.size)


@jit
def _rotation_constraints(x, lam, M):
    """Negated constraints ``-concat(ravel(lam @ inv(x)), ravel(x @ M))`` for one draw."""
    return -np.concatenate([(lam @ np.linalg.inv(x)).ravel(), (x @ M).ravel()])


def get_opt_q(state, init_point):
    """Optimise the rotation matrix for one posterior draw with ``ralm``.

    The objective is ``obj_func2`` (pairwise overlap of the rotated latent densities
    on a [-6, 6] grid). The constraints are built from this draw's own ``lam`` and
    ``m``. NOTE(review): ``ralm`` presumably treats constraint values <= 0 as feasible.

    Args:
        state: posterior draw providing ``m``, ``lam``, ``j`` and ``atoms``.
        init_point: starting rotation (square, ``m.shape[0]`` x ``m.shape[0]``).

    Returns:
        Whatever ``ralm`` returns, which callers use as the optimised rotation.

    Fixes over the previous version:
      * component densities now use ``state.atoms`` (not ``states[0].atoms``);
      * constraints use this draw's ``lam``/``m`` (not globals frozen by ``jit``);
      * ``J`` is normalised into a new array, never modifying ``state.j`` in place.
    """
    M = state.m
    lam = state.lam
    J = state.j / np.sum(state.j)
    xgrid = np.linspace(-6, 6, 1000)

    # (n_atoms, len(xgrid)): density of each atom of this draw on the grid.
    component_dens = tfd.Normal(
        state.atoms[:, 0], np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis]).T

    def f(x):
        return obj_func2(x, J, M, component_dens)

    def cons(x):
        # Passing lam/M as jit arguments keeps one compiled version for all draws.
        return _rotation_constraints(x, lam, M)

    opt_x_pen = ralm(
        f, grad(f), cons, jacfwd(cons), init_point, mu, stepsize,
        init_thr, target_thr, init_lambdas, min_lambda, max_lambda,
        init_rho, dmin, maxiter=100)
    return opt_x_pen

In [None]:
# Rotation for the first retained draw, started from the identity. Its size comes
# from states[0] itself; the global `M` is only assigned in a later cell.
q0 = get_opt_q(states[0], np.eye(states[0].m.shape[0]))

In [None]:
state = states[0]
M = state.m
lam = state.lam
# Normalised jumps as a new array, so state.j (and thus states[0]) is never
# modified in place.
J = state.j / np.sum(state.j)
xgrid = np.linspace(-6, 6, 1000)
# Keep the grid step in sync with the redefined grid. Later cells normalise
# densities on this `xgrid` with `delta`, which was computed for the old [-10, 10] grid.
delta = xgrid[1] - xgrid[0]

# (n_atoms, len(xgrid)) component densities of this draw.
component_dens = np.array([
    tfd.Normal(x[0], np.sqrt(x[1])).prob(xgrid) for x in state.atoms])

obj_func2(q0, J, M, component_dens)

In [None]:
# Latent densities after rotating by q0 (using the M, J and component_dens set in
# the previous cell), each row normalised to integrate to one with step `delta`.
curr_m = q0 @ M
weighted_m = curr_m * J
trans_dens = weighted_m @ component_dens
row_mass = np.sum(trans_dens, axis=1, keepdims=True) * delta
trans_dens = trans_dens / row_mass


In [None]:
# Every 10th draw among the first 10000, taken as a plain list slice instead of
# materialising `onp.array(states)` as an object array. thinned_states[k] pairs with qs[k].
thinned_states = states[0:10000:10]
# Warm-start each optimisation from q0, the solution for states[0].
qs = [get_opt_q(s, q0) for s in thinned_states]

In [None]:
# `open` does not create missing directories; make sure the output folder exists.
os.makedirs("simu2", exist_ok=True)
with open("simu2/chains_mgp3_qs_newloss.pickle", "wb") as fp:
    pickle.dump(qs, fp)

In [None]:
idx = np.arange(1, 1000, 10)
# Thinned draws sliced once (previously `onp.array(states)` was rebuilt on every
# iteration); thinned_states[k] pairs with qs[k].
thinned_states = states[0:10000:10]
# Step of the grid actually plotted: `xgrid` may have been redefined after `delta`.
dx = xgrid[1] - xgrid[0]
selected = idx.tolist()  # ascending Python ints
last_j = selected[-1]

# squeeze=False keeps a 2-D axes array even when nlat == 1.
fig, axes = plt.subplots(nrows=2, ncols=int(nlat), figsize=(20, 10), squeeze=False)

for j in selected:
    state = thinned_states[j]
    q = qs[j]

    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])

    # Top row: latent densities before rotation, shape (nlat, len(xgrid)).
    dens_lat = (state.m * state.j) @ eval_comps.T
    for i in range(nlat):
        d = dens_lat[i, :]
        d = d / np.sum(d * dx)
        axes[0][i].plot(xgrid, d, color="black", alpha=0.3)

    # Bottom row: after rotation by q; the last selected draw is highlighted in red.
    dens_lat = ((q @ state.m) * state.j) @ eval_comps.T
    for i in range(nlat):
        d = dens_lat[i, :]
        d = d / np.sum(d * dx)
        if j == last_j:
            axes[1][i].plot(xgrid, d, color="red", lw=2)
        else:
            axes[1][i].plot(xgrid, d, color="black", alpha=0.3)

In [None]:
from nrmifactors.postprocess import optimal_align as align

In [None]:
# Reference ("template") latent densities for alignment: the last thinned draw
# together with its own optimised rotation. Previously `q` was whatever the preceding
# plotting loop left behind (qs[991]), paired with the unrelated draw states[-1].
thinned_states = states[0:10000:10]
state = thinned_states[-1]
q = qs[-1]
# Step of the grid actually used (xgrid may have been redefined after `delta`).
dx = xgrid[1] - xgrid[0]

eval_comps = tfd.Normal(
    loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])

# (nlat, len(xgrid)) rotated latent densities, each normalised to integrate to one.
dens_lat = ((q @ state.m) * state.j) @ eval_comps.T
template_dens = dens_lat / (np.sum(dens_lat, axis=1, keepdims=True) * dx)

In [None]:
idx = np.arange(1, 1000, 10)
# Thinned draws sliced once instead of rebuilding `onp.array(states)` per iteration.
thinned_states = states[0:10000:10]
# Step of the grid actually plotted (xgrid may have been redefined after `delta`).
dx = xgrid[1] - xgrid[0]

# squeeze=False keeps a 2-D axes array even when nlat == 1.
fig, axes = plt.subplots(nrows=2, ncols=int(nlat), figsize=(20, 10), squeeze=False)

# All selected draws except the last: rotated, normalised latent densities (top row)
# and the output of `align(template_dens, ...)` for the same draw (bottom row).
for j in idx.tolist()[:-1]:
    state = thinned_states[j]
    q = qs[j]

    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])

    dens_lat = ((q @ state.m) * state.j) @ eval_comps.T
    dens_lat = dens_lat / (np.sum(dens_lat, axis=1, keepdims=True) * dx)

    aligned_lat = align(template_dens, dens_lat)

    for i in range(nlat):
        axes[0][i].plot(xgrid, dens_lat[i, :], color="black", alpha=0.2)
        axes[1][i].plot(xgrid, aligned_lat[i, :], color="black", alpha=0.2)

# Template overlaid in red on both rows.
for i in range(nlat):
    axes[0][i].plot(xgrid, template_dens[i, :], color="red", lw=3)
    axes[1][i].plot(xgrid, template_dens[i, :], color="red", lw=3)

# plt.savefig("../latex/images/simu_mgp_latent_dens.pdf", bbox_inches="tight")

In [None]:
idx = np.arange(1, 1000, 10)
# Thinned draws sliced once instead of rebuilding `onp.array(states)` per iteration.
thinned_states = states[0:10000:10]
# Step of the grid actually plotted (xgrid may have been redefined after `delta`).
dx = xgrid[1] - xgrid[0]

# squeeze=False keeps a 2-D axes array even when nlat == 1.
fig, axes = plt.subplots(nrows=2, ncols=int(nlat), figsize=(20, 10), squeeze=False)

# Un-rotated latent densities (top row, unnormalised) and their `align`-ed,
# normalised versions (bottom row) for all selected draws except the last.
for j in idx.tolist()[:-1]:
    state = thinned_states[j]

    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(state.atoms[:, 1])).prob(xgrid[:, np.newaxis])

    # Jumps rescaled by 1e20 (presumably to keep the unnormalised densities away from
    # underflow). The scaled copy gets its own name; it used to overwrite the loop
    # variable `j`.
    jumps = state.j * 1e20
    dens_lat = (state.m * jumps) @ eval_comps.T

    norm_dens_lat = dens_lat / (np.sum(dens_lat, axis=1, keepdims=True) * dx)

    aligned_lat = align(template_dens, norm_dens_lat)

    for i in range(nlat):
        axes[0][i].plot(xgrid, dens_lat[i, :], color="black", alpha=0.2)
        axes[1][i].plot(xgrid, aligned_lat[i, :], color="black", alpha=0.2)

# Template overlaid in red on the aligned row.
for i in range(nlat):
    axes[1][i].plot(xgrid, template_dens[i, :], color="red", lw=3)

#plt.savefig("../latex/images/simu_mgp_latent_dens.pdf", bbox_inches="tight")