In [1]:
import pyro
import pyro.distributions as dist
from pyro.optim.multi import MixedMultiOptimizer
from pyro import poutine

import torch
from torch import tensor
import pandas as pd
import numpy as np


In [2]:
def model(args):
  """Generative model: a chain of unit-scale Normals, loc_prior -> x1 -> x2 -> x3.

  Args:
    args: mapping with
      'loc_prior' - location of the Normal prior on x1,
      'x3_obs'    - observed values for x3; its leading dimension sets the
                    size of the 'mini_batch' plate.

  Returns:
    Tuple (x1, x2, x3) of the values produced by the three sample sites;
    x3 is conditioned on args['x3_obs'] via `obs=`.
  """
  observed = args['x3_obs']
  with pyro.plate('mini_batch', observed.shape[0]):
    prior_mean = args['loc_prior']
    first = pyro.sample('x1', dist.Normal(prior_mean, 1))
    # Each latent is centred on the previous one.
    second = pyro.sample('x2', dist.Normal(first, 1))
    third = pyro.sample('x3', dist.Normal(second, 1), obs=observed)
  return first, second, third

# Inputs shared by the model and the guide: prior location for x1 and the
# three observed x3 values.
args = {
  'loc_prior': tensor(1.),
  'x3_obs': tensor([1., 1.1, 1.2]),
}

# Smoke test: run the model once outside any handler and show (x1, x2, x3).
model(args)


(tensor([ 0.8824, -0.1313,  1.0409]),
 tensor([0.4127, 0.3570, 3.0340]),
 tensor([1.0000, 1.1000, 1.2000]))

In [3]:
# Run the model under `poutine.trace` and keep the recorded Trace object.
trace = pyro.poutine.trace(model).get_trace(args)
trace

<pyro.poutine.trace_struct.Trace at 0x10d90f810>

In [4]:
# Same call as the previous cell but through `trace_logger`; rebinds `trace`.
# NOTE(review): `trace_logger` does not look like a stock Pyro handler -
# presumably it comes from a patched/forked install; confirm before sharing.
trace = pyro.poutine.trace_logger(model).get_trace(args)

In [5]:
# Keep the `trace_logger`-wrapped model in `tl` so it can be invoked separately.
tl = pyro.poutine.trace_logger(model)

In [6]:
# Run the wrapped model on `args` and keep the resulting trace in `tr`.
tr = tl.get_trace(args)

In [7]:
# Reset Pyro's global parameter store so the guide's `pyro.param` initial
# values (10.) take effect again when this cell is re-run.
pyro.clear_param_store()

def guide(args):
  """Variational guide: independent unit-scale Normals over x1 and x2.

  The two locations are learnable parameters 'x1_loc' and 'x2_loc', both
  initialised to 10. Sampling happens inside the same 'mini_batch' plate the
  model uses, sized by the leading dimension of args['x3_obs'].

  Args:
    args: mapping that provides 'x3_obs' (only its shape[0] is read here).

  Returns:
    None. Each call also prints (x1_loc, x1 distribution's loc, sampled x1).
  """
  loc_1 = pyro.param('x1_loc', tensor(10.))
  loc_2 = pyro.param('x2_loc', tensor(10.))
  batch_size = args['x3_obs'].shape[0]
  with pyro.plate('mini_batch', batch_size):
    q_x1 = dist.Normal(loc_1, 1)
    draw_1 = pyro.sample('x1', q_x1)
    q_x2 = dist.Normal(loc_2, 1)
    pyro.sample('x2', q_x2)
    # Debug output: lets the reader compare the parameter, the distribution's
    # loc and the draw on every guide invocation.
    print(loc_1, q_x1.loc, draw_1)

# Trace a single guide run so its param sites can be enumerated.
g_trace = poutine.trace(guide).get_trace(args)
# Author's note: the values change during training, so g_trace needs to be
# recomputed to inspect them again.
params = [
  g_trace.nodes[site_name]["value"].unconstrained()
  for site_name in g_trace.param_nodes
]
params

