In [None]:
import os
# Work from the parent directory; the relative "invalsi/..." paths used further
# down are resolved against it.
os.chdir("../")
# Ask JAX for 64-bit floats. This is set here, ahead of the jax import below.
os.environ["JAX_ENABLE_X64"] = "true"

import pickle
import matplotlib.pyplot as plt
# NOTE: throughout this notebook `np` is jax.numpy, not NumPy.
import jax.numpy as np

In [None]:
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def eval_densities(xgrid, lam, m, j, atoms):
    """Evaluate one Gaussian-mixture density per row of `lam` on `xgrid`.

    The mixture weights are the rows of ``lam @ (m * j)`` rescaled to sum to
    one. Component k is a Normal with mean ``atoms[k, 0]`` and variance
    ``atoms[k, 1]``.

    Returns the densities with the grid along the last axis
    (assumes `xgrid` is 1-D -- TODO confirm).
    """
    mix_weights = lam @ (m * j)
    mix_weights = mix_weights / mix_weights.sum(axis=1)[:, np.newaxis]

    variances = np.array(atoms[:, 1])
    #variances = variances.at[variances < 0.15].set(0.15)
    kernels = tfd.Normal(loc=atoms[:, 0], scale=np.sqrt(variances))
    comps_on_grid = kernels.prob(xgrid[:, np.newaxis])

    # Broadcast (grid, 1, comp) against (1, row, comp), then collapse the
    # component axis and put the rows of `lam` first.
    weighted = comps_on_grid[:, np.newaxis, :] * mix_weights[np.newaxis, :, :]
    return weighted.sum(axis=-1).T

In [None]:
# `states` is indexed and sliced like a sequence below; its elements expose
# the attributes .m, .j, .lam and .atoms.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("invalsi/chains_mgp3.pickle", "rb") as fp:
    states = pickle.load(fp)
    
    
# `data` is loaded here but is not referenced by any other cell in this notebook.
with open("invalsi/math_grades.pickle", "rb") as fp:
    data = pickle.load(fp)

In [None]:
# Number of rows of `m` in the last saved state; used below as the number of
# subplot columns and of rows in the averaged density array.
nlat = states[-1].m.shape[0]

In [None]:
# Overlay the latent densities of every 20th saved state, one panel per row of `m`.
# NOTE(review): `xgrid` is not defined in any cell visible in this notebook;
# it has to exist before this cell runs.
fig, axes = plt.subplots(nrows=1, ncols=nlat, figsize=(12, 3))

# The loop variable keeps the name `j`: a later plotting cell reads it after
# this loop has finished.
for j in np.arange(0, len(states), 20):
    state = states[j]
    # Use the variances of the state being plotted. Previously this cell read a
    # global `vars` that it never assigned, so it either hit the Python builtin
    # or reused variances left behind by another cell. The 0.15 floor mirrors
    # the averaging cell that follows.
    variances = np.array(state.atoms[:, 1])
    variances = variances.at[variances < 0.15].set(0.15)
    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(variances)).prob(xgrid[:, np.newaxis])
    dens_lat = eval_comps[:, np.newaxis, :] * (state.m * state.j)[np.newaxis, :, :]
    dens_lat = np.sum(dens_lat, axis=-1).T

    for i in range(nlat):
        axes[i].plot(xgrid, dens_lat[i, :], color="black", lw=2, alpha=0.3)

#axes[0].legend(fontsize=12)

plt.tight_layout()

#plt.savefig("../latex/images/invalsi_latent_draws.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Average the latent densities over the tail of the chain (at most 500 states).
# NOTE(review): `xgrid` is not defined in any cell visible in this notebook.
tail_states = states[-500:]
avg_lat_dens = np.zeros((nlat, len(xgrid)))

for state in tail_states:
    # Component variances, floored at 0.15 before taking the square root.
    vars = np.array(state.atoms[:, 1])
    vars = vars.at[vars < 0.15].set(0.15)
    eval_comps = tfd.Normal(
        loc=state.atoms[:, 0], scale=np.sqrt(vars)).prob(xgrid[:, np.newaxis])

    dens_lat = eval_comps[:, np.newaxis, :] * (state.m * state.j)[np.newaxis, :, :]
    dens_lat = np.sum(dens_lat, axis=-1).T
    avg_lat_dens += dens_lat

# Divide by the number of states actually summed. Dividing by len(states)
# shrinks the average whenever the chain holds more than 500 draws.
avg_lat_dens /= len(tail_states)

In [None]:
# One panel per row of the averaged latent densities.
fig, axes = plt.subplots(nrows=1, ncols=nlat, figsize=(15, 3))


