In [1]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

# Enable Pyro's runtime validation so invalid distribution parameters/values are
# reported as errors right away instead of silently producing wrong results.
pyro.enable_validation(True)
# Seed the random number generators (torch/random/numpy via Pyro) so every
# sample drawn in this notebook is reproducible under Restart & Run All.
pyro.set_rng_seed(0)

# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

### Distribution shapes: `batch_shape` and `event_shape`

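- 分布の形は`batch_shape + event_shape`(タプルの連結であって、数値の和ではない)。`d.sample(sample_shape)`で得られるサンプルの形は`sample_shape + batch_shape + event_shape`になる
    - `sample_shape`: 同じ分布から独立に(iid)繰り返しサンプリングした回数を表す次元(`d.sample(sample_shape)`で指定)
    - `batch_shape`: パラメータ(例: $\mu$や、条件付けに使う潜在変数$z$)ごとに異なる、互いに独立な分布を並べた次元。`d.log_prob(x)`の形は`batch_shape`(サンプル次元があればその前に`sample_shape`が付く)
    - `event_shape`: 確率変数$x$の1回分の値(事象)の形。`log_prob`はこの次元をまとめて1つの値(同時確率の対数)にする。スカラー値の分布では`()`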

$$
d = [
[{\rm Bern}(x|\mu=0.5), {\rm Bern}(x|\mu=0.1)],
[{\rm Bern}(x|\mu=0.3), {\rm Bern}(x|\mu=1.0)]
]
$$

上の分布$d$から得られるサンプル$x$の例(右下の要素は$\mu=1.0$なので必ず1になる)
$$
x = [[1, 0],[0,1]]
$$

In [55]:
# A (2, 2) tensor of probabilities gives a (2, 2) batch of Bernoulli
# distributions; each event is a single scalar, so event_shape is ().
d = Bernoulli(torch.tensor([[0.5, 0.1], [0.3, 1.0]]))
assert d.batch_shape == (2, 2)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (2, 2)              # == batch_shape + event_shape
assert d.log_prob(x).shape == (2, 2)  # == batch_shape
# Last expression is shown as the cell output (rich display instead of print).
d.batch_shape

torch.Size([2, 2])


`expand`を使うと、パラメータをブロードキャストして`batch_shape`を広げた分布を作れる(ここでは長さ4のバッチを$(3, 4)$に拡張する)

In [51]:
# Broadcast a length-4 batch of Bernoullis up to a (3, 4) batch with `expand`.
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand(torch.Size([3, 4]))
assert (d.batch_shape, d.event_shape) == ((3, 4), ())
x = d.sample()
# Events are scalars, so both a sample and its log-density have the batch shape.
assert x.shape == (3, 4) and d.log_prob(x).shape == (3, 4)

多変量正規分布は空でない`event_shape`を持つ。3次元ベクトル全体が1つの事象なので、サンプルの形は`(3,)`になる。一方、`log_prob`はバッチ要素ごとに1つの値を返す(ここではバッチがないのでスカラー)

In [63]:
# One standard 3-dimensional Gaussian: no batch dimensions, a single event of size 3.
d = MultivariateNormal(loc=torch.zeros(3), covariance_matrix=torch.eye(3))
assert (d.batch_shape, d.event_shape) == ((), (3,))
x = d.sample()
# A sample spans every batch and event dimension.
assert x.shape == d.batch_shape + d.event_shape == (3,)
# log_prob collapses the event dimensions, leaving only batch_shape.
assert d.log_prob(x).shape == d.batch_shape == ()

In [64]:
# Display the 3x3 identity matrix (ones on the diagonal, zeros elsewhere).
torch.eye(3)

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])