# Real NVP

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda:1"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, InverseTransformedDistribution
from pixyz.flows import AffineCoupling, FlowList, BatchNorm1d, Shuffle, Preprocess, Reverse
from pixyz.models import ML

In [4]:
x_dim = 28*28
z_dim = x_dim

In [5]:
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [6]:
class ScaleTranslateNet(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3_s = nn.Linear(hidden_features, in_features)
        self.fc3_t = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        log_s = torch.tanh(self.fc3_s(hidden))
        t = self.fc3_t(hidden)
        return log_s, t

In [7]:
# flow
flow_list = []
num_block = 5

flow_list.append(Preprocess())

for i in range(num_block):
        flow_list.append(AffineCoupling(in_features=x_dim,
                                        scale_translate_net=ScaleTranslateNet(x_dim, 1028),
                                        inverse_mask=(i%2!=0)))
        
        flow_list.append(BatchNorm1d(x_dim))
        
f = FlowList(flow_list)

In [8]:
# inverse transformed distribution (z -> f^-1 -> x)
p = InverseTransformedDistribution(prior=prior, flow=f, var=["x"])
p.to(device)

InverseTransformedDistribution(
  (prior): Normal()
  (flow): FlowList(
    (0): Preprocess()
    (1): AffineCoupling(
      in_features=784, mask_type=channel_wise, inverse_mask=False
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=784, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
    (2): BatchNorm1d()
    (3): AffineCoupling(
      in_features=784, mask_type=channel_wise, inverse_mask=True
      (scale_translate_net): ScaleTranslateNet(
        (fc1): Linear(in_features=784, out_features=1028, bias=True)
        (fc2): Linear(in_features=1028, out_features=1028, bias=True)
        (fc3_s): Linear(in_features=1028, out_features=784, bias=True)
        (fc3_t): Linear(in_features=1028, out_features=784, bias=True)
      )
    )
    (4): BatchNo

In [9]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(-(log p(x))) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [10]:
def train(epoch):
    train_loss = 0
    
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss

    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [12]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = p.forward(x, compute_jacobian=False)
        recon_batch = p.inverse(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.inverse(z_sample).view(-1, 1, 28, 28).cpu()
        return sample

In [13]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:07<00:00, 59.28it/s]

Epoch: 1 Train loss: 1445.6053



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1924.4038


100%|██████████| 469/469 [00:07<00:00, 62.75it/s]


Epoch: 2 Train loss: 1253.4928


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1578.1617


100%|██████████| 469/469 [00:07<00:00, 60.92it/s]


Epoch: 3 Train loss: 1195.0905


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1538.3510


100%|██████████| 469/469 [00:07<00:00, 62.03it/s]


Epoch: 4 Train loss: 1161.3428


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2328.1995


100%|██████████| 469/469 [00:08<00:00, 55.40it/s]

Epoch: 5 Train loss: 1138.2686



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1419.9275


100%|██████████| 469/469 [00:07<00:00, 62.71it/s]

Epoch: 6 Train loss: 1123.0366



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1364.7207


100%|██████████| 469/469 [00:07<00:00, 62.56it/s]


Epoch: 7 Train loss: 1108.6436


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1379.2495


100%|██████████| 469/469 [00:07<00:00, 59.85it/s]


Epoch: 8 Train loss: 1097.9465


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1285.3647


100%|██████████| 469/469 [00:08<00:00, 55.53it/s]

Epoch: 9 Train loss: 1088.0216



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1383.1555


100%|██████████| 469/469 [00:11<00:00, 28.61it/s]


Epoch: 10 Train loss: 1080.9961


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1452.1449


100%|██████████| 469/469 [00:16<00:00, 28.58it/s]

Epoch: 11 Train loss: 1072.6532



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1276.0413


100%|██████████| 469/469 [00:16<00:00, 28.91it/s]

Epoch: 12 Train loss: 1066.7667



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1305.5505


100%|██████████| 469/469 [00:16<00:00, 28.75it/s]

Epoch: 13 Train loss: 1062.9041



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1233.1605


100%|██████████| 469/469 [00:16<00:00, 28.20it/s]

Epoch: 14 Train loss: 1056.3188



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1256.2241


100%|██████████| 469/469 [00:16<00:00, 28.22it/s]


Epoch: 15 Train loss: 1052.5262


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1296.7896


100%|██████████| 469/469 [00:17<00:00, 26.51it/s]


Epoch: 16 Train loss: 1047.8396


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1335.0488


100%|██████████| 469/469 [00:17<00:00, 26.66it/s]


Epoch: 17 Train loss: 1043.9014


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1287.2756


 87%|████████▋ | 406/469 [00:15<00:02, 27.58it/s]Exception in thread Thread-58:
Traceback (most recent call last):
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/threading.py", line 917, in _bootstrap_inner
    self.run()
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/threading.py", line 865, in run
    self._target(*self._args, **self._kwargs)
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 21, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 276, in rebuild_storage_fd
    fd = df.detach()
  File "/home/masa/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/multiprocessing/resource_sharer

KeyboardInterrupt: 