In [None]:
from __future__ import print_function
import argparse
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

# Seed every RNG this notebook draws from so Restart & Run All reproduces the
# same outputs: torch (layer init, sampling) and numpy (the toy input batches,
# which are built with np.random.random and were previously unseeded).
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
from Tars.distributions import GaussianModel

In [None]:
# Feature sizes of the variables x, y, z, a, and the toy mini-batch size.
x_dim, y_dim, z_dim, a_dim = 20, 30, 40, 50
n_batch = 2

class P1(GaussianModel):
    """Conditional Gaussian model p(x | y, a).

    `y` and `a` are each embedded by a ReLU-activated linear layer. The two
    embeddings are concatenated and mapped by two linear heads to the
    Gaussian parameters over `x`.

    Args:
        hidden_dim: width of each of the two per-input embeddings (default 10,
            the value previously hard-coded).
    """

    def __init__(self, hidden_dim=10):
        super(P1, self).__init__(cond_var=["y", "a"], var=["x"])

        self.fc1 = nn.Linear(y_dim, hidden_dim)
        self.fc2 = nn.Linear(a_dim, hidden_dim)
        # The heads parameterize x, so their width must be x_dim. It was the
        # literal 20, which only matched x_dim by coincidence.
        self.fc21 = nn.Linear(2 * hidden_dim, x_dim)
        self.fc22 = nn.Linear(2 * hidden_dim, x_dim)

    def forward(self, y, a):
        """Return the two Gaussian parameter tensors for x given y and a.

        Assumes y is (batch, y_dim) and a is (batch, a_dim); the embeddings are
        concatenated along dim 1. The second output goes through softplus so it
        is positive, presumably the scale (first = loc); TODO confirm against
        Tars GaussianModel.
        """
        h1 = F.relu(self.fc1(y))
        h2 = F.relu(self.fc2(a))
        h12 = torch.cat([h1, h2], 1)
        return self.fc21(h12), F.softplus(self.fc22(h12))

class P2(GaussianModel):
    """Conditional Gaussian model p(z | x, y).

    `x` is embedded by a ReLU-activated linear layer. The embedding is
    concatenated with the raw `y` and passed through a wide ReLU layer. Two
    linear heads then map it to the Gaussian parameters over `z`.

    Args:
        x_hidden_dim: width of the x embedding (default 30, as before).
        hidden_dim: width of the shared hidden layer (default 400, as before).
    """

    def __init__(self, x_hidden_dim=30, hidden_dim=400):
        super(P2, self).__init__(cond_var=["x", "y"], var=["z"])

        self.fc3 = nn.Linear(x_dim, x_hidden_dim)
        self.fc4 = nn.Linear(x_hidden_dim + y_dim, hidden_dim)
        # Bug fix: the heads emitted 20 features, so z never had z_dim
        # dimensions and z_dim was unused. They now match the declared size.
        self.fc51 = nn.Linear(hidden_dim, z_dim)
        self.fc52 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x, y):
        """Return the two Gaussian parameter tensors for z given x and y.

        Assumes x is (batch, x_dim) and y is (batch, y_dim); the x embedding and
        y are concatenated along dim 1. The second output is softplus-activated
        so it is positive, presumably the scale (first = loc); TODO confirm
        against Tars GaussianModel.
        """
        h3 = F.relu(self.fc3(x))
        h4 = F.relu(self.fc4(torch.cat([h3, y], 1)))
        return self.fc51(h4), F.softplus(self.fc52(h4))
    
# Unconditional Gaussian priors with loc=0 and scale=1 over a and y.
p4 = GaussianModel(loc=0, scale=1, var=["a"], dim=a_dim)
p6 = GaussianModel(loc=0, scale=1, var=["y"], dim=y_dim)


def _uniform_batch(width):
    """Return an (n_batch, width) float32 tensor of uniform [0, 1) values."""
    draws = np.random.random((n_batch, width)).astype("float32")
    return torch.from_numpy(draws)


# Drawn in the order x, y, a so the numpy RNG stream is consumed as before.
x, y, a = [_uniform_batch(width) for width in (x_dim, y_dim, a_dim)]

In [None]:
# Build the two conditional models (P1 before P2, which fixes the torch
# init order), then compose distributions with Tars's overloaded `*`.
p1, p2 = P1(), P2()
p3 = p2 * p1
p5 = p3 * p4
p_all = ((p1 * p2) * p4) * p6

In [None]:
# Text form of each distribution. For the composite models, also print the
# factorized form (presumably the per-factor breakdown of the joint).
for _dist in (p1, p2, p3, p4):
    print(_dist.prob_text)
for _dist in (p5, p_all):
    print(_dist.prob_text, _dist.prob_factorized_text)

In [None]:
# Show the tensor type and shape of every parameter reachable from p3
# (presumably the layers of P1 and P2 combined; verify in the output).
for param in p3.parameters():
    tensor_type, shape = type(param.data), param.size()
    print(tensor_type, shape)

In [None]:
# Sample x from p(x | y, a) given the toy y and a batches. With
# return_all=False, presumably only the sampled variable is returned rather
# than the conditioning inputs too; TODO confirm in Tars docs.
# Sampling of p2, p3 and p_all is exercised in the next cell.
p1.sample({"a": a, "y": y}, return_all=False)

In [None]:
def _sample_and_score(model, *sample_args, **sample_kwargs):
    """Sample from `model`, print the log-likelihood of that sample, and return it.

    Positional and keyword arguments are forwarded unchanged to `model.sample`.
    """
    samples = model.sample(*sample_args, **sample_kwargs)
    print(model.log_likelihood(samples))
    return samples


# One call per model replaces the four copy-pasted sample/score pairs.
# `outputs` ends up bound to the p_all sample, exactly as before.
outputs = _sample_and_score(p1, {"y": y, "a": a})
outputs = _sample_and_score(p2, {"x": x, "y": y})
outputs = _sample_and_score(p3, {"y": y, "a": a})
outputs = _sample_and_score(p_all, batch_size=10)