# Conditional Real NVP

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Normal, InverseTransformedDistribution
from pixyz.flows import AffineCoupling, FlowList, BatchNorm1d, Shuffle, Preprocess, Reverse
from pixyz.models import ML
from pixyz.utils import print_latex

In [4]:
x_dim = 28*28
y_dim = 10
z_dim = x_dim

In [5]:
# prior model p(z)
prior = Normal(loc= torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_prior").to(device)

In [6]:
class ScaleTranslateNet(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features+y_dim, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3_s = nn.Linear(hidden_features, in_features)
        self.fc3_t = nn.Linear(hidden_features, in_features)

    def forward(self, x, y):
        hidden = F.relu(self.fc2(F.relu(self.fc1(torch.cat([x, y], 1)))))
        log_s = torch.tanh(self.fc3_s(hidden))
        t = self.fc3_t(hidden)
        return log_s, t

In [7]:
# flow
flow_list = []
num_block = 5

flow_list.append(Preprocess())

for i in range(num_block):
        flow_list.append(AffineCoupling(in_features=x_dim,
                                        scale_translate_net=ScaleTranslateNet(x_dim, 1028),
                                        inverse_mask=(i%2!=0)))
        
        flow_list.append(BatchNorm1d(x_dim))
        
f = FlowList(flow_list)

In [8]:
# inverse transformed distribution (z -> f^-1 -> x)
p = InverseTransformedDistribution(prior=prior, flow=f, var=["x"], cond_var=["y"]).to(device)
print_latex(p)

<IPython.core.display.Math object>

In [9]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x|y) 
Loss function: 
  mean \left(- \log p(x|y) \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [10]:
def train(epoch):
    train_loss = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.train({"x": x, "y": y})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, y in test_loader:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [12]:
def plot_reconstrunction(x, y):
    with torch.no_grad():
        z = p.forward(x, y, compute_jacobian=False)
        recon_batch = p.inverse(z, y).view(-1, 1, 28, 28)
    
        recon = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return recon
    
def plot_image_from_latent(z, y):
    with torch.no_grad():
        sample = p.inverse(z, y).view(-1, 1, 28, 28).cpu()
        return sample
    
def plot_reconstrunction_changing_y(x, y):
    y_change = torch.eye(10)[range(7)].to(device)
    batch_dummy = torch.ones(x.size(0))[:, None].to(device)    
    recon_all = []
    
    with torch.no_grad():
        for _y in y_change:
            z = p.forward(x, y, compute_jacobian=False)
            recon_batch = p.inverse(z, batch_dummy * _y[None,:]).view(-1, 1, 28, 28)
            recon_all.append(recon_batch)
    
        recon_changing_y = torch.cat(recon_all)
        recon_changing_y = torch.cat([x.view(-1, 1, 28, 28), recon_changing_y]).cpu()
        return recon_changing_y

In [13]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [14]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".real_nvp_cond" + exp_time)

plot_number = 5

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
y_sample = torch.eye(10)[[plot_number]*64].to(device)

_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = torch.eye(10)[_y].to(device)

import time
start = time.time()
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8], _y[:8])
    sample = plot_image_from_latent(z_sample, y_sample)
    recon_changing_y = plot_reconstrunction_changing_y(_x[:8], _y[:8])

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
    writer.add_images('Image_reconstrunction_change_y', recon_changing_y, epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time) 
writer.close()

100%|██████████| 469/469 [00:17<00:00, 27.19it/s]

Epoch: 1 Train loss: -2910.2405



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2589.7812


100%|██████████| 469/469 [00:17<00:00, 26.93it/s]

Epoch: 2 Train loss: -3100.3530



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2839.2712


100%|██████████| 469/469 [00:16<00:00, 27.70it/s]


Epoch: 3 Train loss: -3158.2097


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2991.2817


100%|██████████| 469/469 [00:17<00:00, 27.36it/s]

Epoch: 4 Train loss: -3191.9492



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2972.3020


100%|██████████| 469/469 [00:17<00:00, 27.29it/s]

Epoch: 5 Train loss: -3213.2173



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2873.9517


100%|██████████| 469/469 [00:17<00:00, 27.53it/s]

Epoch: 6 Train loss: -3231.2964



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2744.0271


100%|██████████| 469/469 [00:17<00:00, 27.22it/s]

Epoch: 7 Train loss: -3244.7383



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -2981.4692


100%|██████████| 469/469 [00:18<00:00, 26.04it/s]

Epoch: 8 Train loss: -3255.3843



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3144.8003


100%|██████████| 469/469 [00:18<00:00, 25.78it/s]


Epoch: 9 Train loss: -3265.3008


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -1612.7463


100%|██████████| 469/469 [00:17<00:00, 26.53it/s]

Epoch: 10 Train loss: -3273.2192





Test loss: -3091.5256