for i in range(nlat):
    # Label each curve by its own panel index. The previous label was built
    # from `j`, a loop variable leaked by an earlier cell, so this cell only
    # ran if that other loop had been executed first.
    axes[i].plot(xgrid, avg_lat_dens[i, :], label="mu_{0}".format(i + 1))

#axes[0].legend(fontsize=12)

plt.tight_layout()

# plt.savefig("invalsi/avg_latent_dens.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Postprocess

def discretize_dens(dens, xgrid):
    """Replace a gridded density by a step function over unit intervals.

    For each open interval (k, k + 1), k = 2, ..., 10, the values of `dens`
    at the grid points inside it are summed and multiplied by the grid step
    (taken from the first two grid points). That value is written to every
    grid point of the interval. Points outside all intervals stay at zero.

    Note: a later cell defines another function under this same name.
    """
    step = xgrid[1] - xgrid[0]
    binned = np.zeros_like(xgrid)
    for lower in range(2, 11):
        in_bin = (xgrid > lower) & (xgrid < lower + 1)
        bin_mass = np.sum(dens[in_bin]) * step
        binned = binned.at[in_bin].set(bin_mass)
    return binned


# Everything in this block is read from the last saved state.
last_state = states[-1]
M, lam, J = last_state.m, last_state.lam, last_state.j

# Component variances, floored at 0.2 before the kernels are evaluated.
vars = np.array(last_state.atoms[:, 1])
vars = np.where(vars < 0.2, 0.2, vars)

# One row per component, one column per grid point (assuming a 1-D xgrid).
kernels = tfd.Normal(loc=last_state.atoms[:, 0], scale=np.sqrt(vars))
component_dens = kernels.prob(xgrid[:, np.newaxis]).T

#component_dens = np.array([
#    tfd.Normal(x[0], np.sqrt(x[1])).prob(xgrid) for x in states[-1].atoms])

#discretized_dens = np.stack([
#    discretize_dens(x, xgrid) for x in component_dens  
#])

# Grid spacing, taken from the first two grid points.
delta = xgrid[1] - xgrid[0]


def obj_func(x):
    """Sum of squared pairwise overlaps between the transformed densities.

    The rows of ``((x @ M) * J) @ component_dens`` are rescaled so that each
    sums to ``1 / delta`` on the grid. For every pair of distinct rows the
    grid sum of their product times `delta` is squared and accumulated.
    """
    densities = ((x @ M) * J) @ component_dens
    row_mass = np.sum(densities, axis=1, keepdims=True) * delta
    densities = densities / row_mass

    def squared_overlap(a, b):
        return (np.sum(densities[a, :] * densities[b, :]) * delta) ** 2

    n_rows = densities.shape[0]
    # Same pair order as a nested loop over i and j < i; starts from 0.0.
    return sum(
        (squared_overlap(i, j) for i in range(n_rows) for j in range(i)),
        0.0,
    )


def constraints(x):
    """Return the negated entries of ``lam @ inv(x)`` followed by those of ``x @ M``.

    Both matrices are flattened row by row into a single 1-D vector.
    (Presumably the optimiser treats values <= 0 as feasible -- TODO confirm
    against `ralm`.)
    """
    scores = lam @ np.linalg.inv(x)
    measures = x @ M
    stacked = np.concatenate([scores.reshape(-1), measures.reshape(-1)])
    return -stacked

def max0(x):
    """Keep the positive entries of `x` and multiply the others by zero."""
    is_positive = x > 0
    return x * is_positive

def penalty(x, lambdas):
    """Quadratic penalty on the positive part of ``lambdas / 0.5 + constraints(x)``.

    NOTE(review): the outer factor is ``0.5 * 1.0`` while `lambdas` is divided
    by 0.5; confirm that the two constants are meant to differ.
    """
    shifted = lambdas / 0.5 + constraints(x)
    violation = max0(shifted)
    return 0.5 * 1.0 * np.sum(violation ** 2)

In [None]:
# Rows of M @ component_dens, one panel each.
n_panels = M.shape[0]
fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(15,3))

mu_on_grid = M @ component_dens
for j, ax in enumerate(axes):
    ax.plot(xgrid, mu_on_grid[j, :])
    ax.set_title("mu_{0}".format(j+1), fontsize=16)


#plt.savefig("invalsi/avg_mu.pdf", bbox_inches="tight")

In [None]:
from jax import grad


# Objective and its gradient (via jax.grad), under the names the optimiser call uses.
f = obj_func
grad_f = grad(f)

In [None]:
from nrmifactors.postprocess import dissipative_lie_rattle_fast

