In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 5
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli
from Tars.models import VI
from Tars.utils import get_dict_values

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
x_dim = 784
a_dim = 64
z_dim = 32

# inference models
class Q1(Normal):
    def __init__(self):
        super(Q1, self).__init__(cond_var=["x"], var=["a"])

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, a_dim)
        self.fc32 = nn.Linear(512, a_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
class Q2(Normal):
    def __init__(self):
        super(Q2, self).__init__(cond_var=["x"], var=["z"])

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# generative models
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
p1 = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)
    
class P2(Normal):
    def __init__(self):
        super(P2, self).__init__(cond_var=["z"], var=["a"])

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, a_dim)
        self.fc32 = nn.Linear(512, a_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
class P3(Bernoulli):
    def __init__(self):
        super(P3, self).__init__(cond_var=["a"], var=["x"])

        self.fc1 = nn.Linear(a_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, a):
        h = F.relu(self.fc1(a))
        h = F.relu(self.fc2(h))
        return {"probs": F.sigmoid(self.fc3(h))}
    
    
q1 = Q1()
q2 = Q2() 

p2 = P2()
p3 = P3()

In [5]:
q = q1 * q2
_p = p2 * p3
p = _p * p1
p.to(device)
q.to(device)

print(p.prob_factorized_text, p.prob_text)
print(_p.prob_factorized_text, _p.prob_text)
print(q.prob_factorized_text, q.prob_text)

p(x|a)p(a|z)p(z) p(x,a,z)
p(x|a)p(a|z) p(x,a|z)
p(z|x)p(a|x) p(z,a|x)


In [6]:
model = VI(p, q, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [7]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [8]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [9]:
def plot_reconstrunction(data):
    with torch.no_grad():
        z = q.sample({"x": data.view(-1, 784)})
        z = get_dict_values(z, _p.cond_var, return_dict=True) # select latent variables
        recon_batch = _p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([data, recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = _p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [10]:
writer = SummaryWriter()

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = y_original.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(x_original[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_image('Image_from_latent', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:07<00:00, 64.95it/s]

Epoch: 1 Train loss: 186.2912



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 150.8370


100%|██████████| 469/469 [00:07<00:00, 62.62it/s]

Epoch: 2 Train loss: 136.7231



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 129.1961


100%|██████████| 469/469 [00:07<00:00, 63.52it/s]

Epoch: 3 Train loss: 123.9782



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 121.0725


100%|██████████| 469/469 [00:07<00:00, 62.91it/s]


Epoch: 4 Train loss: 117.9895


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.8547


100%|██████████| 469/469 [00:06<00:00, 67.55it/s]

Epoch: 5 Train loss: 113.5125





Test loss: 113.0624
