In [None]:
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals, LogPartition_gauss
from implicitRPM import ImplicitPrior_ExpFam, ImplicitRecognitionFactor_ExpFam, ImplicitRPM

In [None]:
import torch
import numpy as np

J = 2                           # number of marginals (the parameter lists below are truncated to their first J entries)
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
dim_T = 2                       # dimension of sufficient statistics

# currently playing with either Gaussian or Exponential marginals
marginals = 'exponential' 
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    # disabled alternative: random correlation matrix of size (J, J)
    #A = np.random.normal(size=(J,J))
    #P = A.dot(A.T)
    #P = P / np.sqrt(np.outer(np.diag(P), np.diag(P)))    
    # hard-coded 2x2 matrix with off-diagonal -0.85; its size only matches J == 2
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    # joint data distribution; later cells draw training and plotting samples from `px`
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
elif marginals == 'gaussian':
    # NOTE(review): this branch defines no joint `px` and does not wrap the marginals in
    # ObservedMarginal as the exponential branch does; later cells call px.sample_n(...)
    # and would raise NameError here -- confirm whether this branch is still supported.
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    pxjs = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
else: 
    raise Exception('marginals not implemented')
# product of the individual marginals; used below for sampling independent x_j
pxind = IndependentMarginal(pxjs, dims=dim_js)


# define Gaussian prior in natural parametrization  
def activation_out(x, d=1):
    """Map raw network outputs to Gaussian natural parameters (m/sig2, -1/(2*sig2)).

    The first `d` columns of `x` are passed through unchanged. The remaining
    columns are sent through a negated softplus, so they are never positive.
    The two parts are concatenated again along the last axis.
    """
    softplus = torch.nn.Softplus()
    linear_part = x[:, :d]
    negative_part = -softplus(x[:, d:])
    return torch.cat([linear_part, negative_part], axis=-1)
# Log-partition object for the Gaussian family, built with d=dim_Z and D=dim_T
# (presumably latent dimension and number of sufficient statistics -- see the config above).
log_partition_gauss = LogPartition_gauss(d=dim_Z,D=dim_T)
# Prior over Z: the raw natural parameters are initialised with standard-normal draws of
# shape (1, dim_T); `activation_out` is handed over as the output activation.
latent_prior = ImplicitPrior_ExpFam(natparam=torch.normal(mean=0.0, std=torch.ones(dim_T).reshape(1,-1)),
                                    log_partition=log_partition_gauss, activation_out=activation_out)

# define Gaussian factors fj(Z|xj) in natural parametrization  

class Net(torch.nn.Module):
    """Fully connected network: two ReLU hidden layers, a linear read-out and an output activation.

    Args:
        n_in: number of input features.
        n_out: number of output features.
        n_hidden: width of both hidden layers.
        activation_out: callable applied to the output of the last linear layer
            (identity by default).
    """

    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()
        self.activation_out = activation_out
        # Layer names are part of the state dict, so they stay fc1/fc2/fc3.
        self.fc1 = torch.nn.Linear(n_in, n_hidden, bias=True)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden, bias=True)
        self.fc3 = torch.nn.Linear(n_hidden, n_out, bias=True)

    def forward(self, x):
        """Return `activation_out(fc3(relu(fc2(relu(fc1(x))))))`."""
        relu = torch.nn.functional.relu
        hidden = relu(self.fc1(x))
        hidden = relu(self.fc2(hidden))
        return self.activation_out(self.fc3(hidden))

# Rejection-sample network initialisations: build fresh recognition networks and an
# implicit RPM, and accept the first one whose average exp(irpm.eval(.)) over samples
# from the independent marginals is not NaN.
#
# NaN never compares equal to anything (not even to itself), so `x == torch.nan` is
# always False and cannot detect a bad initialisation; torch.isnan has to be used.
# The networks are re-created on every attempt because re-wrapping the very same
# modules would reproduce the same NaN.
max_init_tries = 100  # fail loudly instead of looping forever on a persistently bad setup
for init_try in range(max_init_tries):
    # one network per marginal, mapping x_j to the natural parameters of f_j(Z|x_j)
    natparam_models = [Net(dim_js[j], dim_T, n_hidden=10, activation_out=activation_out) for j in range(J)]
    rec_models = [ImplicitRecognitionFactor_ExpFam(model=m, log_partition=log_partition_gauss) for m in natparam_models]

    # construct implicit RPM
    irpm = ImplicitRPM(rec_models, latent_prior, pxjs)

    if not torch.isnan(torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))).any():
        break
else:
    raise RuntimeError('no NaN-free initialisation found after %d attempts' % max_init_tries)
