In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import RealNVP, Normal
from Tars.models import ML
from Tars.utils import get_dict_values

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
import torch
z_dim = 784

loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

p = RealNVP(prior, var=["x"], in_features=z_dim, num_nn_layers=3, hidden_features=128, num_multiscale_layers=2, num_flow_layers=4, image=True)

In [5]:
p.to(device)

RealNVP(
  (dist): Normal()
  (flows): ModuleList(
    (0): MultiScaleLayer1D(
      (flows): ModuleList(
        (0): AffineCouplingLayer1D(
          in_features=784, pattern=0
          (layers): ModuleList(
            (0): Linear(in_features=784, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
            (2): Linear(in_features=128, out_features=1568, bias=True)
          )
          (batch_norms): ModuleList(
            (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): AffineCouplingLayer1D(
          in_features=784, pattern=1
          (layers): ModuleList(
            (0): Linear(in_features=784, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
            (2): Linear(in_features=128, out_features=1568, bias=T

In [6]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [7]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        lower_bound, loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [8]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        lower_bound, loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [9]:
def plot_reconstrunction(data):
    with torch.no_grad():
        z = p.sample_inv({"x": data.view(-1, 784)})
        recon_batch = p.sample(z, only_flow=True)["x"].view(-1, 1, 28, 28)
    
        comparison = torch.cat([data, recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample({"z":z_sample}, batch_size=64, only_flow=True)["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [10]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = y_original.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(x_original[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_image('Image_from_latent', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:13<00:00, 33.67it/s]

Epoch: 1 Train loss: -2567.9695



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3085.9036


100%|██████████| 469/469 [00:13<00:00, 34.29it/s]

Epoch: 2 Train loss: -3108.9805



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3186.9128


100%|██████████| 469/469 [00:13<00:00, 34.30it/s]

Epoch: 3 Train loss: -3168.9902



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3228.7324


100%|██████████| 469/469 [00:13<00:00, 34.57it/s]


Epoch: 4 Train loss: -3201.2239


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3254.9883


100%|██████████| 469/469 [00:13<00:00, 35.16it/s]


Epoch: 5 Train loss: -3223.8130


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3276.0022


100%|██████████| 469/469 [00:13<00:00, 34.74it/s]

Epoch: 6 Train loss: -3240.5449



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3291.6284


100%|██████████| 469/469 [00:13<00:00, 33.88it/s]

Epoch: 7 Train loss: -3253.4512



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3302.7354


100%|██████████| 469/469 [00:13<00:00, 35.37it/s]


Epoch: 8 Train loss: -3264.2551


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3314.0688


100%|██████████| 469/469 [00:13<00:00, 33.63it/s]

Epoch: 9 Train loss: -3273.5720



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3320.5750


100%|██████████| 469/469 [00:12<00:00, 36.11it/s]

Epoch: 10 Train loss: -3280.7041



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3327.2786


100%|██████████| 469/469 [00:13<00:00, 34.23it/s]


Epoch: 11 Train loss: -3287.2847


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3334.3728


100%|██████████| 469/469 [00:13<00:00, 33.80it/s]

Epoch: 12 Train loss: -3292.4814



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3339.3198


100%|██████████| 469/469 [00:13<00:00, 34.28it/s]


Epoch: 13 Train loss: -3297.9944


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3345.2268


100%|██████████| 469/469 [00:13<00:00, 34.55it/s]


Epoch: 14 Train loss: -3302.0837


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3348.3684


100%|██████████| 469/469 [00:13<00:00, 34.27it/s]


Epoch: 15 Train loss: -3306.6750


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3353.9448


100%|██████████| 469/469 [00:14<00:00, 33.26it/s]


Epoch: 16 Train loss: -3310.2783


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3356.1355


100%|██████████| 469/469 [00:13<00:00, 34.47it/s]


Epoch: 17 Train loss: -3313.2297


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3359.5078


100%|██████████| 469/469 [00:13<00:00, 34.04it/s]


Epoch: 18 Train loss: -3316.3401


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3363.4326


100%|██████████| 469/469 [00:13<00:00, 33.63it/s]


Epoch: 19 Train loss: -3319.1125


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3365.7014


100%|██████████| 469/469 [00:13<00:00, 35.09it/s]


Epoch: 20 Train loss: -3321.9207


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3367.8911


100%|██████████| 469/469 [00:13<00:00, 34.57it/s]


Epoch: 21 Train loss: -3323.8320


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3370.3616


100%|██████████| 469/469 [00:14<00:00, 31.93it/s]

Epoch: 22 Train loss: -3326.6377



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3372.1292


100%|██████████| 469/469 [00:13<00:00, 33.57it/s]


Epoch: 23 Train loss: -3328.4182


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3373.0488


100%|██████████| 469/469 [00:14<00:00, 33.09it/s]


Epoch: 24 Train loss: -3330.5645


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3376.6719


100%|██████████| 469/469 [00:13<00:00, 34.45it/s]


Epoch: 25 Train loss: -3332.4695


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3376.4424


100%|██████████| 469/469 [00:14<00:00, 32.86it/s]


Epoch: 26 Train loss: -3334.2061


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3379.5315


100%|██████████| 469/469 [00:14<00:00, 32.68it/s]

Epoch: 27 Train loss: -3335.4402



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3380.8611


100%|██████████| 469/469 [00:13<00:00, 34.31it/s]


Epoch: 28 Train loss: -3336.9033


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3383.2024


100%|██████████| 469/469 [00:13<00:00, 33.81it/s]


Epoch: 29 Train loss: -3338.6445


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3384.6992


100%|██████████| 469/469 [00:12<00:00, 36.33it/s]


Epoch: 30 Train loss: -3339.8472


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3385.7395


100%|██████████| 469/469 [00:14<00:00, 33.19it/s]

Epoch: 31 Train loss: -3341.5879



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3386.2502


100%|██████████| 469/469 [00:13<00:00, 35.05it/s]

Epoch: 32 Train loss: -3342.6516



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3388.3586


100%|██████████| 469/469 [00:14<00:00, 32.94it/s]


Epoch: 33 Train loss: -3343.4509


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3390.2834


100%|██████████| 469/469 [00:13<00:00, 33.80it/s]


Epoch: 34 Train loss: -3344.7644


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3390.2720


100%|██████████| 469/469 [00:14<00:00, 33.33it/s]


Epoch: 35 Train loss: -3346.1106


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3392.1184


100%|██████████| 469/469 [00:14<00:00, 33.43it/s]

Epoch: 36 Train loss: -3347.2258



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3393.3662


100%|██████████| 469/469 [00:14<00:00, 33.20it/s]


Epoch: 37 Train loss: -3347.8228


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3393.5872


100%|██████████| 469/469 [00:14<00:00, 32.59it/s]


Epoch: 38 Train loss: -3348.6125


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3394.1311


100%|██████████| 469/469 [00:13<00:00, 33.57it/s]


Epoch: 39 Train loss: -3349.7585


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3394.9199


100%|██████████| 469/469 [00:14<00:00, 31.88it/s]


Epoch: 40 Train loss: -3350.6326


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3397.6567


100%|██████████| 469/469 [00:13<00:00, 34.12it/s]


Epoch: 41 Train loss: -3351.3582


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3396.8347


100%|██████████| 469/469 [00:14<00:00, 33.44it/s]


Epoch: 42 Train loss: -3352.5769


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3397.3330


100%|██████████| 469/469 [00:14<00:00, 32.94it/s]


Epoch: 43 Train loss: -3352.8164


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3398.4800


100%|██████████| 469/469 [00:13<00:00, 35.06it/s]


Epoch: 44 Train loss: -3353.7576


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3399.6172


100%|██████████| 469/469 [00:13<00:00, 34.17it/s]


Epoch: 45 Train loss: -3354.3616


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3400.1716


100%|██████████| 469/469 [00:13<00:00, 34.38it/s]


Epoch: 46 Train loss: -3355.0481


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3401.6506


100%|██████████| 469/469 [00:13<00:00, 34.04it/s]


Epoch: 47 Train loss: -3355.8784


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3400.0046


100%|██████████| 469/469 [00:13<00:00, 34.42it/s]


Epoch: 48 Train loss: -3356.9507


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3401.8516


100%|██████████| 469/469 [00:13<00:00, 33.90it/s]


Epoch: 49 Train loss: -3357.2649


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3402.6118


100%|██████████| 469/469 [00:13<00:00, 34.91it/s]


Epoch: 50 Train loss: -3357.6099
Test loss: -3402.9399
