# Examples of creating and operating on distributions in Pixyz

In [1]:
from __future__ import print_function
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

# Seed both random number generators this notebook relies on so that a fresh
# "Restart & Run All" reproduces the same numbers: NumPy generates the example
# inputs x, y and a, while torch drives everything else.
# torch.manual_seed stays last so the cell still displays the Generator object.
np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x7f944c1310d0>

In [2]:
from pixyz.distributions import Normal
from pixyz.utils import print_latex

In [3]:
# Feature sizes associated with the variables x, y, z and a.
x_dim, y_dim, z_dim, a_dim = 20, 30, 40, 50
# Number of rows in each example batch.
batch_n = 2

class P1(Normal):
    """Conditional Gaussian p_{1}(x|y,a) parameterized by a small network.

    ``y`` and ``a`` are each projected to a hidden vector, the two vectors are
    concatenated, and two linear heads map the result to the ``loc`` and
    ``scale`` of ``x``.
    """

    def __init__(self):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"], name="p_{1}")

        hidden_dim = 10  # width of each per-input hidden layer

        self.fc1 = nn.Linear(y_dim, hidden_dim)
        self.fc2 = nn.Linear(a_dim, hidden_dim)
        # Both heads emit x_dim features (instead of a literal 20) so the size
        # of the sampled x stays in step with layers that consume x, such as
        # P2.fc3 = nn.Linear(x_dim, ...).
        self.fc21 = nn.Linear(2 * hidden_dim, x_dim)
        self.fc22 = nn.Linear(2 * hidden_dim, x_dim)

    def forward(self, a, y):
        """Compute the Normal parameters of ``x`` from ``a`` and ``y``.

        Args:
            a: tensor whose dim 1 has ``a_dim`` features.
            y: tensor whose dim 1 has ``y_dim`` features.

        Returns:
            dict with keys ``"loc"`` and ``"scale"``; ``scale`` is passed
            through softplus so it is positive.
        """
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        # Join the two hidden representations along the feature dimension.
        h12 = torch.cat([h1, h2], 1)
        return {"loc": self.fc21(h12), "scale": F.softplus(self.fc22(h12))}

class P2(Normal):
    """Conditional Gaussian p_{2}(z|x,y) parameterized by a small network.

    ``x`` is first projected to a hidden vector, concatenated with ``y``,
    passed through a second hidden layer, and finally mapped by two linear
    heads to the ``loc`` and ``scale`` of ``z``.
    """

    def __init__(self):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"], name="p_{2}")

        x_hidden_dim = 30
        joint_hidden_dim = 400
        # NOTE(review): the heads emit 20 features although z_dim is 40 and is
        # otherwise unused in this notebook -- confirm which size is intended.
        out_dim = 20

        self.fc3 = nn.Linear(x_dim, x_hidden_dim)
        self.fc4 = nn.Linear(x_hidden_dim + y_dim, joint_hidden_dim)
        self.fc51 = nn.Linear(joint_hidden_dim, out_dim)
        self.fc52 = nn.Linear(joint_hidden_dim, out_dim)

    def forward(self, x, y):
        """Return ``{"loc": ..., "scale": ...}`` for ``z`` given ``x`` and ``y``."""
        x_hidden = F.relu(self.fc3(x))
        # Append the raw y features to the hidden representation of x.
        joint = torch.cat([x_hidden, y], 1)
        joint_hidden = F.relu(self.fc4(joint))
        loc = self.fc51(joint_hidden)
        # softplus keeps the scale positive.
        scale = F.softplus(self.fc52(joint_hidden))
        return {"loc": loc, "scale": scale}
    
# Unconditioned Normal(0, 1) priors over a and y, one feature per dimension.
p4 = Normal(
    loc=torch.tensor(0.),
    scale=torch.tensor(1.),
    var=["a"],
    features_shape=[a_dim],
    name="p_{4}",
)
p6 = Normal(
    loc=torch.tensor(0.),
    scale=torch.tensor(1.),
    var=["y"],
    features_shape=[y_dim],
    name="p_{6}",
)

# Example inputs: float32 tensors of shape (batch_n, dim), drawn with NumPy
# in the order x, y, a.
x, y, a = (
    torch.from_numpy(np.random.random((batch_n, width)).astype("float32"))
    for width in (x_dim, y_dim, a_dim)
)

In [4]:
# Build the two network-backed conditionals (P1 first, then P2).
p1, p2 = P1(), P2()