# Start the search from the identity, a square matrix with as many rows as M.
x0 = np.eye(M.shape[0])
# NOTE(review): the meaning of the positional values 0.9, 1e-5 and 1e-6 is
# defined by dissipative_lie_rattle_fast and is not visible in this notebook.
opt_x, niter = dissipative_lie_rattle_fast(f, grad_f, x0, 0.9, 1e-5, 1e-6, maxiter=10000)
# Last expression of the cell: displays the optimised matrix.
opt_x

In [None]:
# Rows of q @ M @ component_dens for the unconstrained optimum, one panel each.
n_panels = M.shape[0]
fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(15,3))

q = opt_x

rotated_mu = q @ M @ component_dens
for j, ax in enumerate(axes):
    ax.plot(xgrid, rotated_mu[j, :])
    ax.set_title("mu_{0}".format(j+1), fontsize=16)

#plt.savefig("invalsi/opt_mu_unc.pdf", bbox_inches="tight")

In [None]:
from jax import jacfwd
from nrmifactors.postprocess import ralm

# Constraint vector at the identity start point; only its shape is reused
# below to size the initial multipliers.
constr_eval = constraints(x0)
# Forward-mode Jacobian of the constraint vector.
grad_cons = jacfwd(constraints)

# Settings passed positionally to `ralm` below.
# NOTE(review): their exact meaning is defined by nrmifactors.postprocess.ralm,
# which is not visible here; the names are the only documentation available.
mu = 0.1
stepsize = 1e-6
init_thr = 1e-2
target_thr = 1e-6
min_lambda = 1e-4
# One initial multiplier per constraint, all set to 0.1.
init_lambdas = np.zeros_like(constr_eval) + 0.1
max_lambda = 30
init_rho = 10
dmin = 1e-6


# Constrained run, started from the unconstrained optimum `opt_x`.
opt_x_pen = ralm(
    obj_func, grad_f, constraints, grad_cons, opt_x, mu, stepsize, 
    init_thr, target_thr, init_lambdas, min_lambda, max_lambda, 
    init_rho, dmin, maxiter=1000)

In [None]:
# Display the matrix returned by the constrained run.
opt_x_pen

In [None]:
# `trapz` is the legacy spelling and is absent from recent SciPy releases;
# prefer `trapezoid` and fall back only where it does not exist yet.
try:
    from scipy.integrate import trapezoid
except ImportError:
    from scipy.integrate import trapz as trapezoid


def discretize_dens(dens, xgrid):
    """Integrate a gridded density over ten bins centred on 1, ..., 10.

    The bins are ``x < 1.5``, then ``[i - 0.5, i + 0.5)`` for i = 2, ..., 9,
    and finally ``x >= 9.5``. Each bin is integrated with the trapezoidal rule
    over the grid points that fall inside it.

    Parameters
    ----------
    dens : array
        Density values on `xgrid`.
    xgrid : array
        Grid on which `dens` was evaluated.

    Returns
    -------
    list
        Ten bin masses, in bin order.
    """
    def _bin_mass(in_bin):
        wh = np.where(in_bin)
        return trapezoid(dens[wh], xgrid[wh])

    out = [_bin_mass(xgrid < 1.5)]
    out.extend(
        _bin_mass((xgrid >= i - 0.5) & (xgrid < i + 0.5)) for i in range(2, 10))
    # `>=` so that a grid point sitting exactly on 9.5 is not dropped: the
    # previous bin is half-open and excludes it.
    out.append(_bin_mass(xgrid >= 9.5))
    return out

In [None]:
# Normalised latent factor densities for the constrained optimum, one panel each.
fig, axes = plt.subplots(nrows=1, ncols=M.shape[0], figsize=(15,3))

q = opt_x_pen

grid_step = xgrid[1] - xgrid[0]
factor_dens = q @ (M * J) @ component_dens
for j, ax in enumerate(axes):
    d = factor_dens[j, :]
    # Rescale so the curve sums to one on the grid.
    d = d / np.sum(d * grid_step)
    ax.plot(xgrid, d)
    ax.set_ylim((0.0, 0.5))

plt.tight_layout()

#plt.savefig("../latex/images/invalsi_latent_factors.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Bar-chart version of the latent factor densities: ten bins per factor.
fig, axes = plt.subplots(nrows=1, ncols=M.shape[0], figsize=(15,3))

q = opt_x_pen

grid_step = xgrid[1] - xgrid[0]
factor_dens = q @ (M * J) @ component_dens
for j, ax in enumerate(axes):
    d = factor_dens[j, :]
    d = d / np.sum(d * grid_step)
    bars = np.arange(1, 11)
    heights = discretize_dens(d, xgrid)
    ax.bar(bars, heights)
    ax.set_xticks(bars)
    ax.set_ylim((0.0, 0.42))