tensor(10., requires_grad=True) tensor(10., requires_grad=True) tensor([11.5285,  9.8527,  9.5337], grad_fn=<AddBackward0>)


[tensor(10., requires_grad=True), tensor(10., requires_grad=True)]

In [8]:
def loss_and_trace(model,guide,args,model_condition_data=None):
  """Single-sample Monte Carlo estimate of the negative ELBO, plus both traces.

  The guide is run once and traced; the model (optionally conditioned on
  `model_condition_data`) is then replayed against that guide trace so both
  share the same latent values. The loss is -(log p - log q), where log p and
  log q are the summed log-probabilities of the model and guide traces.

  Pattern follows
  http://pyro.ai/examples/effect_handlers.html#Example:-Variational-inference-with-a-Monte-Carlo-ELBO
  and the handlers documented at
  https://docs.pyro.ai/en/stable/poutine.html#module-pyro.poutine.handlers

  Args:
    model: callable taking `args`.
    guide: callable taking `args`.
    args: the single argument forwarded to both `guide` and `model`.
    model_condition_data: mapping passed as `data=` to `poutine.condition`;
      None (the default) means an empty dict, i.e. no extra conditioning.

  Returns:
    Tuple (elbo_loss, model_trace, guide_trace).

  TODO: guide_traces['x1/x2_loc'] history not stored. only most recent value
    also not stored in guide_traces[0].nodes['x1']['fn'].loc
  """
  # A None sentinel avoids sharing one mutable `{}` default across all calls.
  if model_condition_data is None:
    model_condition_data = {}
  conditioned_model = poutine.condition(model, data=model_condition_data)
  # NOTE(review): `trace_logger` is used here where the referenced example uses
  # `poutine.trace` - presumably a local/forked handler; confirm.
  guide_trace = poutine.trace_logger(guide).get_trace(args)
  model_trace = poutine.trace_logger(
          poutine.replay(conditioned_model, trace=guide_trace)
      ).get_trace(args)
  log_p = model_trace.log_prob_sum()
  log_q = guide_trace.log_prob_sum()
  elbo = log_p - log_q
  elbo_loss = -elbo
  return elbo_loss, model_trace, guide_trace

def train(model, guide, data):
  """Run one optimisation step per batch and record per-step diagnostics.

  'x1_loc' is updated with Adam (lr 0.1) and 'x2_loc' with SGD (lr 0.01),
  combined through a MixedMultiOptimizer.

  Args:
    model: model callable, forwarded to `loss_and_trace`.
    guide: guide callable, forwarded to `loss_and_trace`.
    data: iterable of batches; each batch is passed as `args` to
      `loss_and_trace`.

  Returns:
    Tuple (losses, model_traces, guide_traces, x1_locs, x2_locs), each a list
    with one entry per batch. Losses and the two `*_locs` lists hold Python
    numbers taken with `.item()` before that step's update.
  """
  optim = MixedMultiOptimizer([
    (['x1_loc'], pyro.optim.Adam({'lr': 0.1})),
    (['x2_loc'], pyro.optim.SGD({'lr': 0.01})),
  ])

  def _hide_non_param(node):
    # Hide everything except param sites so the outer trace records only those.
    return node["type"] != "param"

  losses, model_traces, guide_traces = [], [], []
  x1_locs, x2_locs = [], []
  for batch in data:
    # The outer trace captures every parameter touched while the loss is computed.
    with poutine.trace() as param_capture, poutine.block(hide_fn=_hide_non_param):
      step_loss, step_model_trace, step_guide_trace = loss_and_trace(
        model, guide, batch, model_condition_data={})
      losses.append(step_loss.item())
      model_traces.append(step_model_trace)
      guide_traces.append(step_guide_trace)
      x1_locs.append(pyro.param('x1_loc').item())
      x2_locs.append(pyro.param('x2_loc').item())
    # Collect the unconstrained tensor of each captured parameter for the optimiser.
    params = {}
    for name, site in param_capture.trace.nodes.items():
      if site['type'] == 'param':
        params[name] = site['value'].unconstrained()
    optim.step(step_loss, params)
  return losses, model_traces, guide_traces, x1_locs, x2_locs