# Multiplying distributions yields their joint:
# p3(z,x|y,a) = p2(z|x,y) p1(x|y,a)
p3 = p2 * p1
p3.name = "p_{3}"

# Folding in the prior on a removes a from the conditioning side.
p5 = p3 * p4
p5.name = "p_{5}"

# With priors on both a and y, the joint has no conditioning variables left.
p_all = p1 * p2 * p4 * p6
p_all.name = "p_{all}"

In [5]:
# Text summary of p1: its factorization line followed by the network layers.
print(p1)
# LaTeX form of the same distribution; as the cell's last expression it is
# shown as an IPython Math object.
print_latex(p1)

Distribution:
  p_{1}(x|y,a)
Network architecture:
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
# Text summary of p2: its factorization line followed by the network layers.
print(p2)
# LaTeX form of the same distribution; as the cell's last expression it is
# shown as an IPython Math object.
print_latex(p2)

Distribution:
  p_{2}(z|x,y)
Network architecture:
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
# The product p3 prints its factorization and the networks of both factors
# (P1 and P2).
print(p3)
# LaTeX form of the factorization, shown as the cell's rich output.
print_latex(p3)

Distribution:
  p_{3}(z,x|y,a) = p_{2}(z|x,y)p_{1}(x|y,a)
Network architecture:
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
# p4 is an unconditioned prior: the summary shows an empty cond_var and the
# shapes of its fixed loc/scale instead of network layers.
print(p4)
# LaTeX form of the distribution, shown as the cell's rich output.
print_latex(p4)

Distribution:
  p_{4}(a)
Network architecture:
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )


<IPython.core.display.Math object>

In [9]:
# p5 = p3 * p4: a now appears on the left of the bar, leaving only y as a
# conditioning variable.
print(p5)
# LaTeX form of the factorization, shown as the cell's rich output.
print_latex(p5)

Distribution:
  p_{5}(z,x,a|y) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)
Network architecture:
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
   

<IPython.core.display.Math object>

In [10]:
# p_all is the full joint over z, x, a and y with no conditioning variables.
print(p_all)
# LaTeX form of the factorization, shown as the cell's rich output.
print_latex(p_all)

Distribution:
  p_{all}(z,x,a,y) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)p_{6}(y)
Network architecture:
  Normal(
    name=p_{6}, distribution_name=Normal,
    var=['y'], cond_var=[], input_var=[], features_shape=torch.Size([30])
    (loc): torch.Size([1, 30])
    (scale): torch.Size([1, 30])
  )
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=tor

<IPython.core.display.Math object>

In [11]:
# The product distribution exposes the learnable tensors of both factors;
# print each tensor's type and shape.
for param in p3.parameters():
    print(type(param.data), param.shape)

<class 'torch.Tensor'> torch.Size([10, 30])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([10, 50])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([30, 20])
<class 'torch.Tensor'> torch.Size([30])
<class 'torch.Tensor'> torch.Size([400, 60])
<class 'torch.Tensor'> torch.Size([400])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])


In [12]:
# return_all=False: the returned dict holds only the sampled variable 'x',
# not the conditioning inputs.
p1.sample({"a":a, "y":y}, return_all=False)

{'x': tensor([[-1.1218, -1.1758,  0.3717,  0.7961, -0.0242,  0.3305,  1.2300, -0.8699,
           0.4432,  0.1784, -0.6815, -0.5970,  0.7125, -1.2493, -1.8172, -0.8318,
           0.6193,  0.2931,  0.2749,  0.0363],
         [-0.1584, -0.5985,  0.1077,  2.1423,  0.7052,  0.5437, -0.8555,  0.4686,
          -0.9009,  1.7566,  0.5030, -0.2366, -0.4683,  0.9158,  0.7253,  1.1668,
           0.3656, -0.7716,  0.0998,  0.2611]])}

In [13]:
# sample_shape=[5]: the sampled 'x' gains an extra leading dimension
# (presumably of size 5 -- the displayed output is truncated; confirm).
p1.sample({"a":a, "y":y}, sample_shape=[5], return_all=False)

