# Examples of creating and operating on distributions in Pixyz

In [1]:
from __future__ import print_function
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

# Seed NumPy as well as PyTorch: the input batches x, y, a below are drawn with
# np.random.random, so seeding torch alone does not make a re-run reproducible.
# torch.manual_seed stays last so the cell still displays the Generator object.
np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x7f3ab00939d0>

In [2]:
from pixyz.distributions import Normal
from pixyz.utils import print_latex

In [3]:
# Feature sizes of the random variables x, y, z and a.
# NOTE(review): z_dim is not referenced by any code in the visible cells.
x_dim, y_dim, z_dim, a_dim = 20, 30, 40, 50
# Number of rows in every input / sample batch.
batch_n = 2

class P1(Normal):
    """Conditional Normal p_{1}(x|y,a) parameterized by a two-branch MLP.

    `y` and `a` are each embedded by their own linear layer; the two
    embeddings are concatenated and mapped to the `loc` and `scale` of the
    Normal over `x`.
    """

    def __init__(self):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"], name="p_{1}")

        hidden_dim = 10  # width of each per-input embedding
        self.fc1 = nn.Linear(y_dim, hidden_dim)
        self.fc2 = nn.Linear(a_dim, hidden_dim)
        # The output width is the dimensionality of x. It is tied to x_dim
        # (instead of a literal 20) so it stays in sync with P2, whose first
        # layer consumes x through nn.Linear(x_dim, ...).
        self.fc21 = nn.Linear(2 * hidden_dim, x_dim)
        self.fc22 = nn.Linear(2 * hidden_dim, x_dim)

    def forward(self, a, y):
        """Compute the Normal parameters of x from batches of `a` and `y`.

        Args:
            a: batch of shape (batch, a_dim), fed to ``fc2``.
            y: batch of shape (batch, y_dim), fed to ``fc1``.

        Returns:
            dict with keys ``"loc"`` and ``"scale"``, each of shape
            (batch, x_dim); ``scale`` goes through softplus so it is positive.
        """
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        # Concatenate the two embeddings along the feature axis.
        h12 = torch.cat([h1, h2], 1)
        return {"loc": self.fc21(h12), "scale": F.softplus(self.fc22(h12))}

class P2(Normal):
    """Conditional Normal p_{2}(z|x,y) parameterized by a small MLP.

    `x` is embedded first, the embedding is concatenated with `y`, and the
    result is mapped to the `loc` and `scale` of the Normal over `z`.
    """

    def __init__(self):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"], name="p_{2}")

        x_hidden, joint_hidden, out_features = 30, 400, 20
        # NOTE(review): z is produced with 20 features although z_dim is
        # defined as 40 in this notebook and never used -- confirm which
        # width is intended before changing it (it affects sample shapes).
        self.fc3 = nn.Linear(x_dim, x_hidden)
        self.fc4 = nn.Linear(x_hidden + y_dim, joint_hidden)
        self.fc51 = nn.Linear(joint_hidden, out_features)
        self.fc52 = nn.Linear(joint_hidden, out_features)

    def forward(self, x, y):
        """Return the ``"loc"`` / ``"scale"`` dict of the Normal over z.

        ``scale`` goes through softplus so it is positive.
        """
        x_features = F.relu(self.fc3(x))
        # Join the x embedding with the raw y features along the feature axis.
        joint = torch.cat([x_features, y], dim=1)
        hidden = F.relu(self.fc4(joint))
        loc = self.fc51(hidden)
        scale = F.softplus(self.fc52(hidden))
        return {"loc": loc, "scale": scale}
    
# Unconditional Normal priors with loc 0 and scale 1 over a and y.
p4 = Normal(
    loc=torch.tensor(0.),
    scale=torch.tensor(1.),
    var=["a"],
    features_shape=[a_dim],
    name="p_{4}",
)
p6 = Normal(
    loc=torch.tensor(0.),
    scale=torch.tensor(1.),
    var=["y"],
    features_shape=[y_dim],
    name="p_{6}",
)

# Random float32 input batches of shape (batch_n, dim), drawn in the order x, y, a.
x, y, a = (
    torch.from_numpy(np.random.random((batch_n, dim)).astype("float32"))
    for dim in (x_dim, y_dim, a_dim)
)

