In [1]:
%matplotlib notebook

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS

import torch
import numpy as np
from matplotlib import pyplot

import os
# Let the process continue when more than one copy of the Intel OpenMP runtime is loaded.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... already initialized"
# abort that can occur when torch and numpy/MKL each ship their own OpenMP runtime.
# Intel documents this as an unsafe, unsupported workaround — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Pyro argument/shape validation is left at its default; uncomment to force it on while debugging.
#pyro.enable_validation(True) 

In [2]:
# Problem size: persons, mixture modes, coefficients per observation, and
# observations drawn per (mode, person) pair.
nPerson = 2
nModesPerPerson = 3
nSpectralCoeff = 2
nSamplesPerMode = 3

# Isotropic variance scale at each level of the hierarchy: mode centre,
# person location around its mode centre, observation around its person location.
modeCovarianceFactor = 100
personCovarianceFactor = 1
dataCovarianceFactor = .01

modeMean = torch.zeros(nSpectralCoeff)
modeCovariance, personCovariance, dataCovariance = (
    factor * torch.eye(nSpectralCoeff)
    for factor in (modeCovarianceFactor, personCovarianceFactor, dataCovarianceFactor)
)

In [3]:
def generateData(nPerson, nModesPerPerson, nSpectralCoeff, nSamplesPerMode):
    """Simulate a dataset from the hierarchical mode -> person -> sample prior.

    For every mode m a centre ``modeLoc_m ~ N(modeMean, modeCovariance)`` is drawn;
    for every person p a location ``locModePerson_m_p ~ N(modeLoc_m, personCovariance)``;
    and ``nSamplesPerMode`` observations ``~ N(locModePerson_m_p, dataCovariance)``.
    The mean/covariance tensors are read from the notebook's globals.

    Args:
        nPerson: number of persons drawn per mode.
        nModesPerPerson: number of modes.
        nSpectralCoeff: dimension of each observation; must equal the length of the
            global ``modeMean``.
        nSamplesPerMode: number of observations per (mode, person) pair.

    Returns:
        ``(modeLocs, locs, data)``: ``modeLocs`` maps ``"modeLoc_<m>"`` to the sampled
        centre tensor, ``locs`` maps ``"locModePerson_<m>_<p>"`` to the sampled person
        location, and ``data`` is a NumPy array with one row per observation whose
        column 0 is the person index and whose remaining columns are the sampled
        coefficients.

    Raises:
        ValueError: if ``nSpectralCoeff`` does not match the length of ``modeMean``.
    """
    # The prior tensors come from globals, so make sure the requested dimension agrees
    # with them instead of silently ignoring the argument.
    if modeMean.shape[-1] != nSpectralCoeff:
        raise ValueError("nSpectralCoeff={} does not match modeMean of length {}"
                         .format(nSpectralCoeff, modeMean.shape[-1]))

    modeLocs = {}
    locs = {}
    data = []
    personIndex = []

    for thisMode in pyro.plate("modeLoc_loop", nModesPerPerson):
        modeLocName = "modeLoc_" + str(thisMode)
        thisModeLoc = pyro.sample(modeLocName, dist.MultivariateNormal(modeMean, modeCovariance))
        modeLocs[modeLocName] = thisModeLoc

        # Nested sequential plates get a name that is unique per outer iteration, so the
        # plate names stay distinct should this function ever be run under a Pyro trace.
        for thisPerson in pyro.plate("person_loop_{}".format(thisMode), nPerson):
            locName = "locModePerson_" + str(thisMode) + "_" + str(thisPerson)
            thisLoc = pyro.sample(locName, dist.MultivariateNormal(thisModeLoc, personCovariance))
            locs[locName] = thisLoc

            for thisDataSample in pyro.plate("data_loop_{}_{}".format(thisMode, thisPerson),
                                             nSamplesPerMode):
                dataName = "dataMPS_" + str(thisMode) + "_" + str(thisPerson) + "_" + str(thisDataSample)
                thisDatum = pyro.sample(dataName, dist.MultivariateNormal(thisLoc, dataCovariance)).numpy()
                data.append(thisDatum)
                personIndex.append(thisPerson)

    # column_stack turns the 1-D person index into the first column; the integer
    # index is promoted to the float dtype of the coefficient columns.
    return modeLocs, locs, np.column_stack((np.asarray(personIndex), np.asarray(data)))

In [4]:
# Set SEED to an int (e.g. 1) to make the simulated dataset reproducible;
# None keeps drawing fresh data on every run.
SEED = None
if SEED is not None:
    pyro.set_rng_seed(SEED)
