# Variational inference on a hierarchical latent model

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, Bernoulli
from pixyz.models import VI
from pixyz.utils import print_latex

In [4]:
x_dim = 784
a_dim = 64
z_dim = 32


# inference models
class Q1(Normal):
    def __init__(self):
        super(Q1, self).__init__(cond_var=["x"], var=["a"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, a_dim)
        self.fc32 = nn.Linear(512, a_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    

class Q2(Normal):
    def __init__(self):
        super(Q2, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    

q1 = Q1().to(device)
q2 = Q2().to(device)

q = q1 * q2
q.name = "q"
    
# generative models
class P2(Normal):
    def __init__(self):
        super(P2, self).__init__(cond_var=["z"], var=["a"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, a_dim)
        self.fc32 = nn.Linear(512, a_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
    
class P3(Bernoulli):
    def __init__(self):
        super(P3, self).__init__(cond_var=["a"], var=["x"], name="p")

        self.fc1 = nn.Linear(a_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, a):
        h = F.relu(self.fc1(a))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}


p2 = P2().to(device)
p3 = P3().to(device)

p1 = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
            var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

_p = p2 * p3
p = _p * p1

In [5]:
print(p)
print_latex(p)

Distribution:
  p(x,a,z) = p(x|a)p(a|z)p_{prior}(z)
Network architecture:
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([32])
    (loc): torch.Size([1, 32])
    (scale): torch.Size([1, 32])
  )
  P2(
    name=p, distribution_name=Normal,
    var=['a'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=32, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=64, bias=True)
    (fc32): Linear(in_features=512, out_features=64, bias=True)
  )
  P3(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['a'], input_var=['a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
print(_p)
print_latex(_p)

Distribution:
  p(x,a|z) = p(x|a)p(a|z)
Network architecture:
  P2(
    name=p, distribution_name=Normal,
    var=['a'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=32, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=64, bias=True)
    (fc32): Linear(in_features=512, out_features=64, bias=True)
  )
  P3(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['a'], input_var=['a'], features_shape=torch.Size([])
    (fc1): Linear(in_features=64, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
print(q)
print_latex(q)

Distribution:
  q(a,z|x) = q(a|x)q(z|x)
Network architecture:
  Q2(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=32, bias=True)
    (fc32): Linear(in_features=512, out_features=32, bias=True)
  )
  Q1(
    name=q, distribution_name=Normal,
    var=['a'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=64, bias=True)
    (fc32): Linear(in_features=512, out_features=64, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
model = VI(p, q, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x,a,z), q(a,z|x) 
Loss function: 
  - mean \left(\mathbb{E}_{q(a,z|x)} \left[\log p(x,a,z) - \log q(a,z|x) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [9]:
def train(epoch):
    train_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [10]:
def test(epoch):
    test_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [11]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = q.sample({"x": x})
        z = z.extract(_p.cond_var, return_dict=True) # select latent variables
        recon_batch = _p.sample(z)["x"].view(-1, 1, 28, 28) # TODO: it should be sample_mean
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = _p.sample({"z": z_sample})["x"].view(-1, 1, 28, 28).cpu() # TODO: it should be sample_mean
        return sample

In [12]:
writer = SummaryWriter()

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:06<00:00, 67.02it/s]

Epoch: 1 Train loss: 185.3116



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 147.5463


100%|██████████| 469/469 [00:06<00:00, 67.23it/s]

Epoch: 2 Train loss: 134.1564



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 127.0369


100%|██████████| 469/469 [00:06<00:00, 67.38it/s]


Epoch: 3 Train loss: 122.5826


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.4241


100%|██████████| 469/469 [00:07<00:00, 66.66it/s]

Epoch: 4 Train loss: 116.6982



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 114.9412


100%|██████████| 469/469 [00:07<00:00, 63.65it/s]

Epoch: 5 Train loss: 112.8898



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 112.0781


100%|██████████| 469/469 [00:06<00:00, 67.12it/s]


Epoch: 6 Train loss: 110.4185


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 109.9266


100%|██████████| 469/469 [00:07<00:00, 65.79it/s]


Epoch: 7 Train loss: 108.5817


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 108.3503


100%|██████████| 469/469 [00:07<00:00, 61.78it/s]

Epoch: 8 Train loss: 106.9110



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.4023


100%|██████████| 469/469 [00:07<00:00, 65.14it/s]


Epoch: 9 Train loss: 105.6909


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.2646


100%|██████████| 469/469 [00:07<00:00, 27.98it/s]


Epoch: 10 Train loss: 104.6131
Test loss: 105.6856
