In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn, optim
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.distributions import Categorical, Normal, Multinomial, Binomial, MultivariateNormal, Beta
from pyro.distributions.torch import Bernoulli
import pyro.infer as infer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pyro.set_rng_seed(0)

In [3]:
pyro.__version__

'1.8.4'

In [4]:
dist.Bernoulli(0.5)

Bernoulli(probs: 0.5)

In [5]:
dist.Bernoulli(0.5).sample()

tensor(1.)

In [6]:
dist.Bernoulli(0.5).sample()

tensor(0.)

In [7]:
d = dist.Bernoulli(0.5)

In [8]:
d.sample()

tensor(1.)

In [9]:
d.sample((10,))

tensor([1., 1., 0., 1., 0., 1., 0., 1., 1., 1.])

In [10]:
d.log_prob(d.sample())

tensor(-0.6931)

In [11]:
d = dist.Bernoulli(0.9)

In [12]:
s = d.sample()
s

tensor(1.)

In [13]:
d.log_prob(s)

tensor(-0.1054)

In [14]:
s = d.sample()
s

tensor(1.)

In [15]:
d.log_prob(torch.tensor(0.))

tensor(-2.3026)

In [16]:
np.log(0.9)

-0.10536051565782628

In [17]:
np.log(0.1)

-2.3025850929940455

In [18]:
pyro.set_rng_seed(0)
d = dist.Bernoulli(torch.empty(2, 3).uniform_(0, 1))
d, d.probs

(Bernoulli(probs: torch.Size([2, 3])),
 tensor([[0.4963, 0.7682, 0.0885],
         [0.1320, 0.3074, 0.6341]]))

In [19]:
x = d.sample()
x, x.shape

(tensor([[1., 0., 0.],
         [0., 0., 1.]]),
 torch.Size([2, 3]))

In [20]:
d.batch_shape

torch.Size([2, 3])

In [21]:
d.log_prob(x)

tensor([[-0.7007, -1.4620, -0.0926],
        [-0.1416, -0.3673, -0.4556]])

In [22]:
torch.log(x * d.probs + (1 - x) * (1-d.probs))

tensor([[-0.7007, -1.4620, -0.0926],
        [-0.1416, -0.3673, -0.4556]])

In [23]:
pyro.set_rng_seed(10)
probs = torch.empty(2, 3).uniform_(0, 1)
d = dist.Bernoulli(probs).to_event(1)
d, probs

(Independent(Bernoulli(probs: torch.Size([2, 3])), 1),
 tensor([[0.4581, 0.4829, 0.3125],
         [0.6150, 0.2139, 0.4118]]))

In [24]:
d.batch_shape

torch.Size([2])

In [25]:
d.event_shape

torch.Size([3])

In [26]:
x = d.sample()
x

tensor([[0., 0., 0.],
        [1., 0., 0.]])

In [27]:
d.log_prob(x)

tensor([-1.6468, -1.2576])

In [28]:
np.log(1 - np.array([0.4581, 0.4829, 0.3125])).sum()

-1.6468862454314406

In [29]:
np.log(np.array([0.6150, 1- 0.2139, 1-0.4118])).sum()

-1.2574925322202877

In [30]:
pyro.set_rng_seed(10)
probs = torch.empty(15, 8).uniform_(0, 1)
d = dist.Bernoulli(probs).to_event(1)
d, probs

