In [None]:
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals
from discreteRPM import discreteRPM, Prior_discrete, RecognitionFactor_discrete

import torch
import numpy as np

J = 2                           # number of marginals (the 3-entry parameter lists below are truncated to J)
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
K = 7                           # used below as recognition-network output size and as length of the prior parameter
dim_T = K                       # dimension of sufficient statistics


# currently playing with either Gaussian or Exponential marginals
marginals = 'exponential' 
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]   # keep the first J rates
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    #A = np.random.normal(size=(J,J))
    #P = A.dot(A.T)
    #P = P / np.sqrt(np.outer(np.diag(P), np.diag(P)))    
    # fixed 2x2 matrix with unit diagonal and -0.85 off-diagonal (presumably the copula correlation — TODO confirm)
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
elif marginals == 'gaussian':
    # NOTE(review): this branch does not wrap the Normals in ObservedMarginal and defines neither
    # `px` nor `rates`, both of which later cells use — confirm before switching to 'gaussian'.
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    pxjs = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
else: 
    raise Exception('marginals not implemented')
pxind = IndependentMarginal(pxjs, dims=dim_js)


# define Gaussian prior in natural parametrization  
def activation_out(x,d=1): # NN returns natural parameters; in Gaussian case, that is m/sig2, -1/(2*sig2)
    """Map raw network outputs to sign-constrained natural parameters.

    Args:
        x: 2-D tensor; columns are split at index ``d``.
        d: number of leading columns that are passed through unchanged.

    Returns:
        Tensor of the same shape as ``x`` whose first ``d`` columns equal those
        of ``x`` and whose remaining columns are ``-softplus`` of the input.
    """
    unconstrained = x[:, :d]
    nonpositive = -torch.nn.functional.softplus(x[:, d:])
    return torch.cat((unconstrained, nonpositive), dim=-1)

# define Gaussian factors fj(Z|xj) in natural parametrization  