# displayed as the cell output (evaluated on a fresh sample)
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))

In [None]:
import matplotlib.pyplot as plt

# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# One 1-D axis of 100 points per marginal; the range is loc +/- 3*scale for Gaussian
# marginals and (0.001, 3/rate) for exponential marginals.
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
# indexing='ij' is spelled out: it is the layout the code below relies on (first output
# varies along dim 0), and calling torch.meshgrid without it is deprecated and warns.
if J == 3:
    XX,YY,ZZ = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten(), ZZ.flatten()], axis=-1)
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1), ZZ.reshape(-1,1)]).reshape(100,100,100)
elif J == 2:
    XX,YY = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
    # irpm.eval receives one (n_points, 1) column per marginal; the result is put back on the grid
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100)

# Two-panel figure of the gridded log p(x) (image on the left, contours on the right),
# each overlaid with 1000 samples drawn from the joint `px`. Only drawn for J == 2.
# NOTE(review): axis limits use 3/rates[...], i.e. they assume the exponential setting.
# NOTE(review): in notebook order this cell comes before any training cell, although the
# titles say 'with copula loss' -- confirm the intended wording.
if J == 2:
    plt.figure(figsize=(16,8))
    plt.subplot(1,2,1)
    # rows of logpx follow the first grid axis (x_1) and are drawn along the vertical axis
    plt.imshow(logpx.detach().numpy(), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    #plt.plot(xjs05[1], xjs05[0], 'ro')
    #xjs= pxind.sample_n(50)
    #plt.plot(xjs[1], xjs[0], 'kx')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')

    plt.subplot(1,2,2)
    # x_2 (YY) on the horizontal axis, x_1 (XX) on the vertical axis, matching the image
    plt.contour(YY, XX , logpx.detach().numpy(), levels=5)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    #plt.plot(xjs05[1], xjs05[0], 'ro')
    #xjs= pxind.sample_n(50)
    #plt.plot(xjs[1], xjs[0], 'kx')
    # a fresh set of samples is drawn for the second panel
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')



In [None]:
# prior mean parameters before training
# (presumably the gradient of the log-partition function at the prior's parameters -- confirm in implicitRPM)
log_partition_gauss.grad(irpm.latent_prior.param)

In [None]:
# marginal mean parameters before training (they'll differ from prior mean above !)
# For every factor j: apply log_partition_gauss.grad to the j-th output of irpm(...) on
# 10000 samples from the independent marginals and average over the samples (axis 0).
[torch.mean(log_partition_gauss.grad(irpm(pxind.sample_n(n=10000))[j]),axis=0) for j in range(irpm.J)]

In [None]:
import matplotlib.pyplot as plt
# Adam over all parameters exposed by the implicit RPM; this optimiser object is reused
# by the later "continue training" cells.
optimizer = torch.optim.Adam(irpm.parameters(), lr=1e-4)

epochs = 1000      # passes over the training set
N = 10000          # number of training samples drawn from px
batch_size = 1000  # samples per optimiser step

class RPMDataset(torch.utils.data.Dataset):
    """Dataset over J aligned sequences; item `idx` is the list of the idx-th entry of each."""

    def __init__(self,xjs):
        self.J = len(xjs)
        # every sequence must be as long as the first one
        assert all(len(xjs[0]) == len(xj) for xj in xjs)
        self.xjs = xjs

    def __len__(self):
        first = self.xjs[0]
        return len(first)

    def __getitem__(self,idx):
        item = []
        for j in range(self.J):
            item.append(self.xjs[j][idx])
        return item

# Training set: N joint samples from `px`; each dataset item holds one entry per marginal.
ds = RPMDataset(px.sample_n(n=N))
# shuffle every epoch; drop_last=True discards a trailing incomplete batch
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# ls holds one loss value per optimiser step (N//batch_size steps per epoch); t is the global step counter
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        # loss returned by irpm.training_step_sm (presumably the score-matching / "copula" loss -- confirm)
        loss = irpm.training_step_sm(batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
plt.plot(ls)
plt.show()
# cell output: average of exp(irpm.eval(.)) over 10000 samples from the independent marginals
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=10000))))

In [None]:
# debugging - checking if loss coincides with numerical estimate
# Compares irpm.loss_sm against a reference assembled from autograd derivatives of
# irpm.eval_sum. The indices [0] and [1] below hard-code two marginals (J == 2).

xjs = pxind.sample_n(100)

