# Recognition-parametrized Variational autoencoders
- $p_\theta(\mathcal{X},\mathcal{Z})$ is a conditionally normalized RPM, whereas $q_\psi(\mathcal{Z} | \mathcal{X})$ is from a jointly normalized RPM

- All RPMs here assume the observations $\mathbf{x}_j$ are conditionally independent given the latent $\mathcal{Z}$.

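- Application here: a 2D Gaussian copula with exponential marginals (`marginals = 'exponential'`). The code also supports independent Gaussian marginals (`'gaussian'`) and a synthetic nonlinear example (`'none'`, the current default: $\mathbf{x}_1 \approx \mathcal{Z}$, $\mathbf{x}_2 \approx \tanh(\mathcal{Z})$, plus Gaussian noise). $p_\theta(\mathcal{Z})$ and $f_{\theta_j}(\mathcal{Z} \mid \mathbf{x}_j)$ are Gaussian for each $j = 1, 2$.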

In [None]:
%load_ext autoreload
%autoreload 2

from rpm import RPMEmpiricalMarginals, EmpiricalDistribution, LogPartition_gauss_diagonal, LogPartition_vonMises
from rpm import ExpFam, ConditionalExpFam, SemiparametricConditionalExpFam, RPM, RPVAE
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals

import torch
import numpy as np

dtype = torch.float

J = 2                           # two marginals
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
dim_T = 2                       # dimension of sufficient statistics

N = 1000

# Data-generating setting:
#   'exponential' - Gaussian copula with exponential marginals (true density available as `px`)
#   'gaussian'    - independent Gaussian marginals (true density available as `px`)
#   'none'        - synthetic data x_1 = Z + noise, x_2 = tanh(Z) + noise (no parametric `px`)
marginals = 'none'
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    pxind = IndependentMarginal(pxjs, dims=dim_js)
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
    xjs = px.sample_n(N)
elif marginals == 'gaussian':
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    # Wrapped in ObservedMarginal exactly like the exponential branch.
    pxjs = [ObservedMarginal(torch.distributions.normal.Normal(loc=locs[j], scale=scales[j])) for j in range(J)]
    pxind = IndependentMarginal(pxjs, dims=dim_js)
    # `px` was previously never defined in this branch (NameError); with
    # independent marginals the true joint is the product distribution itself.
    px = pxind
    xjs = px.sample_n(N)
elif marginals == 'none':
    Z = 1. * np.random.normal(size=(N, dim_Z))
    def link(Z):
        return np.stack([Z, np.tanh(Z)], axis=1)
    XZ = link(Z)  # (N, 2, dim_Z): deterministic, so compute it once
    xjs = [XZ[:, j] + 0.1 * np.random.normal(size=(N, dim_js[j])) for j in range(J)]
    xjs = [torch.tensor(xj, dtype=dtype) for xj in xjs]
else:
    raise ValueError('marginals not implemented')
pxjs = RPMEmpiricalMarginals(xjs)


# define Gaussian prior in natural parametrization
def activation_out(x, d=1):
    """Turn raw network outputs into Gaussian natural parameters.

    Along the last axis, the first ``d`` entries (eta_1 = m/sig2) are returned
    unchanged. The remaining entries (eta_2 = -1/(2*sig2)) are passed through
    a negated softplus so that they are negative.
    """
    linear_part = x[..., :d]
    precision_part = -torch.nn.functional.softplus(x[..., d:])
    return torch.cat([linear_part, precision_part], dim=-1)
# Log-partition function shared by the latent prior and every factor / inference model built below.
# The family is presumably a Gaussian with diagonal covariance (per the class name) over the
# dim_Z-dimensional latent -- TODO confirm in rpm.LogPartition_gauss_diagonal.
log_partition = LogPartition_gauss_diagonal(d=dim_Z)

# define Gaussian factors fj(Z|xj) in natural parametrization

class Net(torch.nn.Module):
    """MLP with two ReLU hidden layers and a configurable output map.

    The input is flattened to ``(batch, -1)`` before the first layer. Any
    trailing shape whose total size equals ``n_in`` is therefore accepted.
    """

    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        # Creation order fc1 -> fc2 -> fc3 fixes how parameter initialisation
        # consumes the global RNG.
        self.fc1 = torch.nn.Linear(n_in, n_hidden, bias=True)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden, bias=True)
        self.fc3 = torch.nn.Linear(n_hidden, n_out, bias=True)

    def forward(self, x):
        h = x.reshape(x.shape[0], -1)
        for hidden_layer in (self.fc1, self.fc2):
            h = torch.nn.functional.relu(hidden_layer(h))
        return self.activation_out(self.fc3(h))

# Latent prior p(Z): randomly initialised natural parameters of shape (1, dim_T).
latent_prior = ExpFam(
    natparam=torch.normal(mean=0.0, std=torch.ones(1, dim_T)),
    log_partition=log_partition,
    activation_out=activation_out,
)

# One recognition network per marginal, mapping x_j to the natural parameters of f_j(Z | x_j).
natparam_models = [Net(d_j, dim_T, n_hidden=50, activation_out=activation_out) for d_j in dim_js]
rec_factors = [ConditionalExpFam(model=net, log_partition=log_partition) for net in natparam_models]

# J amortised inference networks; each sees all marginals concatenated (input size sum(dim_js)).
ivi_natparam_models = [Net(sum(dim_js), dim_T, n_hidden=50, activation_out=activation_out) for _ in range(J)]
ivi_rec_models = [ConditionalExpFam(model=net, log_partition=log_partition) for net in ivi_natparam_models]

# Non-amortised alternative that was tried:
#   nu = torch.nn.parameter.Parameter(activation_out(torch.normal(mean=0.0, std=torch.ones(N, J, dim_T)/1000.)))
nu = ivi_rec_models

rpvae = RPVAE(rec_factors, latent_prior=latent_prior, px=pxjs, nu=nu)


In [None]:
import matplotlib.pyplot as plt
optimizer = torch.optim.Adam(rpvae.parameters(), lr=1e-3)

epochs = 1
batch_size = 32

# Each batch is [x_1, ..., x_J, idx], where idx holds the datapoints' indices into the full data set.
ds = torch.utils.data.TensorDataset(*xjs, torch.arange(N))
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# drop_last=True => exactly N // batch_size optimisation steps per epoch
ls = np.zeros(epochs * (N // batch_size))
t = 0
for _ in range(epochs):
    for *x_batch, idx_batch in dl:
        optimizer.zero_grad()
        loss = rpvae.training_step(batch=x_batch, idx_data=idx_batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
plt.plot(ls)
plt.show()


In [None]:
epochs = 2000

# Continue training the RPVAE with the optimizer/data loader from the previous cell.
ls = np.zeros(epochs * (N // batch_size))
t = 0
for _ in range(epochs):
    for *x_batch, idx_batch in dl:
        optimizer.zero_grad()
        loss = rpvae.training_step(batch=x_batch, idx_data=idx_batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
plt.plot(ls)
plt.show()


In [None]:
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity

# Evaluate the learned RPVAE density on a 100x100 grid over (x_1, x_2):
#   log p(x) = log_w + log_p0,
# where log_w comes from rpvae.elbo_innervi and log_p0 is a product-of-marginals
# baseline. Both must live on the same 'ij'-ordered grid.
if J != 2:
    raise NotImplementedError('grid evaluation and plotting below assume J == 2')

if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j], 100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j], 100) for j in range(J)]
elif marginals == 'none':
    xxs = [torch.linspace(-3, 3, 100), torch.linspace(-1.1, 1.1, 100)]
else:
    raise ValueError('marginals not implemented')

XX, YY = torch.meshgrid(*xxs, indexing='ij')
xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
xgrid = xgrid.reshape(*xgrid.shape, 1)   # (100*100, J, 1)

if marginals == 'gaussian':
    # Product of the true Gaussian marginals. This was previously missing,
    # which raised a NameError on `log_p0` below.
    normals = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
    log_p0 = normals[0].log_prob(xxs[0]).reshape(-1, 1) + normals[1].log_prob(xxs[1]).reshape(1, -1)
elif marginals == 'exponential':
    # Product of the true exponential marginals.
    log_p0 = torch.log(torch.Tensor(rates)[0]) - rates[0] * xxs[0].reshape(-1,1) + torch.log(torch.Tensor(rates)[1]) - rates[1] * xxs[1].reshape(1,-1)
else:
    # 'none': no parametric marginals, so use a product of Gaussian KDEs fitted to each x_j.
    kdes = [KernelDensity(kernel="gaussian", bandwidth='scott').fit(xjs[j]) for j in range(J)]
    log_pj0s = [kdes[j].score_samples(xgrid[:, j]) for j in range(J)]
    log_p0 = torch.tensor(sum(log_pj0s[j].reshape(100, 100) for j in range(J)), dtype=dtype)

log_w = rpvae.elbo_innervi([xgrid[:, j] for j in range(J)]).reshape(100, 100)
logpx = log_w + log_p0


# Left: true density (when a ground-truth `px` exists) or the raw samples.
# Middle: density of the learned model on the same grid.
# Both images show the transposed 'ij'-ordered grid, with the horizontal extent
# taken from xgrid[:, 0]. x_1 is therefore horizontal and x_2 vertical (the
# labels were previously swapped).
extent = (float(xgrid[:, 0].min()), float(xgrid[:, 0].max()),
          float(xgrid[:, 1].min()), float(xgrid[:, 1].max()))

plt.figure(figsize=(16,6))
plt.subplot(1,3,1)
try:
    logpx_true = px.log_probs(xgrid.detach().numpy().squeeze(-1)).reshape(100,100).T
except Exception:
    # No usable ground-truth density (e.g. marginals == 'none'): show the samples instead.
    plt.plot(xjs[0], xjs[1], '.')
else:
    plt.imshow(np.exp(logpx_true), origin='lower', extent=extent, aspect='auto')
    plt.colorbar()
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.title('true p(x) and samples')

plt.subplot(1,3,2)
plt.imshow(np.exp(logpx.detach().numpy()).T, origin='lower', extent=extent, aspect='auto')
plt.colorbar()
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.title('p(x) of learned amortized RPM')

plt.subplot(1,3,3)
# Right panel: compare the latent prior p(Z) with each aggregate factor
# F_j(Z) = mean over datapoints of f_j(Z | x_j) and with the aggregate
# approximate posterior Q(Z) = mean over datapoints of q(Z | x).
rec_factors, prior = rpvae.joint_model[0], rpvae.joint_model[1]
eta0 = prior.nat_param
phi0 = prior.phi()
etajs_all = rpvae.factorNatParams(eta_off=eta0) # N-by-J-by-T
# Log-partition value of every factor at every datapoint, stacked along the J axis.
phijs_all = torch.stack([rec_factors[j].log_partition(etajs_all[:,j]) for j in range(rpvae.J)],axis=1)

# Latent grid and the sufficient statistics T(Z) = (Z, Z^2) used for the 1-D Gaussian, shape (200, 2).
Z = torch.linspace(-5, 5, 200)
tZ = torch.stack([Z, Z**2], axis=1)
pZ = torch.exp((eta0 * tZ).sum(axis=-1) - phi0)
plt.plot(Z.detach().numpy(), pZ.detach().numpy(), label='prior p(Z)')

for j in range(J):
    etaj_all = etajs_all[:,j]
    phij_all = phijs_all[:,j]
    # exp(<eta, T(Z)> - phi) for every grid point (dim 0) and datapoint (dim 1)
    fj = torch.exp((etaj_all.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phij_all.unsqueeze(0))
    Fj = fj.mean(axis=1)  # average over datapoints
    plt.plot(Z.detach().numpy(), Fj.detach().numpy(), label='Fj(Z), j='+str(j+1))
plt.xlabel('Z')
plt.ylabel('density')
plt.title('Z-marginals')

# Posterior natural parameters for the data held in the model's empirical marginals
# (joint_model[2]); presumably `.x` is the stored data of each marginal -- TODO confirm.
eta_q, eta_j = rpvae.comp_eta_q(xjs=[pxj.x for pxj in rpvae.joint_model[2].pxjs])
phi_q = prior.log_partition(eta_q)
q = torch.exp((eta_q.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phi_q.unsqueeze(0))
Q = q.mean(axis=1)  # aggregate posterior: average q over datapoints
plt.plot(Z.detach().numpy(), Q.detach().numpy(), ':', label='Q(Z)')
plt.legend()

plt.show()

In [None]:
# Mean and variance of each recognition factor f_j(Z | x_j) as functions of x_j.
# For a 1-D Gaussian with T(Z) = (Z, Z^2), the natural parameters are
# eta = (mu/sig2, -1/(2*sig2)), hence sig2 = -1/(2*eta_2) and mu = eta_1 * sig2.
# (The previous -2/eta_2 gave 4*sig2, and therefore also a mean 4x too large.)
plt.figure(figsize=(6,4))
for j in range(J):
    etaj = rpvae.joint_model[0][j](xgrid[:,j]).detach().numpy()
    sig2 = -0.5 / etaj[:,1]
    mu = etaj[:,0] * sig2
    xj_grid = xgrid[:,j].detach().numpy()

    plt.subplot(2,2,1+j*2)
    plt.plot(xj_grid, mu, '.')
    if j == 0:
        plt.title(r'factor posterior mean $\mu_j(x_j)$')
    plt.ylabel('j = ' + str(j+1))

    plt.subplot(2,2,2+j*2)
    plt.plot(xj_grid, sig2, '.')
    if j == 0:
        plt.title(r'factor posterior variance $\sigma_j^2(x_j)$')
plt.show()

In [None]:
# Diagnostic: plot q(Z | x) for the first 5 training points, the prior p(Z) and
# the first factor f_1(Z | x_1) for the same points, on a wide latent grid.
prior = rpvae.joint_model[1]
eta_q, eta_j = rpvae.comp_eta_q(xjs)
eta0, phi0 = prior.nat_param, prior.phi()
Z = torch.linspace(-20, 20, 2000)
tZ = torch.stack([Z, Z**2], axis=1)   # sufficient statistics T(Z) = (Z, Z^2)
Zn = Z.detach().numpy()

phi_q = prior.log_partition(eta_q)
log_q = (eta_q.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phi_q.unsqueeze(0)
q = torch.exp(log_q)
q_lines = plt.plot(Zn, q[:,:5].detach().numpy(), ':')
q_lines[0].set_label('q(Z|x), first 5 data')   # one legend entry instead of five identical ones

log_pZ = ((eta0.unsqueeze(0) * tZ).sum(axis=-1) - phi0)[0]
pZ = torch.exp(log_pZ)
plt.plot(Zn, pZ.detach().numpy(), label='prior p(Z)')

rec_factors = rpvae.joint_model[0]
phijs = torch.stack([m.phi_x(xj, eta_off=eta0) for m, xj in zip(rec_factors, xjs)], axis=1)

log_fjZ = (eta_j.unsqueeze(0) * tZ.unsqueeze(1).unsqueeze(2)).sum(axis=-1) - phijs.unsqueeze(0)
fjZ = torch.exp(log_fjZ)
f_lines = plt.plot(Zn, fjZ[:,:5,0].detach().numpy(), '--')
f_lines[0].set_label('f_1(Z|x_1), first 5 data')

plt.legend()
plt.show()

# Sanity check: each factor density should integrate to ~1 over the Z grid (Riemann sum).
print((fjZ.sum(axis=0) * torch.diff(Z)[0]).unique())


# Rebuild the unnormalised q(Z | x) from its parts, p(Z)^(1-J) * prod_j f_j(Z | x_j),
# and plot it for the first 5 datapoints.
log_q_ = (1-J) * log_pZ.unsqueeze(1) + log_fjZ.sum(axis=-1)
q_ = torch.exp(log_q_)
plt.plot(Z.detach().numpy(), q_[:,:5].detach().numpy(), ':')
plt.show()

# Numerical log normaliser (Riemann sum over the Z grid) vs. the model's own.
log_q_norm_num = torch.log(q_.sum(axis=0) * torch.diff(Z)[0])
q_log_normalizer = rpvae.comp_q_log_normalizer(eta_q, eta_j, eta0)

# The two are compared up to an additive offset, aligned at their maxima.
off = q_log_normalizer.max() - log_q_norm_num.max()

num = log_q_norm_num.detach().numpy()
lo, hi = num.min(), num.max()
plt.plot([lo, hi], [lo, hi], 'k-', linewidth=0.5)   # identity line
plt.plot(num, (q_log_normalizer - off).detach().numpy(), '.')
plt.show()

plt.plot(phi_q.detach().numpy(), q_log_normalizer.detach().numpy(), '.')
plt.show()

# VI RPM (non-amortized: one free set of posterior natural parameters per datapoint)

In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution, LogPartition_gauss_diagonal, LogPartition_vonMises
from rpm import ExpFam, ConditionalExpFam, SemiparametricConditionalExpFam, RPM, RPVAE
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals

import torch
import numpy as np

dtype = torch.float

J = 2                           # two marginals
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
dim_T = 2                       # dimension of sufficient statistics

N = 1000


# Data-generating setting:
#   'exponential' - Gaussian copula with exponential marginals (true density available as `px`)
#   'gaussian'    - independent Gaussian marginals (true density available as `px`)
#   'none'        - synthetic data x_1 = Z + noise, x_2 = tanh(Z) + noise (no parametric `px`)
marginals = 'none'
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    pxind = IndependentMarginal(pxjs, dims=dim_js)
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
    xjs = px.sample_n(N)
elif marginals == 'gaussian':
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    # Wrapped in ObservedMarginal exactly like the exponential branch.
    pxjs = [ObservedMarginal(torch.distributions.normal.Normal(loc=locs[j], scale=scales[j])) for j in range(J)]
    pxind = IndependentMarginal(pxjs, dims=dim_js)
    # `px` was previously never defined in this branch (NameError); with
    # independent marginals the true joint is the product distribution itself.
    px = pxind
    xjs = px.sample_n(N)
elif marginals == 'none':
    Z = 2. * np.random.normal(size=(N, dim_Z))
    def link(Z):
        return np.stack([Z, np.tanh(Z)], axis=1)

    XZ = link(Z)  # (N, 2, dim_Z): deterministic, so compute it once
    xjs = [XZ[:, j] + 0.1 * np.random.normal(size=(N, dim_js[j])) for j in range(J)]
    xjs = [torch.tensor(xj, dtype=dtype) for xj in xjs]
else:
    raise ValueError('marginals not implemented')
pxjs = RPMEmpiricalMarginals(xjs)


In [None]:
# define Gaussian prior in natural parametrization  
def activation_out(x, d=1):
    """Turn raw network outputs into Gaussian natural parameters.

    Along the LAST axis, the first ``d`` entries (eta_1 = m/sig2) are returned
    unchanged. The remaining entries (eta_2 = -1/(2*sig2)) are passed through
    a negated softplus so that they are negative.

    Ellipsis indexing is used, as in the RPVAE section, so inputs of any rank
    are accepted. The previous ``x[:, :d]`` form failed for 1-D inputs and
    sliced the wrong axis for 3-D inputs.
    """
    return torch.cat([x[..., :d], -torch.nn.functional.softplus(x[..., d:])], dim=-1)
# Log-partition function shared by the latent prior, the factor models and q built below.
# The family is presumably a Gaussian with diagonal covariance (per the class name) over the
# dim_Z-dimensional latent -- TODO confirm in rpm.LogPartition_gauss_diagonal.
log_partition = LogPartition_gauss_diagonal(d=dim_Z)

# define Gaussian factors fj(Z|xj) in natural parametrization  

class Net(torch.nn.Module):
    """MLP with two ReLU hidden layers followed by ``activation_out``.

    The input is flattened to ``(batch, -1)`` before the first layer.
    The previous version built ``fc3`` twice. The first copy was discarded at
    once, which wasted an allocation and an RNG draw.
    """

    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        self.fc1 = torch.nn.Linear(n_in, n_hidden, bias=True)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden, bias=True)
        self.fc3 = torch.nn.Linear(n_hidden, n_out, bias=True)

    def forward(self, x):
        x = x.reshape(x.shape[0], -1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return self.activation_out(x)

# Latent prior p(Z): randomly initialised natural parameters of shape (1, dim_T).
latent_prior = ExpFam(
    natparam=torch.normal(mean=0.0, std=torch.ones(1, dim_T)),
    log_partition=log_partition,
    activation_out=activation_out,
)

# One recognition network per marginal, mapping x_j to the natural parameters of f_j(Z | x_j).
natparam_models = [Net(d_j, dim_T, n_hidden=50, activation_out=activation_out) for d_j in dim_js]
rec_factors = [ConditionalExpFam(model=net, log_partition=log_partition) for net in natparam_models]

# Non-amortised posterior: one free row of natural parameters per training datapoint (N rows).
q = SemiparametricConditionalExpFam(
    natparams=torch.normal(mean=0.0, std=torch.ones(N, dim_T)),
    log_partition=log_partition,
    activation_out=activation_out,
)

rpm = RPM(rec_factors, latent_prior=latent_prior, px=pxjs, q=q)

In [None]:
import matplotlib.pyplot as plt
optimizer = torch.optim.Adam(rpm.parameters(), lr=1e-3)

epochs = 1
batch_size = 32

# Each batch is [x_1, ..., x_J, idx]; idx selects the rows of q's per-datapoint parameters.
ds = torch.utils.data.TensorDataset(*xjs, torch.arange(N))
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# drop_last=True => exactly N // batch_size optimisation steps per epoch
ls = np.zeros(epochs * (N // batch_size))
t = 0
for _ in range(epochs):
    for *x_batch, idx_batch in dl:
        optimizer.zero_grad()
        loss = rpm.training_step(batch=x_batch, idx_data=idx_batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
plt.plot(ls)
plt.show()


In [None]:
epochs = 2000
batch_size = 32  # note: `dl` from the previous cell already fixes the batch size

# Continue training the RPM with the optimizer/data loader from the previous cell.
ls = np.zeros(epochs * (N // batch_size))
t = 0
for _ in range(epochs):
    for *x_batch, idx_batch in dl:
        optimizer.zero_grad()
        loss = rpm.training_step(batch=x_batch, idx_data=idx_batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.item()
        t += 1
plt.plot(ls)
plt.show()


In [None]:

# Fit a fresh non-amortised posterior q_test on the evaluation grid `xgrid`
# (one row of natural parameters per grid point). The trained joint model
# (factors, prior, marginals) is kept fixed.
q_test = SemiparametricConditionalExpFam(natparams=torch.normal(mean=0.0, std=torch.ones(len(xgrid), dim_T)),
                                         log_partition=log_partition, activation_out=activation_out)

# Only q_test is optimised. The training `optimizer` does not contain its
# parameters, so a dedicated optimizer is required.
optimizer_test = torch.optim.Adam(q_test.parameters(), lr=1e-3)
q_train = rpm.q   # posterior fitted to the training data; restore with `rpm.q = q_train`
rpm.q = q_test

epochs = 1
batch_size = 32

# Batches must come from the grid data set: idx indexes q_test's rows.
ds_test = torch.utils.data.TensorDataset(*[xgrid[:,j] for j in range(rpm.J)], torch.arange(len(xgrid)))
dl_test = torch.utils.data.DataLoader(dataset=ds_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Sized by the number of grid batches, not by N (len(xgrid) >> N would overflow).
ls, t = np.zeros(epochs * len(dl_test)), 0

for p in rpm.joint_model.parameters():
    p.requires_grad = False
try:
    for _ in range(epochs):
        for *x_batch, idx_batch in dl_test:
            optimizer_test.zero_grad()
            loss = rpm.training_step(batch=x_batch, idx_data=idx_batch, batch_idx=t)
            loss.backward()
            optimizer_test.step()
            ls[t] = loss.item()
            t += 1
finally:
    # Always unfreeze the joint model, even if the loop is interrupted.
    for p in rpm.joint_model.parameters():
        p.requires_grad = True

plt.plot(ls)
plt.show()


In [None]:
# RPVAE wrapper around the factors / prior / empirical marginals from the VI RPM section
# (in top-to-bottom execution order these are the objects trained via rpm.parameters()).
# nu=None: presumably no separate inference model is attached -- TODO confirm in rpm.RPVAE.
rpvae_ = RPVAE( rec_factors, latent_prior=latent_prior, px=pxjs, nu=None)

In [None]:
import matplotlib.pyplot as plt

# Evaluate the dRPM (rpvae_) density on a 100x100 grid: log p(x) = log_w + log_p0.
# The grid is now built BEFORE log_p0. Previously the KDE baseline ('none') was
# scored on a stale `xgrid` left over from an earlier cell (different y-range),
# so log_p0 and log_w referred to different points.
if J != 2:
    raise NotImplementedError('grid evaluation and plotting below assume J == 2')

if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j], 100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j], 100) for j in range(J)]
elif marginals == 'none':
    xxs = [torch.linspace(-3, 3, 100), torch.linspace(-1.2, 1.2, 100)]
else:
    raise ValueError('marginals not implemented')

XX, YY = torch.meshgrid(*xxs, indexing='ij')
xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
xgrid = xgrid.reshape(*xgrid.shape, 1)   # (100*100, J, 1)

if marginals == 'gaussian':
    # Product of the true Gaussian marginals (previously undefined -> NameError).
    normals = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
    log_p0 = normals[0].log_prob(xxs[0]).reshape(-1, 1) + normals[1].log_prob(xxs[1]).reshape(1, -1)
elif marginals == 'exponential':
    # Product of the true exponential marginals.
    log_p0 = torch.log(torch.Tensor(rates)[0]) - rates[0] * xxs[0].reshape(-1,1) + torch.log(torch.Tensor(rates)[1]) - rates[1] * xxs[1].reshape(1,-1)
else:
    # 'none': product of Gaussian KDEs fitted to each marginal. The import is local
    # because this section does not otherwise import KernelDensity.
    from sklearn.neighbors import KernelDensity
    kdes = [KernelDensity(kernel="gaussian", bandwidth='scott').fit(xjs[j]) for j in range(J)]
    log_pj0s = [kdes[j].score_samples(xgrid[:, j]) for j in range(J)]
    log_p0 = torch.tensor(sum(log_pj0s[j].reshape(100, 100) for j in range(J)), dtype=dtype)

log_w = rpvae_.elbo_innervi([xgrid[:, j] for j in range(J)]).reshape(100, 100)
logpx = log_w + log_p0


# Left: true density (when a ground-truth `px` exists) or the raw samples.
# Middle: density of the dRPM on the same grid.
# x_1 is on the horizontal axis (extent from xgrid[:, 0]) and x_2 on the vertical
# one. The labels were previously swapped and only set when `px` was missing.
extent = (float(xgrid[:, 0].min()), float(xgrid[:, 0].max()),
          float(xgrid[:, 1].min()), float(xgrid[:, 1].max()))

plt.figure(figsize=(16,6))
plt.subplot(1,3,1)
try:
    logpx_true = px.log_probs(xgrid.detach().numpy().squeeze(-1)).reshape(100,100).T
except Exception:
    # No usable ground-truth density (e.g. marginals == 'none'): show the samples instead.
    plt.plot(xjs[0], xjs[1], '.')
else:
    plt.imshow(np.exp(logpx_true), origin='lower', extent=extent, aspect='auto')
    plt.colorbar()
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.title('true p(x) and samples')

plt.subplot(1,3,2)
plt.imshow(np.exp(logpx.detach().numpy()).T, origin='lower', extent=extent, aspect='auto')
plt.colorbar()
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.title('p(x) under dRPM')

plt.subplot(1,3,3)
# Right panel: latent prior p(Z), aggregate factors F_j(Z) = mean_n f_j(Z | x_j^(n))
# and aggregate posterior Q(Z) for the dRPM.
etajs_all = rpvae_.factorNatParams() # N-by-J-by-T
rec_factors, prior = rpvae_.joint_model[0], rpvae_.joint_model[1]
phijs_all = torch.stack([rec_factors[j].log_partition(etajs_all[:,j]) for j in range(rpvae_.J)], axis=1)
eta0 = prior.nat_param
phi0 = prior.phi()

Z = torch.linspace(-0.5, 0.5, 200)
tZ = torch.stack([Z, Z**2], axis=1)   # sufficient statistics T(Z) = (Z, Z^2)
Zn = Z.detach().numpy()
pZ = torch.exp((eta0 * tZ).sum(axis=-1) - phi0)
plt.plot(Zn, pZ.detach().numpy(), label='prior p(Z)')

for j in range(J):
    etaj_all = etajs_all[:,j]
    phij_all = phijs_all[:,j]
    fj = torch.exp((etaj_all.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phij_all.unsqueeze(0))
    Fj = fj.mean(axis=1)   # average over datapoints
    plt.plot(Zn, Fj.detach().numpy(), label='Fj(Z), j='+str(j+1))
# Axis labels are set once; the panel shows several densities, not just F_j.
plt.xlabel('Z')
plt.ylabel('density')
plt.title('Z-marginals')

eta_q, eta_j = rpvae_.comp_eta_q(xjs=[pxj.x for pxj in rpvae_.joint_model[2].pxjs])
phi_q = prior.log_partition(eta_q)
q = torch.exp((eta_q.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phi_q.unsqueeze(0))
Q = q.mean(axis=1)   # aggregate posterior over datapoints
plt.plot(Zn, Q.detach().numpy(), ':', label='Q(Z)')
plt.legend()

plt.show()

In [None]:
# Latent prior p(Z) and aggregate factors F_j(Z) = mean_n f_j(Z | x_j^(n)) of the trained VI RPM.
etajs_all = rpm.factorNatParams()                     # N-by-J-by-T
phijs_all = torch.stack([rpm.joint_model[0][j].log_partition(etajs_all[:,j]) for j in range(rpm.J)], axis=1)
eta0 = rpm.joint_model[1].nat_param
phi0 = rpm.joint_model[1].phi()

Z = torch.linspace(-5, 5, 200)
tZ = torch.stack([Z, Z**2], axis=1)   # sufficient statistics T(Z) = (Z, Z^2)
Zn = Z.detach().numpy()


plt.figure(figsize=(6,3))

pZ = torch.exp((eta0 * tZ).sum(axis=-1) - phi0)

plt.plot(Zn, pZ.detach().numpy(), label='prior p(Z)')
for j in range(J):
    etaj_all = etajs_all[:,j]
    phij_all = phijs_all[:,j]
    fj = torch.exp((etaj_all.unsqueeze(0) * tZ.unsqueeze(1)).sum(axis=-1) - phij_all.unsqueeze(0))
    Fj = fj.mean(axis=1)   # average over datapoints
    plt.plot(Zn, Fj.detach().numpy(), '--', label='Fj(Z), j='+str(j+1))
# Labels are set once; the plot shows the prior as well as the F_j.
plt.xlabel('Z')
plt.ylabel('density')

plt.legend()
plt.show()