In [1]:
import os
# Request 64-bit precision from JAX; set before `jax` is imported further down.
os.environ["JAX_ENABLE_X64"] = "true"
# NOTE(review): the two names below look like XLA options rather than
# environment variables. Such options are normally passed through the single
# XLA_FLAGS variable, so confirm that these assignments have any effect.
os.environ["xla_cpu_multi_thread_eigen"] = "true"
os.environ["intra_op_parallelism_threads"] = "4"


import jax.numpy as np
import matplotlib.pyplot as plt
from jax import random
from jax.experimental.sparse import COO
from nrmifactors import algorithm as algo
from nrmifactors.state import State
import nrmifactors.priors as priors
import pickle

from tensorflow_probability.substrates import jax as tfp
# Short alias for the TFP-on-JAX distributions namespace used throughout.
tfd = tfp.distributions

# Root PRNG key (seed 0); every random draw in this notebook is derived from it.
key = random.PRNGKey(0)



In [2]:
from nrmifactors.algorithm import update_tau_gmrf, update_lambda_gmrf

In [3]:
# Load the PUMA neighbourhood structure. It is used below as an adjacency
# matrix whose row sums form the diagonal of the precision matrix.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load pickles that were generated locally / come from a trusted source.
with open("income_data/california_puma_neighbors.pickle", "rb") as fp:
    neighbors = pickle.load(fp)

In [None]:
# Graph Laplacian of the neighbourhood graph: diag(row sums) - adjacency.
degrees = neighbors.sum(axis=1)
prec = np.diag(degrees) - neighbors

# Pseudo log-determinant: only eigenvalues above 1e-6 contribute, so the
# (numerically) zero eigenvalues of the Laplacian are ignored.
eigvals, eigvecs = np.linalg.eigh(prec)
nonzero_eigvals = eigvals[eigvals > 1e-6]
prec_logdet = np.log(nonzero_eigvals).sum()

# Keep the precision matrix in sparse COO form from here on.
prec = COO.fromdense(prec)

In [None]:
nlat = 10
true_tau = 5

# Precision used only to simulate data: true_tau * (D - 0.95 * W), where D is
# the diagonal matrix of row sums of `neighbors`. The 0.95 factor shrinks the
# off-diagonal part (presumably so that the matrix can be inverted).
fake_prec = true_tau * (np.diag(neighbors.sum(axis=1)) - 0.95 * neighbors)
cov = np.linalg.inv(fake_prec)

# Derive a dedicated subkey for this draw: the same JAX key must never be
# consumed by more than one sampling call, otherwise the draws are correlated.
key, log_lam_key = random.split(key)

# One Gaussian field per latent factor; transpose so rows index the areas.
log_lam = tfd.MultivariateNormalFullCovariance(
    np.zeros(cov.shape[0]), cov).sample(sample_shape=(nlat,), seed=log_lam_key).T
lam = np.exp(log_lam)

In [None]:
# Scale the simulated loadings by 20. Computed from `log_lam` rather than from
# `lam` itself so that re-running this cell does not rescale an already
# scaled array.
lam = 20 * np.exp(log_lam)

In [None]:
natoms = 10
ndata = 500
ngr = lam.shape[0]
m = np.ones((nlat, natoms))
j = np.ones(natoms)

# Unnormalised mixture weights, one row per group and one column per atom.
weights = (lam @ m) * j

# Fresh subkey for this draw so that the root key is not reused across
# independent sampling calls.
key, u_key = random.split(key)
# Auxiliary variables: Gamma(ndata, total unnormalised weight) for each group.
u = tfd.Gamma(np.ones(ngr) * ndata, np.sum(weights, axis=1)).sample(seed=u_key)

# Normalise every row so that it sums to one.
weights = weights / np.sum(weights, axis=1).reshape(-1, 1)