pyro.clear_param_store()
modeLocs, locs, data = generateData(nPerson, nModesPerPerson, nSpectralCoeff, nSamplesPerMode)
# Column 0 of `data` is the person index, stored in a float column. The model uses it
# as a tensor index, so convert it explicitly to int64. `np.cast` was removed in NumPy 2.0,
# and 'int' was platform-dependent (int32 on Windows).
person = torch.tensor(data[:, 0].astype(np.int64))
samples = torch.tensor(data[:, 1:], dtype=torch.float32)
print(modeLocs)
print()
print(locs)

# Scatter of the first two spectral coefficients; ';' suppresses the Line2D repr.
pyplot.plot(data[:, 1], data[:, 2], '.');

{'modeLoc_0': tensor([-2.9559,  7.3163]), 'modeLoc_1': tensor([ -8.6081, -10.8014]), 'modeLoc_2': tensor([ 2.2951, -2.7334])}

{'locModePerson_0_0': tensor([-1.4431,  8.0140]), 'locModePerson_0_1': tensor([-3.2525,  6.9165]), 'locModePerson_1_0': tensor([ -9.4024, -10.5962]), 'locModePerson_1_1': tensor([ -7.9304, -11.4029]), 'locModePerson_2_0': tensor([ 0.8431, -2.0032]), 'locModePerson_2_1': tensor([-0.2921, -3.6744])}


<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fd1a48e65e0>]

In [8]:
@config_enumerate
def model(person, sample):
    """Hierarchical Gaussian mixture over (mode, person) locations.

    Each mode m has a centre ``modeLoc_<m> ~ N(modeMean, modeCovariance)``, and each
    (mode, person) pair has a location ``locModePerson_<m>_<p> ~ N(modeLoc_<m>,
    personCovariance)``. Every observation gets a discrete ``modeAssignment`` with a
    uniform prior over the modes (marked for enumeration by ``@config_enumerate``) and
    is drawn from ``N(loc[modeAssignment, person], dataCovariance)``.

    Args:
        person: integer tensor with the person index of each observation; its length
            sets the size of the 'data_loop' plate.
        sample: observed coefficients, presumably of shape (len(person), nSpectralCoeff)
            — TODO confirm against callers.
    """
    modeLocations = []
    for thisMode in range(nModesPerPerson):
        modeLocName = "modeLoc_" + str(thisMode)
        thisModeLoc = pyro.sample(modeLocName, dist.MultivariateNormal(modeMean, modeCovariance))

        personLocations = []
        for thisPerson in range(nPerson):
            locName = "locModePerson_" + str(thisMode) + "_" + str(thisPerson)
            personLocations.append(
                pyro.sample(locName, dist.MultivariateNormal(thisModeLoc, personCovariance)))
        modeLocations.append(torch.stack(personLocations))

    # (nModesPerPerson, nPerson, nSpectralCoeff). Built with stack instead of in-place
    # writes into a preallocated zeros tensor.
    loc = torch.stack(modeLocations)

    with pyro.plate('data_loop', len(person)):
        modeAssignment = pyro.sample('modeAssignment',
                                     dist.Categorical((1. / nModesPerPerson) * torch.ones(nModesPerPerson)))

        # Advanced indexing broadcasts modeAssignment against person, so this also works
        # when modeAssignment carries an extra enumeration dimension.
        pyro.sample('obs',
                    dist.MultivariateNormal(loc[modeAssignment, person, :], dataCovariance),
                    obs=sample)


# Continuous latent site names of `model`, exposed to the AutoDelta guide. They are derived
# from the config (not hard-coded 3/2) so they stay in sync with the sites sampled above.
modeLocList = ["modeLoc_" + str(thisMode) for thisMode in range(nModesPerPerson)]
locModPersonList = ["locModePerson_" + str(thisMode) + "_" + str(thisPerson)
                    for thisMode in range(nModesPerPerson) for thisPerson in range(nPerson)]


In [19]:
# Display the per-(mode, person) site names exposed to the guide.
list(locModPersonList)

['locModePerson_0_0',
 'locModePerson_0_1',
 'locModePerson_1_0',
 'locModePerson_1_1',
 'locModePerson_2_0',
 'locModePerson_2_1']

In [15]:
# guideFlag selects the variational family for the continuous location sites:
#   0 -> AutoDelta (MAP point estimates);
#   1 -> hand-written mean-field Gaussian guide.
# In both cases the discrete `modeAssignment` site is NOT sampled by the guide: the model
# is decorated with @config_enumerate, so TraceEnum_ELBO sums it out inside the model.
# NOTE(review): in the recorded AutoDelta run, every modeLoc_<m> converged to the same
# value, and so did every locModePerson_<m>_<p> for a given person. The mixture
# components collapsed onto each other. A symmetry-breaking initialisation
# (e.g. a custom init_loc_fn) may be needed — confirm.
guideFlag = 0