# f1: reverse-mode Jacobian of irpm.eval_sum w.r.t. its first argument (the list of x_j)
f1 = torch.func.jacrev(irpm.eval_sum, argnums=0)
# f2: Jacobian of f1, i.e. second derivatives of irpm.eval_sum
f2 = torch.func.jacrev(f1)
loss_auto = 0.
# half the sum of squared first derivatives w.r.t. x_1 and x_2
loss_auto = loss_auto + 0.5*(f1(xjs)[0]**2 + f1(xjs)[1]**2).squeeze()
# plus the diagonals of the two same-variable second-derivative blocks
# NOTE(review): f1/f2 are re-evaluated on every use; the values are identical, only slower
loss_auto = loss_auto + torch.diag(f2(xjs)[0][0].squeeze()) + torch.diag(f2(xjs)[1][1].squeeze())

# the two curves should overlap; the cell output is their difference
plt.plot(irpm.loss_sm(xjs).detach())
plt.plot(loss_auto.detach(), '--')
plt.show()
irpm.loss_sm(xjs) - loss_auto


In [None]:
# prior mean parameters after learning (same expression as evaluated before training)
log_partition_gauss.grad(irpm.latent_prior.param)

In [None]:
# marginal mean parameters after training (they should match the prior mean now, that's what we trained for !)
# Per factor j: average over 10000 samples from the independent marginals (axis 0).
[torch.mean(log_partition_gauss.grad(irpm(pxind.sample_n(n=10000))[j]),axis=0) for j in range(irpm.J)]

In [None]:
# Re-evaluate log p(x) on the grid spanned by `xxs` with the *current* model parameters.
# Without this, the estimated marginals below would be computed from whatever `logpx`
# an earlier cell left behind (in notebook order: the pre-training evaluation).
grid_axes = torch.meshgrid(*xxs, indexing='ij')
logpx = irpm.eval([g.reshape(-1,1) for g in grid_axes]).reshape(grid_axes[0].shape)

for j in range(J):

    # left panel: true marginal density vs. marginal obtained by summing the gridded
    # density over all other axes (Riemann sum with the grid spacing of those axes)
    plt.subplot(1,2,1)
    plt.plot(xxs[j], torch.exp(pxjs[j].log_prob(xxs[j])).detach().numpy(), label='true p(xj)')
    knotj = [k for k in range(J)]
    knotj.pop(j)   # all axes except j
    fac = np.prod([xxs[k].diff()[0] for k in knotj])
    plt.plot(xxs[j], fac * torch.sum(torch.exp(logpx),dim=tuple(knotj)).detach().numpy(), label='est. p(xj)')
    plt.legend()
    plt.title(r'marginals $p(x_j)$')
    
    # right panel: prior p(Z) vs. the average of the per-sample Gaussians for factor j
    plt.subplot(1,2,2)
    # mean parameters (first and second moment) -> mean and variance of the prior
    mu = irpm.latent_prior.log_partition.grad(irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    xx = torch.linspace((m-3*torch.sqrt(sig2)).detach().numpy()[0], 
                        (m+3*torch.sqrt(sig2)).detach().numpy()[0], 100)
    prior = torch.distributions.normal.Normal(loc=m, scale=torch.sqrt(sig2))
    plt.plot(xx.detach(), torch.exp(prior.log_prob(xx)).detach().numpy(), label='p(Z)')
    xj = pxind.sample_n(n=1000)[j]
    # natural parameters of factor j added to those of the prior, one row per sample
    mu = irpm.latent_prior.log_partition.grad(irpm.rec_models[j](xj)+irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    def fj(x):
        """Equal-weight mixture density of the per-sample Gaussians N(m[n], sig2[n]) at points x."""
        # All components are evaluated in one broadcast call, shape (n_samples, len(x)).
        # This avoids building one distribution object per sample and no longer
        # overwrites the global training-set size `N` with the number of components.
        components = torch.distributions.normal.Normal(loc=m.unsqueeze(-1), scale=torch.sqrt(sig2).unsqueeze(-1))
        return torch.exp(components.log_prob(x)).mean(dim=0)
    plt.plot(xx.detach(), fj(xx).detach().numpy(), label='Fj(Z)')
    plt.legend()
    plt.title(r'prior $p(Z)$ vs factor evidence $F_j(Z)$')
    plt.show()

In [None]:
# Two-panel figure of log p(x) on the (XX, YY) grid with the current model parameters,
# overlaid with samples from the joint `px`. Only drawn for J == 2.
if J == 2:
    plt.figure(figsize=(16,8))
    plt.subplot(1,2,1)
    # re-evaluate on the 100x100 grid so the figure reflects the current parameters
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100)
    plt.imshow(logpx.detach().numpy(), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    # limits are set again after plotting the samples so points outside the grid range are clipped
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    
    # right panel: contours only, with a fresh set of samples
    plt.subplot(1,2,2)
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.show()


In [None]:
# Continue training with the existing optimiser `optimizer` and data loader `dl`.
# N is reset from the dataset because an intermediate plotting cell may have overwritten it.
N = len(ds)
epochs = 1000
batch_size = 1000

# ls holds one loss value per optimiser step; t is the step counter for this run
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        loss = irpm.training_step_sm(batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
plt.plot(ls)
plt.show()
# cell output: average of exp(irpm.eval(.)) over 10000 samples from the independent marginals
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=10000))))

