# Recognition-parametrized variational autoencoders
- The generative model $p_\theta(\mathcal{X},\mathcal{Z})$ is a conditionally normalized RPM, whereas the approximate posterior $q_\psi(\mathcal{Z} \mid \mathcal{X})$ comes from a jointly normalized RPM.

- All RPMs used here are conditionally independent, i.e. the observed components factorize given the latents.

- Here we apply this to the Poisson bouncing-balls data from the AISTATS RPM paper, except that we use a 50-dimensional Gaussian instead of the Gaussian process used there.

In [None]:
# IPython autoreload: with mode 2, imported modules are reloaded before each
# cell executes, so edits to the local modules imported below take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from utils_data_external import linear_regression_1D_latent as regLatent
from utils_data_external import plot_poisson_balls
import matplotlib.pyplot as plt

from exps import init_gaussian_rpm

from rpm import RPMEmpiricalMarginals


# --- Select which saved fit to load ------------------------------------------
# These settings only build the identifier that names the results folder and
# the file prefix inside it.
N = 100
rpm_variant = 'amortized'
amortize_ivi = 'full'
model_seed = 0

identifier = f'{rpm_variant}_{amortize_ivi}_N_{N}_seed_{model_seed}'
root = os.curdir
res_dir = 'fits'
# Files of one fit live under fits/<identifier>/ and are prefixed <identifier>.
fn_base = os.path.join(res_dir, identifier, identifier)

# --- Load saved data, experiment settings, loss trace and fitted model --------
data = torch.tensor(np.load(fn_base + '_data.npy'))
true_latent_ext = torch.tensor(np.load(fn_base + '_latents.npy'))

# The experiment settings are stored as a single object array under 'arr_0'
# (hence allow_pickle=True and .tolist() to recover the dict). Only load such
# files from trusted sources: unpickling can execute arbitrary code.
# np.load on an .npz returns an NpzFile holding an open file handle, so use it
# as a context manager to make sure the handle is closed.
with np.load(fn_base + '_exp_dict.npz', allow_pickle=True) as exp_npz:
    exp_dict = exp_npz['arr_0'].tolist()
N, J, K, T = exp_dict['N'], exp_dict['J'], exp_dict['K'], exp_dict['T']
init_rb_bandwidth = exp_dict['init_rb_bandwidth']
ls = np.load(fn_base + '_loss.npy')

# One entry per observed component j (column j along axis 1 of `data`).
xjs = [data[:, j] for j in range(J)]
pxjs = RPMEmpiricalMarginals(xjs)
# The same components, stacked back together along the last axis.
observations = (torch.stack(xjs, dim=-1),)

# T evenly spaced locations on [0, 1], as a (T, 1) column.
obs_locs = torch.linspace(0, 1, T).reshape(-1, 1)
model = init_gaussian_rpm(N, J, K, T, pxjs,
                          init_rb_bandwidth, obs_locs,
                          rpm_variant, amortize_ivi
                          )
# map_location='cpu' lets a checkpoint written by a GPU run be deserialized on
# a CPU-only machine; load_state_dict then copies the tensors into the model's
# existing parameters.
model.load_state_dict(torch.load(fn_base + '_rpm_state_dict', map_location='cpu'))


# --- Latent moments under q ----------------------------------------------
# Index 1 of model.joint_model is used as the prior over the latent.
prior = model.joint_model[1]
# Natural parameters of the prior; passed to comp_eta_q in both branches.
eta_0 = prior.nat_param
# The 'amortized' and 'temporal' variants return a tuple (second element is
# unused here); the other variants are called with explicit data indices.
if rpm_variant in ['amortized', 'temporal']:    
    eta_q, _ = model.comp_eta_q(xjs, eta_0)
else: 
    eta_q = model.comp_eta_q(xjs, idx_data=np.arange(N), eta_0=eta_0)
# Convert natural parameters to mean parameters, presumably the expected
# sufficient statistics E_q[t(Z)] of the approximate posterior — confirm
# against the log_partition implementation.
EqtZ = prior.log_partition.nat2meanparam(eta_q)

# NOTE(review): assumes t(Z) = [Z, vec(Z Z^T)] for a T-dim Gaussian, i.e. the
# first T columns are E[Z] and the remaining T*T entries reshape to E[Z Z^T].
mu = EqtZ[:,:T]
# Marginal variances: diagonal of the second-moment matrix minus mean squared.
sig2 = torch.diagonal(EqtZ[:,T:].reshape(-1,T,T),dim1=-2,dim2=-1) - mu**2

# Compare the fitted latent with the ground-truth latent using the project
# helper linear_regression_1D_latent (imported as regLatent). It returns the
# true latent, the fitted mean and variance, and R2 — presumably after a
# 1-D linear regression aligning fit to truth; confirm in utils_data_external.
# mu gets a trailing singleton axis, presumably the 1-D latent-dimension axis
# the helper expects.
latent_true, latent_mean_fit, latent_variance_fit, R2 = regLatent(
    latent_true = true_latent_ext,
    latent_mean_fit = mu.unsqueeze(-1), 
    latent_variance_fit = sig2)

# --- Diagnostics ---------------------------------------------------------------
# Training loss trace as saved by the fit (x-axis labelled in epochs).
plt.plot(ls)
plt.gca().set(xlabel='epoch', ylabel='loss')
plt.show()

# Observations together with the fitted latent mean and variance returned by
# the regression step.
plot_poisson_balls(
    observations,
    obs_locs=obs_locs.squeeze(-1),
    latent_mean_fit=latent_mean_fit.squeeze(-1),
    latent_variance_fit=latent_variance_fit,
)