In [4]:
# Instantiate the two network-parameterized conditionals.
p1 = P1()
p2 = P2()
# Multiplying distributions builds a joint: p2(z|x,y) * p1(x|y,a) is a
# distribution over (z, x) given (y, a).
p3 = p2 * p1
# NOTE(review): the recorded print(p3) output still shows "p_{2}(z,x|y,a)", so
# this rename does not appear to reach the printed name -- confirm against the
# installed Pixyz version. The same applies to p5 and p_all below.
p3.name = "p_{3}"
# Adding the prior over a removes a from the conditioning variables.
p5 = p3 * p4
p5.name = "p_{5}"
# Adding the prior over y as well leaves no conditioning variables.
p_all = p1*p2*p4*p6
p_all.name = "p_{all}"

In [5]:
# Plain-text summary of p1: its factorization followed by its network layers.
print(p1)
# Render p1 as LaTeX (shows up as an IPython Math object in the output).
print_latex(p1)

Distribution:
  p_{1}(x|y,a)
Network architecture:
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
# Plain-text summary of p2(z|x,y) and its layers, then the LaTeX rendering.
print(p2)
print_latex(p2)

Distribution:
  p_{2}(z|x,y)
Network architecture:
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
# Summary of the product p2 * p1: the printed factorization lists both factors
# and the architecture section lists each factor's network.
print(p3)
print_latex(p3)

Distribution:
  p_{2}(z,x|y,a) = p_{2}(z|x,y)p_{1}(x|y,a)
Network architecture:
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
# Summary of the prior p4(a): it has no conditioning variables, and its loc and
# scale are reported with shape [1, a_dim].
print(p4)
print_latex(p4)

Distribution:
  p_{4}(a)
Network architecture:
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )


<IPython.core.display.Math object>

In [9]:
# Summary of p5 = p3 * p4: a joint over (z, x, a) conditioned only on y.
print(p5)
print_latex(p5)

Distribution:
  p_{2}(z,x,a|y) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)
Network architecture:
  p_{4}(a):
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_f

<IPython.core.display.Math object>

In [10]:
# Summary of p_all: the full joint over (x, y, a, z) with no conditioning variables.
print(p_all)
print_latex(p_all)

Distribution:
  p_{1}(x,y,a,z) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)p_{6}(y)
Network architecture:
  p_{6}(y):
  Normal(
    name=p_{6}, distribution_name=Normal,
    var=['y'], cond_var=[], input_var=[], features_shape=torch.Size([30])
    (loc): torch.Size([1, 30])
    (scale): torch.Size([1, 30])
  )
  p_{4}(a):
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_va

<IPython.core.display.Math object>

In [11]:
# Walk every learnable tensor reachable from the joint p3 and report its type
# and shape (the recorded output lists P2's layers first, then P1's).
for weight in p3.parameters():
    print(type(weight.data), weight.size())

<class 'torch.Tensor'> torch.Size([30, 20])
<class 'torch.Tensor'> torch.Size([30])
<class 'torch.Tensor'> torch.Size([400, 60])
<class 'torch.Tensor'> torch.Size([400])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([10, 30])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([10, 50])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])


In [12]:
# Draw x from p1(x|y,a) for the given batch; with return_all=False the result
# dict contains only the sampled variable 'x'.
p1.sample({"a":a, "y":y}, return_all=False)

{'x': tensor([[-1.0299, -1.2263,  0.5289,  0.9194, -0.2101,  0.2768,  1.3235, -0.8458,
           0.3476,  0.0318, -0.7057, -0.5483,  0.6671, -0.8597, -1.7904, -0.8084,
           0.6117,  0.2948,  0.3227, -0.0383],
         [-0.3045, -0.6468,  0.0891,  2.2200,  0.5947,  0.5854, -0.9290,  0.4790,
          -0.7963,  1.9300,  0.5019, -0.2916, -0.5468,  0.7446,  0.8173,  1.0218,
           0.2750, -0.7355,  0.1038,  0.1159]])}

In [13]:
# Same draw with sample_shape=[5]: the returned 'x' gains a leading sample axis
# (the output is a 3-D tensor; presumably 5 x batch_n x x_dim -- the recorded
# output is truncated).
p1.sample({"a":a, "y":y}, sample_shape=[5], return_all=False)