In [None]:
# Evaluate log p(x) once on the grid spanned by `xxs` with the current model parameters.
# The evaluation does not depend on j, so it is done before the loop instead of once per
# marginal; building the grid from `xxs` also removes the hard-coded two-axis call.
grid_axes = torch.meshgrid(*xxs, indexing='ij')
logpx = irpm.eval([g.reshape(-1,1) for g in grid_axes]).reshape(grid_axes[0].shape)

for j in range(J):

    # left panel: true marginal density vs. marginal obtained by summing the gridded
    # density over all other axes (Riemann sum with the grid spacing of those axes)
    plt.subplot(1,2,1)
    plt.plot(xxs[j], torch.exp(pxjs[j].log_prob(xxs[j])).detach().numpy(), label='true p(xj)')
    knotj = [k for k in range(J)]
    knotj.pop(j)   # all axes except j
    fac = np.prod([xxs[k].diff()[0] for k in knotj])
    plt.plot(xxs[j], fac * torch.sum(torch.exp(logpx),dim=tuple(knotj)).detach().numpy(), label='est. p(xj)')
    plt.legend()
    plt.title(r'marginals $p(x_j)$')
    
    # right panel: prior p(Z) vs. the average of the per-sample Gaussians for factor j
    plt.subplot(1,2,2)
    # mean parameters (first and second moment) -> mean and variance of the prior
    mu = irpm.latent_prior.log_partition.grad(irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    xx = torch.linspace((m-3*torch.sqrt(sig2)).detach().numpy()[0], 
                        (m+3*torch.sqrt(sig2)).detach().numpy()[0], 100)
    prior = torch.distributions.normal.Normal(loc=m, scale=torch.sqrt(sig2))
    plt.plot(xx.detach(), torch.exp(prior.log_prob(xx)).detach().numpy(), label='p(Z)')
    xj = pxind.sample_n(n=1000)[j]
    # natural parameters of factor j added to those of the prior, one row per sample
    mu = irpm.latent_prior.log_partition.grad(irpm.rec_models[j](xj)+irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    def fj(x):
        """Equal-weight mixture density of the per-sample Gaussians N(m[n], sig2[n]) at points x."""
        # All components are evaluated in one broadcast call, shape (n_samples, len(x)).
        # This avoids building one distribution object per sample and no longer
        # overwrites the global training-set size `N` with the number of components.
        components = torch.distributions.normal.Normal(loc=m.unsqueeze(-1), scale=torch.sqrt(sig2).unsqueeze(-1))
        return torch.exp(components.log_prob(x)).mean(dim=0)
    plt.plot(xx.detach(), fj(xx).detach().numpy(), label='Fj(Z)')
    plt.legend()
    plt.title(r'prior $p(Z)$ vs factor evidence $F_j(Z)$')
    plt.show()

In [None]:
# Two-panel figure of log p(x) on the (XX, YY) grid with the current model parameters,
# overlaid with samples from the joint `px`. Only drawn for J == 2.
if J == 2:
    plt.figure(figsize=(16,8))
    plt.subplot(1,2,1)
    # re-evaluate on the 100x100 grid so the figure reflects the current parameters
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100)
    plt.imshow(logpx.detach().numpy(), origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    # limits are set again after plotting the samples so points outside the grid range are clipped
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    
    # right panel: contours only, with a fresh set of samples
    plt.subplot(1,2,2)
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.show()


In [None]:
# Continue training with the existing optimiser `optimizer` and data loader `dl`.
# N is reset from the dataset because an intermediate plotting cell may have overwritten it.
N = len(ds)
epochs = 1000
batch_size = 1000

# ls holds one loss value per optimiser step; t is the step counter for this run
ls,t = np.zeros(epochs*(N//batch_size)),0
for i in range(epochs):
    for batch in dl:
        optimizer.zero_grad()
        loss = irpm.training_step_sm(batch, batch_idx=t)
        loss.backward()
        optimizer.step()
        ls[t] = loss.detach().numpy()
        t+=1
plt.plot(ls)
plt.show()
# cell output: average of exp(irpm.eval(.)) over 10000 samples from the independent marginals
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=10000))))

