In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, HMC, NUTS
from pyro.infer.mcmc.api import MCMC
import pyro.poutine as poutine
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, TracePredictive

from pyro.infer.mcmc.util import predictive
from pyro.distributions.util import sum_rightmost

from torch.autograd import Variable
import matplotlib.pyplot as plt

pyro.set_rng_seed(42)

In [2]:
N = 50 # Size of the dataset
X_data = torch.rand(N,1) # Sampling of N uniformly distributed points
a, b = 10, 5
sigma = 5 
Y_data = a * X_data + b + dist.Normal(loc=0, scale=sigma).sample([N,1]) # Computing Y_data with normal noise

In [3]:
def model(x, y):
    
    A = pyro.sample('A', dist.Normal(0,1))
    B = pyro.sample('B', dist.Normal(0,1))

    sigma = pyro.sample('sigma', dist.Uniform(0,1))    

    prediction = A*x + B
    return pyro.sample("obs", dist.Normal(prediction, sigma),  obs=y)

In [20]:
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=60, warmup_steps=0, num_chains=1)

In [21]:
mcmc.run(X_data, Y_data)
mcmc.summary()

sample: 100%|██████████| 60/60 [00:44<00:00,  1.43s/it, step size=3.91e-03, acc. prob=1.000]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A      4.98      5.40      5.03     -1.59     10.76      2.40      3.27
         B      3.77      3.28      4.96     -0.38      8.65      4.44      1.31
     sigma      0.92      0.13      1.00      0.74      1.00      4.18      1.41







In [26]:
from pyro.infer.mcmc.util import predictive
samples = mcmc.get_samples()


In [37]:
samples['B'].shape
samples['A'].shape

torch.Size([60])

{'A': tensor([-1.5919, -1.5729, -1.5629, -1.5461, -1.4656, -1.4639, -1.3832, -1.3660,
         -1.3605, -1.3566, -1.3540, -1.3084, -1.3027, -1.2823, -1.2733, -1.2593,
         -1.2594, -1.2541, -1.2499, -1.2004, -1.1966, -1.1971, -1.1637,  4.8613,
          4.9437,  4.8319,  4.8484,  4.6741,  4.7236,  5.0392,  5.0320,  5.6243,
          6.5540,  6.7215, 11.7084,  9.3853, 10.7626, 10.0524, 10.8032, 10.5241,
         10.5447, 10.4374,  9.9782, 10.7427, 10.7114, 10.2598, 10.3310,  9.9177,
          9.5031, 11.5683, 10.3961, 10.6714, 10.2148,  9.9109, 10.6928, 11.1219,
         10.8143, 10.4378, 10.1244, 10.5912]),
 'B': tensor([-0.4121, -0.3794, -0.3550, -0.3377, -0.2033, -0.1979, -0.0903, -0.0777,
         -0.0690, -0.0693, -0.0683,  0.0191,  0.0279,  0.0708,  0.0771,  0.0963,
          0.1217,  0.1239,  0.1243,  0.2243,  0.2277,  0.2416,  0.3343,  9.4318,
          9.3195,  9.0640,  8.9471,  8.6157,  8.6471,  8.2595,  8.2617,  7.9498,
          7.6205,  7.5153,  4.2795,  5.6118,  5.1550

In [None]:
trace = predictive(model, samples, X_data, Y_data, return_trace=True)

In [27]:
trace.compute_log_prob()

In [28]:
log = trace.nodes['obs']['log_prob']

In [30]:
log.shape

torch.Size([50, 60])

In [48]:
trace.nodes['_INPUT']

{'name': '_INPUT', 'type': 'args', 'args': (tensor([[0.8823],
          [0.9150],
          [0.3829],
          [0.9593],
          [0.3904],
          [0.6009],
          [0.2566],
          [0.7936],
          [0.9408],
          [0.1332],
          [0.9346],
          [0.5936],
          [0.8694],
          [0.5677],
          [0.7411],
          [0.4294],
          [0.8854],
          [0.5739],
          [0.2666],
          [0.6274],
          [0.2696],
          [0.4414],
          [0.2969],
          [0.8317],
          [0.1053],
          [0.2695],
          [0.3588],
          [0.1994],
          [0.5472],
          [0.0062],
          [0.9516],
          [0.0753],
          [0.8860],
          [0.5832],
          [0.3376],
          [0.8090],
          [0.5779],
          [0.9040],
          [0.5547],
          [0.3423],
          [0.6343],
          [0.3644],
          [0.7104],
          [0.9464],
          [0.7890],
          [0.2814],
          [0.7886],
          [0.5895]