In [49]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

import matplotlib.pyplot as plt
import numpy as np

# sampling and log_prob [tensorflow_probability]

**simple example**
- exponential distribution

In [50]:
# One rate per entry of a 2x3 grid -> a batch of six scalar Exponentials
# (printed below as batch_shape=[2, 3], event_shape=[]).
rate_grid = [[1., 1.5, .8],
             [.3, .4, 1.8]]
exp = tfd.Exponential(rate=rate_grid)
print(exp)

tfp.distributions.Exponential("Exponential", batch_shape=[2, 3], event_shape=[], dtype=float32)


In [51]:
# Without a "reinterpreted_batch_ndims" argument, Independent keeps only the FIRST
# batch dim as batch and moves all remaining (rightmost) batch dims into the event.
# Here batch_shape [2, 3] -> batch_shape [2], event_shape [3].
ind_exp = tfd.Independent(exp)
print(ind_exp)

# Inline check of the default behaviour described above.
assert ind_exp.batch_shape.as_list() == [2]
assert ind_exp.event_shape.as_list() == [3]

tfp.distributions.Independent("IndependentExponential", batch_shape=[2], event_shape=[3], dtype=float32)


In [52]:
# Draw the samples once and read every dimension from that single tensor.
# Each .sample() call draws fresh random values, so calling it repeatedly is
# wasteful and the printed shapes would not refer to the same draw.
exp_samples = ind_exp.sample(4)
print("whole shape of independent distribution(joint D) of batched exps : ", exp_samples.shape)
print("sample shape : ", exp_samples.shape[0])
print("batch shape : ", exp_samples.shape[1])
print("event shape : ", exp_samples.shape[2])

whole shape of independent distribution(joint D) of batched exps :  (4, 2, 3)
sample shape :  4
batch shape :  2
event shape :  3


**more complicated example**
- exponential distribution
    - rank 4 params

In [53]:
# Rank-4 rate tensor; the printed result below shows it becomes
# batch_shape=[2, 1, 2, 3] with a scalar event.
rates = [
    [[[1., 1.5, .8],
      [.3, .4, 1.8]]],
    [[[.2, .4, 1.4],
      [.4, 1.1, .9]]],
]

exp = tfd.Exponential(rate=rates)
print(exp)

tfp.distributions.Exponential("Exponential", batch_shape=[2, 1, 2, 3], event_shape=[], dtype=float32)


In [54]:
# Reinterpret the two rightmost batch dims as event dims:
# batch_shape [2, 1, 2, 3] -> batch_shape [2, 1], event_shape [2, 3].
ind_exp = tfd.Independent(distribution=exp, reinterpreted_batch_ndims=2)
print(ind_exp)

tfp.distributions.Independent("IndependentExponential", batch_shape=[2, 1], event_shape=[2, 3], dtype=float32)


## "sampling" of batched multi-event distribution

In [55]:
# Shape of ind_exp.sample([4, 2]) : (4, 2, 2, 1, 2, 3)
#   4, 2 : sample shape
#   2, 1 : batch shape
#   2, 3 : event shape

# Sample once so the printed shape and the printed slice come from the same draw.
grid_samples = ind_exp.sample([4, 2])
print("ind_exp sample([4,2]) shape : ", grid_samples.shape)
print()
print(grid_samples[0])

ind_exp sample([4,2]) shape :  (4, 2, 2, 1, 2, 3)

tf.Tensor(
[[[[[1.4325756  0.0644755  4.535669  ]
    [0.8761179  0.9497326  0.00807053]]]


  [[[0.88344014 3.3059225  1.1386163 ]
    [7.113349   1.410286   0.67349964]]]]



 [[[[0.75737804 0.09854684 0.19304658]
    [1.0603809  0.4898638  0.53262   ]]]


  [[[4.989893   1.8158811  2.899054  ]
    [5.6882505  1.3662548  0.49126583]]]]], shape=(2, 2, 1, 2, 3), dtype=float32)


## "log_prob" of batched multi-event distribution

- broad-casting rule
    - 만약, ind_exp의 batch_shape=[2,1], event_shape=[2,3]이라면,
    - ind_exp.log_prob(0.5) 값의 shape는 batch_shape와 동일하게 나온다.([2,1])
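        - 어떻게 scalar를 입력했는데, batch_shape와 동일한 shape로 출력되는 걸까?
            - broadcasting 연산
                - scalar 0.5가 batch_shape + event_shape([2,1,2,3]) 전체로 broadcast되어, 모든 원소에 같은 값을 입력한 것처럼 연산된다.
                - Independent의 log_prob는 event 차원([2,3])에 대해 log 확률을 합산하므로, 결과 shape는 batch_shape([2,1])가 된다.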
                
                
```python
    ind_exp.log_prob(0.5) == ind_exp.log_prob([ [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] ])  # input shape=(2,3)=event_shape, output shape=(2,1)
    ind_exp.log_prob([[0.5,0.5,0.5]]) == ind_exp.log_prob([ [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] ]) 
```

In [56]:
print(ind_exp)
print()

# Evaluate the (deterministic) log-density at the scalar 0.5 once and reuse it.
log_prob_half = ind_exp.log_prob(0.5)
print(log_prob_half)
print()
print("shape of ind_exp.log_prob(0.5) : ", log_prob_half.shape)

tfp.distributions.Independent("IndependentExponential", batch_shape=[2, 1], event_shape=[2, 3], dtype=float32)

tf.Tensor(
[[-4.2501554]
 [-5.3155975]], shape=(2, 1), dtype=float32)

shape of ind_exp.log_prob(0.5) :  (2, 1)


In [57]:
# Passing an explicit event-shaped input gives the same values as the scalar
# call above, which shows that the scalar 0.5 was broadcast.
half_event = [[0.5, 0.5, 0.5],
              [0.5, 0.5, 0.5]]
explicit_log_prob = ind_exp.log_prob(half_event)
print(explicit_log_prob)

# Verify the claim instead of relying on comparing two printouts by eye.
np.testing.assert_allclose(explicit_log_prob, ind_exp.log_prob(0.5))

tf.Tensor(
[[-4.2501554]
 [-5.3155975]], shape=(2, 1), dtype=float32)


## another example

- rank 5 input을 가정해보자.
```python
tf.random.uniform((5,1,1,2,1))
```

- 이 때, [B, E] = [2, 1, 2, 3]인 분포(위에서 정의한 "ind_exp")에도 broadcasting이 가능할까?
    - 가능하다면, log_prob 메서드가 반환하는 값의 shape는?
    - (처음에는 명확하지 않아 더 공부해봐야 했던 부분 → 아래 answer에서 정리)
    - answer) 가능하며, 반환 shape는 (5, 2, 1)이다.
        - why?) input (5,1,1,2,1)을 batch_shape + event_shape = (2,1,2,3)에 **오른쪽 정렬**해서 broadcast한다.
            - 맨 오른쪽 2개 차원 (2,1) -> event_shape (2,3)에 대응 (2는 일치, 1은 3으로 broadcast)
            - 그 앞 2개 차원 (1,1) -> batch_shape (2,1)에 대응 (첫 1은 2로 broadcast, 둘째 1은 일치)
            - 남는 맨 앞 차원 5 -> sample shape
            - log_prob는 event 차원에 대해 합산하므로, 결과 shape = sample shape + batch_shape = (5, 2, 1)

In [77]:
# Display the Independent distribution currently bound to ind_exp
# (rich repr shows batch_shape=[2, 1], event_shape=[2, 3]).
ind_exp

<tfp.distributions.Independent 'IndependentExponential' batch_shape=[2, 1] event_shape=[2, 3] dtype=float32>

In [79]:
# Rank-5 input used to probe broadcasting of log_prob against
# ind_exp (batch_shape [2, 1], event_shape [2, 3]).
input_shape = (5, 1, 1, 2, 1)
input_sample = tf.random.uniform(input_shape)
print(input_sample.shape)

(5, 1, 1, 2, 1)


In [59]:
# Evaluate log_prob on input_sample from the previous cell, rather than on a
# freshly drawn tensor, so the result corresponds to the input whose shape was
# just printed. Expected result shape: sample (5,) + batch [2, 1] -> (5, 2, 1).
ind_exp.log_prob(input_sample)

<tf.Tensor: shape=(5, 2, 1), dtype=float32, numpy=
array([[[-4.7249455],
        [-5.4074545]],

       [[-5.269409 ],
        [-6.346131 ]],

       [[-3.1244216],
        [-4.3141217]],

       [[-6.0522842],
        [-6.6951456]],

       [[-3.3583426],
        [-4.5451894]]], dtype=float32)>

**example: `Independent` over a batched `MultivariateNormalDiag`**

In [95]:
# loc (2, 3, 1) and scale_diag (4,) broadcast together; the printed result shows
# batch_shape=[2, 3], event_shape=[4].
loc = tf.zeros((2, 3, 1))
scale_diag = tf.ones(4)

# Build the batched MVN once and reuse it, instead of constructing it twice.
batched_mv = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
print("batched_mv : \n", batched_mv)
print()

# Default reinterpreted_batch_ndims keeps only the first batch dim (2) as batch and
# moves the remaining batch dim (3) into the event -> event_shape [3, 4].
dist = tfd.Independent(batched_mv)
print("indep dist : \n", dist)

batched_mv : 
 tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[2, 3], event_shape=[4], dtype=float32)

indep dist : 
 tfp.distributions.Independent("IndependentMultivariateNormalDiag", batch_shape=[2], event_shape=[3, 4], dtype=float32)


In [91]:
# Re-show the Independent(MultivariateNormalDiag) distribution as plain text.
print(str(dist))

tfp.distributions.Independent("IndependentMultivariateNormalDiag", batch_shape=[2], event_shape=[3, 4], dtype=float32)


In [89]:
# Rank-4 input (2, 1, 1, 4), right-aligned against batch_shape [2] + event_shape [3, 4]:
#   2      -> sample shape
#   1      -> broadcast against batch_shape [2]
#   (1, 4) -> broadcast against event_shape [3, 4]
# The result therefore has shape sample + batch = (2, 2).
mv_input = tf.random.uniform((2, 1, 1, 4))
dist.log_prob(mv_input)

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-12.978109, -12.978109],
       [-14.191105, -14.191105]], dtype=float32)>