In [None]:
# Wider and finer grid (exponential case: up to 10/rate with 1000 points per axis).
# The grid and the model evaluation do not depend on j, so they are computed once
# here instead of once per marginal. This cell assumes two marginals (XX, YY).
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 10/rates[j],1000) for j in range(J)]
XX,YY = torch.meshgrid(*xxs, indexing='ij')
xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
# reshape to the actual grid shape (the gaussian branch uses 100 points per axis, not 1000)
logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(XX.shape)
# normalise so that exp(logpx) sums to one over the grid, i.e. a probability mass per grid cell
logpx = logpx - torch.log(torch.exp(logpx).sum())    

for j in range(J):

    plt.subplot(1,2,1)
    plt.plot(xxs[j], torch.exp(pxjs[j].log_prob(xxs[j])).detach().numpy(), label='true p(xj)')
    knotj = [k for k in range(J)]
    knotj.pop(j)   # all axes except j
    # Summing the cell masses over the other axes gives the mass per cell along axis j;
    # dividing by the spacing of axis j (not of the other axes) turns it into a density.
    dxj = xxs[j].diff()[0].item()
    plt.plot(xxs[j], torch.sum(torch.exp(logpx),dim=tuple(knotj)).detach().numpy() / dxj, label='est. p(xj)')
    plt.legend()
    plt.title(r'marginals $p(x_j)$')
    
    # right panel: prior p(Z) vs. the average of the per-sample Gaussians for factor j
    plt.subplot(1,2,2)
    # mean parameters (first and second moment) -> mean and variance of the prior
    mu = irpm.latent_prior.log_partition.grad(irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    xx = torch.linspace((m-3*torch.sqrt(sig2)).detach().numpy()[0], 
                        (m+3*torch.sqrt(sig2)).detach().numpy()[0], 100)
    prior = torch.distributions.normal.Normal(loc=m, scale=torch.sqrt(sig2))
    plt.plot(xx.detach(), torch.exp(prior.log_prob(xx)).detach().numpy(), label='p(Z)')
    xj = pxind.sample_n(n=1000)[j]
    # natural parameters of factor j added to those of the prior, one row per sample
    mu = irpm.latent_prior.log_partition.grad(irpm.rec_models[j](xj)+irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    def fj(x):
        """Equal-weight mixture density of the per-sample Gaussians N(m[n], sig2[n]) at points x."""
        # All components are evaluated in one broadcast call, shape (n_samples, len(x)).
        # This avoids building one distribution object per sample and no longer
        # overwrites the global training-set size `N` with the number of components.
        components = torch.distributions.normal.Normal(loc=m.unsqueeze(-1), scale=torch.sqrt(sig2).unsqueeze(-1))
        return torch.exp(components.log_prob(x)).mean(dim=0)
    plt.plot(xx.detach(), fj(xx).detach().numpy(), label='Fj(Z)')
    plt.legend()
    plt.title(r'prior $p(Z)$ vs factor evidence $F_j(Z)$')
    plt.show()

In [None]:
# Two-panel figure of the grid-normalised log p(x): full range (up to 10/rate) on the
# left, zoom to 3/rate on the right, each overlaid with samples from the joint `px`.
# Relies on XX/YY being the 1000x1000 grid built by the preceding cell.
if J == 2:
    plt.figure(figsize=(16,8))
    plt.subplot(1,2,1)
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(1000,1000)
    # normalise so that exp(logpx) sums to one over the grid
    logpx = logpx - torch.log(torch.exp(logpx).sum())
    plt.imshow(logpx.detach().numpy(), origin='lower', extent=(0., 10/rates[1], 0., 10/rates[0]), aspect='auto')
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    plt.axis((0., 10/rates[1], 0., 10/rates[0]))
    
    # right panel: same contours, axis limits restricted to (0, 3/rate)
    plt.subplot(1,2,2)
    plt.contour(YY, XX , logpx.detach().numpy(), levels=10)
    plt.ylabel(r'$x_1$')
    plt.xlabel(r'$x_2$')
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.title('log p(x) under iRPM with copula loss')
    xjs = px.sample_n(n=1000)
    plt.plot(xjs[1], xjs[0], 'r.')
    # limits are set again after plotting the samples so points outside the range are clipped
    plt.axis((0., 3/rates[1], 0., 3/rates[0]))
    plt.show()


# second part: combine copula loss with maximum likelihood of implicit RPM

In [None]:
import copy
# Snapshot of the model trained so far; a later cell plots it next to the model
# trained with the full loss (see `models = [irpm0, irpm]`).
irpm0 = copy.deepcopy(irpm)

In [None]:
# construct implicit RPM
# NOTE(review): `rec_models` and `latent_prior` are the objects created in the setup cell;
# if ImplicitRPM keeps references to them, this model starts from whatever state the
# earlier training left in them rather than from a fresh initialisation -- confirm.
irpm = ImplicitRPM(rec_models, latent_prior, pxjs)

# ahem ... rejection sampling network initializations...
# NaN never compares equal to anything (not even to itself), so `x == torch.nan` is
# always False and cannot detect a bad initialisation; torch.isnan has to be used.
# The number of retries is bounded so a persistently NaN model fails loudly.
max_init_tries = 100
init_try = 0
while torch.isnan(torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))).any():
    init_try += 1
    if init_try >= max_init_tries:
        raise RuntimeError('no NaN-free initialisation found after %d attempts' % max_init_tries)
    irpm = ImplicitRPM(rec_models, latent_prior, pxjs)
