# Examples of creating and operating on distributions in Pixyz

In [1]:
from __future__ import print_function
import argparse
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

# Seed every RNG this notebook draws from: the example inputs are generated with
# np.random, while parameter initialisation and sampling use torch.
np.random.seed(1)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(1)

<torch._C.Generator at 0x7fac2fc9b8b0>

In [2]:
from pixyz.distributions import Normal

In [3]:
# Dimensionalities shared by the example networks and the sample tensors below.
x_dim = 20  # width of x (input size of P2.fc3)
y_dim = 30  # width of y (input size of P1.fc1, concatenated into P2.fc4)
z_dim = 40  # NOTE(review): not referenced in the visible code; P2's heads are hard-coded to 20 - confirm
a_dim = 50  # width of a (input size of P1.fc2)
n_batch = 2  # number of rows in the sample tensors x, y and a

class P1(Normal):
    """Conditional Normal distribution p1(x | y, a) parameterised by a small MLP.

    ``y`` and ``a`` each go through their own linear layer + ReLU, the two
    10-unit hidden vectors are concatenated, and two linear heads produce the
    ``loc`` and ``scale`` of the distribution over ``x``.
    """

    def __init__(self):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"], name="p1")

        self.fc1 = nn.Linear(y_dim, 10)
        self.fc2 = nn.Linear(a_dim, 10)
        # The heads emit the parameters of x, so their width must follow x_dim
        # rather than a literal: samples of x are fed to P2.fc3, which expects
        # x_dim input features.
        self.fc21 = nn.Linear(10+10, x_dim)
        self.fc22 = nn.Linear(10+10, x_dim)

    def forward(self, a, y):
        """Compute the Normal parameters of x from the conditioning variables.

        Args:
            a: tensor whose last dimension is ``a_dim``.
            y: tensor whose last dimension is ``y_dim``.
            (Both are concatenated along dim 1, so they are presumably
            ``(batch, features)`` - TODO confirm for other ranks.)

        Returns:
            dict with keys ``"loc"`` and ``"scale"``; ``scale`` goes through
            softplus so it is strictly positive.
        """
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        h12 = torch.cat([h1, h2], 1)
        return {"loc": self.fc21(h12), "scale": F.softplus(self.fc22(h12))}

class P2(Normal):
    """Conditional Normal distribution p2(z | x, y) parameterised by a small MLP.

    ``x`` is embedded first, the embedding is concatenated with the raw ``y``,
    and a second hidden layer feeds the ``loc`` and ``scale`` heads.
    """

    def __init__(self):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"], name="p2")

        x_hidden = 30
        joint_hidden = 400
        # NOTE(review): the head width is a literal 20 and does not use the
        # module-level z_dim (40) - confirm which one is intended.
        out_features = 20

        self.fc3 = nn.Linear(x_dim, x_hidden)
        self.fc4 = nn.Linear(x_hidden + y_dim, joint_hidden)
        self.fc51 = nn.Linear(joint_hidden, out_features)
        self.fc52 = nn.Linear(joint_hidden, out_features)

    def forward(self, x, y):
        """Return ``{"loc", "scale"}`` for z given ``x`` and ``y``.

        The scale head goes through softplus so it is strictly positive.
        """
        x_feat = F.relu(self.fc3(x))
        joint = torch.cat([x_feat, y], 1)
        joint_feat = F.relu(self.fc4(joint))

        loc = self.fc51(joint_feat)
        scale = F.softplus(self.fc52(joint_feat))
        return {"loc": loc, "scale": scale}
    
# Unconditional Normal(0, 1) priors over a and y (no network attached).
p4 = Normal(var=["a"], dim=a_dim, loc=0, scale=1, name="p4")
p6 = Normal(var=["y"], dim=y_dim, loc=0, scale=1, name="p6")


def _uniform_batch(dim):
    """Return an (n_batch, dim) float32 tensor of np.random.random draws."""
    return torch.from_numpy(np.random.random((n_batch, dim)).astype("float32"))


# Example inputs; drawn in the order x, y, a.
x = _uniform_batch(x_dim)
y = _uniform_batch(y_dim)
a = _uniform_batch(a_dim)

In [4]:
# Instantiate the two network-parameterised conditionals.
p1 = P1()
p2 = P2()
# `*` multiplies distributions into a joint: p3(z,x|y,a) = p2(z|x,y)p1(x|y,a).
p3 = p2 * p1
p3.name = "p3"
# Multiplying in the prior p4(a) moves `a` out of the conditioning set: p5(z,x,a|y).
p5 = p3 * p4
p5.name = "p5"
# With priors on both a and y the product has no conditioning variables: p_all(z,x,a,y).
p_all = p1*p2*p4*p6
p_all.name = "p_all"