In [None]:
# Draw `ndata` component allocations per group (previously a hard-coded 500,
# which silently disagreed with `ndata` if that constant was changed), using
# a dedicated subkey instead of re-consuming the root key.
key, alloc_key = random.split(key)
clus_allocs = tfd.Categorical(probs=weights).sample(sample_shape=(ndata,), seed=alloc_key).T

In [None]:
# Start the chain from a constant matrix with the same shape as the true `lam`.
mcmc_lam = np.ones(lam.shape) * 20
lam_chain = []

# Single source of truth for the number of sweeps, so the progress message
# always shows the real total (it used to print "/1000" for a 100-step loop).
n_steps = 100
for i in range(n_steps):
    print("\r{0}/{1}".format(i+1, n_steps), flush=True, end=" ")
    # Each call returns the updated loadings together with a new PRNG key.
    mcmc_lam, key, step_size = update_lambda_gmrf(
        clus_allocs, mcmc_lam, m, j, u, prec, prec_logdet, true_tau, key)
    lam_chain.append(mcmc_lam)

In [None]:
def eval_densities(xgrid, lam, m, j, atoms):
    """Evaluate one Gaussian-mixture density per row of ``lam`` on ``xgrid``.

    The mixture weights are ``(lam @ m) * j`` normalised so that each row sums
    to one. Component ``k`` is a normal with mean ``atoms[k, 0]`` and variance
    ``atoms[k, 1]``.

    Parameters
    ----------
    xgrid : 1-D array of evaluation points.
    lam, m, j : arrays such that ``(lam @ m) * j`` has one row per density and
        one column per atom.
    atoms : 2-D array; column 0 holds the means, column 1 the variances.

    Returns
    -------
    Array with one row per row of ``lam`` and one column per grid point.
    """
    weights = np.matmul(lam, m) * j
    weights /= weights.sum(axis=1)[:, np.newaxis]
    # Component densities: rows index grid points, columns index atoms.
    eval_comps = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(atoms[:, 1])).prob(xgrid[:, np.newaxis])
    # Broadcast to (grid point, density, atom) and sum the atoms out.
    dens = eval_comps[:, np.newaxis, :] * weights[np.newaxis, :, :]
    dens = np.sum(dens, axis=-1).T
    return dens

In [None]:
# Display the simulated loadings `lam` (scaled by 20 in an earlier cell).
lam 

In [None]:
xgrid = np.linspace(-15, 15, 1000)
# One atom per mixture component: equally spaced means, common variance 0.3.
# The number of atoms follows `natoms` so it always matches the columns of `m`.
atoms = np.hstack([
    np.linspace(-10, 10, natoms).reshape(-1, 1),
    np.ones(natoms).reshape(-1, 1) * 0.3])

# Compare the true and the estimated density of the SAME group. The previous
# version plotted the truth for group 0 against the estimates for group -20.
group_idx = 0

true_dens = eval_densities(xgrid, lam, m, j, atoms)
plt.plot(xgrid, true_dens[group_idx, :], label="True")

# A few snapshots of the chain: 50th-from-last, 5th-from-last and last draw.
for i in [-50, -5, -1]:
    est_dens = eval_densities(xgrid, lam_chain[i], m, j, atoms)
    plt.plot(xgrid, est_dens[group_idx, :], label="Chain index {0}".format(i))

plt.legend()
plt.show()

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(20, 20))
axes = axes.flat

# One panel per latent factor, showing the trace of the first 5 groups.
# The loop variable is deliberately NOT called `j`: that name holds the jump
# vector used by the sampler and by `eval_densities`, and overwriting it with
# an integer breaks any earlier cell that is re-run afterwards.
for lat in range(nlat):
    for i in range(5):
        axes[lat].plot([x[i, lat] for x in lam_chain])

In [4]:
import numpy as onp
import pandas as pd

