# Examples of creating and operating distributions in Pixyz

In [1]:
from __future__ import print_function
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

torch.manual_seed(1)

<torch._C.Generator at 0x7fe4600389b0>

In [2]:
from pixyz.distributions import Normal
from pixyz.utils import print_latex

In [3]:
x_dim = 20
y_dim = 30
z_dim = 40
a_dim = 50
batch_n = 2

class P1(Normal):
    def __init__(self):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"], name="p_{1}")

        self.fc1 = nn.Linear(y_dim, 10)
        self.fc2 = nn.Linear(a_dim, 10)
        self.fc21 = nn.Linear(10+10, 20)
        self.fc22 = nn.Linear(10+10, 20)

    def forward(self, a, y):
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        h12 = torch.cat([h1, h2], 1)
        return {"loc": self.fc21(h12), "scale": F.softplus(self.fc22(h12))}

class P2(Normal):
    def __init__(self):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"], name="p_{2}")

        self.fc3 = nn.Linear(x_dim, 30)
        self.fc4 = nn.Linear(30+y_dim, 400)
        self.fc51 = nn.Linear(400, 20)
        self.fc52 = nn.Linear(400, 20)

    def forward(self, x, y):
        h3 = F.relu(self.fc3(x))
        h4 = F.relu(self.fc4(torch.cat([h3, y], 1)))
        return {"loc": self.fc51(h4), "scale": F.softplus(self.fc52(h4))}
    
p4 = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["a"], features_shape=[a_dim], name="p_{4}")
p6 = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["y"], features_shape=[y_dim], name="p_{6}")
    
x = torch.from_numpy(np.random.random((batch_n, x_dim)).astype("float32"))
y = torch.from_numpy(np.random.random((batch_n, y_dim)).astype("float32"))
a = torch.from_numpy(np.random.random((batch_n, a_dim)).astype("float32"))

In [4]:
p1 = P1()
p2 = P2()
p3 = p2 * p1
p3.name = "p_{3}"
p5 = p3 * p4
p5.name = "p_{5}"
p_all = p1*p2*p4*p6
p_all.name = "p_{all}"

In [5]:
print(p1)
print_latex(p1)

Distribution:
  p_{1}(x|y,a)
Network architecture:
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
print(p2)
print_latex(p2)

Distribution:
  p_{2}(z|x,y)
Network architecture:
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
print(p3)
print_latex(p3)

Distribution:
  p_{2}(z,x|y,a) = p_{2}(z|x,y)p_{1}(x|y,a)
Network architecture:
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
print(p4)
print_latex(p4)

Distribution:
  p_{4}(a)
Network architecture:
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )


<IPython.core.display.Math object>

In [9]:
print(p5)
print_latex(p5)

Distribution:
  p_{2}(z,x,a|y) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)
Network architecture:
  p_{4}(a):
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_f

<IPython.core.display.Math object>

In [10]:
print(p_all)
print_latex(p_all)

Distribution:
  p_{1}(x,y,a,z) = p_{2}(z|x,y)p_{1}(x|y,a)p_{4}(a)p_{6}(y)