# displayed as the cell output (evaluated on a fresh sample)
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))

In [None]:
import matplotlib.pyplot as plt
# Training with irpm.training_step and lmbda=10.0 (presumably the weight combining the
# copula loss with the likelihood term, per the section title -- confirm in implicitRPM).
# A new Adam optimiser is created for the freshly constructed model.
optimizer = torch.optim.Adam(irpm.parameters(), lr=1e-3)
T = 1000          # number of optimiser steps
ls = np.zeros(T)  # loss per step
for t in range(T):
    optimizer.zero_grad()
    # a fresh batch of 1000 joint samples is drawn at every step (no fixed dataset)
    xjs = px.sample_n(n=1000)
    loss = irpm.training_step(batch=xjs, batch_idx=0, lmbda=10.0)
    loss.backward()
    optimizer.step()
    ls[t] = loss.detach().numpy()
# log-scale y axis
plt.semilogy(ls)
plt.show()
# cell output: average of exp(irpm.eval(.)) over 1000 samples from the independent marginals
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))

In [None]:
# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# One 1-D axis of 100 points per marginal; the range is loc +/- 3*scale for Gaussian
# marginals and (0.001, 3/rate) for exponential marginals.
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
# indexing='ij' is spelled out: it is the layout the code below relies on (first output
# varies along dim 0), and calling torch.meshgrid without it is deprecated and warns.
if J == 3:
    XX,YY,ZZ = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten(), ZZ.flatten()], axis=-1)
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1), ZZ.reshape(-1,1)]).reshape(100,100,100)
elif J == 2:
    XX,YY = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
    # irpm.eval receives one (n_points, 1) column per marginal; the result is put back on the grid
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100)