{'x': tensor([[[ 1.8714e-01,  3.6696e-01,  2.1010e-01, -5.5118e-01,  1.2658e+00,
           -1.3178e+00,  1.0771e+00, -1.3143e-01,  1.0405e+00, -7.2369e-01,
            9.5043e-01,  6.9438e-01, -6.4541e-02, -1.1016e-01,  1.4737e+00,
            3.0139e-01,  1.4404e+00, -7.8267e-01,  8.4115e-01, -9.3408e-01],
          [ 2.0706e+00,  1.0844e-01, -2.5218e-01,  5.2909e-01,  2.5337e-01,
            8.5192e-01,  3.5012e-01, -5.5830e-01,  4.4407e-02, -8.1839e-01,
           -2.2697e-02,  1.0360e+00,  9.2339e-01, -3.9723e-01,  1.0899e+00,
           -1.5907e+00,  1.4165e+00, -1.1127e+00, -5.8682e-01,  2.2109e-01]],
 
         [[ 3.3799e-01,  4.6899e-01,  1.3249e+00,  9.5066e-02, -2.6627e-01,
            3.0221e-01,  1.3340e-01, -1.1036e+00,  3.8074e-01, -9.4399e-01,
           -1.1762e-01,  3.2594e-01,  5.6761e-01, -2.6429e-01,  5.6004e-01,
            6.5701e-01,  2.3289e+00, -1.8887e-01,  1.1128e+00,  2.5848e-01],
          [-4.3580e-01, -1.8528e+00,  1.2349e+00,  2.7701e-01,  1.0223e-01,
 

In [14]:
# With return_all=True the result also carries the conditioning inputs 'a' and
# 'y' (presumably alongside the sampled 'x'; the recorded output is truncated).
p1.sample({"a":a, "y":y}, return_all=True)

{'a': tensor([[0.1247, 0.3626, 0.5750, 0.3316, 0.1122, 0.2393, 0.9769, 0.6199, 0.4377,
          0.0989, 0.6779, 0.9160, 0.0729, 0.7601, 0.6406, 0.2142, 0.1419, 0.0939,
          0.8958, 0.6005, 0.7959, 0.5817, 0.9536, 0.6578, 0.5547, 0.4744, 0.5304,
          0.9512, 0.9543, 0.3990, 0.5173, 0.8549, 0.4733, 0.7530, 0.6687, 0.0554,
          0.4222, 0.7571, 0.9099, 0.3253, 0.8394, 0.0964, 0.1431, 0.4612, 0.3356,
          0.8926, 0.5417, 0.5819, 0.7935, 0.4104],
         [0.0295, 0.4776, 0.7399, 0.7366, 0.3285, 0.4687, 0.5953, 0.8438, 0.6898,
          0.2132, 0.0071, 0.4417, 0.1486, 0.1484, 0.1871, 0.2375, 0.4639, 0.8257,
          0.5834, 0.7746, 0.1868, 0.9428, 0.1989, 0.3176, 0.9830, 0.1743, 0.0506,
          0.6278, 0.2065, 0.2463, 0.5054, 0.2877, 0.5163, 0.9787, 0.4039, 0.0271,
          0.8808, 0.5863, 0.6645, 0.2771, 0.6532, 0.2723, 0.8658, 0.5841, 0.6185,
          0.9491, 0.0068, 0.0945, 0.0780, 0.1591]]),
 'y': tensor([[0.4249, 0.2331, 0.9856, 0.6609, 0.2501, 0.4779, 0.5012, 

In [15]:
# log_prob() returns a symbolic object rather than a number: it prints as
# "\log p_{1}(x|y,a)" and is evaluated later through .eval(...).
p1_log_prob = p1.log_prob()
print(p1_log_prob)
print_latex(p1_log_prob)

\log p_{1}(x|y,a)


<IPython.core.display.Math object>

In [16]:
# Sample from p1, then evaluate log p1(x|y,a) on the resulting dict; the
# result holds one value per batch row.
outputs = p1.sample({"y": y, "a": a})
print(p1_log_prob.eval(outputs))

tensor([-24.8088, -25.9759], grad_fn=<SumBackward1>)


In [17]:
# Same pattern for p2: sample z given the x and y batches, then evaluate
# log p2(z|x,y) on that sample. Note that this rebinds `outputs`.
outputs = p2.sample({"x":x, "y":y})
print(p2.log_prob().eval(outputs))

tensor([-24.9166, -24.2261], grad_fn=<SumBackward1>)


In [18]:
# Re-sample from p1 and show the full result dict; by default it includes the
# conditioning inputs 'y' and 'a' (the recorded output is truncated).
outputs = p1.sample({"y": y, "a": a})
print(outputs)

{'y': tensor([[0.4249, 0.2331, 0.9856, 0.6609, 0.2501, 0.4779, 0.5012, 0.2993, 0.2764,
         0.7004, 0.4063, 0.8842, 0.9118, 0.9472, 0.0198, 0.5519, 0.7312, 0.6658,
         0.0805, 0.6665, 0.4161, 0.3674, 0.7932, 0.4358, 0.8718, 0.8125, 0.6108,
         0.8514, 0.1534, 0.4207],
        [0.4866, 0.2115, 0.4247, 0.6349, 0.6166, 0.3609, 0.1176, 0.6920, 0.1982,
         0.7056, 0.2711, 0.7976, 0.0407, 0.0423, 0.3801, 0.2085, 0.1222, 0.3232,
         0.2642, 0.4416, 0.8992, 0.8277, 0.1803, 0.3078, 0.6140, 0.9737, 0.8137,
         0.0308, 0.3883, 0.1112]]), 'a': tensor([[0.1247, 0.3626, 0.5750, 0.3316, 0.1122, 0.2393, 0.9769, 0.6199, 0.4377,
         0.0989, 0.6779, 0.9160, 0.0729, 0.7601, 0.6406, 0.2142, 0.1419, 0.0939,
         0.8958, 0.6005, 0.7959, 0.5817, 0.9536, 0.6578, 0.5547, 0.4744, 0.5304,
         0.9512, 0.9543, 0.3990, 0.5173, 0.8549, 0.4733, 0.7530, 0.6687, 0.0554,
         0.4222, 0.7571, 0.9099, 0.3253, 0.8394, 0.0964, 0.1431, 0.4612, 0.3356,
         0.8926, 0.5417, 0.5

In [19]:
# Feed the dict returned by p1.sample straight into p2, chaining the two
# samplers by hand; the returned dict still carries 'y' and 'a'.
p2.sample(outputs)

{'y': tensor([[0.4249, 0.2331, 0.9856, 0.6609, 0.2501, 0.4779, 0.5012, 0.2993, 0.2764,
          0.7004, 0.4063, 0.8842, 0.9118, 0.9472, 0.0198, 0.5519, 0.7312, 0.6658,
          0.0805, 0.6665, 0.4161, 0.3674, 0.7932, 0.4358, 0.8718, 0.8125, 0.6108,
          0.8514, 0.1534, 0.4207],
         [0.4866, 0.2115, 0.4247, 0.6349, 0.6166, 0.3609, 0.1176, 0.6920, 0.1982,
          0.7056, 0.2711, 0.7976, 0.0407, 0.0423, 0.3801, 0.2085, 0.1222, 0.3232,
          0.2642, 0.4416, 0.8992, 0.8277, 0.1803, 0.3078, 0.6140, 0.9737, 0.8137,
          0.0308, 0.3883, 0.1112]]),
 'a': tensor([[0.1247, 0.3626, 0.5750, 0.3316, 0.1122, 0.2393, 0.9769, 0.6199, 0.4377,
          0.0989, 0.6779, 0.9160, 0.0729, 0.7601, 0.6406, 0.2142, 0.1419, 0.0939,
          0.8958, 0.6005, 0.7959, 0.5817, 0.9536, 0.6578, 0.5547, 0.4744, 0.5304,
          0.9512, 0.9543, 0.3990, 0.5173, 0.8549, 0.4733, 0.7530, 0.6687, 0.0554,
          0.4222, 0.7571, 0.9099, 0.3253, 0.8394, 0.0964, 0.1431, 0.4612, 0.3356,
          0.8926

In [20]:
# Sample from the joint p3 given y and a, then evaluate the joint log-probability
# on that sample; the result holds one value per batch row.
outputs = p3.sample({"y":y, "a":a}, batch_n=batch_n)
print(p3.log_prob().eval(outputs))

tensor([-40.1202, -34.6922], grad_fn=<AddBackward0>)


In [21]:
# p_all has no conditioning variables, so sampling takes no input dict, only
# the batch size; the joint log-probability is then evaluated on that sample.
outputs = p_all.sample(batch_n=batch_n)
print(p_all.log_prob().eval(outputs))

tensor([-155.2850, -152.5048], grad_fn=<AddBackward0>)