Network architecture:
  p_{6}(y):
  Normal(
    name=p_{6}, distribution_name=Normal,
    var=['y'], cond_var=[], input_var=[], features_shape=torch.Size([30])
    (loc): torch.Size([1, 30])
    (scale): torch.Size([1, 30])
  )
  p_{4}(a):
  Normal(
    name=p_{4}, distribution_name=Normal,
    var=['a'], cond_var=[], input_var=[], features_shape=torch.Size([50])
    (loc): torch.Size([1, 50])
    (scale): torch.Size([1, 50])
  )
  p_{1}(x|y,a):
  P1(
    name=p_{1}, distribution_name=Normal,
    var=['x'], cond_var=['y', 'a'], input_var=['y', 'a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p_{2}(z|x,y):
  P2(
    name=p_{2}, distribution_name=Normal,
    var=['z'], cond_va

<IPython.core.display.Math object>

In [11]:
for param in p3.parameters():
     print(type(param.data), param.size())

<class 'torch.Tensor'> torch.Size([30, 20])
<class 'torch.Tensor'> torch.Size([30])
<class 'torch.Tensor'> torch.Size([400, 60])
<class 'torch.Tensor'> torch.Size([400])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([10, 30])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([10, 50])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])


In [12]:
p1.sample({"a":a, "y":y}, return_all=False)

{'x': tensor([[-0.9833, -1.0652,  0.4239,  0.8173, -0.2204,  0.4001,  1.3073, -0.9367,
           0.4676,  0.0725, -0.8273, -0.6580,  0.6714, -1.1744, -1.8351, -0.7814,
           0.7336,  0.1674,  0.3620,  0.0546],
         [-0.2711, -0.6360,  0.0101,  2.1712,  0.6242,  0.4904, -0.8587,  0.4797,
          -0.9147,  1.8677,  0.5589, -0.3913, -0.5078,  0.8452,  0.7145,  1.1389,
           0.3086, -0.6795,  0.0261,  0.2041]])}

In [13]:
p1.sample({"a":a, "y":y}, sample_shape=[5], return_all=False)

{'x': tensor([[[ 2.4950e-01,  3.1602e-01,  1.3576e-01, -5.6682e-01,  1.4176e+00,
           -1.2555e+00,  1.0424e+00, -1.7783e-01,  1.2008e+00, -6.7176e-01,
            8.4457e-01,  6.5844e-01,  7.3068e-05, -3.6145e-01,  1.6219e+00,
            3.5318e-01,  1.6454e+00, -8.1888e-01,  9.0467e-01, -8.9661e-01],
          [ 2.1985e+00,  1.4022e-01, -3.2016e-01,  4.7673e-01,  2.9189e-01,
            7.5498e-01,  3.5511e-01, -5.0099e-01,  6.3490e-03, -8.2951e-01,
            2.2636e-03,  8.9336e-01,  1.0691e+00, -3.2212e-01,  9.7468e-01,
           -1.4594e+00,  1.3640e+00, -1.0668e+00, -6.6085e-01,  3.0555e-01]],
 
         [[ 4.0230e-01,  4.0447e-01,  1.1434e+00,  4.1433e-02, -2.8270e-01,
            4.2656e-01,  2.7652e-02, -1.2106e+00,  5.0271e-01, -8.8879e-01,
           -2.3366e-01,  2.6811e-01,  5.8011e-01, -5.2861e-01,  6.5424e-01,
            7.1672e-01,  2.6231e+00, -2.7531e-01,  1.1890e+00,  3.6978e-01],
          [-4.0754e-01, -1.8754e+00,  1.1188e+00,  2.2412e-01,  1.4475e-01,
 

In [14]:
p1.sample({"a":a, "y":y}, return_all=True)

{'a': tensor([[0.2873, 0.9435, 0.1231, 0.2204, 0.6657, 0.5794, 0.0838, 0.2769, 0.6844,
          0.1111, 0.7587, 0.8047, 0.0507, 0.9262, 0.2786, 0.1218, 0.6858, 0.6872,
          0.2060, 0.1119, 0.5898, 0.1417, 0.5471, 0.3720, 0.8313, 0.9112, 0.6975,
          0.7669, 0.0205, 0.5423, 0.7177, 0.6533, 0.9923, 0.3031, 0.6046, 0.6131,
          0.2235, 0.6828, 0.2267, 0.3587, 0.8634, 0.8237, 0.4217, 0.2676, 0.1127,
          0.1590, 0.5252, 0.5577, 0.5843, 0.3678],
         [0.2433, 0.8302, 0.8332, 0.7060, 0.4069, 0.3686, 0.9766, 0.3705, 0.1188,
          0.6160, 0.1727, 0.4442, 0.7902, 0.4048, 0.5020, 0.3769, 0.8155, 0.1155,
          0.0530, 0.4988, 0.0719, 0.5762, 0.7310, 0.0912, 0.5370, 0.7789, 0.1860,
          0.7515, 0.0460, 0.2896, 0.2969, 0.2924, 0.7538, 0.5971, 0.0058, 0.8062,
          0.3838, 0.0946, 0.9118, 0.6873, 0.9220, 0.3024, 0.2468, 0.2095, 0.5070,
          0.6059, 0.3747, 0.6300, 0.8308, 0.1317]]),
 'y': tensor([[0.4487, 0.0598, 0.7338, 0.3522, 0.7955, 0.7567, 0.8746, 

In [15]:
p1_log_prob = p1.log_prob()
print(p1_log_prob)
print_latex(p1_log_prob)

\log p_{1}(x|y,a)


<IPython.core.display.Math object>

In [16]:
outputs = p1.sample({"y": y, "a": a})
print(p1_log_prob.eval(outputs))

tensor([-25.0882, -25.9153], grad_fn=<SumBackward1>)


In [17]:
outputs = p2.sample({"x":x, "y":y})
print(p2.log_prob().eval(outputs))

tensor([-25.0042, -24.5120], grad_fn=<SumBackward1>)


In [18]:
outputs = p1.sample({"y": y, "a": a})
print(outputs)

{'y': tensor([[0.4487, 0.0598, 0.7338, 0.3522, 0.7955, 0.7567, 0.8746, 0.3631, 0.7995,
         0.2807, 0.9368, 0.0192, 0.9368, 0.8792, 0.2240, 0.1647, 0.9051, 0.8877,
         0.4056, 0.9226, 0.9058, 0.5873, 0.9171, 0.7015, 0.2538, 0.4920, 0.3192,
         0.5458, 0.0588, 0.0675],
        [0.9459, 0.8521, 0.0066, 0.4740, 0.8384, 0.5122, 0.0363, 0.2689, 0.6294,
         0.0306, 0.6115, 0.2558, 0.4377, 0.0827, 0.0874, 0.6609, 0.6415, 0.0991,
         0.4069, 0.6617, 0.9121, 0.0812, 0.1757, 0.2514, 0.3985, 0.0874, 0.2304,
         0.8469, 0.6807, 0.7081]]), 'a': tensor([[0.2873, 0.9435, 0.1231, 0.2204, 0.6657, 0.5794, 0.0838, 0.2769, 0.6844,
         0.1111, 0.7587, 0.8047, 0.0507, 0.9262, 0.2786, 0.1218, 0.6858, 0.6872,
         0.2060, 0.1119, 0.5898, 0.1417, 0.5471, 0.3720, 0.8313, 0.9112, 0.6975,
         0.7669, 0.0205, 0.5423, 0.7177, 0.6533, 0.9923, 0.3031, 0.6046, 0.6131,
         0.2235, 0.6828, 0.2267, 0.3587, 0.8634, 0.8237, 0.4217, 0.2676, 0.1127,
         0.1590, 0.5252, 0.5

In [19]:
p2.sample(outputs)

{'y': tensor([[0.4487, 0.0598, 0.7338, 0.3522, 0.7955, 0.7567, 0.8746, 0.3631, 0.7995,
          0.2807, 0.9368, 0.0192, 0.9368, 0.8792, 0.2240, 0.1647, 0.9051, 0.8877,
          0.4056, 0.9226, 0.9058, 0.5873, 0.9171, 0.7015, 0.2538, 0.4920, 0.3192,
          0.5458, 0.0588, 0.0675],
         [0.9459, 0.8521, 0.0066, 0.4740, 0.8384, 0.5122, 0.0363, 0.2689, 0.6294,
          0.0306, 0.6115, 0.2558, 0.4377, 0.0827, 0.0874, 0.6609, 0.6415, 0.0991,
          0.4069, 0.6617, 0.9121, 0.0812, 0.1757, 0.2514, 0.3985, 0.0874, 0.2304,
          0.8469, 0.6807, 0.7081]]),
 'a': tensor([[0.2873, 0.9435, 0.1231, 0.2204, 0.6657, 0.5794, 0.0838, 0.2769, 0.6844,
          0.1111, 0.7587, 0.8047, 0.0507, 0.9262, 0.2786, 0.1218, 0.6858, 0.6872,
          0.2060, 0.1119, 0.5898, 0.1417, 0.5471, 0.3720, 0.8313, 0.9112, 0.6975,
          0.7669, 0.0205, 0.5423, 0.7177, 0.6533, 0.9923, 0.3031, 0.6046, 0.6131,
          0.2235, 0.6828, 0.2267, 0.3587, 0.8634, 0.8237, 0.4217, 0.2676, 0.1127,
          0.1590

In [20]:
outputs = p3.sample({"y":y, "a":a}, batch_n=batch_n)
print(p3.log_prob().eval(outputs))

tensor([-40.4417, -34.7918], grad_fn=<AddBackward0>)


In [21]:
outputs = p_all.sample(batch_n=batch_n)
print(p_all.log_prob().eval(outputs))

tensor([-155.2850, -152.5048], grad_fn=<AddBackward0>)