In [None]:
# Per marginal j: true vs. estimated marginal density (left) and prior vs. averaged
# per-sample Gaussians (right). Uses `logpx`/`xxs` from the preceding grid cell and
# `xjs`, the last batch drawn in the preceding training cell.
for j in range(J):

    # left panel: the estimate sums the gridded density over all other axes
    # (Riemann sum with the grid spacing of those axes)
    plt.subplot(1,2,1)
    plt.plot(xxs[j], torch.exp(pxjs[j].log_prob(xxs[j])).detach().numpy(), label='true p(xj)')
    knotj = [k for k in range(J)]
    knotj.pop(j)   # all axes except j
    fac = np.prod([xxs[k].diff()[0] for k in knotj])
    plt.plot(xxs[j], fac * torch.sum(torch.exp(logpx),dim=tuple(knotj)).detach().numpy(), label='est. p(xj)')
    plt.legend()
    plt.title(r'marginals $p(x_j)$')
    
    plt.subplot(1,2,2)
    # mean parameters (first and second moment) -> mean and variance of the prior
    mu = irpm.latent_prior.log_partition.grad(irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    xx = torch.linspace((m-3*torch.sqrt(sig2)).detach().numpy()[0], 
                        (m+3*torch.sqrt(sig2)).detach().numpy()[0], 100)
    prior = torch.distributions.normal.Normal(loc=m, scale=torch.sqrt(sig2))
    plt.plot(xx.detach(), torch.exp(prior.log_prob(xx)).detach().numpy(), label='p(Z)')
    xj = xjs[j] #pxind.sample_n(n=1000)[j]
    # natural parameters of factor j added to those of the prior, one row per sample
    mu = irpm.latent_prior.log_partition.grad(irpm.rec_models[j](xj)+irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    def fj(x):
        """Equal-weight mixture density of the per-sample Gaussians N(m[n], sig2[n]) at points x."""
        # All components are evaluated in one broadcast call, shape (n_samples, len(x)).
        # This avoids building one distribution object per sample and no longer
        # overwrites the global `N` with the number of components.
        components = torch.distributions.normal.Normal(loc=m.unsqueeze(-1), scale=torch.sqrt(sig2).unsqueeze(-1))
        return torch.exp(components.log_prob(x)).mean(dim=0)
    plt.plot(xx.detach(), fj(xx).detach().numpy(), label='Fj(Z)')
    plt.legend()
    plt.title(r'prior $p(Z)$ vs factor evidence $F_j(Z)$')
    plt.show()

In [None]:
# Side-by-side log p(x) of the snapshot `irpm0` (left) and the current `irpm` (right)
# on the 100x100 (XX, YY) grid, each overlaid with 50 samples from the joint `px`.
models = [irpm0, irpm]
# NOTE(review): the figure is created even when J != 2, in which case it stays empty
plt.figure(figsize=(12,7))
if J == 2:
    for i in range(2):
        logpx = models[i].eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100).detach().numpy()
        plt.subplot(1,2,i+1)
        plt.imshow(logpx, origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
        plt.contour(YY, XX , logpx, levels=5)
        plt.ylabel(r'$x_1$')
        plt.xlabel(r'$x_2$')
        xjs = px.sample_n(n=50)
        plt.plot(xjs[1], xjs[0], 'r.')
        # clip the view to the grid range after plotting the samples
        plt.axis((0., 3/rates[1], 0., 3/rates[0]))
        if i == 0:
            plt.title('log p(x) under iRPM with copula loss')
        elif i == 1:
            plt.title('log p(x) under iRPM with full loss')
    plt.show()

In [None]:
import matplotlib.pyplot as plt
# Further training with irpm.training_step and lmbda=10.0, for T=140 additional steps.
# NOTE(review): a new Adam optimiser is created here, so the moment estimates of the
# previous training cell are not carried over -- confirm this is intended.
optimizer = torch.optim.Adam(irpm.parameters(), lr=1e-3)
T = 140           # number of optimiser steps
ls = np.zeros(T)  # loss per step
for t in range(T):
    optimizer.zero_grad()
    # a fresh batch of 1000 joint samples is drawn at every step (no fixed dataset)
    xjs = px.sample_n(n=1000)
    loss = irpm.training_step(batch=xjs, batch_idx=0, lmbda=10.0)
    loss.backward()
    optimizer.step()
    ls[t] = loss.detach().numpy()
# log-scale y axis
plt.semilogy(ls)
plt.show()
# cell output: average of exp(irpm.eval(.)) over 1000 samples from the independent marginals
torch.mean(torch.exp(irpm.eval(pxind.sample_n(n=1000))))

In [None]:
# super lazy marginals for the RPM: numerically integrate over p(x1, ..., xJ) via below grid for J=2 or J=3
# One 1-D axis of 100 points per marginal; the range is loc +/- 3*scale for Gaussian
# marginals and (0.001, 3/rate) for exponential marginals.
if marginals == 'gaussian':
    xxs = [torch.linspace(locs[j]-3*scales[j], locs[j]+3*scales[j],100) for j in range(J)]
elif marginals == 'exponential':
    xxs = [torch.linspace(0.001, 3/rates[j],100) for j in range(J)]
# indexing='ij' is spelled out: it is the layout the code below relies on (first output
# varies along dim 0), and calling torch.meshgrid without it is deprecated and warns.
if J == 3:
    XX,YY,ZZ = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten(), ZZ.flatten()], axis=-1)
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1), ZZ.reshape(-1,1)]).reshape(100,100,100)
elif J == 2:
    XX,YY = torch.meshgrid(*xxs, indexing='ij')
    xgrid = torch.stack([XX.flatten(), YY.flatten()], axis=-1)
    # irpm.eval receives one (n_points, 1) column per marginal; the result is put back on the grid
    logpx = irpm.eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100)


