# Contrasting Pyro with TFP (TensorFlow Probability)

- https://www.upgrad.com/blog/statistical-programming-in-machine-learning/

In [32]:
import pyro
import torch

import matplotlib.pyplot as plt
import numpy as np

# distributions [pyro]

- Most distributions in Pyro are thin wrappers around PyTorch distributions
    - ```torch.distributions.distribution.Distribution```
        - https://pytorch.org/docs/stable/distributions.html
    - Interface differences between PyTorch distributions and Pyro
        - see https://docs.pyro.ai/en/dev/distributions.html


In [33]:
# Optional: seed the torch, NumPy and Pyro RNGs for reproducible sampling.
# Left disabled here.
# torch.manual_seed(0)
# np.random.seed(0)
# pyro.set_rng_seed(0)

# Independent distributions [pyro]


**previous example**

In [54]:
# A single 2-dimensional Gaussian: one distribution whose event is a length-2 vector.
# The diagonal of `cov_mat` holds variances (not standard deviations).
locs = torch.tensor([-1, 0.5], dtype=torch.float32)
cov_mat = torch.diag(torch.tensor([1, 1.5], dtype=torch.float32))

mv_normal = pyro.distributions.MultivariateNormal(locs, covariance_matrix=cov_mat)
print(mv_normal)

MultivariateNormal(loc: torch.Size([2]), covariance_matrix: torch.Size([2, 2]))


In [55]:
# Report how the distribution splits its dimensions into batch and event parts.
for label, shape in (
    ("batch shape : ", mv_normal.batch_shape),
    ("event shape : ", mv_normal.event_shape),
):
    print(label, shape)

batch shape :  torch.Size([])
event shape :  torch.Size([2])


In [56]:
# `log_prob` takes one realization of the 2-dimensional random variable (a length-2
# vector) and returns a single scalar log-density for that vector.
x_obs = torch.tensor([-0.2, 1.8], dtype=torch.float32)
print(mv_normal.log_prob(x_obs))

tensor(-2.9239)


In [43]:
# Two univariate Normals batched in one object (batch_shape = [2]).
# NOTE: `scale` is a standard deviation, so these have variances [1, 2.25].
scales = torch.tensor([1, 1.5], dtype=torch.float32)
batched_normal = pyro.distributions.Normal(locs, scales)
print(batched_normal)

Normal(loc: torch.Size([2]), scale: torch.Size([2]))


In [44]:
# For a batched distribution, `log_prob` takes one value per batch member and returns
# one log-density per member (a length-2 tensor here), not a single scalar.
x_obs = torch.tensor([-0.2, 1.8], dtype=torch.float32)
print(batched_normal.log_prob(x_obs))

tensor([-1.2389, -1.7000])


**independent distribution**

- The independent distribution gives us a way to absorb some or all of the batch dimensions into the event_shape
    - For example, we could use the independent distribution to transform our batched_normal distribution so that it's equivalent to the multivariate normal diag distribution.

In [58]:
# Rebuild the batch of two univariate Normals (batch_shape=[2], event_shape=[]).
# NOTE: `scale` is a standard deviation, so this batch has variances [1, 2.25]. It is
# therefore NOT the same Gaussian as `mv_normal`, whose covariance diagonal is [1, 1.5];
# that, not the Independent wrapper, is why their log_probs differ.
batched_normal = pyro.distributions.Normal(
    loc=locs, scale=torch.tensor([1, 1.5], dtype=torch.float32)
)
print(batched_normal)
print("batch shape of 'batched_normal' : ", batched_normal.batch_shape)
print("event shape of 'batched_normal' : ", batched_normal.event_shape)
print()

# Reinterpret the distributions in the batch as one joint (independent) distribution.
# `reinterpreted_batch_ndims` says how many batch dimensions are absorbed into the
# event shape; there is only one batch dimension (of size 2), so it is 1.
independent_normal = pyro.distributions.Independent(batched_normal, reinterpreted_batch_ndims=1)
print(independent_normal)
print("batch shape of 'independent_normal' : ", independent_normal.batch_shape)
print("event shape of 'independent_normal' : ", independent_normal.event_shape)
print()

# Method form of the same transformation. `.to_event()` replaces the deprecated
# `.independent()`; it yields the same Independent(Normal, 1) wrapper as above, so it
# is an equivalent shortcut rather than a "wrong case".
independent_normal2 = batched_normal.to_event(reinterpreted_batch_ndims=1)
print(independent_normal2)
print("batch shape of 'independent_normal2' : ", independent_normal2.batch_shape)
print("event shape of 'independent_normal2' : ", independent_normal2.event_shape)

Normal(loc: torch.Size([2]), scale: torch.Size([2]))
batch shape of 'batched_normal' :  torch.Size([2])
event shape of 'batched_normal' :  torch.Size([])

Independent(Normal(loc: torch.Size([2]), scale: torch.Size([2])), 1)
batch shape of 'independent_normal' :  torch.Size([])
event shape of 'independent_normal' :  torch.Size([2])

Independent(Normal(loc: torch.Size([2]), scale: torch.Size([2])), 1)
batch shape of 'independent_normal' :  torch.Size([])
event shape of 'independent_normal' :  torch.Size([2])


In [59]:
# Without the Independent wrapper the batch still yields one log-density per member.
x_obs = torch.tensor([-0.2, 1.8], dtype=torch.float32)
print(batched_normal.log_prob(x_obs))

tensor([-1.2389, -1.7000])


In [60]:
# Both calls return a scalar, just as with a multivariate normal diag distribution:
# the size-2 dimension is an event dimension in each case.
# The two values differ because the distributions are not the same Gaussian:
# `covariance_matrix` holds variances ([1, 1.5]) whereas Normal's `scale` is a standard
# deviation, so scale=[1, 1.5] corresponds to variances [1, 2.25].
x_obs = torch.tensor([-0.2, 1.8], dtype=torch.float32)
print(mv_normal.log_prob(x_obs))
print(independent_normal.log_prob(x_obs))

# Like-for-like comparison: use the square root of the covariance diagonal as the scale.
# This prints the same value as `mv_normal` above (-2.9239).
matched_scale = torch.diagonal(cov_mat).sqrt()
matched_normal = pyro.distributions.Independent(
    pyro.distributions.Normal(loc=locs, scale=matched_scale),
    reinterpreted_batch_ndims=1,
)
print(matched_normal.log_prob(x_obs))

tensor(-2.9239)
tensor(-2.9389)


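**왜 다르지 ...? (Why do they differ?)**

- `MultivariateNormal` is parameterized by `covariance_matrix`, whose diagonal holds **variances** ($\sigma^2 = [1, 1.5]$), whereas `Normal` is parameterized by `scale`, which is a **standard deviation** ($\sigma$). Passing `scale=[1, 1.5]` therefore gives variances $[1, 2.25]$, so `independent_normal` is a different Gaussian from `mv_normal` and its log-probability is `-2.9389` instead of `-2.9239`.
- Using `scale = torch.sqrt(torch.diagonal(cov_mat))` (i.e. $[1, \sqrt{1.5}]$) makes the two distributions identical, and both log-probabilities come out as `-2.9239`.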