class Net(torch.nn.Module):
    """Fully connected network with two ReLU hidden layers and a pluggable output map.

    Args:
        n_in: size of the input features.
        n_out: size of the output layer.
        n_hidden: width shared by both hidden layers.
        activation_out: callable applied to the output of the last linear layer
            (identity by default).
    """
    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        # layers are created input-to-output so parameter order and initialisation are unchanged
        self.fc1 = torch.nn.Linear(in_features=n_in, out_features=n_hidden)
        self.fc2 = torch.nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.fc3 = torch.nn.Linear(in_features=n_hidden, out_features=n_out)

    def forward(self, x):
        """Return ``activation_out(fc3(relu(fc2(relu(fc1(x))))))``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.activation_out(self.fc3(hidden))

# one recognition network per marginal, each with dim_js[j] inputs and K outputs
natparam_models = [Net(dim_js[j], K, n_hidden=13) for j in range(J)]
rec_models = [RecognitionFactor_discrete(model=m) for m in natparam_models]

# construct the discrete RPM; the prior is initialised from a random length-K parameter vector
drpm = discreteRPM( rec_models, latent_prior=Prior_discrete(param=torch.randn(size=(K,))), pxjs=pxjs )

# check analytical gradients numerically

With loss $\mathcal{L}_N(\theta) = \sum_n \omega_\theta(x^{(n)})$, we can get gradients $\nabla\mathcal{L}_N(\theta)$ either via auto-differentiation using $\omega_\theta(x) = \int \prod_j \frac{f_{\theta_j}(Z|x_j)}{F_{\theta_j}(Z)} p_\theta(Z) dZ$ or alternatively via

$\nabla_{\theta_j}\mathcal{L}_N(\theta) = \sum_n \frac{\partial\eta_j^{(n)}}{\partial\theta_j}^\top \left( \mathbb{E}[t(Z)|x^{(n)}, \theta] - \frac{\tilde{p}_{N,\theta}(x_j^{(n)})}{p_j(x_j^{(n)})} \tilde{\mathbb{E}}[t(Z)|x_j^{(n)}, \theta] + \left(\frac{\tilde{p}_{N,\theta}(x_j^{(n)})}{p_j(x_j^{(n)})}-1\right) \frac{\partial\Phi(\eta_j(x_j^{(n)}))}{\partial\eta_j}\right) $

where $\tilde{p}$ signifies the RPM with substituted prior $p_\theta(Z) \leftarrow \tilde{p}_{N,\theta}(Z) = \frac{1}{N}\sum_n p_{N,\theta}(Z|x^{(n)})$ and otherwise the same model components $f_{\theta_j}(Z|x_j), \ p_j(x_j)$.

Here we check the validity of the above form, written in terms of posterior expectations over $t(Z)$, against PyTorch auto-diff gradients of $\omega_\theta(x)$, using the fact that in the discrete case we can easily compute all involved integrals over $Z$.

In [None]:
from functorch import make_functional
import matplotlib.pyplot as plt

N = 23
# random standard-normal test inputs, one (N, dim_j) tensor per marginal
# NOTE(review): these can be negative while pxjs above are Exponential — confirm drpm.eval tolerates that
xjs = [torch.randn(size=(N,dim_js[j])) for j in range(drpm.J)]
logw, posterior, w_tilda_j, posterior_tilda_j  = drpm.eval(xjs)

# auto-diff reference: back-propagate the mean of the first output of drpm.eval (named logw above)
drpm.zero_grad()
loss = drpm.eval(xjs)[0].mean() # mean of logw; NOTE(review): the earlier comment said "negative" but no sign flip is applied here — confirm
loss.backward()

with torch.no_grad():

    j = 0   # only the first recognition factor is checked
    m = drpm.rec_models[j]
    frec_model, params = make_functional(m)
    # Jacobian of the recognition model's output w.r.t. its parameters, one entry per element of params
    detaj_dthetajs = torch.func.jacrev(frec_model, argnums=0)(params, xjs[j])
    
    # bracketed term of the analytical gradient from the markdown above; exp(log_probs) is used as its last factor
    tmp = (posterior-w_tilda_j[:,j].reshape(-1,1)*posterior_tilda_j[:,j] + (w_tilda_j[:,j].reshape(-1,1)-1.0)*torch.exp(m.log_probs(xjs[j])))
    grad_ana = []
    for k in range(len(params)):
        detaj_dthetajk = detaj_dthetajs[k]
        # append singleton axes to tmp so it broadcasts over the parameter axes of the Jacobian:
        # two for 4-D Jacobians, one otherwise (presumably weight matrices vs. bias vectors — TODO confirm)
        tmpk = detaj_dthetajk*(tmp.reshape(*tmp.shape, 1, 1)) if detaj_dthetajk.ndim==4 else detaj_dthetajk*(tmp.reshape(*tmp.shape, 1))
        grad_ana.append(
            # sum over axis 1, then average over axis 0 (matching the .mean() used for the auto-diff loss)
            1. * tmpk.sum(axis=1).mean(axis=0)
        )
        print("\n analytical \n")
        print(grad_ana[-1])
        plt.plot(grad_ana[-1].detach())
        # look up the k-th parameter of drpm and overlay its auto-diff gradient as a dashed line;
        # the "numerical" label below refers to this auto-diff result
        for i,p in enumerate(drpm.parameters()):
            if i == k:
                plt.plot(p.grad.detach(), '--')
                print("\n numerical \n")
                print(p.grad)
                break
        plt.show()


In [None]:
import matplotlib.pyplot as plt
# Adam over all parameters exposed by drpm
optimizer = torch.optim.Adam(drpm.parameters(), lr=1e-4)

epochs = 500       # number of passes over the data loader
N = 100            # number of samples drawn from px for training
batch_size = 10    # mini-batch size handed to the DataLoader

class RPMDataset(torch.utils.data.Dataset):
    """Dataset over J aligned sequences; item ``idx`` is the list of the idx-th entry of each.

    Args:
        xjs: sequence of J equally long indexable containers (one per marginal).
    """
    def __init__(self,xjs):
        self.J = len(xjs)
        # every component must be as long as the first one
        assert all(len(xjs[j]) == len(xjs[0]) for j in range(self.J))
        self.xjs = xjs

    def __len__(self):
        first_component = self.xjs[0]
        return len(first_component)

    def __getitem__(self,idx):
        sample = []
        for j in range(self.J):
            sample.append(self.xjs[j][idx])
        return sample

# draw N samples from px and wrap them for shuffled mini-batching
ds = RPMDataset(px.sample_n(n=N))
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# ls stores one loss per optimizer step; with drop_last=True each epoch has N//batch_size steps
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        loss = drpm.training_step(batch, batch_idx=t)   # t is the global step counter
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
# loss curve over all optimizer steps
plt.plot(ls)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# one 100-point axis per marginal: mean +/- 3 std (gaussian) or [0.001, 3/rate] (exponential)
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
# evaluate the model on the full Cartesian grid; only element [0] of drpm.eval's return value is kept
# NOTE(review): torch.meshgrid is called without `indexing=`; recent PyTorch warns about that
# NOTE(review): the gradient-check cell names this same return slot `logw` — confirm that
# 'log p(x)' (variable name and plot titles) is really what is shown
if J == 3:
    XX,YY,ZZ = torch.meshgrid(*xxs)
    xgrid = torch.stack([XX.flatten(), YY.flatten(), ZZ.flatten()], axis=-1)
    logpx = drpm.eval([XX.reshape(-1,1), YY.reshape(-1,1), ZZ.reshape(-1,1)])[0].reshape(100,100,100)
elif J == 2:
    XX,YY = torch.meshgrid(*xxs)
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
    logpx = drpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)])[0].reshape(100,100)

# plotting is only implemented for J == 2; the axis extents use `rates`, which only the exponential branch defines
if J == 2:
    plt.figure(figsize=(16,8))
    # left panel: image of the grid values (rows = first marginal on the y-axis, columns = second on the x-axis)
    plt.subplot(1,2,1)
    plt.imshow(logpx.detach().numpy(), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under dRPM with copula loss')
    #plt.plot(xjs05[1], xjs05[0], 'ro')
    #xjs= pxind.sample_n(50)
    #plt.plot(xjs[1], xjs[0], 'kx')
    # overlay 1000 fresh samples from px; note this rebinds the notebook-level name `xjs`
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')

    # right panel: the same values as a 5-level contour plot, with another fresh set of samples
    plt.subplot(1,2,2)
    plt.contour(YY, XX , logpx.detach().numpy(), levels=5)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under dRPM with copula loss')
    #plt.plot(xjs05[1], xjs05[0], 'ro')
    #xjs= pxind.sample_n(50)
    #plt.plot(xjs[1], xjs[0], 'kx')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')



# RPM without conditional independence assumption
- dependency structure is overrated !

In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution
from discreteRPM import discreteRPM, discretenonCondIndRPM, Prior_discrete, RecognitionFactor_discrete, RecognitionFactor_scaled_discrete
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals

import torch
import numpy as np

J = 2                           # number of marginals (the 3-entry parameter lists below are truncated to J)
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
K = 20                          # used below as recognition-network output size and as length of the prior parameter
dim_T = K                       # dimension of sufficient statistics


# currently playing with either Gaussian or Exponential marginals
marginals = 'exponential' 
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]   # keep the first J rates
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    #A = np.random.normal(size=(J,J))
    #P = A.dot(A.T)
    #P = P / np.sqrt(np.outer(np.diag(P), np.diag(P)))    
    # fixed 2x2 matrix with unit diagonal and -0.85 off-diagonal (presumably the copula correlation — TODO confirm)
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
elif marginals == 'gaussian':
    # NOTE(review): this branch does not wrap the Normals in ObservedMarginal and defines neither
    # `px` nor `rates`, both of which later cells use — confirm before switching to 'gaussian'.
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    pxjs = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
else: 
    raise Exception('marginals not implemented')
pxind = IndependentMarginal(pxjs, dims=dim_js)


# define Gaussian prior in natural parametrization  
def activation_out(x,d=1): # NN returns natural parameters; in Gaussian case, that is m/sig2, -1/(2*sig2)
    """Map raw network outputs to sign-constrained natural parameters.

    Args:
        x: 2-D tensor; columns are split at index ``d``.
        d: number of leading columns that are passed through unchanged.

    Returns:
        Tensor of the same shape as ``x`` whose first ``d`` columns equal those
        of ``x`` and whose remaining columns are ``-softplus`` of the input.
    """
    unconstrained = x[:, :d]
    nonpositive = -torch.nn.functional.softplus(x[:, d:])
    return torch.cat((unconstrained, nonpositive), dim=-1)

# define Gaussian factors fj(Z|xj) in natural parametrization  

class Net(torch.nn.Module):
    """Fully connected network with two ReLU hidden layers and a pluggable output map.

    Inputs are flattened to ``(batch, -1)`` before the first layer.

    Args:
        n_in: size of the flattened input features.
        n_out: size of the output layer.
        n_hidden: width shared by both hidden layers.
        activation_out: callable applied to the output of the last linear layer
            (identity by default).
    """
    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        # layers are created input-to-output so parameter order and initialisation are unchanged
        self.fc1 = torch.nn.Linear(in_features=n_in, out_features=n_hidden)
        self.fc2 = torch.nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.fc3 = torch.nn.Linear(in_features=n_hidden, out_features=n_out)

    def forward(self, x):
        """Flatten all non-batch axes, then return ``activation_out(fc3(relu(fc2(relu(fc1(.))))))``."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.activation_out(self.fc3(hidden))