In [None]:
# Per marginal j: true vs. estimated marginal density (left) and prior vs. averaged
# per-sample Gaussians (right). Uses `logpx`/`xxs` from the preceding grid cell and
# `xjs`, the last batch drawn in the preceding training cell.
for j in range(J):

    # left panel: the estimate sums the gridded density over all other axes
    # (Riemann sum with the grid spacing of those axes)
    plt.subplot(1,2,1)
    plt.plot(xxs[j], torch.exp(pxjs[j].log_prob(xxs[j])).detach().numpy(), label='true p(xj)')
    knotj = [k for k in range(J)]
    knotj.pop(j)   # all axes except j
    fac = np.prod([xxs[k].diff()[0] for k in knotj])
    plt.plot(xxs[j], fac * torch.sum(torch.exp(logpx),dim=tuple(knotj)).detach().numpy(), label='est. p(xj)')
    plt.legend()
    plt.title(r'marginals $p(x_j)$')
    
    plt.subplot(1,2,2)
    # mean parameters (first and second moment) -> mean and variance of the prior
    mu = irpm.latent_prior.log_partition.grad(irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    xx = torch.linspace((m-3*torch.sqrt(sig2)).detach().numpy()[0], 
                        (m+3*torch.sqrt(sig2)).detach().numpy()[0], 100)
    prior = torch.distributions.normal.Normal(loc=m, scale=torch.sqrt(sig2))
    plt.plot(xx.detach(), torch.exp(prior.log_prob(xx)).detach().numpy(), label='p(Z)')
    xj = xjs[j] #pxind.sample_n(n=1000)[j]
    # natural parameters of factor j added to those of the prior, one row per sample
    mu = irpm.latent_prior.log_partition.grad(irpm.rec_models[j](xj)+irpm.latent_prior.param)
    m, sig2 = mu[:,0], mu[:,1] - mu[:,0]**2
    def fj(x):
        """Equal-weight mixture density of the per-sample Gaussians N(m[n], sig2[n]) at points x."""
        # All components are evaluated in one broadcast call, shape (n_samples, len(x)).
        # This avoids building one distribution object per sample and no longer
        # overwrites the global `N` with the number of components.
        components = torch.distributions.normal.Normal(loc=m.unsqueeze(-1), scale=torch.sqrt(sig2).unsqueeze(-1))
        return torch.exp(components.log_prob(x)).mean(dim=0)
    plt.plot(xx.detach(), fj(xx).detach().numpy(), label='Fj(Z)')
    plt.legend()
    plt.title(r'prior $p(Z)$ vs factor evidence $F_j(Z)$')
    plt.show()

In [None]:
# Single-panel log p(x) of the current model on the 100x100 (XX, YY) grid, overlaid
# with 50 samples from the joint `px`.
# NOTE(review): the list holds the same model twice and only index 0 is used; the
# indirection looks like a leftover from the two-model comparison cell.
models = [irpm, irpm]
plt.figure(figsize=(7,7))
i = 0
logpx = models[i].eval([XX.reshape(-1,1), YY.reshape(-1,1)]).reshape(100,100).detach().numpy()
plt.imshow(logpx, origin='lower', extent=(0., 3/rates[1], 0., 3/rates[0]), aspect='auto')
plt.contour(YY, XX , logpx, levels=5)
plt.ylabel(r'$x_1$')
plt.xlabel(r'$x_2$')
xjs = px.sample_n(n=50)
plt.plot(xjs[1], xjs[0], 'r.')
# clip the view to the grid range after plotting the samples
plt.axis((0., 3/rates[1], 0., 3/rates[0]))
plt.title('log p(x) under iRPM with full loss')
plt.show()

# test-bed playground

In [None]:
# Gaussian copula for data generation (rather than independence between marginals)
import scipy.stats as stats
# Gaussian copula with exponential marginals:
#   1. draw correlated standard normals with correlation matrix P,
#   2. map them to uniforms with the standard normal CDF,
#   3. map the uniforms to Exponential(rate) via x = -log(u)/rate.
# Step 2 only yields uniforms if every latent coordinate has unit variance. The rows of
# the mixing matrix are therefore rescaled to unit norm, which makes A.dot(A.T) equal to
# the correlation matrix P. Mixing with the unscaled random matrix gives non-unit
# variances and hence marginals that are not Exponential(rate).
A = np.random.normal(size=(J,J))
A = A / np.sqrt(np.diag(A.dot(A.T))).reshape(J,1)
P = A.dot(A.T)   # unit diagonal: correlation matrix of the latent normals
N = 100000
zz = stats.norm.cdf(A.dot(np.random.normal(size=(J,N))))
xx = (-1.0/np.array(rates).reshape(J,1))*np.log(zz)


# one histogram per marginal, titled with its rate; one subplot column per marginal
for j in range(J):
    plt.subplot(1,J,j+1)
    plt.hist(xx[j,:], density=True)
    plt.title(rates[j])
plt.show()