In [2]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

## 分布のshapes: batch_shape and event_shape

- batch_shape
    - 独立した分布の数を表す
- event_shape
    - 確率変数の次元

## サンプルのshapes

サンプルのshapeは分布のbatch_shapeと分布のevent_shapeの和

#### 例1: 分布のevent_shapeが0

つまり、0つの分布について考える

分布

In [3]:
d = Bernoulli(1/2)
print(d.batch_shape)
print(d.event_shape)
print(d.shape)
# 分布の数が0

torch.Size([])
torch.Size([])
<bound method TorchDistributionMixin.shape of Bernoulli(probs: 0.5)>


サンプリング

In [4]:
x = d.sample()
print(x.shape)
print(d.log_prob(x).shape)

torch.Size([])
torch.Size([])


#### 例2: 分布のbatch_shapeが1

In [5]:
d = Bernoulli(0.5 * torch.ones(1))
print(d.batch_shape)
print(d.event_shape)
print(d.shape)

torch.Size([1])
torch.Size([])
<bound method TorchDistributionMixin.shape of Bernoulli(probs: tensor([0.5000]))>


In [6]:
# 別の定義の仕方
d = Bernoulli(0.5).expand([1])
print(d.batch_shape)
print(d.event_shape)
print(d.shape)

torch.Size([1])
torch.Size([])
<bound method TorchDistributionMixin.shape of Bernoulli(probs: tensor([0.5000]))>


サンプリング

In [7]:
x = d.sample()
print(x.shape, ":", x)
print(d.log_prob(x).shape, ":",d.log_prob(x))

torch.Size([1]) : tensor([0.])
torch.Size([1]) : tensor([-0.6931])


#### 例3: 分布のbatch_shapeが1, event_shapeが1

In [8]:
# 多変量(event_shapeが1)にするのにパラメータが一つ必要なのでtorch.ones(1, 1)
d = Bernoulli(0.5 * torch.ones(1, 1)).to_event(1)
print(d.batch_shape)
print(d.event_shape)
print(d.shape)

torch.Size([1])
torch.Size([1])
<bound method TorchDistributionMixin.shape of Independent(Bernoulli(probs: tensor([[0.5000]])), 1)>


In [9]:
x = d.sample()
print(x.shape, ":", x)
print(d.log_prob(x).shape, ":",d.log_prob(x))

torch.Size([1, 1]) : tensor([[1.]])
torch.Size([1]) : tensor([-0.6931])


### 例4: BayesianRegressionのlinear.weight

In [10]:
in_features, out_features = 3, 2
d = Normal(0., 1.).expand([in_features, out_features]).to_event(1)
print(d.batch_shape)
print(d.event_shape)
print(d.shape)

torch.Size([3])
torch.Size([2])
<bound method TorchDistributionMixin.shape of Independent(Normal(loc: torch.Size([3, 2]), scale: torch.Size([3, 2])), 1)>


In [11]:
x = d.sample()
print(x.shape, ":", x)
print(d.log_prob(x).shape, ":",torch.exp(d.log_prob(x)))

torch.Size([3, 2]) : tensor([[ 0.4296, -2.0161],
        [ 0.1983, -1.8853],
        [ 1.7071, -0.4525]])
torch.Size([3]) : tensor([0.0190, 0.0264, 0.0335])


### 例5: 多変量混合分布

In [12]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
print(d.batch_shape)
print(d.event_shape)
print(d.shape)

torch.Size([])
torch.Size([3])
<bound method TorchDistributionMixin.shape of MultivariateNormal(loc: torch.Size([3]), covariance_matrix: torch.Size([3, 3]))>


In [13]:
x = d.sample()
print(x.shape, ":", x)
print(d.log_prob(x).shape, ":",torch.exp(d.log_prob(x)))

torch.Size([3]) : tensor([0.2313, 0.3353, 0.0639])
torch.Size([]) : tensor(0.0583)


## It is always safe to assume dependence

In [14]:
with pyro.plate("x_plate", 10):
    x = pyro.sample("x", Normal(0, 1))  # .expand([10]) is automatic
    assert x.shape == (10,)

`pyro.plate`を使うと独立性を表現出来る

In [18]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())

In [22]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
     log_prob       |    
       b dist       | 2  
        value       | 2  
     log_prob       |    
 c_plate dist       |    
        value     2 |    
     log_prob       |    
       c dist     2 |    
        value     2 |    
     log_prob     2 |    
 d_plate dist       |    
        value     3 |    
     log_prob       |    
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |    
  x_axis dist       |    
        value     3 |    
     log_prob       |    
  y_axis dist       |    
        value     2 |    
     log_prob       |    
       x dist   3 1 |    
        value   3 1 |    
     log_prob   3 1 |    
       y dist 2 1 1 |    
        value 2 1 1 |    
     log_prob 2 1 1 |    
      xy dist 2 3 1 |    
        value 2 3 1 |    
     log_prob 2 3 1 |    
       z dist 2 3 1 | 5  
        value 2 3 1 | 5  
     log_pro