In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project's .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
from jax import grad, vmap
from jax.tree_util import tree_map
from scipy.io import loadmat

from utils import prepare_data
from models import *
from advi import *

## Toy Regression

In [None]:
# Synthetic, noise-free linear data: N standard-normal rows of dimension d,
# with targets y = X_data @ w_star.
N = 100
d = 3
key = jrandom.key(42)
data_key, weight_key, sample_key = jrandom.split(key, num=3)
X_data = jrandom.normal(data_key, (N, d))
w_star = 2 * jrandom.normal(weight_key, (d,))
y = jnp.matmul(X_data, w_star)

linear_model = LinearModel(X_data, y, 1, 1)
linear_advi = mean_field_advi(linear_model)

# run_advi is driven by its own key, independent of the data-generation keys.
key = jrandom.key(52)
loss = linear_advi.run_advi(
    key,
    100,
    1001,
    1e-4,
    print_every=500,
    adaptive=True,
    alpha=0.5,
)

# Print one draw from the fitted approximation next to the generating weights.
print("Sample ", linear_advi.sample(sample_key)[0])
print("True value", w_star)

## HLR Experiments

In [None]:
# NOTE(review): beta_prior and alpha_prior are assigned here but are not passed
# to HLR_Model or run_advi in this cell — confirm whether the model was meant
# to receive them.
beta_prior = 100
alpha_prior = 1

# Build the model from the prepared dataset and wrap it in mean_field_advi.
data = prepare_data()
hlr_model = HLR_Model(data)
hlr_advi = mean_field_advi(hlr_model)

k = jrandom.key(42)
loss = hlr_advi.run_advi(
    k,
    10,
    10001,
    1e-5,
    print_every=100,
    adaptive=False,
    alpha=0.5,
)

## Faces Experiments 

In [None]:
# Load the Frey faces .mat file and keep the first 100 rows of the transposed
# "ff" array (one row per image after the transpose).
# `loadmat` is imported in the notebook's top import cell.
all_data = loadmat("data/frey_rawface.mat")
data = all_data["ff"].T[:100]

# Later cells display each row with reshape(28, 20); fail here, at load time,
# rather than deep inside a plotting cell if the file has a different layout.
assert data[0].size == 28 * 20, f"expected 560 pixels per face, got {data[0].size}"

# NOTE(review): `data` keeps whatever dtype the .mat file stores — confirm the
# NMF models handle it (e.g. cast to float if it is an unsigned integer type).

In [None]:
# Run mean_field_advi on the Poisson-Gamma NMF model of the face data and show
# one image per learned `betas` row.
rank = 10
nmf_model = NMF_Model_PoissonGamma(data, rank, gamma_prior_shape=1, gamma_prior_scale=1)
nmf_dim = nmf_model.dim
k = jrandom.key(15)

# NOTE(review): theta/beta from this random draw are never read afterwards;
# this looks like a smoke test of t_inv_map on a vector of length nmf_dim —
# confirm or delete. `k` is also passed to run_advi below; that reuse is only
# harmless because this draw is discarded.
trial_vec = jrandom.normal(k, shape=(nmf_dim,))
theta, beta = nmf_model.t_inv_map(trial_vec)

nmf_advi = mean_field_advi(nmf_model)
nmf_advi.run_advi(k, 10, 10001, 1e-5, print_every=500, adaptive=False)

# Push the fitted "mu" parameters through t_inv_map to obtain thetas and betas.
thetas, betas = nmf_model.t_inv_map(nmf_advi.params["mu"])

# Size the grid from `rank` instead of hard-coding 10 panels, so changing
# `rank` above neither skips components nor indexes past the end of `betas`.
# For rank = 10 this is the same 2 x 5 grid with figsize (12, 5).
n_cols = 5
n_rows = -(-rank // n_cols)  # ceiling division without importing math
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows), squeeze=False)
axs = axs.flatten()
for i in range(rank):
    axs[i].imshow(betas[i].reshape(28, 20), cmap="gray")
    axs[i].axis("off")
    axs[i].set_title(rf"$\beta_{{{i}}}$")
# Hide any unused panels when rank is not a multiple of n_cols.
for ax in axs[rank:]:
    ax.axis("off")

In [None]:
# Run mean_field_advi on the Poisson-DirExp NMF model of the face data and show
# one image per learned `betas` row.
rank = 10
nmf_model = NMF_Model_PoissonDirExp(data, rank)
nmf_dim = nmf_model.dim
k = jrandom.key(15)

# NOTE(review): theta/beta from this random draw are never read afterwards;
# this looks like a smoke test of t_inv_map on a vector of length nmf_dim —
# confirm or delete. `k` is also passed to run_advi below; that reuse is only
# harmless because this draw is discarded.
trial_vec = jrandom.normal(k, shape=(nmf_dim,))
theta, beta = nmf_model.t_inv_map(trial_vec)

nmf_advi = mean_field_advi(nmf_model)
nmf_advi.run_advi(k, 10, 10001, 1e-5, print_every=100, adaptive=True)

# Push the fitted "mu" parameters through t_inv_map to obtain thetas and betas.
thetas, betas = nmf_model.t_inv_map(nmf_advi.params["mu"])

# Size the grid from `rank` instead of hard-coding 10 panels, so changing
# `rank` above neither skips components nor indexes past the end of `betas`.
# For rank = 10 this is the same 2 x 5 grid with figsize (12, 5).
n_cols = 5
n_rows = -(-rank // n_cols)  # ceiling division without importing math
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows), squeeze=False)
axs = axs.flatten()
for i in range(rank):
    axs[i].imshow(betas[i].reshape(28, 20), cmap="gray")
    axs[i].axis("off")
    axs[i].set_title(rf"$\beta_{{{i}}}$")
# Hide any unused panels when rank is not a multiple of n_cols.
for ax in axs[rank:]:
    ax.axis("off")

In [None]:
# Compare the model's reconstruction of row `u` (thetas[u] @ betas) with the
# corresponding original image, side by side.
u = 25
recon_fig, (recon_ax, truth_ax) = plt.subplots(1, 2)

# Left panel: reconstruction from the fitted factors.
reconstruction = thetas[u] @ betas
recon_ax.imshow(reconstruction.reshape(28, 20), cmap="gray")
recon_ax.axis("off")
recon_ax.set_title(rf"$\theta_{{{u}}}\cdot\beta$")

# Right panel: the observed image for the same row.
truth_ax.imshow(data[u].reshape(28, 20), cmap="gray")
truth_ax.set_title(rf"$Y_{{{u}, true}}$")
truth_ax.axis("off")

print(f"theta_{u}: {thetas[u]}")