In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
import scipy
import matplotlib as mpl
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.distributions import Categorical, Normal, Multinomial, Binomial, MultivariateNormal, Beta, constraints
from pyro.distributions.torch import Bernoulli
import pyro.infer as infer
from pyro.infer import TraceEnum_ELBO, Trace_ELBO, config_enumerate
from pyro import poutine
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.messenger import Messenger

plt.style.use('seaborn-v0_8')


In [119]:
def model():
    p = pyro.param("p", torch.arange(6.) / 6)
    
    a = pyro.sample("a", Categorical(torch.ones(6, dtype=torch.float) / 6.))
    b = pyro.sample("b", Bernoulli(p[a]))
    locs = torch.tensor([0.2, 0.9])
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.5))
        e = pyro.sample("e", Normal(locs[c.long()], 1.0))
    
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3, 4), torch.ones(3, 4)).to_event(1))

pyro.set_rng_seed(0)
pyro.clear_param_store()
model = config_enumerate(model)
trace = poutine.trace(poutine.enum(model, first_available_dim=-2)).get_trace()
trace.nodes["c"]

{'type': 'sample',
 'name': 'c',
 'fn': Bernoulli(probs: torch.Size([4])),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[[[0.]]],
 
 
         [[[1.]]]]),
 'infer': {'enumerate': 'parallel',
  'expand': False,
  '_enumerate_dim': -4,
  '_dim_to_id': {-4: 2, -2: 0, -3: 1}},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (CondIndepStackFrame(name='c_plate', dim=-1, size=4, counter=0),),
 'done': True,
 'stop': False,
 'continuation': None}

In [120]:
trace.compute_log_prob()

In [121]:
trace.format_shapes().split("\n")

['Trace Shapes:            ',
 ' Param Sites:            ',
 '            p         6  ',
 'Sample Sites:            ',
 '       a dist         |  ',
 '        value     6 1 |  ',
 '     log_prob     6 1 |  ',
 '       b dist     6 1 |  ',
 '        value   2 1 1 |  ',
 '     log_prob   2 6 1 |  ',
 ' c_plate dist         |  ',
 '        value       4 |  ',
 '     log_prob         |  ',
 '       c dist       4 |  ',
 '        value 2 1 1 1 |  ',
 '     log_prob 2 1 1 4 |  ',
 '       e dist 2 1 1 4 |  ',
 '        value 2 1 1 4 |  ',
 '     log_prob 2 1 1 4 |  ',
 ' d_plate dist         |  ',
 '        value       3 |  ',
 '     log_prob         |  ',
 '       d dist       3 | 4',
 '        value       3 | 4',
 '     log_prob       3 |  ']

In [122]:
trace.nodes["d"]

{'type': 'sample',
 'name': 'd',
 'fn': Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[-0.7193, -0.4033, -0.5966,  0.1820],
         [-0.8567,  1.1006, -1.0712,  0.1227],
         [-0.5663,  0.3731, -0.8920, -1.5091]]),
 'infer': {'_dim_to_id': {-2: 0, -3: 1, -4: 2}},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (CondIndepStackFrame(name='d_plate', dim=-1, size=3, counter=0),),
 'done': True,
 'stop': False,
 'continuation': None,
 'unscaled_log_prob': tensor([-4.2103, -5.2296, -5.4422]),
 'log_prob': tensor([-4.2103, -5.2296, -5.4422]),
 'log_prob_sum': tensor(-14.8822)}

In [123]:
trace.nodes["a"]

{'type': 'sample',
 'name': 'a',
 'fn': Categorical(probs: torch.Size([6]), logits: torch.Size([6])),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[0],
         [1],
         [2],
         [3],
         [4],
         [5]]),
 'infer': {'enumerate': 'parallel',
  'expand': False,
  '_enumerate_dim': -2,
  '_dim_to_id': {-2: 0}},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (),
 'done': True,
 'stop': False,
 'continuation': None,
 'unscaled_log_prob': tensor([[-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918]]),
 'log_prob': tensor([[-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918]]),
 'log_prob_sum': tensor(-10.7506)}

In [124]:
trace.nodes["b"]["value"]

tensor([[[0.]],

        [[1.]]])

In [125]:
trace.nodes["c"]["value"]

tensor([[[[0.]]],


        [[[1.]]]])

In [126]:
help(poutine.trace)

Help on function trace in module pyro.poutine.handlers:

trace(fn=None, *args, **kwargs)
    Convenient wrapper of :class:`~pyro.poutine.trace_messenger.TraceMessenger` 
    
    
    Return a handler that records the inputs and outputs of primitive calls
    and their dependencies.
    
    Consider the following Pyro program:
    
        >>> def model(x):
        ...     s = pyro.param("s", torch.tensor(0.5))
        ...     z = pyro.sample("z", dist.Normal(x, s))
        ...     return z ** 2
    
    We can record its execution using ``trace``
    and use the resulting data structure to compute the log-joint probability
    of all of the sample sites in the execution or extract all parameters.
    
        >>> trace = pyro.poutine.trace(model).get_trace(0.0)
        >>> logp = trace.log_prob_sum()
        >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
    
    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :

In [127]:
trace.nodes["a"]

{'type': 'sample',
 'name': 'a',
 'fn': Categorical(probs: torch.Size([6]), logits: torch.Size([6])),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[0],
         [1],
         [2],
         [3],
         [4],
         [5]]),
 'infer': {'enumerate': 'parallel',
  'expand': False,
  '_enumerate_dim': -2,
  '_dim_to_id': {-2: 0}},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (),
 'done': True,
 'stop': False,
 'continuation': None,
 'unscaled_log_prob': tensor([[-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918]]),
 'log_prob': tensor([[-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918],
         [-1.7918]]),
 'log_prob_sum': tensor(-10.7506)}

In [128]:
np.log(1/6.)

-1.791759469228055

In [129]:
trace.nodes["b"]

{'type': 'sample',
 'name': 'b',
 'fn': Bernoulli(probs: torch.Size([6, 1]), logits: torch.Size([6, 1])),
 'is_observed': False,
 'args': (),
 'kwargs': {},
 'value': tensor([[[0.]],
 
         [[1.]]]),
 'infer': {'enumerate': 'parallel',
  'expand': False,
  '_enumerate_dim': -3,
  '_dim_to_id': {-3: 1, -2: 0}},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (),
 'done': True,
 'stop': False,
 'continuation': None,
 'unscaled_log_prob': tensor([[[-1.1921e-07],
          [-1.8232e-01],
          [-4.0547e-01],
          [-6.9315e-01],
          [-1.0986e+00],
          [-1.7918e+00]],
 
         [[-1.5942e+01],
          [-1.7918e+00],
          [-1.0986e+00],
          [-6.9315e-01],
          [-4.0547e-01],
          [-1.8232e-01]]], grad_fn=<NegBackward0>),
 'log_prob': tensor([[[-1.1921e-07],
          [-1.8232e-01],
          [-4.0547e-01],
          [-6.9315e-01],
          [-1.0986e+00],
          [-1.7918e+00]],
 
         [[-1.5942e+01],
          [-1.7918e+00],
         

In [130]:
trace.nodes["b"]["log_prob"].size()

torch.Size([2, 6, 1])

In [131]:
trace.nodes["a"]["log_prob"].exp()

tensor([[0.1667],
        [0.1667],
        [0.1667],
        [0.1667],
        [0.1667],
        [0.1667]])

In [132]:
np.log(0.1667)

-1.7915594892253888

In [133]:
np.log(1 - 0.1667)

-0.1823615575939759

In [134]:
trace.nodes["e"]["fn"].loc

tensor([[[[0.2000, 0.2000, 0.2000, 0.2000]]],


        [[[0.9000, 0.9000, 0.9000, 0.9000]]]])

In [135]:
trace.nodes["b"]["fn"].logits

tensor([[-15.9424],
        [ -1.6094],
        [ -0.6931],
        [  0.0000],
        [  0.6931],
        [  1.6094]], grad_fn=<SubBackward0>)

In [136]:
trace.nodes["b"]["fn"].probs

tensor([[0.0000],
        [0.1667],
        [0.3333],
        [0.5000],
        [0.6667],
        [0.8333]], grad_fn=<IndexBackward0>)

In [137]:
(1 - trace.nodes["b"]["fn"].probs).log()

tensor([[ 0.0000],
        [-0.1823],
        [-0.4055],
        [-0.6931],
        [-1.0986],
        [-1.7918]], grad_fn=<LogBackward0>)

In [141]:
trace.nodes["b"]["log_prob"][0]

tensor([[-1.1921e-07],
        [-1.8232e-01],
        [-4.0547e-01],
        [-6.9315e-01],
        [-1.0986e+00],
        [-1.7918e+00]], grad_fn=<SelectBackward0>)

In [142]:
(trace.nodes["b"]["fn"].probs).log()

tensor([[   -inf],
        [-1.7918],
        [-1.0986],
        [-0.6931],
        [-0.4055],
        [-0.1823]], grad_fn=<LogBackward0>)

In [143]:
trace.nodes["b"]["log_prob"][1]

tensor([[-15.9424],
        [ -1.7918],
        [ -1.0986],
        [ -0.6931],
        [ -0.4055],
        [ -0.1823]], grad_fn=<SelectBackward0>)