#natparam_model = Net(sum(dim_js), K, n_hidden=13)
#rec_model = RecognitionFactor_discrete(model=natparam_model) 
# single joint recognition network: input size sum(dim_js), K+1 outputs
natparam_model = Net(sum(dim_js), K+1, n_hidden=50)
rec_model = RecognitionFactor_scaled_discrete(model=natparam_model) 

# training data: N samples from px, also used to build the empirical marginals
N = 5000
xjs = px.sample_n(n=N)
pxj = RPMEmpiricalMarginals(xjs)

# construct the RPM without conditional independence (one joint recognition factor)
drpm = discretenonCondIndRPM( rec_model, latent_prior=Prior_discrete(param=torch.randn(size=(K,))), pxjs=pxj, full_F=False )


# NOTE(review): the lines below rebind `drpm` to a conditionally independent discreteRPM with one
# recognition network per marginal, discarding the model built just above. Later cells use both
# `drpm.rec_model` and `drpm.rec_models`, so they cannot all run against the same binding — confirm
# which model each downstream cell is meant to use.
natparam_models = [Net(dim_js[j], K, n_hidden=50) for j in range(J)]
rec_models = [RecognitionFactor_discrete(model=m) for m in natparam_models]
drpm = discreteRPM( rec_models, latent_prior=Prior_discrete(param=torch.randn(size=(K,))), pxjs=pxjs)