In [5]:
# Show p1's summary: the distribution p1(x|y,a) and the layers of its P1 module.
print(p1)

Distribution:
  p1(x|y,a) (Normal)
Network architecture:
  P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )


In [6]:
# Show p2's summary: the distribution p2(z|x,y) and the layers of its P2 module.
print(p2)

Distribution:
  p2(z|x,y) (Normal)
Network architecture:
  P2(
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


In [7]:
# A product prints its factorisation followed by the network of each factor.
print(p3)

Distribution:
  p3(z,x|y,a) = p2(z|x,y)p1(x|y,a)
Network architecture:
  p1(x|y,a) (Normal): P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p2(z|x,y) (Normal): P2(
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


In [8]:
# The fixed prior p4(a) has no sub-layers, so its architecture prints as a bare Normal().
print(p4)

Distribution:
  p4(a) (Normal)
Network architecture:
  Normal()


In [9]:
# p5 = p3 * p4: the factorisation now includes p4(a), and only y remains conditioned on.
print(p5)

Distribution:
  p5(z,x,a|y) = p2(z|x,y)p1(x|y,a)p4(a)
Network architecture:
  p4(a) (Normal): Normal()
  p1(x|y,a) (Normal): P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p2(z|x,y) (Normal): P2(
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


In [10]:
# The full joint over z, x, a and y; it lists all four factors and has no conditioning variables.
print(p_all)

Distribution:
  p_all(z,x,a,y) = p2(z|x,y)p1(x|y,a)p4(a)p6(y)
Network architecture:
  p6(y) (Normal): Normal()
  p4(a) (Normal): Normal()
  p1(x|y,a) (Normal): P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p2(z|x,y) (Normal): P2(
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )


In [11]:
# A product distribution exposes the parameters of all its factors;
# list each one's storage type and shape.
for param in p3.parameters():
    print(type(param.data), param.shape)

<class 'torch.Tensor'> torch.Size([10, 30])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([10, 50])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([30, 20])
<class 'torch.Tensor'> torch.Size([30])
<class 'torch.Tensor'> torch.Size([400, 60])
<class 'torch.Tensor'> torch.Size([400])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])


In [12]:
# Draw x ~ p1(x|y,a) for the example batch. With return_all=False the returned
# dict holds only the sampled variable "x" (see the output below).
p1.sample({"a":a, "y":y}, return_all=False)
# Analogous calls for the other distributions, left commented out so the cell
# displays a single result; uncomment one to try it.
#p2.sample({"x":x, "y":y})
#p3.sample({"y":y, "a":a})
#p4.sample()
#p5.sample({"y":y})
#p6.sample()
#p_all.sample()

{'x': tensor([[-1.1235, -1.1559,  0.4218,  0.8778, -0.1497,  0.2739,  1.1814, -0.7278,
           0.2572,  0.1075, -0.7142, -0.7021,  0.6641, -1.1700, -1.8278, -0.9027,
           0.6691,  0.2645,  0.2566, -0.1142],
         [-0.2431, -0.5863, -0.0452,  2.1263,  0.9091,  0.5982, -0.9394,  0.3520,
          -0.7051,  1.8862,  0.4602, -0.2422, -0.6304,  0.8388,  0.8246,  1.1748,
           0.3473, -0.8007,  0.2327,  0.3098]])}

In [13]:
# Conditional distributions: sample given the observed conditioning variables,
# then score the resulting dict of variables under the same distribution.
conditional_cases = (
    (p1, {"y": y, "a": a}),
    (p2, {"x": x, "y": y}),
    (p3, {"y": y, "a": a}),
)
for dist, given in conditional_cases:
    outputs = dist.sample(given)
    print(dist.log_likelihood(outputs))

# The full joint needs no inputs; batch_size sets how many samples are drawn.
outputs = p_all.sample(batch_size=10)
print(p_all.log_likelihood(outputs))

tensor([-26.2268, -23.8622], grad_fn=<SumBackward2>)
tensor([-18.9722, -19.5073], grad_fn=<SumBackward2>)
tensor([-45.1698, -41.0407], grad_fn=<AddBackward0>)
tensor([-155.4684, -163.4326, -150.2627, -150.2103, -159.1462, -163.7559,
        -168.1021, -162.1275, -160.1595, -142.4833], grad_fn=<AddBackward0>)