In [9]:
# Run 100 optimisation steps, each on the same `args` batch.
losses, model_traces, guide_traces, x1_locs, x2_locs = train(
  model,
  guide,
  data=[args for _ in range(100)],
)

tensor(10., requires_grad=True) tensor(10., requires_grad=True) tensor([10.1091,  9.3309,  9.8076], grad_fn=<AddBackward0>)
tensor(9.9000, requires_grad=True) tensor(9.9000, requires_grad=True) tensor([ 9.8543, 10.0657, 10.6622], grad_fn=<AddBackward0>)
tensor(9.7999, requires_grad=True) tensor(9.7999, requires_grad=True) tensor([10.2894, 11.0416, 11.2399], grad_fn=<AddBackward0>)
tensor(9.6995, requires_grad=True) tensor(9.6995, requires_grad=True) tensor([11.1328, 11.3302,  8.5887], grad_fn=<AddBackward0>)
tensor(9.5991, requires_grad=True) tensor(9.5991, requires_grad=True) tensor([ 8.5364,  8.2282, 11.7881], grad_fn=<AddBackward0>)
tensor(9.4989, requires_grad=True) tensor(9.4989, requires_grad=True) tensor([ 8.1829, 10.6106,  9.9618], grad_fn=<AddBackward0>)
tensor(9.3985, requires_grad=True) tensor(9.3985, requires_grad=True) tensor([10.6386, 10.5825,  9.2441], grad_fn=<AddBackward0>)
tensor(9.2979, requires_grad=True) tensor(9.2979, requires_grad=True) tensor([ 8.7102,  8.4640, 

tensor(3.2894, requires_grad=True) tensor(3.2894, requires_grad=True) tensor([2.5061, 2.9386, 3.6717], grad_fn=<AddBackward0>)
tensor(3.2447, requires_grad=True) tensor(3.2447, requires_grad=True) tensor([2.7557, 4.0397, 3.6361], grad_fn=<AddBackward0>)
tensor(3.1991, requires_grad=True) tensor(3.1991, requires_grad=True) tensor([1.9799, 4.2611, 3.8760], grad_fn=<AddBackward0>)
tensor(3.1524, requires_grad=True) tensor(3.1524, requires_grad=True) tensor([2.5879, 2.8925, 3.9470], grad_fn=<AddBackward0>)
tensor(3.1065, requires_grad=True) tensor(3.1065, requires_grad=True) tensor([2.4013, 6.3705, 4.1411], grad_fn=<AddBackward0>)
tensor(3.0579, requires_grad=True) tensor(3.0579, requires_grad=True) tensor([0.9394, 1.7866, 4.4502], grad_fn=<AddBackward0>)
tensor(3.0130, requires_grad=True) tensor(3.0130, requires_grad=True) tensor([3.5117, 2.4233, 3.1102], grad_fn=<AddBackward0>)
tensor(2.9676, requires_grad=True) tensor(2.9676, requires_grad=True) tensor([2.0787, 3.1578, 4.0806], grad_fn=

In [11]:
# Peek at the first three stored guide traces (train() appends one per step).
guide_traces[:3]

[<pyro.poutine.trace_struct.Trace at 0x136776590>,
 <pyro.poutine.trace_struct.Trace at 0x136776bd0>,
 <pyro.poutine.trace_struct.Trace at 0x1367a1510>]

In [None]:
# Distribution object recorded at the `x1` sample site of the first training step.
fn = guide_traces[0].nodes['x1']['fn']
fn

In [None]:
# At a param site, `fn` is the parameter store's bound `get_param` method, so
# `__self__` is the store itself. Read it from the `x1_loc` site directly
# instead of relying on whatever `fn` an earlier cell left behind (the `fn`
# taken from the `x1` sample site is a distribution and has no `__self__`).
# NOTE: `_params` / `_constraints` are private attributes of the store.
param_store = guide_traces[0].nodes['x1_loc']['fn'].__self__
param_store._params, param_store._constraints

In [None]:
# Full site dict recorded for the `x1_loc` param in the first step's guide trace.
guide_traces[0].nodes['x1_loc']

In [12]:
# Compare the `x1_loc` param site of the first two steps. In the recorded
# output both show the same 'value', unlike the per-step values printed during
# training - i.e. the traces do not hold a per-step snapshot of the parameter.
[guide_traces[i].nodes['x1_loc'] for i in [0,1]]

[{'type': 'param',
  'name': 'x1_loc',
  'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x136776e90>>,
  'is_observed': False,
  'args': ('x1_loc', tensor(10.)),
  'kwargs': {'constraint': Real(), 'event_dim': None},
  'value': tensor(2.5326, requires_grad=True),
  'scale': 1.0,
  'mask': None,
  'cond_indep_stack': (),
  'done': True,
  'stop': False,
  'continuation': None,
  'infer': {}},
 {'type': 'param',
  'name': 'x1_loc',
  'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x136776dd0>>,
  'is_observed': False,
  'args': ('x1_loc', tensor(10.)),
  'kwargs': {'constraint': Real(), 'event_dim': None},
  'value': tensor(2.5326, requires_grad=True),
  'scale': 1.0,
  'mask': None,
  'cond_indep_stack': (),
  'done': True,
  'stop': False,
  'continuation': None,
  'infer': {}}]

In [14]:
# Compare the `x1` sample site of the first and last training step; unlike the
# param sites, the sampled 'value' tensors differ between steps.
[guide_traces[i].nodes['x1'] for i in [0,-1]]

[{'type': 'sample',
  'name': 'x1',
  'fn': Normal(loc: torch.Size([3]), scale: torch.Size([3])),
  'is_observed': False,
  'args': (),
  'kwargs': {},
  'value': tensor([10.1091,  9.3309,  9.8076], grad_fn=<AddBackward0>),
  'infer': {},
  'scale': 1.0,
  'mask': None,
  'cond_indep_stack': (CondIndepStackFrame(name='mini_batch', dim=-1, size=3, counter=0),),
  'done': True,
  'stop': True,
  'continuation': None,
  'log_prob_sum': tensor(-3.0051, grad_fn=<SumBackward0>)},
 {'type': 'sample',
  'name': 'x1',
  'fn': Normal(loc: torch.Size([3]), scale: torch.Size([3])),
  'is_observed': False,
  'args': (),
  'kwargs': {},
  'value': tensor([1.5169, 2.6776, 2.4642], grad_fn=<AddBackward0>),
  'infer': {},
  'scale': 1.0,
  'mask': None,
  'cond_indep_stack': (CondIndepStackFrame(name='mini_batch', dim=-1, size=3, counter=0),),
  'done': True,
  'stop': True,
  'continuation': None,
  'log_prob_sum': tensor(-3.3209, grad_fn=<SumBackward0>)}]

In [None]:
# All sites recorded in the first step's guide trace.
guide_traces[0].nodes

In [None]:
# `x1_loc` param site from the second step's guide trace.
guide_traces[1].nodes['x1_loc']

In [None]:
# The param site's `fn` is the store's bound `get_param`, so this looks the
# parameter up by name through the first step's trace - presumably returning
# its current value rather than the step-0 value; confirm by running.
guide_traces[0].nodes['x1_loc']['fn']('x1_loc')

In [None]:
# Same lookup through the last step's trace, for comparison with the first.
guide_traces[-1].nodes['x1_loc']['fn']('x1_loc')

In [10]:
# `loc` of the x1 guide distribution at every 10th training step. The range is
# derived from the number of stored traces rather than a hard-coded 100, so the
# cell keeps working if the number of training steps changes.
# In the recorded output every entry is identical: the stored distributions do
# not preserve the per-step parameter value.
[guide_traces[i].nodes['x1']['fn'].loc for i in range(0, len(guide_traces), 10)]

[tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>),
 tensor([2.5326, 2.5326, 2.5326], grad_fn=<AsStridedBackward0>)]