def get_weights(Nx, Ny):
    """Build spatially varying 3-component mixture weights on an Nx-by-Ny grid.

    Cell ``(col, row)`` has centre ``(col + 0.5, row + 0.5)`` and is stored at
    position ``col + row * Nx``. For every cell two linear scores of the
    centred coordinates are computed (slopes +0.3 and -0.3) and mapped to the
    simplex with ``inv_alr``.

    Returns an array with one row of three weights per grid cell.
    """
    slope = 0.3
    n_cells = Nx * Ny

    # Centres of the grid cells, filled row by row.
    centers = onp.zeros((n_cells, 2))
    for row in range(Ny):
        for col in range(Nx):
            centers[col + row * Nx, :] = np.array([col + 0.5, row + 0.5])

    centroid = np.mean(centers, axis=0)

    def _cell_weights(center):
        # Offsets of this cell from the centre of the grid.
        dx = center[0] - centroid[0]
        dy = center[1] - centroid[1]
        first = slope * dx + slope * dy
        second = -slope * dx + -slope * dy
        return inv_alr([first, second])

    return np.array([_cell_weights(center) for center in centers])

def inv_alr(x):
    """Inverse additive log-ratio transform.

    Appends a reference coordinate equal to 0 to ``x`` and applies a softmax,
    so a vector of length ``d`` is mapped to ``d + 1`` positive weights that
    sum to one.

    Parameters
    ----------
    x : 1-D sequence or array of real numbers.

    Returns
    -------
    JAX array of length ``len(x) + 1``.
    """
    # Plain NumPy handles Python lists (including lists of 0-d JAX arrays).
    logits = onp.append(onp.asarray(x, dtype=float), 0.0)
    # Subtract the maximum before exponentiating: the result is unchanged
    # mathematically, but large inputs no longer overflow to inf / NaN.
    out = onp.exp(logits - logits.max())
    return np.array(out / out.sum())


def simulate_from_mixture(weights, rng=None):
    """Draw one observation from a three-component Gaussian mixture.

    The components have means -5, 0 and 5 and unit standard deviation.

    Parameters
    ----------
    weights : sequence of three probabilities summing to one.
    rng : optional random source exposing ``choice`` and ``normal`` (for
        example ``numpy.random.default_rng(seed)``). When omitted, the global
        ``numpy.random`` state is used, exactly as before.

    Returns
    -------
    A single sampled value.
    """
    means = [-5, 0, 5]
    sampler = onp.random if rng is None else rng
    # First pick the component, then sample from the corresponding normal.
    comp = sampler.choice(len(means), p=weights)
    return sampler.normal(loc=means[comp], scale=1)


def simulate_data(weights, numSamples):
    """Simulate ``numSamples`` observations for every row of ``weights``.

    Row ``g`` of ``weights`` holds the mixture weights of group ``g``. The
    result is a long-format DataFrame with columns ``group`` and ``datum``,
    ordered by group.
    """
    records = [
        [group, simulate_from_mixture(weights[group])]
        for group in range(len(weights))
        for _ in range(numSamples)
    ]
    return pd.DataFrame(records, columns=["group", "datum"])


def compute_G(Nx, Ny):
    """Adjacency matrix of an Nx-by-Ny lattice with 4-neighbour connectivity.

    Nodes are numbered row by row with ``Nx`` nodes per row, so node ``k`` is
    linked to ``k + 1`` (unless ``k`` is the last node of its row) and to
    ``k + Nx`` (the node below it). The matrix is symmetric with entries in
    {0, 1}.

    Horizontal and vertical links are written separately. The previous
    implementation added the +-1 and +-Nx diagonals and then zeroed the
    row-wrapping entries, which erased every edge when ``Nx == 1`` because the
    two sets of diagonals coincide.

    Returns
    -------
    JAX array of shape ``(Nx * Ny, Nx * Ny)``.
    """
    N = Nx * Ny
    G = onp.zeros((N, N))
    nodes = onp.arange(N)

    # Horizontal neighbours: skip the last node of each row so that links do
    # not wrap around the border of the grid.
    left = nodes[(nodes % Nx) != Nx - 1]
    G[left, left + 1] = 1
    G[left + 1, left] = 1

    # Vertical neighbours: every node except those in the last row.
    upper = nodes[:N - Nx]
    G[upper, upper + Nx] = 1
    G[upper + Nx, upper] = 1

    return np.array(G)


