# Original variational autoencoder

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, Bernoulli, Deterministic
from pixyz.losses import KullbackLeibler, CrossEntropy
from pixyz.models import Model
from pixyz.utils import print_latex

In [4]:
x_dim = 784
z_dim = 8


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}
    
# generative model p(x|z)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
    
p = Generator().to(device)
q = Inference().to(device)
    
# prior model p(z)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"],
               features_shape=[z_dim], name="p_prior").to(device)

In [5]:
print(p)
print_latex(p)

Distribution:
  p(x|z)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=8, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
print(q)
print_latex(q)

Distribution:
  q(z|x)
Network architecture:
  Inference(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc31): Linear(in_features=512, out_features=8, bias=True)
    (fc32): Linear(in_features=512, out_features=8, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
reconst = -p.log_prob().expectation(q)
kl = KullbackLeibler(q, prior)
loss_cls = reconst.mean() + kl.mean()
print(loss_cls)
print_latex(loss_cls)

mean \left(- \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] \right)


<IPython.core.display.Math object>

In [8]:
model = Model(loss_cls, distributions=[p, q], optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x|z), q(z|x) 
Loss function: 
  mean \left(- \mathbb{E}_{q(z|x)} \left[\log p(x|z) \right] \right) + mean \left(D_{KL} \left[q(z|x)||p_{prior}(z) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [9]:
def train(epoch):
    train_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss.item()))
    return train_loss

In [10]:
def test(epoch):
    test_loss = 0
    for x, _ in tqdm(test_loader):
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss
 
    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss.item()))
    return test_loss

In [11]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = q.sample({"x": x}, return_all=False)
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [None]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [12]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("../runs/" + v + ".factorvae-baseline" + exp_time)

plot_dim = 8

z_sample = []
for i in range(plot_dim):
    z_batch = torch.zeros(plot_dim, z_dim)
    z_batch[:, i] = (torch.arange(plot_dim,dtype=torch.float32)*2.)/(plot_dim-1.)-1
    z_sample.append(z_batch)
z_sample = torch.cat(z_sample, dim=0).to(device)
#z_sample = 0.5 * torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()

