In [1]:
import pyro
import pyro.distributions as dist
from pyro.optim.multi import MixedMultiOptimizer
from pyro import poutine

import torch
from torch import tensor
import pandas as pd
import numpy as np



In [2]:
def model(args):
  """Three-stage Gaussian chain with the last stage observed.

  Inside a plate named 'mini_batch' whose size is the first dimension of
  ``args['x3_obs']``:
    x1 ~ Normal(args['loc_prior'], 1)
    x2 ~ Normal(x1, 1)
    x3 ~ Normal(x2, 1), observed as args['x3_obs']

  Args:
    args: dict providing the keys 'x3_obs' and 'loc_prior'.

  Returns:
    The tuple (x1, x2, x3) of values produced at the three sample sites.
  """
  observed = args['x3_obs']
  prior_loc = args['loc_prior']
  with pyro.plate('mini_batch', observed.shape[0]):
    x1 = pyro.sample('x1', dist.Normal(prior_loc, 1))
    x2 = pyro.sample('x2', dist.Normal(x1, 1))
    x3 = pyro.sample('x3', dist.Normal(x2, 1), obs=observed)
  return x1, x2, x3

# Inputs reused by the cells below: the prior mean for x1 and three
# observed values for x3.
args = {
  'loc_prior': tensor(1.),
  'x3_obs': tensor([1., 1.1, 1.2]),
}

# Run the model once; the displayed tuple is (x1, x2, x3).
model(args)


(tensor([ 1.2195, -0.1973,  1.3594]),
 tensor([0.6021, 0.8084, 0.5671]),
 tensor([1.0000, 1.1000, 1.2000]))

In [3]:
# Run the model once under a trace handler and keep the resulting Trace.
trace = poutine.trace(model).get_trace(args)
trace

<pyro.poutine.trace_struct.Trace at 0x105510050>

In [4]:
# Same as the previous cell, but through the trace_logger handler.
# NOTE(review): trace_logger does not look like part of the stock pyro.poutine
# API — presumably a local extension; confirm before sharing this notebook.
trace = poutine.trace_logger(model).get_trace(args)

In [5]:
# Keep the wrapped model around so it can be traced in a separate step.
tl = poutine.trace_logger(model)

In [6]:
# Execute the wrapped model on `args` and keep the returned trace.
tr = tl.get_trace(args)

In [7]:
# Reset Pyro's global parameter store before the guide's parameters are created.
pyro.clear_param_store()

def guide(args):
  """Variational guide for the latent sites 'x1' and 'x2' of the model.

  Registers two scalar parameters, 'x1_loc' and 'x2_loc' (both initialised
  to 10.), and samples 'x1' ~ Normal(x1_loc, 1) and 'x2' ~ Normal(x2_loc, 1)
  inside a plate named 'mini_batch' sized by ``args['x3_obs'].shape[0]``.

  Args:
    args: dict providing the key 'x3_obs' (only its first dimension is used).

  Returns:
    None; the function is used for its sample and param sites only.
  """
  # Parameters are registered before `args` is inspected.
  loc_for_x1 = pyro.param('x1_loc', tensor(10.))
  loc_for_x2 = pyro.param('x2_loc', tensor(10.))
  batch_size = args['x3_obs'].shape[0]
  with pyro.plate('mini_batch', batch_size):
    pyro.sample('x1', dist.Normal(loc_for_x1, 1))
    pyro.sample('x2', dist.Normal(loc_for_x2, 1))

# Trace one run of the guide and collect the unconstrained tensor behind each
# of its param sites.
# Author's note: the values change during training, so g_trace needs to be
# recomputed afterwards.
g_trace = poutine.trace(guide).get_trace(args)
params = [
  g_trace.nodes[site_name]["value"].unconstrained()
  for site_name in g_trace.param_nodes
]
params

[tensor(10., requires_grad=True), tensor(10., requires_grad=True)]

In [8]:
def loss_and_trace(model,guide,args,model_condition_data=None):
  """Single-sample Monte Carlo estimate of the negative ELBO, plus both traces.

  The guide is run once and traced. The (optionally conditioned) model is then
  replayed against that guide trace, so both traces share the guide's values at
  the sites they have in common.

  Args:
    model: callable invoked as ``model(args)``.
    guide: callable invoked as ``guide(args)``.
    args: object passed unchanged to both ``model`` and ``guide``.
    model_condition_data: optional dict handed to ``poutine.condition`` as
      ``data``. ``None`` (the default) is treated as an empty dict, i.e. no
      extra conditioning. A fresh dict is created on every call so that no
      dict object is shared between calls.

  Returns:
    Tuple ``(elbo_loss, model_trace, guide_trace)`` where
    ``elbo_loss = -(model_trace.log_prob_sum() - guide_trace.log_prob_sum())``.

  TODO: guide_traces['x1/x2_loc'] history not stored. only most recent value
    also not stored in guide_traces[0].nodes['x1']['fn'].loc
  """
  if model_condition_data is None:
    model_condition_data = {}
  # http://pyro.ai/examples/effect_handlers.html#Example:-Variational-inference-with-a-Monte-Carlo-ELBO
  # https://docs.pyro.ai/en/stable/poutine.html#module-pyro.poutine.handlers
  conditioned_model = poutine.condition(model, data=model_condition_data)
  guide_trace = poutine.trace_logger(guide).get_trace(args)
  # Replay makes the model reuse the values recorded in guide_trace.
  model_trace = poutine.trace_logger(
          poutine.replay(conditioned_model, trace=guide_trace)
      ).get_trace(args)
  p = model_trace.log_prob_sum()
  q = guide_trace.log_prob_sum()
  elbo = p - q
  elbo_loss = -elbo
  return elbo_loss, model_trace, guide_trace