In [None]:
import matplotlib.pyplot as plt
# Adam over all parameters exposed by drpm
optimizer = torch.optim.Adam(drpm.parameters(), lr=1e-4)

epochs = 1000
batch_size = 32

# stack the per-marginal samples along a new axis 1 so that one tensor holds all marginals
xs = torch.stack(xjs, axis=1)

ds = torch.utils.data.TensorDataset(xs)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# ls stores one loss per optimizer step; N is the sample count set in the setup cell
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        # a batch that arrives as a length-1 container is unwrapped to its single tensor
        # NOTE(review): this feeds one stacked tensor, which fits the joint-factor model rather than the
        # per-marginal discreteRPM that the setup cell binds to `drpm` last — confirm which is intended
        loss = drpm.training_step(batch[0] if len(batch)==1 else batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
# loss curve over all optimizer steps
plt.plot(ls)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# one 100-point axis per marginal: mean +/- 3 std (gaussian) or [0.001, 3/rate] (exponential)
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
if J == 2:
    XX,YY = torch.meshgrid(*xxs)
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)   # (100*100, 2) grid points
    xgrid = xgrid.reshape(*xgrid.shape, 1)                       # trailing singleton axis -> (100*100, 2, 1)

# evaluate the model on the grid with full_F temporarily switched on
drpm.full_F = True
log_w = drpm.eval([xgrid[:,j] for j in range(J)])[0].reshape(100,100)
# add the log density of independent Exponential(rates[0]) x Exponential(rates[1]) marginals (hard-coded for J == 2)
logpx = log_w + torch.log(torch.Tensor(rates)[0]) - rates[0] * xxs[0].reshape(-1,1) + torch.log(torch.Tensor(rates)[1]) - rates[1] * xxs[1].reshape(1,-1)
drpm.full_F = False