100%|██████████| 469/469 [00:06<00:00, 69.55it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 1 Train loss: 168.7405


100%|██████████| 79/79 [00:01<00:00, 74.84it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 130.9974


100%|██████████| 469/469 [00:06<00:00, 74.02it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 2 Train loss: 124.0870


100%|██████████| 79/79 [00:01<00:00, 77.35it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 119.4285


100%|██████████| 469/469 [00:06<00:00, 73.63it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 3 Train loss: 116.7694


100%|██████████| 79/79 [00:01<00:00, 75.62it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 115.4097


100%|██████████| 469/469 [00:06<00:00, 71.87it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 4 Train loss: 113.6266


100%|██████████| 79/79 [00:01<00:00, 76.22it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 113.0323


100%|██████████| 469/469 [00:06<00:00, 72.72it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 5 Train loss: 111.6023


100%|██████████| 79/79 [00:01<00:00, 69.73it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 111.4387


100%|██████████| 469/469 [00:06<00:00, 75.90it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 6 Train loss: 110.0827


100%|██████████| 79/79 [00:00<00:00, 84.34it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 110.1414


100%|██████████| 469/469 [00:06<00:00, 70.96it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 7 Train loss: 108.9575


100%|██████████| 79/79 [00:01<00:00, 74.03it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 109.2630


100%|██████████| 469/469 [00:06<00:00, 69.78it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 8 Train loss: 107.9880


100%|██████████| 79/79 [00:01<00:00, 75.74it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 108.4400


100%|██████████| 469/469 [00:06<00:00, 69.42it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 9 Train loss: 107.2631


100%|██████████| 79/79 [00:01<00:00, 55.12it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 108.0609


100%|██████████| 469/469 [00:06<00:00, 68.38it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 10 Train loss: 106.5247


100%|██████████| 79/79 [00:01<00:00, 78.37it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.6435


100%|██████████| 469/469 [00:06<00:00, 71.46it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 11 Train loss: 105.9431


100%|██████████| 79/79 [00:01<00:00, 74.53it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 107.1741


100%|██████████| 469/469 [00:06<00:00, 71.53it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 12 Train loss: 105.4531


100%|██████████| 79/79 [00:01<00:00, 77.84it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.9526


100%|██████████| 469/469 [00:06<00:00, 72.38it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 13 Train loss: 105.0054


100%|██████████| 79/79 [00:01<00:00, 77.82it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.5668


100%|██████████| 469/469 [00:07<00:00, 61.98it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 14 Train loss: 104.6023


100%|██████████| 79/79 [00:01<00:00, 76.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 106.1027


100%|██████████| 469/469 [00:06<00:00, 67.64it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 15 Train loss: 104.2539


100%|██████████| 79/79 [00:01<00:00, 65.10it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.8058


100%|██████████| 469/469 [00:07<00:00, 63.52it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 16 Train loss: 103.8652


100%|██████████| 79/79 [00:01<00:00, 76.49it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.6606


100%|██████████| 469/469 [00:06<00:00, 70.79it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 17 Train loss: 103.5775


100%|██████████| 79/79 [00:01<00:00, 76.95it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.4449


100%|██████████| 469/469 [00:06<00:00, 72.01it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 18 Train loss: 103.3625


100%|██████████| 79/79 [00:01<00:00, 73.49it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.2553


100%|██████████| 469/469 [00:06<00:00, 70.59it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 19 Train loss: 103.0787


100%|██████████| 79/79 [00:01<00:00, 77.06it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.0151


100%|██████████| 469/469 [00:06<00:00, 71.46it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 20 Train loss: 102.7914


100%|██████████| 79/79 [00:01<00:00, 71.76it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 105.0150


100%|██████████| 469/469 [00:06<00:00, 68.99it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 21 Train loss: 102.5817


100%|██████████| 79/79 [00:01<00:00, 72.22it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.8675


100%|██████████| 469/469 [00:06<00:00, 71.39it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 22 Train loss: 102.4381


100%|██████████| 79/79 [00:01<00:00, 76.45it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.8971


100%|██████████| 469/469 [00:07<00:00, 63.46it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 23 Train loss: 102.2058


100%|██████████| 79/79 [00:01<00:00, 47.06it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.6963


100%|██████████| 469/469 [00:06<00:00, 70.65it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 24 Train loss: 102.0070


100%|██████████| 79/79 [00:01<00:00, 68.81it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.8087


100%|██████████| 469/469 [00:06<00:00, 70.32it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 25 Train loss: 101.7603


100%|██████████| 79/79 [00:01<00:00, 74.92it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.5486


100%|██████████| 469/469 [00:06<00:00, 70.77it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 26 Train loss: 101.6905


100%|██████████| 79/79 [00:01<00:00, 76.51it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.6585


100%|██████████| 469/469 [00:06<00:00, 71.44it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 27 Train loss: 101.5281


100%|██████████| 79/79 [00:01<00:00, 74.50it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.7171


100%|██████████| 469/469 [00:06<00:00, 70.65it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 28 Train loss: 101.3781


100%|██████████| 79/79 [00:01<00:00, 77.03it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.2319


100%|██████████| 469/469 [00:06<00:00, 71.22it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 29 Train loss: 101.1949


100%|██████████| 79/79 [00:01<00:00, 77.12it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.2708


100%|██████████| 469/469 [00:06<00:00, 70.67it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 30 Train loss: 101.0606


100%|██████████| 79/79 [00:01<00:00, 72.84it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.4067


100%|██████████| 469/469 [00:11<00:00, 40.90it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 31 Train loss: 100.9643


100%|██████████| 79/79 [00:02<00:00, 27.98it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0708


100%|██████████| 469/469 [00:15<00:00, 30.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 32 Train loss: 100.8095


100%|██████████| 79/79 [00:02<00:00, 27.00it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0534


100%|██████████| 469/469 [00:15<00:00, 30.47it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 33 Train loss: 100.7044


100%|██████████| 79/79 [00:02<00:00, 28.47it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.1950


100%|██████████| 469/469 [00:15<00:00, 30.54it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 34 Train loss: 100.5666


100%|██████████| 79/79 [00:02<00:00, 27.71it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0960


100%|██████████| 469/469 [00:15<00:00, 30.35it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 35 Train loss: 100.4533


100%|██████████| 79/79 [00:02<00:00, 29.03it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9149


100%|██████████| 469/469 [00:15<00:00, 29.49it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 36 Train loss: 100.3888


100%|██████████| 79/79 [00:02<00:00, 28.29it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 104.0673


100%|██████████| 469/469 [00:15<00:00, 30.03it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 37 Train loss: 100.2926


100%|██████████| 79/79 [00:02<00:00, 27.47it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9416


100%|██████████| 469/469 [00:16<00:00, 29.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 38 Train loss: 100.2084


100%|██████████| 79/79 [00:02<00:00, 30.04it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.8869


100%|██████████| 469/469 [00:16<00:00, 29.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 39 Train loss: 100.1314


100%|██████████| 79/79 [00:02<00:00, 31.02it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.8073


100%|██████████| 469/469 [00:16<00:00, 32.15it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 40 Train loss: 100.0285


100%|██████████| 79/79 [00:02<00:00, 34.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.7653


100%|██████████| 469/469 [00:16<00:00, 28.84it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 41 Train loss: 99.8829


100%|██████████| 79/79 [00:02<00:00, 28.80it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9433


100%|██████████| 469/469 [00:16<00:00, 29.24it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 42 Train loss: 99.8498


100%|██████████| 79/79 [00:02<00:00, 32.11it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.7820


100%|██████████| 469/469 [00:16<00:00, 29.16it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 43 Train loss: 99.7240


100%|██████████| 79/79 [00:02<00:00, 29.39it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.6074


100%|██████████| 469/469 [00:15<00:00, 29.94it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 44 Train loss: 99.6710


100%|██████████| 79/79 [00:02<00:00, 30.95it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.8032


100%|██████████| 469/469 [00:15<00:00, 27.94it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 45 Train loss: 99.5592


100%|██████████| 79/79 [00:02<00:00, 28.31it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.7522


100%|██████████| 469/469 [00:15<00:00, 30.08it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 46 Train loss: 99.5191


100%|██████████| 79/79 [00:02<00:00, 27.16it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.8591


100%|██████████| 469/469 [00:15<00:00, 27.80it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 47 Train loss: 99.4623


100%|██████████| 79/79 [00:02<00:00, 29.39it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.4759


100%|██████████| 469/469 [00:15<00:00, 29.46it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 48 Train loss: 99.3298


100%|██████████| 79/79 [00:02<00:00, 27.67it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.7823


100%|██████████| 469/469 [00:15<00:00, 27.85it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 49 Train loss: 99.2813


100%|██████████| 79/79 [00:02<00:00, 29.92it/s]
  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.6436


100%|██████████| 469/469 [00:15<00:00, 29.71it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 50 Train loss: 99.2217


100%|██████████| 79/79 [00:02<00:00, 32.99it/s]

Test loss: 103.4995