plt.tight_layout()
#plt.savefig("../latex/images/invalsi_latent_factors_discrete.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Scores expressed against the transformed measures: lam times the inverse of the optimum.
post_lam = lam @ np.linalg.inv(opt_x_pen)

In [None]:
# Grid mass of each row of q @ M @ component_dens.
# `q` is whatever the last plotting cell assigned (opt_x_pen when run in order).
# NOTE(review): the latent-factor plots use `M * J` in this product while this
# line uses `M` alone -- confirm that leaving out `J` is intended.
masses = np.sum(q @ M @ component_dens * delta, axis=1)

In [None]:
# Weight each score by the mass of its measure, then rescale every row to sum to one.
weighted_scores = post_lam * masses
lambda_trans = weighted_scores / weighted_scores.sum(axis=1, keepdims=True)

In [None]:
# Column sums of the row-normalised scores (shown as the cell output).
np.sum(lambda_trans, axis=0)

# cluster based on the factor scores

In [None]:
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering
import numpy as onp


def plot_dendrogram(model, **kwargs):
    """Draw the dendrogram of a fitted agglomerative clustering model.

    Parameters
    ----------
    model : object
        Fitted estimator exposing `children_`, `distances_` and `labels_`.
    **kwargs
        Forwarded unchanged to `scipy.cluster.hierarchy.dendrogram`.
    """
    # Number of original samples under each merge node.
    counts = onp.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                # Internal node: reuse the count stored when it was formed.
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # Build the linkage matrix with NumPy, not jax.numpy (`np` here). SciPy
    # requires a float64 matrix, and a JAX array is only float64 when x64
    # mode happens to be enabled.
    linkage_matrix = onp.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)

In [None]:
# Complete-linkage agglomerative clustering with no fixed number of clusters.
# plot_dendrogram reads `distances_` from the fitted model
# (presumably the reason for distance_threshold=0 / n_clusters=None -- TODO confirm).
model = AgglomerativeClustering(linkage="complete", distance_threshold=0, n_clusters=None)

# Convert the JAX scores to a NumPy array before handing them to scikit-learn.
X = onp.array(lambda_trans)
model.fit(X)

In [None]:
fig = plt.figure(figsize=(6, 3))

#plt.title("Hierarchical Clustering Dendrogram")
# Truncate the dendrogram by level; no `p` is passed, so SciPy's default depth
# applies rather than a fixed three levels.
plot_dendrogram(model, truncate_mode="level")
# Hide the x tick labels.
plt.xticks([])
#plt.savefig("../latex/images/invalsi_hclust_complete.pdf", bbox_inches="tight")
plt.show()

In [None]:
from scipy import cluster

# Same scores as a NumPy array, clustered again with SciPy's complete linkage.
X = onp.array(lambda_trans)
Z = cluster.hierarchy.complete(X)
# Cut the tree into 4 clusters and keep the single column of labels.
# The plotting cells below hard-code 4 panels to match.
cutree = cluster.hierarchy.cut_tree(Z, n_clusters=4)[:, 0]

In [None]:
# One panel per cluster: the density implied by the cluster's average scores.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12,3))
axes = axes.flat

# NOTE(review): the latent-factor plots use `M * J` in this product; confirm
# that `M` alone is intended here.
latent_dens = q @ M @ component_dens

cluster_labels = onp.array(np.unique(cutree))
for l in cluster_labels:
    members = np.where(cutree == l)[0]
    cluster_lam = lambda_trans[members, :].mean(axis=0)
    dens = cluster_lam @ latent_dens
    # Rescale so the curve sums to one on the grid.
    dens = dens / np.sum(dens * delta)
    panel = axes[l]
    panel.plot(xgrid, dens)
    panel.set_ylim((0.0, 0.35))

plt.tight_layout()
#plt.savefig("../latex/images/invalsi_cluster_dens.pdf")

In [None]:
# Bar-chart version of the per-cluster densities: ten bins per cluster.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12,3))
axes = axes.flat

# NOTE(review): the latent-factor plots use `M * J` in this product; confirm
# that `M` alone is intended here.
latent_dens = q @ M @ component_dens

cluster_labels = onp.array(np.unique(cutree))
for l in cluster_labels:
    members = np.where(cutree == l)[0]
    cluster_lam = lambda_trans[members, :].mean(axis=0)
    dens = cluster_lam @ latent_dens
    dens = dens / np.sum(dens * delta)
    bars = np.arange(1, 11)
    heights = discretize_dens(dens, xgrid)
    panel = axes[l]
    panel.bar(bars, heights)
    panel.set_xticks(bars)
    panel.set_ylim((0.0, 0.32))

plt.tight_layout()
#plt.savefig("../latex/images/invalsi_cluster_dens_discrete.pdf")