(Independent(Bernoulli(probs: torch.Size([15, 8])), 1),
 tensor([[0.4581, 0.4829, 0.3125, 0.6150, 0.2139, 0.4118, 0.6938, 0.9693],
         [0.6178, 0.3304, 0.5479, 0.4440, 0.7041, 0.5573, 0.6959, 0.9849],
         [0.2924, 0.4823, 0.6150, 0.4967, 0.4521, 0.0575, 0.0687, 0.0501],
         [0.0108, 0.0343, 0.1212, 0.0490, 0.0310, 0.7192, 0.8067, 0.8379],
         [0.7694, 0.6694, 0.7203, 0.2235, 0.9502, 0.4655, 0.9314, 0.6533],
         [0.8914, 0.8988, 0.3955, 0.3546, 0.5752, 0.4787, 0.5782, 0.7536],
         [0.1093, 0.4771, 0.1076, 0.9829, 0.1483, 0.5956, 0.3634, 0.7842],
         [0.5017, 0.4497, 0.8660, 0.9567, 0.1371, 0.0177, 0.5417, 0.6575],
         [0.6141, 0.9619, 0.7244, 0.2700, 0.1576, 0.8879, 0.9792, 0.2627],
         [0.1800, 0.6750, 0.1424, 0.3790, 0.0055, 0.6368, 0.3295, 0.2203],
         [0.1821, 0.3241, 0.7375, 0.5283, 0.9306, 0.3213, 0.3537, 0.9894],
         [0.0231, 0.3032, 0.1898, 0.9811, 0.7662, 0.3325, 0.2877, 0.6533],
         [0.3310, 0.6552, 0.2441, 0.1906, 0.

In [31]:
d.sample()

tensor([[0., 0., 1., 1., 0., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 0., 0., 1., 0., 1., 1.],
        [1., 0., 0., 1., 0., 1., 0., 1.],
        [0., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 1., 0.],
        [1., 1., 0., 1., 0., 0., 1., 1.],
        [1., 0., 1., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0., 0., 1.],
        [0., 1., 1., 0., 0., 1., 1., 1.],
        [0., 1., 0., 1., 0., 0., 1., 1.],
        [1., 0., 1., 0., 0., 0., 0., 1.]])

In [32]:
d.log_prob(d.sample())

tensor([-3.9376, -4.7065, -3.8444, -2.4066, -4.2979, -4.4157, -4.0161, -2.8469,
        -3.9206, -3.5516, -3.9249, -4.2868, -3.2856, -5.0895, -2.6148])

In [33]:
d = dist.MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)

In [34]:
d.sample()

tensor([-0.1316,  0.8379,  2.6584])

In [35]:
d = dist.MultivariateNormal(torch.zeros(10, 3), torch.eye(3, 3).expand(10, -1, -1)).to_event(1)

In [36]:
d.sample()

tensor([[ 0.6050,  1.3021, -0.1537],
        [ 0.2865,  0.2404,  1.6265],
        [-0.0273, -0.7515, -0.5487],
        [-0.3148, -0.4927,  1.2738],
        [ 2.2051,  1.4851, -0.7850],
        [-1.3380, -0.2595, -0.2028],
        [ 1.5568, -0.5618, -0.7364],
        [ 0.3485,  2.1104, -1.1749],
        [ 0.6335, -1.0158, -1.6675],
        [-1.5590,  1.4037,  0.7496]])

In [37]:
d.event_shape

torch.Size([10, 3])

In [38]:
d.batch_shape

torch.Size([])

In [39]:
x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))
x

tensor([ 0.3164,  0.9872,  0.8355, -1.1924, -0.6644, -0.6571,  0.4060,  0.2794,
         1.4254, -1.5937])

In [40]:
dist.Normal(0, 1).expand([2])

Normal(loc: torch.Size([2]), scale: torch.Size([2]))

In [41]:
dist.Normal(0, 1).expand([2]).batch_shape

torch.Size([2])

In [42]:
dist.Normal(0, 1).expand([2]).event_shape

torch.Size([])

In [43]:
pyro.set_rng_seed(50)
pyro.clear_param_store()
data = torch.bernoulli(torch.clip(torch.randn(10), min=0, max=1))
probs = pyro.param('p', lambda:torch.empty((10, )).uniform_(0, 1.))
print("data is: ", data)
print("probs is: ", probs.data)
with pyro.plate('x', len(data), subsample_size=5) as ind:
    label = data[ind]
    print("ind: ", ind)
    pred = probs[ind]
    print("pred ", pred)
    x = pyro.sample('x', dist.Bernoulli(pred))
    print('x ', x)



data is:  tensor([0., 0., 1., 1., 0., 0., 1., 0., 0., 0.])
probs is:  tensor([0.3490, 0.1953, 0.2792, 0.2526, 0.3792, 0.7686, 0.6907, 0.7526, 0.1184,
        0.8699])
ind:  tensor([0, 6, 5, 2, 4])
pred  tensor([0.3490, 0.6907, 0.7686, 0.2792, 0.3792], grad_fn=<IndexBackward0>)
x  tensor([0., 1., 0., 0., 0.])


In [44]:
pyro.set_rng_seed(50)
pyro.clear_param_store()
data = torch.bernoulli(torch.clip(torch.randn(10), min=0, max=1))
probs = pyro.param('p', lambda:torch.empty((10, )).uniform_(0, 1.))
with pyro.plate('x', len(data), subsample_size=5) as ind:
    label = data[ind]
    print("Label: ", label)
    pred = probs[ind]
    print("pred ", pred)
    d = dist.Bernoulli(pred)
    x = pyro.sample('x', d, obs=label)
    print('x ', x, torch.all(x == label))
    print(d.log_prob(x))

Label:  tensor([0., 1., 0., 1., 0.])
pred  tensor([0.3490, 0.6907, 0.7686, 0.2792, 0.3792], grad_fn=<IndexBackward0>)
x  tensor([0., 1., 0., 1., 0.]) tensor(True)
tensor([-0.4293, -0.3700, -1.4635, -1.2757, -0.4768], grad_fn=<NegBackward0>)


In [50]:
pyro.set_rng_seed(10)
pyro.clear_param_store()
p = pyro.param("p", torch.arange(6.) / 6)
locs = pyro.param("locs", torch.tensor([-1., 1.]))
a = pyro.sample("a", dist.Categorical(torch.ones(6) / 6))
b = pyro.sample("b", dist.Bernoulli(p[a]))

with pyro.plate("c_plate", 4):
    c = pyro.sample("c", dist.Bernoulli(0.3))
    with pyro.plate("d_plate", 5):
        d = pyro.sample("d",  Bernoulli(0.4))
        e_loc = locs[d.long()].unsqueeze(-1)
        e_scale = torch.arange(1., 8.)
        e = pyro.sample("e", Normal(e_loc, e_scale)
                        .to_event(1))


print('a ', a)
print('b ', b)
print('c ', c)
print('d ', d)
print('e_loc ', e_loc.size())
print('e ', e.size())

a  tensor(2)
b  tensor(0.)
c  tensor([0., 0., 1., 0.])
d  tensor([[0., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 1., 1., 1.]])
e_loc  torch.Size([5, 4, 1])
e  torch.Size([5, 4, 7])


In [53]:
pyro.set_rng_seed(0)
pyro.clear_param_store()
p = 0.65

x = pyro.sample("x", Bernoulli(p).expand([10]), )
x

tensor([1., 0., 1., 1., 1., 1., 1., 0., 1., 1.])

In [54]:
pyro.set_rng_seed(0)
pyro.clear_param_store()
p = 0.65

with pyro.plate('data', 10):
    x = pyro.sample("x", Bernoulli(p) )
    print(x)

tensor([1., 0., 1., 1., 1., 1., 1., 0., 1., 1.])


In [57]:
pyro.set_rng_seed(0)
pyro.clear_param_store()
p = 0.65

with pyro.plate('data', 10):
    x = pyro.sample("x", Bernoulli(torch.zeros(10).float().fill_(p)) )
    print(x)

tensor([1., 0., 1., 1., 1., 1., 1., 0., 1., 1.])


In [56]:
pyro.set_rng_seed(0)
pyro.clear_param_store()
p = 0.65

with pyro.plate('data', 10):
    x = pyro.sample("x", Bernoulli(p).expand([10]).to_event(1) )
    print(x)

tensor([[1., 0., 1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 0., 1., 1.],
        [0., 0., 1., 0., 1., 1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 1., 1.],
        [0., 0., 0., 1., 0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 0., 0., 1., 1., 1., 0., 0.],
        [1., 1., 0., 0., 0., 0., 1., 1., 1., 1.],
        [1., 0., 0., 1., 1., 1., 1., 1., 0., 1.],
        [0., 0., 0., 0., 1., 1., 0., 1., 1., 1.]])
