In [1]:
from __future__ import print_function
import argparse
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

torch.manual_seed(1)

<torch._C.Generator at 0x7fcc9c14d710>

In [2]:
from Tars.distributions import Normal

In [3]:
x_dim = 20
y_dim = 30
z_dim = 40
a_dim = 50
n_batch = 2

class P1(Normal):
    def __init__(self):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"])

        self.fc1 = nn.Linear(y_dim, 10)
        self.fc2 = nn.Linear(a_dim, 10)
        self.fc21 = nn.Linear(10+10, 20)
        self.fc22 = nn.Linear(10+10, 20)

    def forward(self, a, y):
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        h12 = torch.cat([h1, h2], 1)
        return {"loc": self.fc21(h12), "scale": F.softplus(self.fc22(h12))}

class P2(Normal):
    def __init__(self):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"])

        self.fc3 = nn.Linear(x_dim, 30)
        self.fc4 = nn.Linear(30+y_dim, 400)
        self.fc51 = nn.Linear(400, 20)
        self.fc52 = nn.Linear(400, 20)

    def forward(self, x, y):
        h3 = F.relu(self.fc3(x))
        h4 = F.relu(self.fc4(torch.cat([h3, y], 1)))
        return {"loc": self.fc51(h4), "scale": F.softplus(self.fc52(h4))}
    
p4 = Normal(loc=0, scale=1, var=["a"], dim=a_dim)
p6 = Normal(loc=0, scale=1, var=["y"], dim=y_dim)
    
x = torch.from_numpy(np.random.random((n_batch, x_dim)).astype("float32"))
y = torch.from_numpy(np.random.random((n_batch, y_dim)).astype("float32"))
a = torch.from_numpy(np.random.random((n_batch, a_dim)).astype("float32"))

In [4]:
p1 = P1()
p2 = P2()
p3 = p2 * p1
p5 = p3 * p4
p_all = p1*p2*p4*p6

In [9]:
print(p1)
print(p2)
print(p3)
print(p4)
print(p5)
print(p_all)

Distribution:
  p(x|y,a) (Normal)
Network architecture:
  P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
Distribution:
  p(z|x,y) (Normal)
Network architecture:
  P2(
    (fc3): Linear(in_features=20, out_features=30, bias=True)
    (fc4): Linear(in_features=60, out_features=400, bias=True)
    (fc51): Linear(in_features=400, out_features=20, bias=True)
    (fc52): Linear(in_features=400, out_features=20, bias=True)
  )
Distribution:
  p(z,x|y,a) = p(z|x,y)p(x|y,a)
Network architecture:
  p(x|y,a) (Normal): P1(
    (fc1): Linear(in_features=30, out_features=10, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
    (fc21): Linear(in_features=20, out_features=20, bias=True)
    (fc22): Linear(in_features=20, out_features=20, bias=True)
  )
  p(z|x,y) (Normal

In [6]:
for param in p3.parameters():
     print(type(param.data), param.size())

<class 'torch.Tensor'> torch.Size([10, 30])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([10, 50])
<class 'torch.Tensor'> torch.Size([10])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 20])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([30, 20])
<class 'torch.Tensor'> torch.Size([30])
<class 'torch.Tensor'> torch.Size([400, 60])
<class 'torch.Tensor'> torch.Size([400])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 400])
<class 'torch.Tensor'> torch.Size([20])


In [7]:
p1.sample({"a":a, "y":y}, return_all=False)
#p2.sample({"x":x, "y":y})
#p3.sample({"y":y, "a":a})
#p4.sample()
#p5.sample({"y":y})
#p6.sample()
#p_all.sample()

{'x': tensor([[-1.2072, -1.2068,  0.4025,  0.7547, -0.1015,  0.2445,  1.2220,
          -0.5812,  0.2365,  0.1987, -0.6541, -0.5867,  0.6629, -1.1235,
          -1.9549, -1.0275,  0.6051,  0.2591,  0.2361, -0.0200],
         [-0.2084, -0.6144,  0.1112,  2.1750,  0.8007,  0.5263, -0.8072,
           0.4648, -0.9276,  1.8455,  0.5529, -0.2155, -0.6126,  0.9263,
           0.6773,  1.1610,  0.3911, -0.8201,  0.1603,  0.1784]])}

In [8]:
outputs = p1.sample({"y":y, "a":a})
print(p1.log_likelihood(outputs))

outputs = p2.sample({"x":x, "y":y})
print(p2.log_likelihood(outputs))

outputs = p3.sample({"y":y, "a":a})
print(p3.log_likelihood(outputs))

outputs = p_all.sample(batch_size=10)
print(p_all.log_likelihood(outputs))

tensor([-26.3069, -23.4132])
tensor([-19.1880, -19.4689])
tensor([-45.3444, -40.6475])
tensor([-155.4684, -163.4326, -150.2627, -150.2103, -159.1462, -163.7559,
        -168.1021, -162.1275, -160.1595, -142.4833])
