In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli
from Tars.losses import KullbackLeibler
from Tars.models import VAE

In [3]:
kwargs = {'num_workers': 1, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

In [4]:
x_dim = 784
z_dim = 64


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"])

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": F.sigmoid(self.fc3(h))}
    
    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

In [5]:
p = Generator()
q = Inference()

p.to(device)
q.to(device)

Inference(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
  (fc31): Linear(in_features=512, out_features=64, bias=True)
  (fc32): Linear(in_features=512, out_features=64, bias=True)
)

In [6]:
kl = KullbackLeibler(q, prior)

In [7]:
model = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [8]:
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        loss = model.train({"x": data.view(-1, 784)})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [9]:
def test(epoch):
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        loss = model.test({"x": data.view(-1, 784)})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [10]:
def plot_reconstrunction(data):
    with torch.no_grad():
        z = q.sample({"x": data.view(-1, 784)}, return_all=False)
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([data, recon_batch]).cpu()
        return comparison
    
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [11]:
writer = SummaryWriter()

z_sample = 0.5 * torch.randn(64, z_dim).to(device)
x_original, y_original = iter(test_loader).next()
x_original = x_original.to(device)
y_original = y_original.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(x_original[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_image('Image_from_latent', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)
    
writer.close()

100%|██████████| 469/469 [00:05<00:00, 85.98it/s]

Epoch: 1 Train loss: 173.3753



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 131.9921


100%|██████████| 469/469 [00:05<00:00, 85.88it/s]

Epoch: 2 Train loss: 120.1560



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 111.7963


100%|██████████| 469/469 [00:05<00:00, 90.15it/s]

Epoch: 3 Train loss: 106.8615



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 103.9233


100%|██████████| 469/469 [00:05<00:00, 92.53it/s]


Epoch: 4 Train loss: 100.6442


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 98.3785


100%|██████████| 469/469 [00:05<00:00, 90.12it/s]


Epoch: 5 Train loss: 97.0214


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 96.1433


100%|██████████| 469/469 [00:05<00:00, 86.66it/s]


Epoch: 6 Train loss: 94.7921


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 94.6992


100%|██████████| 469/469 [00:05<00:00, 84.82it/s]

Epoch: 7 Train loss: 93.2771



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 93.1864


100%|██████████| 469/469 [00:05<00:00, 88.61it/s]

Epoch: 8 Train loss: 92.1183



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 92.5052


100%|██████████| 469/469 [00:05<00:00, 86.61it/s]

Epoch: 9 Train loss: 91.2383



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 91.9623


100%|██████████| 469/469 [00:05<00:00, 92.24it/s]

Epoch: 10 Train loss: 90.4645



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 91.2512


100%|██████████| 469/469 [00:05<00:00, 80.09it/s]

Epoch: 11 Train loss: 89.8202



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 90.7999


100%|██████████| 469/469 [00:05<00:00, 90.64it/s]

Epoch: 12 Train loss: 89.2731



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 90.4077


100%|██████████| 469/469 [00:05<00:00, 86.16it/s]

Epoch: 13 Train loss: 88.7849



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 89.7735


100%|██████████| 469/469 [00:05<00:00, 86.05it/s]

Epoch: 14 Train loss: 88.3347



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 89.8244


100%|██████████| 469/469 [00:05<00:00, 86.68it/s]

Epoch: 15 Train loss: 88.0247



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 89.2636


100%|██████████| 469/469 [00:05<00:00, 85.29it/s]

Epoch: 16 Train loss: 87.6558



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 89.4142


100%|██████████| 469/469 [00:05<00:00, 90.04it/s]

Epoch: 17 Train loss: 87.3342



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 89.2592


100%|██████████| 469/469 [00:05<00:00, 90.46it/s]

Epoch: 18 Train loss: 87.1370



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.7934


100%|██████████| 469/469 [00:05<00:00, 86.51it/s]

Epoch: 19 Train loss: 86.8406



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.6980


100%|██████████| 469/469 [00:05<00:00, 90.33it/s]

Epoch: 20 Train loss: 86.6075



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.7881


100%|██████████| 469/469 [00:05<00:00, 91.95it/s]

Epoch: 21 Train loss: 86.4312



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.3607


100%|██████████| 469/469 [00:04<00:00, 94.73it/s]

Epoch: 22 Train loss: 86.2661



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.6472


100%|██████████| 469/469 [00:05<00:00, 87.56it/s]

Epoch: 23 Train loss: 86.0539



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.1764


100%|██████████| 469/469 [00:05<00:00, 91.95it/s]


Epoch: 24 Train loss: 85.8349


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 88.0565


100%|██████████| 469/469 [00:05<00:00, 93.53it/s]

Epoch: 25 Train loss: 85.7426



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.9629


100%|██████████| 469/469 [00:05<00:00, 88.80it/s]

Epoch: 26 Train loss: 85.5502



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.8864


100%|██████████| 469/469 [00:05<00:00, 89.77it/s]

Epoch: 27 Train loss: 85.4542



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.9345


100%|██████████| 469/469 [00:05<00:00, 88.77it/s]


Epoch: 28 Train loss: 85.3081


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.7444


100%|██████████| 469/469 [00:05<00:00, 84.09it/s]

Epoch: 29 Train loss: 85.1860



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.7464


100%|██████████| 469/469 [00:05<00:00, 84.24it/s]

Epoch: 30 Train loss: 85.0693



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.6087


100%|██████████| 469/469 [00:05<00:00, 92.18it/s]


Epoch: 31 Train loss: 84.9674


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.7600


100%|██████████| 469/469 [00:05<00:00, 89.06it/s]


Epoch: 32 Train loss: 84.8570


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4784


100%|██████████| 469/469 [00:05<00:00, 88.54it/s]

Epoch: 33 Train loss: 84.7186



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4771


100%|██████████| 469/469 [00:05<00:00, 82.28it/s]

Epoch: 34 Train loss: 84.6474



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.1865


100%|██████████| 469/469 [00:05<00:00, 91.23it/s]


Epoch: 35 Train loss: 84.5766


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4882


100%|██████████| 469/469 [00:04<00:00, 95.57it/s]

Epoch: 36 Train loss: 84.4395



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4806


100%|██████████| 469/469 [00:04<00:00, 94.43it/s]


Epoch: 37 Train loss: 84.4220


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4723


100%|██████████| 469/469 [00:05<00:00, 87.19it/s]

Epoch: 38 Train loss: 84.2513



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.2888


100%|██████████| 469/469 [00:05<00:00, 86.65it/s]


Epoch: 39 Train loss: 84.2195


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.4478


100%|██████████| 469/469 [00:05<00:00, 87.73it/s]

Epoch: 40 Train loss: 84.1103



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.1926


100%|██████████| 469/469 [00:05<00:00, 90.62it/s]

Epoch: 41 Train loss: 84.0454



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.3045


100%|██████████| 469/469 [00:05<00:00, 92.64it/s]

Epoch: 42 Train loss: 83.9556



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 86.9242


100%|██████████| 469/469 [00:05<00:00, 89.45it/s]

Epoch: 43 Train loss: 83.9095



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.0789


100%|██████████| 469/469 [00:05<00:00, 88.23it/s]

Epoch: 44 Train loss: 83.8391



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.2021


100%|██████████| 469/469 [00:05<00:00, 87.59it/s]

Epoch: 45 Train loss: 83.7887



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.3755


100%|██████████| 469/469 [00:05<00:00, 91.77it/s]

Epoch: 46 Train loss: 83.7290



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.1768


100%|██████████| 469/469 [00:05<00:00, 82.17it/s]

Epoch: 47 Train loss: 83.6373



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.2549


100%|██████████| 469/469 [00:05<00:00, 92.58it/s]


Epoch: 48 Train loss: 83.6225


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.1653


100%|██████████| 469/469 [00:05<00:00, 87.88it/s]

Epoch: 49 Train loss: 83.5318



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 87.2180


100%|██████████| 469/469 [00:04<00:00, 95.04it/s]


Epoch: 50 Train loss: 83.4636
Test loss: 86.9549