if guideFlag == 0:
    guide = AutoDelta(poutine.block(model, expose=modeLocList + locModPersonList))
elif guideFlag == 1:

    def guide(person, sample):
        """Mean-field guide: an independent diagonal Gaussian for every location site.

        Site names and structure (plain Python loops, no plates) mirror `model`.
        """
        for thisMode in range(nModesPerPerson):
            modeLocName = "modeLoc_" + str(thisMode)
            # Lazy initialisers run only when the param is first created. Each mode centre
            # starts at its own prior draw so the components do not start on top of each other.
            modeQLoc = pyro.param(modeLocName + "_qLoc",
                                  lambda: dist.MultivariateNormal(modeMean, modeCovariance).sample())
            modeQScale = pyro.param(modeLocName + "_qScale",
                                    lambda: torch.ones(nSpectralCoeff),
                                    constraint=dist.constraints.positive)
            pyro.sample(modeLocName, dist.Normal(modeQLoc, modeQScale).to_event(1))

            for thisPerson in range(nPerson):
                locName = "locModePerson_" + str(thisMode) + "_" + str(thisPerson)
                # Person locations start at their mode's initial centre.
                personQLoc = pyro.param(locName + "_qLoc", lambda: modeQLoc.detach().clone())
                personQScale = pyro.param(locName + "_qScale",
                                          lambda: torch.ones(nSpectralCoeff),
                                          constraint=dist.constraints.positive)
                pyro.sample(locName, dist.Normal(personQLoc, personQScale).to_event(1))

In [11]:
pyro.clear_param_store()

svi = SVI(model,
          guide,
          Adam({"lr": .05}),
          # `model` has exactly one plate ('data_loop'). Stating the nesting depth lets
          # TraceEnum_ELBO allocate enumeration dims for `modeAssignment` without guessing.
          loss=TraceEnum_ELBO(max_plate_nesting=1))

num_iters = 5000
losses = []

for i in range(num_iters):
    # svi.step returns the loss (an estimate of the negative ELBO) after one gradient step.
    loss = svi.step(person, samples)
    losses.append(loss)
    if i % 500 == 0:
        print("Elbo loss: {}".format(loss))

Elbo loss: 80464.3515625
Elbo loss: 64327.6484375
Elbo loss: 64327.6484375
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125
Elbo loss: 64327.64453125


In [12]:
# Loss trace of the SVI run on a log scale.
lossFig, lossAx = pyplot.subplots(figsize=(10, 3), dpi=100)
lossFig.set_facecolor('white')
lossAx.plot(losses)
lossAx.set_xlabel('iters')
lossAx.set_ylabel('loss')
lossAx.set_yscale('log')
lossAx.set_title('Convergence of SVI');

<IPython.core.display.Javascript object>

In [14]:
# Ground-truth locations used to simulate the data, for comparison with the learned parameters.
print(modeLocs, locs, sep="\n\n")

{'modeLoc_0': tensor([-2.9559,  7.3163]), 'modeLoc_1': tensor([ -8.6081, -10.8014]), 'modeLoc_2': tensor([ 2.2951, -2.7334])}

{'locModePerson_0_0': tensor([-1.4431,  8.0140]), 'locModePerson_0_1': tensor([-3.2525,  6.9165]), 'locModePerson_1_0': tensor([ -9.4024, -10.5962]), 'locModePerson_1_1': tensor([ -7.9304, -11.4029]), 'locModePerson_2_0': tensor([ 0.8431, -2.0032]), 'locModePerson_2_1': tensor([-0.2921, -3.6744])}


In [13]:
# (name, value) pairs of every parameter currently in Pyro's param store.
[*pyro.get_param_store().items()]

[('AutoDelta.modeLoc_0',
  Parameter containing:
  tensor([-3.5666, -2.1417], requires_grad=True)),
 ('AutoDelta.locModePerson_0_0',
  Parameter containing:
  tensor([-3.3386, -1.5307], requires_grad=True)),
 ('AutoDelta.locModePerson_0_1',
  Parameter containing:
  tensor([-3.8302, -2.7741], requires_grad=True)),
 ('AutoDelta.modeLoc_1',
  Parameter containing:
  tensor([-3.5666, -2.1417], requires_grad=True)),
 ('AutoDelta.locModePerson_1_0',
  Parameter containing:
  tensor([-3.3386, -1.5307], requires_grad=True)),
 ('AutoDelta.locModePerson_1_1',
  Parameter containing:
  tensor([-3.8302, -2.7741], requires_grad=True)),
 ('AutoDelta.modeLoc_2',
  Parameter containing:
  tensor([-3.5666, -2.1417], requires_grad=True)),
 ('AutoDelta.locModePerson_2_0',
  Parameter containing:
  tensor([-3.3386, -1.5307], requires_grad=True)),
 ('AutoDelta.locModePerson_2_1',
  Parameter containing:
  tensor([-3.8302, -2.7741], requires_grad=True))]

