In [97]:
import torch
import pyro
import numpy as np

pyro.set_rng_seed(101)

In [98]:
# A standard normal N(0, 1) built directly from torch. `loc` and `scale`
# stay module-level because a later cell reuses them.
loc, scale = 0, 1.
normal = torch.distributions.Normal(loc=loc, scale=scale)

# rsample() gives a reparameterized draw; log_prob() scores it under N(0, 1).
x = normal.rsample()
print('sample', x)
print('log prob', normal.log_prob(x))

sample tensor(-1.3905)
log prob tensor(-1.8857)


In [99]:
def torch_weather():
    """Draw one (condition, temperature) pair using plain torch distributions.

    The sky is cloudy with probability 0.3. The temperature is then drawn
    from Normal(55, 10) when cloudy, or Normal(75, 15) when sunny.

    Returns:
        tuple[str, float]: ('cloudy' or 'sunny', sampled temperature).
    """
    temp_params = {'cloudy': (55.0, 10.0), 'sunny': (75.0, 15.0)}

    coin = torch.distributions.Bernoulli(.3).sample()
    condition = 'cloudy' if coin.item() == 1.0 else 'sunny'

    mean_temp, scale_temp = temp_params[condition]
    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
    return condition, temp.item()


In [100]:
# The same N(loc, scale) draw as the torch cell above, but made through Pyro
# so it is registered under the sample-site name 'my_sample'.
x = pyro.sample('my_sample', pyro.distributions.Normal(loc=loc, scale=scale))
print(x)

tensor(-0.8152)


In [101]:
def weather():
    """Pyro version of the weather model, with two named sample sites.

    Site 'cloudy' is Bernoulli(0.3). Site 'temp' is Normal(55, 10) when that
    draw is cloudy, or Normal(75, 15) when it is sunny.

    Returns:
        tuple[str, float]: ('cloudy' or 'sunny', sampled temperature).
    """
    temp_params = {'cloudy': (55.0, 10.0), 'sunny': (75.0, 15.0)}

    coin = pyro.sample('cloudy', pyro.distributions.Bernoulli(.3))
    condition = 'cloudy' if coin.item() == 1.0 else 'sunny'

    mean_temp, scale_temp = temp_params[condition]
    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return condition, temp.item()

In [102]:
# Three independent forward runs of the weather model, one tuple per line.
print(*[weather() for _ in range(3)], sep='\n')

('cloudy', 64.5440444946289)
('sunny', 94.37557983398438)
('sunny', 72.5186767578125)


In [103]:
def ice_cream_sales():
    """Sample ice-cream sales conditioned on one draw of the weather model.

    Expected sales are 200 on a sunny day hotter than 80 (strictly), and 50
    otherwise. Actual sales are the Pyro sample site 'ice_cream', drawn from
    Normal(expected, 10).

    Returns:
        float: the sampled sales value.
    """
    condition, temp = weather()

    hot_and_sunny = condition == 'sunny' and temp > 80.0
    if hot_and_sunny:
        expected_sales = 200.
    else:
        expected_sales = 50.

    sales = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))
    return sales.item()

In [104]:
def geometric(p, t=None):
    """Count failed Bernoulli(p) trials before the first success.

    Each trial is a Pyro sample site named ``x_{t}``, ``x_{t+1}``, ..., so site
    names stay unique within one call. The site names and sampling order match
    the recursive version. This is a loop rather than recursion so that a small
    ``p`` (long runs of failures) cannot exhaust Python's recursion limit.

    Args:
        p: success probability of each trial. Must be > 0, otherwise no
            success is possible and the loop would never end.
        t: index used in the first trial's site name. ``None`` means 0.

    Returns:
        int: number of trials that came up 0 before the first 1.

    Raises:
        ValueError: if ``p`` is not > 0.
    """
    if t is None:
        t = 0
    # Guard before sampling: with p == 0 the loop below would spin forever.
    if not p > 0:
        raise ValueError(f'geometric requires p > 0, got {p}')

    failures = 0
    while pyro.sample(f'x_{t + failures}', pyro.distributions.Bernoulli(p)).item() != 1:
        failures += 1
    return failures

# One draw: number of failed Bernoulli(0.5) trials before the first success.
print(geometric(p=.5))

0


In [105]:
def normal_product(loc, scale):
    """Return z1 * z2 for two independent Normal(loc, scale) draws.

    The draws are Pyro sample sites named 'z1' and 'z2', sampled in that order.
    NOTE(review): the site names are fixed, so calling this more than once in
    one model run would repeat them. Confirm whether that is intended.
    """
    z1, z2 = (pyro.sample(site, pyro.distributions.Normal(loc, scale)) for site in ('z1', 'z2'))
    return z1 * z2


def make_normal_normal():
    """Draw a latent mean once and return a model closed over it.

    The latent mean is the Pyro sample site 'mu_latent' ~ Normal(0, 1), drawn
    when this factory is called. The returned callable takes ``scale`` and
    returns ``normal_product(mu_latent, scale)``.
    """
    mu_latent = pyro.sample('mu_latent', pyro.distributions.Normal(0, 1))

    def fn(scale):
        return normal_product(mu_latent, scale)

    return fn

# Build the model (this draws 'mu_latent'), then evaluate it with scale 1.
print(make_normal_normal()(scale=1))

tensor(2.1493)