def eval_densities(xgrid, lam, m, j, atoms):
    """Evaluate one Gaussian-mixture density per row of ``lam`` on ``xgrid``.

    Mixture weights are ``(lam @ m) * j`` with every row normalised to sum to
    one. Component ``k`` is a normal with mean ``atoms[k, 0]`` and variance
    ``atoms[k, 1]``. The result has one row per row of ``lam`` and one column
    per grid point.
    """
    unnormalised = np.matmul(lam, m) * j
    mix_weights = unnormalised / unnormalised.sum(axis=1)[:, np.newaxis]

    # Component densities: rows index grid points, columns index atoms.
    kernels = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(atoms[:, 1]))
    comp_dens = kernels.prob(xgrid[:, np.newaxis])

    # Broadcast to (grid point, density, atom), then sum the atoms out.
    weighted = comp_dens[:, np.newaxis, :] * mix_weights[np.newaxis, :, :]
    return np.sum(weighted, axis=-1).T


def get_true_dens(xgrid, weights, atoms):
    """Evaluate Gaussian-mixture densities with the given weights on ``xgrid``.

    ``weights`` has one row per density and one column per component and is
    used as is (no normalisation is applied here). Component ``k`` is a normal
    with mean ``atoms[k, 0]`` and variance ``atoms[k, 1]``. The result has one
    row per row of ``weights`` and one column per grid point.
    """
    kernels = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(atoms[:, 1]))
    comp_dens = kernels.prob(xgrid[:, np.newaxis])
    # (grid point, density, atom) -> sum over atoms -> (density, grid point).
    weighted = comp_dens[:, np.newaxis, :] * weights[np.newaxis, :, :]
    return np.sum(weighted, axis=-1).T

In [5]:
nx = 16
ngroups = nx**2
# Adjacency matrix of the nx-by-nx lattice.
W = compute_G(nx, nx)

weights = get_weights(nx, nx)

# Seed NumPy's global generator (used inside `simulate_data`) so the simulated
# dataset is identical on every Restart & Run All.
onp.random.seed(0)
datas = simulate_data(weights, 50)

# Split the long-format frame into one array of observations per group.
groupedData = [
    datas.loc[datas["group"] == g, "datum"].values for g in range(ngroups)
]

In [6]:
# Observations as a 2-D array: one row per group.
data = np.stack(groupedData)

# Graph Laplacian of the lattice: diag(row sums) - adjacency.
degrees = W.sum(axis=1)
prec = np.diag(degrees) - W

# Pseudo log-determinant: only eigenvalues above 1e-6 contribute.
eigvals, eigvecs = np.linalg.eigh(prec)
nonzero_eigvals = eigvals[eigvals > 1e-6]
prec_logdet = np.log(nonzero_eigvals).sum()

# Keep the precision matrix in sparse COO form from here on.
prec = COO.fromdense(prec)

In [7]:
from sklearn.cluster import KMeans

natoms = 10
nlat = 5

# Initialise allocations and atom locations with k-means on the pooled data.
# `random_state` is fixed so the initial state of the sampler is reproducible.
km = KMeans(n_clusters=natoms, random_state=0)
pooled = data.reshape(-1, 1)
km.fit(pooled)
clus = km.predict(pooled).reshape(data.shape)
means = km.cluster_centers_

# Atoms: column 0 holds the cluster centres, column 1 is set to 0.3
# (presumably the initial variance; confirm against the sampler).
init_atoms = np.hstack([means, np.ones_like(means) * 0.3])