# MCMC

In [10]:
# `modeAssignment` is discrete. It is enumerated out (the model is @config_enumerate),
# which requires the plate nesting depth of the model: one plate, 'data_loop'.
kernel = NUTS(model, max_plate_nesting=1)
mcmc = MCMC(kernel, num_samples=250, warmup_steps=50)
mcmc.run(person, samples)
# Mapping from latent site name to its stacked posterior draws.
posterior_samples = mcmc.get_samples()

Sample: 100%|██████████| 300/300 [00:29, 10.15it/s, step size=7.20e-01, acc. prob=0.818]


# Eight schools example

In [None]:


import torch
from torch.distributions import constraints, transforms

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam

#logging.basicConfig(format='%(message)s', level=logging.INFO)

J = 8
# Per-school observed values `y` and their scales `sigma` (the model uses sigma as the
# known observation standard deviation). Both use torch's default floating dtype.
y = torch.tensor([28., 8., -3., 7., -1., 1., 18., 12.], dtype=torch.get_default_dtype())
sigma = torch.tensor([15., 10., 16., 11., 9., 11., 10., 18.], dtype=torch.get_default_dtype())
# One row per school: column 0 is y, column 1 is sigma.
data = torch.stack((y, sigma), dim=1)


def model(data):
    """Non-centred hierarchical "eight schools" model.

    theta_j = mu + tau * eta_j, where eta_j ~ N(0, 1) per school, and mu ~ N(0, 10) and
    tau ~ HalfCauchy(25) are shared by all schools. Then y_j ~ N(theta_j, sigma_j), with
    sigma_j treated as known.

    Args:
        data: tensor with one row per school; column 0 is y, column 1 is sigma.
    """
    y = data[:, 0]
    sigma = data[:, 1]

    # mu and tau are population-level (shared) parameters, so they are sampled OUTSIDE the
    # plate. Inside pyro.plate("data", J) their (1,)-shaped distributions would be broadcast
    # to (J,), giving every school its own independent mu and tau (no pooling at all).
    mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))

    with pyro.plate("data", J):
        eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))

        theta = mu + tau * eta

        pyro.sample("obs", dist.Normal(theta, sigma), obs=y)


def guide(data):
    """Mean-field guide for `model`.

    eta and mu get Normal factors; the positive tau gets a Normal pushed through exp.
    Site names and plate placement mirror `model` (mu and tau outside the plate).
    """
    # Lazy initialisers: each is evaluated only when its param is first created, instead
    # of drawing fresh random tensors on every guide call.
    # note that we initialize our scales to be pretty narrow
    m_eta_param = pyro.param("loc_eta", lambda: torch.randn(J))
    s_eta_param = pyro.param("scale_eta", lambda: 0.1 * torch.rand(J), constraint=constraints.positive)
    m_mu_param = pyro.param("loc_mu", lambda: torch.randn(1))
    s_mu_param = pyro.param("scale_mu", lambda: 0.1 * torch.rand(1), constraint=constraints.positive)
    m_logtau_param = pyro.param("loc_logtau", lambda: torch.randn(1))
    s_logtau_param = pyro.param("scale_logtau", lambda: 0.1 * torch.rand(1), constraint=constraints.positive)

    # guide distributions
    dist_eta = dist.Normal(m_eta_param, s_eta_param)
    dist_mu = dist.Normal(m_mu_param, s_mu_param)
    dist_tau = dist.TransformedDistribution(dist.Normal(m_logtau_param, s_logtau_param),
                                            transforms=transforms.ExpTransform())

    # Shared sites outside the plate, matching the model.
    pyro.sample('mu', dist_mu)
    pyro.sample('tau', dist_tau)
    with pyro.plate("data", J):
        pyro.sample('eta', dist_eta)



# Plain (non-enumerated) SVI with Adam on the eight-schools model/guide.
optim = Adam({'lr': .05})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, loss=elbo)

pyro.clear_param_store()
num_steps = 1000
for step in range(num_steps):
    loss = svi.step(data)
    if step % 100 != 0:
        continue
    # Progress line every 100 steps, reported with a 1-based step number.
    print("[epoch %04d] loss: %.4f" % (step + 1, loss))

# Dump every learned parameter as a NumPy array.
for paramName, paramValue in pyro.get_param_store().items():
    print(paramName)
    print(paramValue.detach().cpu().numpy())


# Scratch