def train(model, guide, data):
  """Optimise the guide parameters with one ELBO step per batch.

  'x1_loc' is stepped with Adam (lr=0.1) and 'x2_loc' with SGD (lr=0.01),
  both wrapped in a single MixedMultiOptimizer.

  Args:
    model: model callable forwarded to loss_and_trace.
    guide: guide callable forwarded to loss_and_trace.
    data: iterable of batches; each batch is passed as ``args`` to
      loss_and_trace.

  Returns:
    Tuple ``(losses, model_traces, guide_traces, x1_locs, x2_locs)`` of lists
    with one entry per batch. The recorded parameter values are read before
    that batch's optimizer step.
  """
  optimizer = MixedMultiOptimizer([
    (['x1_loc'], pyro.optim.Adam({'lr': 0.1})),
    (['x2_loc'], pyro.optim.SGD({'lr': 0.01})),
  ])
  losses, model_traces, guide_traces = [], [], []
  x1_locs, x2_locs = [], []

  def is_not_param(node):
    # Predicate for poutine.block: hide every site that is not a parameter.
    return node["type"] != "param"

  for batch in data:
    # The outer trace records whatever the inner block lets through, so after
    # this `with` only parameter sites appear in `captured.trace`.
    with poutine.trace() as captured, poutine.block(hide_fn=is_not_param):
      elbo_loss, model_trace, guide_trace = loss_and_trace(
        model, guide, batch, model_condition_data={})
      losses.append(elbo_loss.item())
      model_traces.append(model_trace)
      guide_traces.append(guide_trace)
      x1_locs.append(pyro.param('x1_loc').item())
      x2_locs.append(pyro.param('x2_loc').item())

    # Map each captured parameter name to its unconstrained tensor, which is
    # the form passed to the optimizer's step.
    params = {}
    for site_name, site in captured.trace.nodes.items():
      if site['type'] == 'param':
        params[site_name] = site['value'].unconstrained()
    optimizer.step(elbo_loss, params)

  return losses, model_traces, guide_traces, x1_locs, x2_locs

In [9]:
# 100 optimisation steps, every one on the same `args` batch.
losses, model_traces, guide_traces, x1_locs, x2_locs = train(
  model,
  guide,
  data=[args for _ in range(100)],
)

In [10]:
# Peek at the first three stored guide traces (train() stores one per batch).
guide_traces[:3]

[<pyro.poutine.trace_struct.Trace at 0x105530c10>,
 <pyro.poutine.trace_struct.Trace at 0x12e20f2d0>,
 <pyro.poutine.trace_struct.Trace at 0x10553c1d0>]

In [11]:
# Dump every site recorded in the guide trace of the first training step.
guide_traces[0].nodes

OrderedDict([('_INPUT',
              {'name': '_INPUT',
               'type': 'args',
               'args': ({'loc_prior': tensor(1.),
                 'x3_obs': tensor([1.0000, 1.1000, 1.2000])},),
               'kwargs': {}}),
             ('x1_loc',
              {'type': 'param',
               'name': 'x1_loc',
               'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x12e39b850>>,
               'is_observed': False,
               'args': ('x1_loc', tensor(10.)),
               'kwargs': {'constraint': Real(), 'event_dim': None},
               'value': tensor(2.5163, requires_grad=True),
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
               'done': True,
               'stop': False,
               'continuation': None,
               'infer': {}}),
             ('x2_loc',
              {'type': 'param',
               'name': 'x2_loc',
               'fn': <bou

In [12]:
# The param site's 'fn' is the param store's bound get_param (see the node dump
# above), so this call asks the store for 'x1_loc' instead of reading a value
# saved at step 0.
guide_traces[0].nodes['x1_loc']['fn']('x1_loc')

tensor(2.5163)

In [13]:
# Same lookup through the last step's trace. It prints the same value as the
# first step's, which is the "history not stored" issue noted in the TODO of
# loss_and_trace.
guide_traces[-1].nodes['x1_loc']['fn']('x1_loc')

tensor(2.5163)