In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import RealNVP, Normal
from Tars.models import ML
from Tars.utils import get_dict_values

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
import torch
z_dim = 784

loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

p = RealNVP(prior, var=["x"], in_features=z_dim, num_nn_layers=3, hidden_features=128, num_multiscale_layers=2, num_flow_layers=4, image=True)

In [5]:
p.to(device)

RealNVP(
  (dist): Normal()
  (flows): ModuleList(
    (0): MultiScaleLayer1D(
      (flows): ModuleList(
        (0): AffineCouplingLayer1D(
          in_features=784, pattern=0
          (layers): ModuleList(
            (0): Linear(in_features=784, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
            (2): Linear(in_features=128, out_features=1568, bias=True)
          )
          (batch_norms): ModuleList(
            (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): AffineCouplingLayer1D(
          in_features=784, pattern=1
          (layers): ModuleList(
            (0): Linear(in_features=784, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
            (2): Linear(in_features=128, out_features=1568, bias=T

In [6]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [7]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        lower_bound, loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [8]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        lower_bound, loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [9]:
def plot_reconstrunction(data):
    with torch.no_grad():
        z = p.sample_inv({"x": data.view(-1, 784)})
        recon_batch = p.sample(z, only_flow=True)["x"].view(-1, 1, 28, 28)
    
        comparison = torch.cat([data, recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample({"z":z_sample}, batch_size=64, only_flow=True)["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [10]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = y_original.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(x_original[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_image('Image_from_latent', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:17<00:00, 27.40it/s]

Epoch: 1 Train loss: -2567.9695



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3085.9038


100%|██████████| 469/469 [00:16<00:00, 28.80it/s]


Epoch: 2 Train loss: -3108.9768


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3187.0439


100%|██████████| 469/469 [00:16<00:00, 28.59it/s]


Epoch: 3 Train loss: -3168.9814


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3228.7153


100%|██████████| 469/469 [00:15<00:00, 30.49it/s]


Epoch: 4 Train loss: -3201.2483


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3255.1375


100%|██████████| 469/469 [00:14<00:00, 31.50it/s]


Epoch: 5 Train loss: -3223.8127


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3275.9392


100%|██████████| 469/469 [00:15<00:00, 30.11it/s]


Epoch: 6 Train loss: -3240.5552


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3291.6694


100%|██████████| 469/469 [00:15<00:00, 30.54it/s]


Epoch: 7 Train loss: -3253.4851


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3302.7915


100%|██████████| 469/469 [00:14<00:00, 31.95it/s]


Epoch: 8 Train loss: -3264.2896


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3314.1008


100%|██████████| 469/469 [00:15<00:00, 30.72it/s]


Epoch: 9 Train loss: -3273.5837


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3320.5833


100%|██████████| 469/469 [00:15<00:00, 30.33it/s]


Epoch: 10 Train loss: -3280.7205


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3327.1350


100%|██████████| 469/469 [00:15<00:00, 30.96it/s]


Epoch: 11 Train loss: -3287.3193


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3334.3677


100%|██████████| 469/469 [00:14<00:00, 31.68it/s]

Epoch: 12 Train loss: -3292.5337



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3339.3689


100%|██████████| 469/469 [00:15<00:00, 29.77it/s]


Epoch: 13 Train loss: -3298.0535


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3345.1772


100%|██████████| 469/469 [00:15<00:00, 31.18it/s]


Epoch: 14 Train loss: -3302.1680


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3348.4463


100%|██████████| 469/469 [00:15<00:00, 29.34it/s]

Epoch: 15 Train loss: -3306.6646



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3353.7434


100%|██████████| 469/469 [00:14<00:00, 31.33it/s]


Epoch: 16 Train loss: -3310.2876


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3356.1448


100%|██████████| 469/469 [00:14<00:00, 33.04it/s]


Epoch: 17 Train loss: -3313.1897


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3359.5315


100%|██████████| 469/469 [00:14<00:00, 32.85it/s]


Epoch: 18 Train loss: -3316.3232


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3363.5000


100%|██████████| 469/469 [00:14<00:00, 31.82it/s]


Epoch: 19 Train loss: -3319.0955


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3365.9082


100%|██████████| 469/469 [00:15<00:00, 30.10it/s]


Epoch: 20 Train loss: -3321.8789


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3367.9204


100%|██████████| 469/469 [00:15<00:00, 30.72it/s]


Epoch: 21 Train loss: -3323.7595


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3370.3176


100%|██████████| 469/469 [00:16<00:00, 28.54it/s]


Epoch: 22 Train loss: -3326.6165


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3372.1106


100%|██████████| 469/469 [00:14<00:00, 32.93it/s]


Epoch: 23 Train loss: -3328.4038


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3373.0930


100%|██████████| 469/469 [00:15<00:00, 30.65it/s]


Epoch: 24 Train loss: -3330.5718


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3376.8799


100%|██████████| 469/469 [00:15<00:00, 30.66it/s]


Epoch: 25 Train loss: -3332.5059


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3376.6240


100%|██████████| 469/469 [00:15<00:00, 30.27it/s]


Epoch: 26 Train loss: -3334.2529


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3379.7332


100%|██████████| 469/469 [00:15<00:00, 31.00it/s]


Epoch: 27 Train loss: -3335.5002


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3380.9536


100%|██████████| 469/469 [00:14<00:00, 32.73it/s]


Epoch: 28 Train loss: -3336.9714


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3383.2107


100%|██████████| 469/469 [00:15<00:00, 30.13it/s]


Epoch: 29 Train loss: -3338.7776


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3384.5591


100%|██████████| 469/469 [00:15<00:00, 31.07it/s]


Epoch: 30 Train loss: -3339.9631


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3385.7864


100%|██████████| 469/469 [00:15<00:00, 30.48it/s]


Epoch: 31 Train loss: -3341.7151


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3386.3472


100%|██████████| 469/469 [00:15<00:00, 29.40it/s]


Epoch: 32 Train loss: -3342.8098


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3388.5654


100%|██████████| 469/469 [00:14<00:00, 31.37it/s]


Epoch: 33 Train loss: -3343.5840


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3390.5247


100%|██████████| 469/469 [00:14<00:00, 31.35it/s]


Epoch: 34 Train loss: -3345.0017


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3390.4624


100%|██████████| 469/469 [00:16<00:00, 28.41it/s]


Epoch: 35 Train loss: -3346.2695


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3392.3867


100%|██████████| 469/469 [00:16<00:00, 28.90it/s]


Epoch: 36 Train loss: -3347.4602


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3393.7634


100%|██████████| 469/469 [00:15<00:00, 30.15it/s]


Epoch: 37 Train loss: -3348.0076


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3393.8032


100%|██████████| 469/469 [00:15<00:00, 29.51it/s]


Epoch: 38 Train loss: -3348.8303


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3394.0720


100%|██████████| 469/469 [00:16<00:00, 28.24it/s]


Epoch: 39 Train loss: -3349.9426


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3395.2175


100%|██████████| 469/469 [00:15<00:00, 29.32it/s]


Epoch: 40 Train loss: -3350.8569


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3397.8203


100%|██████████| 469/469 [00:15<00:00, 30.11it/s]


Epoch: 41 Train loss: -3351.5134


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3396.7622


100%|██████████| 469/469 [00:15<00:00, 30.13it/s]


Epoch: 42 Train loss: -3352.7805


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3397.6423


100%|██████████| 469/469 [00:15<00:00, 30.32it/s]


Epoch: 43 Train loss: -3353.0374


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3398.7935


100%|██████████| 469/469 [00:15<00:00, 29.34it/s]


Epoch: 44 Train loss: -3353.9500


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3399.5723


100%|██████████| 469/469 [00:15<00:00, 30.52it/s]

Epoch: 45 Train loss: -3354.5142



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3400.2488


100%|██████████| 469/469 [00:15<00:00, 30.21it/s]


Epoch: 46 Train loss: -3355.2141


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3401.9307


100%|██████████| 469/469 [00:14<00:00, 31.55it/s]


Epoch: 47 Train loss: -3356.0544


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3400.3831


100%|██████████| 469/469 [00:14<00:00, 32.56it/s]


Epoch: 48 Train loss: -3357.0891


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3402.0854


100%|██████████| 469/469 [00:15<00:00, 30.05it/s]


Epoch: 49 Train loss: -3357.3765


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: -3403.0464


100%|██████████| 469/469 [00:15<00:00, 30.08it/s]


Epoch: 50 Train loss: -3357.7859
Test loss: -3403.0200