# reference log density of the data-generating model on the same grid
logpx_true = px.log_probs(xgrid.detach().numpy().squeeze(-1)).reshape(100,100)

plt.figure(figsize=(16,6))

# panel 1: true density with the training samples overlaid
# (titles say p(x), not log p(x), because the images show np.exp of the log densities)
plt.subplot(1,3,1)
plt.imshow(np.exp(logpx_true), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('true p(x) and samples')
plt.plot(xjs[1], xjs[0], 'r.', markersize=0.5)


# panel 2: density implied by the model
plt.subplot(1,3,2)
plt.imshow(np.exp(logpx.detach().numpy()), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('p(x) under dRPM')
#plt.plot(xjs[1], xjs[0], 'r.')

# Disabled third panel (arg-max latent state per grid point). It was previously kept as a bare
# triple-quoted string, which the notebook echoed as cell output; it is now plain comments.
# NOTE(review): it reshapes to 7 latent states and uses `drpm.rec_model`, so it needs updating before re-enabling.
#plt.subplot(1,3,3)
#m = drpm.rec_model
#log_fxs = m.log_probs(xgrid)                                     # N^J       - K 
#log_denom = torch.logsumexp(log_fxs,dim=0).reshape(1,-1) - np.log(len(xgrid)**J) # 1         - K
#pOverF = (torch.exp(drpm.latent_prior.log_probs()).reshape(1,-1)/torch.exp(log_denom).reshape(1,-1)).detach().numpy()
#posts = torch.exp(log_fxs).detach().numpy() * pOverF 
#posts = posts.reshape(100,100,7)
#z_posts = np.argmax(posts, axis=-1)
#plt.imshow(z_posts)
#plt.colorbar()


In [None]:
# left panel: model density on the grid (np.exp of logpx, hence titled p(x)) with the samples overlaid
plt.subplot(1,2,1)
plt.imshow(np.exp(logpx.detach().numpy()), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('p(x) under dRPM')
plt.plot(xjs[1], xjs[0], 'r.', markersize=0.5)


# right panel: exponential of the last output column of the joint recognition network on the grid
# NOTE(review): requires `drpm` to expose `rec_model` (joint-factor model) — confirm the binding of `drpm`
plt.subplot(1,2,2)
plt.imshow(torch.exp(drpm.rec_model.model(xgrid)[:,-1]).detach().numpy().reshape(100,100), 
           origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')

plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('exp of last recognition-network output')

plt.show()

In [None]:
import matplotlib.pyplot as plt

# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# one 100-point axis per marginal: mean +/- 3 std (gaussian) or [0.001, 3/rate] (exponential)
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
if J == 2:
    # NOTE(review): torch.meshgrid is called without `indexing=`; recent PyTorch warns about that
    XX,YY = torch.meshgrid(*xxs)
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)   # (100*100, 2) grid points
    xgrid = xgrid.reshape(*xgrid.shape, 1)                       # trailing singleton axis -> (100*100, 2, 1)

    
# evaluate the model on the grid with full_F temporarily switched on; element [0] of the result is used
drpm.full_F = True
log_w = drpm.eval([xgrid[:,j] for j in range(J)])[0].reshape(100,100)
drpm.full_F = False

# log density of independent Exponential(rates[0]) x Exponential(rates[1]) on the grid (hard-coded for J == 2)
log_p0 = torch.log(torch.Tensor(rates)[0]) - rates[0] * xxs[0].reshape(-1,1) + torch.log(torch.Tensor(rates)[1]) - rates[1] * xxs[1].reshape(1,-1)
logpx = log_w + log_p0


# reference log density of the data-generating model on the same grid
logpx_true = px.log_probs(xgrid.detach().numpy().squeeze(-1)).reshape(100,100)


plt.figure(figsize=(16,6))

# panel 1: true density with the training samples overlaid
plt.subplot(1,3,1)
plt.imshow(np.exp(logpx_true), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('true p(x) and samples')
plt.plot(xjs[1], xjs[0], 'r.', markersize=0.5)
plt.colorbar()


# panel 2: density implied by the model
plt.subplot(1,3,2)
plt.imshow(np.exp(logpx.detach().numpy()), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('p(x) under dRPM')
#plt.plot(xjs[1], xjs[0], 'r.')
plt.colorbar()


# panel 3: per grid point, sum over the last axis of exp(sum over factors of the recognition outputs),
# multiplied by the independent-marginal density
# NOTE(review): this calls each recognition factor directly (m(...)) rather than m.log_probs(...) — confirm they agree
plt.subplot(1,3,3)
log_fs = torch.stack([m(xgrid[:,j]) for j,m in enumerate(drpm.rec_models)], dim=1)
log_norm_f =  np.log(torch.exp(log_fs.sum(axis=1)).sum(axis=-1).detach().numpy().reshape(100,100))
log_norm = log_norm_f + log_p0.detach().numpy()

plt.imshow(np.exp(log_norm),
           origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title(r'normalization of $\prod_j \ f(Z|x_j)$')
plt.colorbar()



# RP-VAEs

In [None]:
import torch

class discreteRPVAE(torch.nn.Module):
    """
    Variational auto-encoder for recognition-Parametrized Model (RPM)
    with a discrete latent variable.

    Args:
        rec_models: sequence of J recognition factors (torch modules). Each must
            expose ``log_probs(xj)``; the results are stacked along axis 1 and
            treated as an N-J-K tensor of log factor values.
        latent_prior: object exposing ``log_probs()``; its result is broadcast
            against N-K tensors (presumably a length-K vector of log prior
            probabilities — TODO confirm against Prior_discrete).
        pxjs: sequence of J marginals; stored on the instance but not read by
            any method of this class.

    Note:
        ``eval(xjs)`` shadows ``torch.nn.Module.eval()``, so this class cannot be
        switched to evaluation mode through the usual ``module.eval()`` call.
    """
    def __init__(self, rec_models, latent_prior, pxjs):

        super().__init__()
        self.J = len(rec_models)
        assert len(pxjs) == self.J

        self.rec_models = torch.nn.ModuleList(rec_models)
        self.latent_prior = latent_prior
        self.pxjs = pxjs

    def _log_factors(self, xjs):
        """Return (sum_j log f_j(Z|x_j), sum_j log F_j(Z)) with shapes N-K and 1-K."""
        J = self.J
        assert len(xjs) == J
        N = xjs[0].shape[0]
        assert all(xj.shape[0] == N for xj in xjs)

        log_fnji = torch.stack([m.log_probs(xj) for m, xj in zip(self.rec_models, xjs)], dim=1)  # N-J-K
        # log of the batch mean of f, kept in the log domain: log(mean(exp(.))) returns -inf as soon
        # as exp underflows, which turns the ratios below into inf/nan
        log_Fji = torch.logsumexp(log_fnji, dim=0) - log_fnji.new_tensor(float(N)).log()         #   J-K
        log_prod_Fj = log_Fji.sum(dim=0).reshape(1, -1)                                          # 1 - K
        log_prod_fj = log_fnji.sum(dim=1)                                                        # N - K
        return log_prod_fj, log_prod_Fj

    def eval(self, xjs):
        """Return log w(x) = log sum_Z p(Z) prod_j f_j(Z|x_j)/F_j(Z) for each of the N samples.

        Args:
            xjs: list of J tensors with a common leading dimension N.

        Returns:
            Tensor of shape (N,).
        """
        log_prod_fj, log_prod_Fj = self._log_factors(xjs)
        log_joint_factor = self.latent_prior.log_probs() + (log_prod_fj - log_prod_Fj)           # N - K
        return torch.logsumexp(log_joint_factor, dim=-1)                                         # N

    def elbo(self, xjs):
        """Return the per-sample evidence lower bound on log w(x).

        The variational posterior is q(Z|x) proportional to p(Z)^(1-J) prod_j f_j(Z|x_j).

        Args:
            xjs: list of J tensors with a common leading dimension N.

        Returns:
            Tensor of shape (N,).
        """
        J = self.J
        log_prod_fj, log_prod_Fj = self._log_factors(xjs)
        log_prior = self.latent_prior.log_probs()

        log_q = (1 - J) * log_prior + log_prod_fj                                                # N - K
        lognorm_q = torch.logsumexp(log_q, dim=1).unsqueeze(-1)                                  # N - 1
        log_q = log_q - lognorm_q                                                                # N - K

        # log[p(Z) prod_j f_j/F_j] - log q(Z|x); the log f_j terms cancel analytically
        log_ratio = lognorm_q + (J * log_prior - log_prod_Fj)                                    # N - K
        return (torch.exp(log_q) * log_ratio).sum(dim=1)                                         # N

    def training_step(self, batch, batch_idx):
        """Return the scalar training loss, the negative batch-mean ELBO.

        Args:
            batch: list of J tensors, passed to ``elbo`` unchanged.
            batch_idx: unused; kept for interface compatibility with the training loops.
        """
        xjs = batch
        return -self.elbo(xjs).mean()


In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution
from discreteRPM import discreteRPM, discretenonCondIndRPM, Prior_discrete, RecognitionFactor_discrete, RecognitionFactor_scaled_discrete
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals

import torch
import numpy as np

J = 2                           # number of marginals (the 3-entry parameter lists below are truncated to J)
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
K = 20                          # used below as recognition-network output size and as length of the prior parameter
dim_T = K                       # dimension of sufficient statistics


# currently playing with either Gaussian or Exponential marginals
marginals = 'exponential' 
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]   # keep the first J rates
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    #A = np.random.normal(size=(J,J))
    #P = A.dot(A.T)
    #P = P / np.sqrt(np.outer(np.diag(P), np.diag(P)))    
    # fixed 2x2 matrix with unit diagonal and -0.85 off-diagonal (presumably the copula correlation — TODO confirm)
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
elif marginals == 'gaussian':
    # NOTE(review): this branch does not wrap the Normals in ObservedMarginal and defines neither
    # `px` nor `rates`, both of which later cells use — confirm before switching to 'gaussian'.
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    pxjs = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
else: 
    raise Exception('marginals not implemented')
pxind = IndependentMarginal(pxjs, dims=dim_js)


# define Gaussian prior in natural parametrization  
def activation_out(x,d=1): # NN returns natural parameters; in Gaussian case, that is m/sig2, -1/(2*sig2)
    """Map raw network outputs to sign-constrained natural parameters.

    Args:
        x: 2-D tensor; columns are split at index ``d``.
        d: number of leading columns that are passed through unchanged.

    Returns:
        Tensor of the same shape as ``x`` whose first ``d`` columns equal those
        of ``x`` and whose remaining columns are ``-softplus`` of the input.
    """
    unconstrained = x[:, :d]
    nonpositive = -torch.nn.functional.softplus(x[:, d:])
    return torch.cat((unconstrained, nonpositive), dim=-1)

# define Gaussian factors fj(Z|xj) in natural parametrization  

class Net(torch.nn.Module):
    """Fully connected network with two ReLU hidden layers and a pluggable output map.

    Inputs are flattened to ``(batch, -1)`` before the first layer.

    Args:
        n_in: size of the flattened input features.
        n_out: size of the output layer.
        n_hidden: width shared by both hidden layers.
        activation_out: callable applied to the output of the last linear layer
            (identity by default).
    """
    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        # layers are created input-to-output so parameter order and initialisation are unchanged
        self.fc1 = torch.nn.Linear(in_features=n_in, out_features=n_hidden)
        self.fc2 = torch.nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.fc3 = torch.nn.Linear(in_features=n_hidden, out_features=n_out)

    def forward(self, x):
        """Flatten all non-batch axes, then return ``activation_out(fc3(relu(fc2(relu(fc1(.))))))``."""
        flat = x.reshape(x.shape[0], -1)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.activation_out(self.fc3(hidden))

# one recognition network per marginal, each with dim_js[j] inputs and K outputs
natparam_models = [Net(dim_js[j], K, n_hidden=50) for j in range(J)]
rec_models = [RecognitionFactor_discrete(model=m) for m in natparam_models]
# RP-VAE built from the class defined in the previous cell; the prior is initialised from a random length-K parameter
drpm = discreteRPVAE( rec_models, latent_prior=Prior_discrete(param=torch.randn(size=(K,))), pxjs=pxjs)


In [None]:
N = 5000                 # number of training samples
xjs = px.sample_n(n=N)   # draw them from px defined in the setup cell


In [None]:
import matplotlib.pyplot as plt
# Adam over all parameters exposed by drpm
optimizer = torch.optim.Adam(drpm.parameters(), lr=1e-4)

epochs = 200
batch_size = 32

# one dataset column per marginal, so each batch is a sequence of J tensors
ds = torch.utils.data.TensorDataset(*xjs)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# ls stores one loss per optimizer step; with drop_last=True each epoch has N//batch_size steps
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        # a length-1 batch would be unwrapped to its single tensor; otherwise the batch is passed through as is
        loss = drpm.training_step(batch[0] if len(batch)==1 else batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
# loss curve over all optimizer steps
plt.plot(ls)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# one 100-point axis per marginal: mean +/- 3 std (gaussian) or [0.001, 3/rate] (exponential)
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
if J == 2:
    # NOTE(review): torch.meshgrid is called without `indexing=`; recent PyTorch warns about that
    XX,YY = torch.meshgrid(*xxs)
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)   # (100*100, 2) grid points
    xgrid = xgrid.reshape(*xgrid.shape, 1)                       # trailing singleton axis -> (100*100, 2, 1)

    
# evaluate log w on the grid; discreteRPVAE.eval returns the tensor directly (no [0] indexing)
# NOTE(review): discreteRPVAE as defined in this notebook never reads `full_F`; the two
# assignments below look like leftovers from the discreteRPM cells
drpm.full_F = True
log_w = drpm.eval([xgrid[:,j] for j in range(J)]).reshape(100,100)
drpm.full_F = False

# log density of independent Exponential(rates[0]) x Exponential(rates[1]) on the grid (hard-coded for J == 2)
log_p0 = torch.log(torch.Tensor(rates)[0]) - rates[0] * xxs[0].reshape(-1,1) + torch.log(torch.Tensor(rates)[1]) - rates[1] * xxs[1].reshape(1,-1)
logpx = log_w + log_p0


# reference log density of the data-generating model on the same grid
logpx_true = px.log_probs(xgrid.detach().numpy().squeeze(-1)).reshape(100,100)


plt.figure(figsize=(16,6))

# panel 1: true density with the training samples overlaid
plt.subplot(1,3,1)
plt.imshow(np.exp(logpx_true), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('true p(x) and samples')
plt.plot(xjs[1], xjs[0], 'r.', markersize=0.5)
plt.colorbar()


# panel 2: density implied by the model
plt.subplot(1,3,2)
plt.imshow(np.exp(logpx.detach().numpy()), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('p(x) under dRPM')
#plt.plot(xjs[1], xjs[0], 'r.')
plt.colorbar()


# panel 3: per grid point, sum over the last axis of exp(sum over factors of the recognition outputs),
# multiplied by the independent-marginal density
# NOTE(review): this calls each recognition factor directly (m(...)) rather than m.log_probs(...) — confirm they agree
plt.subplot(1,3,3)
log_fs = torch.stack([m(xgrid[:,j]) for j,m in enumerate(drpm.rec_models)], dim=1)
log_norm_f =  np.log(torch.exp(log_fs.sum(axis=1)).sum(axis=-1).detach().numpy().reshape(100,100))
log_norm = log_norm_f + log_p0.detach().numpy()

plt.imshow(np.exp(log_norm),
           origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title(r'normalization of $\prod_j \ f(Z|x_j)$')
plt.colorbar()