{'x': tensor([[[ 0.2363,  0.3990,  0.0692, -0.5842,  1.4920, -1.3252,  0.9833,
           -0.1309,  1.1616, -0.5395,  1.1062,  0.6611, -0.0148, -0.4342,
            1.4834,  0.2732,  1.4455, -0.7009,  0.8057, -0.8664],
          [ 2.0172,  0.1515, -0.2260,  0.4535,  0.3293,  0.8124,  0.4819,
           -0.4917, -0.0216, -0.8234,  0.0050,  0.9981,  1.1045, -0.1657,
            0.9829, -1.4629,  1.4324, -1.1665, -0.6037,  0.3600]],
 
         [[ 0.4047,  0.4998,  1.1270,  0.0223, -0.0818,  0.3570,  0.0382,
           -1.1366,  0.4775, -0.7488, -0.0467,  0.2880,  0.6136, -0.6018,
            0.5595,  0.6272,  2.3315, -0.1531,  1.0839,  0.3354],
          [-0.2787, -1.7961,  1.2282,  0.2017,  0.1629,  0.0798,  0.3795,
           -1.5618,  0.1000,  0.1157, -0.6087,  0.1748, -0.8421, -0.2185,
            0.4638,  1.1066, -1.2051,  0.1678,  0.4344,  0.3228]],
 
         [[-0.3361,  0.0329,  0.6392, -0.2693,  0.0803, -1.3670,  0.1784,
            0.9852,  1.2060,  0.9109, -0.4192,  1.3437, -0.

In [14]:
# return_all=True: the returned dict also carries the conditioning inputs
# 'a' and 'y' alongside the sample.
p1.sample({"a":a, "y":y}, return_all=True)

{'a': tensor([[0.9683, 0.0151, 0.4744, 0.6016, 0.6464, 0.3306, 0.5251, 0.4898, 0.5678,
          0.0677, 0.0069, 0.7992, 0.1025, 0.8956, 0.1047, 0.9507, 0.7454, 0.1687,
          0.2084, 0.5539, 0.4110, 0.7812, 0.6907, 0.8145, 0.7193, 0.2960, 0.1326,
          0.6641, 0.9409, 0.2728, 0.1460, 0.3379, 0.8449, 0.5248, 0.6381, 0.9087,
          0.5832, 0.9114, 0.0425, 0.5759, 0.2675, 0.1727, 0.7583, 0.9467, 0.8108,
          0.1090, 0.5807, 0.4598, 0.4033, 0.6732],
         [0.1867, 0.0601, 0.6697, 0.5651, 0.5856, 0.1790, 0.3660, 0.0844, 0.1551,
          0.7289, 0.6591, 0.0059, 0.8418, 0.8162, 0.3121, 0.0367, 0.6153, 0.2835,
          0.5682, 0.9936, 0.9931, 0.9135, 0.9406, 0.0304, 0.7227, 0.8995, 0.5438,
          0.2197, 0.9104, 0.0571, 0.3544, 0.5923, 0.6819, 0.4499, 0.7974, 0.0223,
          0.8558, 0.0491, 0.1000, 0.6722, 0.2084, 0.5161, 0.6107, 0.0359, 0.9278,
          0.7299, 0.8748, 0.0903, 0.2951, 0.9971]]),
 'y': tensor([[0.1726, 0.8717, 0.3409, 0.0083, 0.9316, 0.0753, 0.6211, 

In [15]:
# log_prob() called without data returns a symbolic object: printing it shows
# the expression \log p_{1}(x|y,a) rather than numeric values.
p1_log_prob = p1.log_prob()
print(p1_log_prob)
# LaTeX form of the same expression, shown as the cell's rich output.
print_latex(p1_log_prob)

\log p_{1}(x|y,a)


<IPython.core.display.Math object>

In [16]:
# Sample x given y and a, keeping the full dict so it can be scored directly.
outputs = p1.sample({"y": y, "a": a})
# Evaluate the symbolic log-probability on that dict: one value per batch row.
print(p1_log_prob.eval(outputs))

tensor([-24.9974, -25.6844], grad_fn=<SumBackward2>)


In [17]:
# Sample z given x and y, then score the sample under p2 in a single
# expression; the result has one log-probability per batch row.
outputs = p2.sample({"x":x, "y":y})
print(p2.log_prob().eval(outputs))

tensor([-24.7590, -24.5750], grad_fn=<SumBackward2>)


In [18]:
# With default arguments the returned dict includes the conditioning inputs
# ('y', 'a') as well; the printed output is truncated in this export.
outputs = p1.sample({"y": y, "a": a})
print(outputs)

{'y': tensor([[0.1726, 0.8717, 0.3409, 0.0083, 0.9316, 0.0753, 0.6211, 0.2979, 0.5826,
         0.0279, 0.4268, 0.9649, 0.6165, 0.2802, 0.5397, 0.1417, 0.7223, 0.0319,
         0.8294, 0.0267, 0.3987, 0.8499, 0.7430, 0.2817, 0.2454, 0.4912, 0.4041,
         0.0688, 0.4532, 0.2794],
        [0.8561, 0.7470, 0.7878, 0.6870, 0.5151, 0.1039, 0.8625, 0.9268, 0.6096,
         0.6477, 0.3695, 0.6836, 0.6534, 0.3756, 0.3954, 0.8602, 0.2202, 0.5959,
         0.6758, 0.0462, 0.6555, 0.9980, 0.0156, 0.0297, 0.6170, 0.4583, 0.9426,
         0.4794, 0.4864, 0.4973]]), 'a': tensor([[0.9683, 0.0151, 0.4744, 0.6016, 0.6464, 0.3306, 0.5251, 0.4898, 0.5678,
         0.0677, 0.0069, 0.7992, 0.1025, 0.8956, 0.1047, 0.9507, 0.7454, 0.1687,
         0.2084, 0.5539, 0.4110, 0.7812, 0.6907, 0.8145, 0.7193, 0.2960, 0.1326,
         0.6641, 0.9409, 0.2728, 0.1460, 0.3379, 0.8449, 0.5248, 0.6381, 0.9087,
         0.5832, 0.9114, 0.0425, 0.5759, 0.2675, 0.1727, 0.7583, 0.9467, 0.8108,
         0.1090, 0.5807, 0.4

In [19]:
# Feed the dict produced by p1 into p2, which is conditioned on 'x' and 'y';
# the extra key 'a' is carried through to the result.
p2.sample(outputs)

{'y': tensor([[0.1726, 0.8717, 0.3409, 0.0083, 0.9316, 0.0753, 0.6211, 0.2979, 0.5826,
          0.0279, 0.4268, 0.9649, 0.6165, 0.2802, 0.5397, 0.1417, 0.7223, 0.0319,
          0.8294, 0.0267, 0.3987, 0.8499, 0.7430, 0.2817, 0.2454, 0.4912, 0.4041,
          0.0688, 0.4532, 0.2794],
         [0.8561, 0.7470, 0.7878, 0.6870, 0.5151, 0.1039, 0.8625, 0.9268, 0.6096,
          0.6477, 0.3695, 0.6836, 0.6534, 0.3756, 0.3954, 0.8602, 0.2202, 0.5959,
          0.6758, 0.0462, 0.6555, 0.9980, 0.0156, 0.0297, 0.6170, 0.4583, 0.9426,
          0.4794, 0.4864, 0.4973]]),
 'a': tensor([[0.9683, 0.0151, 0.4744, 0.6016, 0.6464, 0.3306, 0.5251, 0.4898, 0.5678,
          0.0677, 0.0069, 0.7992, 0.1025, 0.8956, 0.1047, 0.9507, 0.7454, 0.1687,
          0.2084, 0.5539, 0.4110, 0.7812, 0.6907, 0.8145, 0.7193, 0.2960, 0.1326,
          0.6641, 0.9409, 0.2728, 0.1460, 0.3379, 0.8449, 0.5248, 0.6381, 0.9087,
          0.5832, 0.9114, 0.0425, 0.5759, 0.2675, 0.1727, 0.7583, 0.9467, 0.8108,
          0.1090

In [20]:
# Sample from the product p3 given its conditioning variables y and a.
# batch_n is passed explicitly here; presumably it sets the number of rows
# drawn -- confirm against the Pixyz documentation.
outputs = p3.sample({"y":y, "a":a}, batch_n=batch_n)
# Log-probability of the sampled dict under p3: one value per batch row.
print(p3.log_prob().eval(outputs))

tensor([-40.1329, -34.7207], grad_fn=<AddBackward0>)


In [21]:
# p_all has no conditioning variables, so sample() takes no input dict;
# with batch_n=2 the evaluated result below has two entries.
outputs = p_all.sample(batch_n=batch_n)
# Log-probability of the full joint sample: one value per batch row.
print(p_all.log_prob().eval(outputs))

tensor([-155.2850, -152.5049], grad_fn=<AddBackward0>)
