# Real NVP (using the ML class)

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 5
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import RealNVP, Normal
from pixyz.models import ML
from pixyz.utils import get_dict_values

In [4]:
import torch
z_dim = 784

loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

p = RealNVP(prior, var=["x"], dim=z_dim, num_nn_layers=3, hidden_features=128, num_multiscale_layers=2, num_flow_layers=4, image=True)
p.to(device)

print(p)

Distribution:
  p(x=RealNVP(z)) (None)
Network architecture:
  RealNVP(
    (prior): Normal()
    (flows): ModuleList(
      (0): MultiScaleLayer1D(
        (flows): ModuleList(
          (0): AffineCouplingLayer1D(
            in_features=784, pattern=0
            (layers): ModuleList(
              (0): Linear(in_features=784, out_features=128, bias=True)
              (1): Linear(in_features=128, out_features=128, bias=True)
              (2): Linear(in_features=128, out_features=1568, bias=True)
            )
            (batch_norms): ModuleList(
              (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (1): AffineCouplingLayer1D(
            in_features=784, pattern=1
            (layers): ModuleList(
              (0): Linear(in_features=784, out_features=128, bias=True)
              (1): Linear(in_fea

In [5]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [6]:
def train(epoch):
    train_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [7]:
def test(epoch):
    test_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [8]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = p.sample_inv({"x": x})
        recon_batch = p.sample(z, only_flow=True, return_all=False)["x"].view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample({"z":z_sample}, only_flow=True, return_all=False)["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [9]:
writer = SummaryWriter()

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:11<00:00, 39.29it/s]

Epoch: 1 Train loss: -2566.2290



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3088.5864


100%|██████████| 469/469 [00:11<00:00, 40.37it/s]

Epoch: 2 Train loss: -3112.4900



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3191.1558


100%|██████████| 469/469 [00:11<00:00, 40.03it/s]


Epoch: 3 Train loss: -3171.8572


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3231.2468


100%|██████████| 469/469 [00:11<00:00, 40.12it/s]


Epoch: 4 Train loss: -3203.5776


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3257.1323


100%|██████████| 469/469 [00:11<00:00, 39.99it/s]

Epoch: 5 Train loss: -3225.6096





Test loss: -3276.9561