In [8]:
prior = priors.NrmiFacPrior(
    kern_prior=priors.NNIGPrior(0.0, 0.01, 5.0, 5.0),
    lam_prior_gmrf=priors.GMRFPrior(sigma=prec, sigma_logdet=prec_logdet, tau_a=2, tau_b=2),
    lam_prior="gmrf",
    m_prior=priors.GammaPrior(2.0, 2.0),
    j_prior=priors.GammaPrior(2.0, 2.0))


lam = np.ones((data.shape[0], nlat))

# Sample the initial `m` from its Gamma prior with a dedicated subkey, so the
# key later handed to the sampler has not already been consumed here.
key, m_key = random.split(key)
m = tfd.Gamma(prior.m_prior.a, prior.m_prior.b).sample((nlat, natoms), seed=m_key).astype(float)

j = np.ones(natoms).astype(float) * 0.5
u = np.ones(ngroups).astype(float)

# Alternative initialisation of the allocations (uniform at random), kept for
# reference; the k-means allocations `clus` are used instead.
#clus = tfd.Categorical(probs=np.ones(natoms)/natoms).sample(data.shape, seed=key)
state = State(
    iter=0,
    atoms=init_atoms, 
    j=j, 
    lam=lam,
    m=m, 
    clus=clus, 
    u=u,
    tau=0.5
)

In [9]:
from jax.ops import index_update, index

# Boolean mask of missing (NaN) observations, wrapped as an index expression.
missing_mask = np.isnan(data)
nan_idx = index[missing_mask]
# Number of observed (non-NaN) values in every group, stored as floats.
nobs_by_group = np.sum(~missing_mask, axis=1).astype(float)

In [10]:
from copy import deepcopy

niter = 10000
nburn = 5000
thin = 1

# The chain starts with a copy of the initial state. Copies are needed because
# the state object is updated by the sampler at every step.
states = [deepcopy(state)]
for i in range(niter):
    print("\r{0}/{1}".format(i+1, niter), flush=True, end=" ")
    state, key = algo.run_one_step(state, data, nan_idx, nobs_by_group, prior, key)
    # Iterations 0 .. nburn-1 are the burn-in, so the first state to keep is
    # the one produced at i == nburn (`i > nburn` discarded one draw too many).
    if (i >= nburn) and (i % thin == 0):
        states.append(deepcopy(state))

1/10000 

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/opt/homebrew/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-d9868d31db80>", line 10, in <module>
    state, key = algo.run_one_step(state, data, nan_idx, nobs_by_group, prior, key)
  File "/Users/marioberaha/research/bnp/nrmi_factor_models/nrmifactors/nrmifactors/algorithm.py", line 387, in run_one_step
    state.atoms, rng_key = update_atoms(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/api.py", line 424, in cache_miss
    out_flat = xla.xla_call(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 1551, in call

TypeError: object of type 'NoneType' has no len()

In [None]:
# Atoms of the data-generating mixture: means -5, 0, 5 (first column) and
# unit variance (second column).
true_atoms = np.column_stack([
    np.array([-5, 0, 5]),
    np.ones(3),
])

In [None]:
# Areas and chain positions to display.
area_idx = np.array([1, 2 * 3, 90, -1])
chain_idx = np.array([1, 50, -1])
xgrid = np.linspace(-10, 10, 1000)

fig, axes = plt.subplots(nrows=1, ncols=len(area_idx), figsize=(15, 5))

true_dens = get_true_dens(xgrid, weights[area_idx, :], true_atoms)

# One curve per selected chain state in every panel.
for c in chain_idx:
    snapshot = states[c]
    est_dens = eval_densities(
        xgrid, snapshot.lam[area_idx, :], snapshot.m, snapshot.j, snapshot.atoms)
    for ax, dens in zip(axes, est_dens):
        ax.plot(xgrid, dens, label="Iter {0}".format(c))

# Overlay the data-generating density of each area.
for ax, dens in zip(axes, true_dens):
    ax.plot(xgrid, dens, label="True Dens")

axes[0].legend()
plt.show()

In [None]:
# Inspect the true mixture weights of area 90 (one of the areas plotted above).
weights[90, :]