In [1]:
import torch
import math

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
BATCH_SIZE=100

In [3]:
transform=transforms.Compose([
        transforms.ToTensor(),
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
eval_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
class IAF_Layer(nn.Module):
    def __init__(self,  input_dim, hidden_dim, out_dim, reverse=False):
        super(IAF_Layer, self).__init__()
        
        self.m_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, out_dim)
            )
        self.s_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, out_dim)
            )
        
        self.sigmoid = nn.Sigmoid()
        
        self.reverse = reverse
    
    def forward(self, z, h):
        
#         print(z.shape)
#         exit()
#         if self.reverse:
#             z =  torch.flip(z, dims=-1)
        
        iaf_input = torch.cat([z, h], dim=-1)
        m = self.m_layer(iaf_input)
        s = self.s_layer(iaf_input)
        sigma = self.sigmoid(s)
    
        return z * sigma + m * (1. - sigma), sigma

In [5]:
class IAF_VAE(nn.Module):
    def __init__(self, x_dim, h1, h2, z_dim, num_iaf):
        super(IAF_VAE, self).__init__()
        
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, z_dim),
            nn.ReLU()
        )
        
        self.enc_mu = nn.Linear(z_dim, z_dim)
        self.enc_log_var = nn.Linear(z_dim, z_dim)
        
        iaf_layers = []
        for i in range(num_iaf):
            reverse = i > 0 
            iaf_layers.append(IAF_Layer(z_dim*2, z_dim*2, z_dim, reverse=reverse))
        self.iaf_layers = nn.ModuleList(iaf_layers)
        
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h2),
            nn.ReLU(),
            nn.Linear(h2, h1),
            nn.ReLU(),
            nn.Linear(h1, x_dim),
            nn.Sigmoid()
            )
    
    def forward(self, x):
        # original z_0
        h = self.enc(x)
        mu, log_var = self.enc_mu(h), self.enc_log_var(h)
        z_0, eps_0 = self.sampling(mu, log_var)
        
        z_t = z_0
        z_sigmas = []
        
        # iaf flows 
        for iaf in self.iaf_layers:
            z_t, sigma_z_t = iaf(z_t, h)
            z_sigmas.append(sigma_z_t)
        
        z = z_t 
        
        return self.dec(z), mu, log_var, z_sigmas, eps_0, z
        
    def sampling(self, mu, log_var):
        # reparametrization trick
        std = torch.exp(0.5*log_var)
        eps = torch.rand_like(std)
        return mu + (eps * std), eps

vae = IAF_VAE(x_dim=784, h1=512, h2=256, z_dim=2, num_iaf=4)

In [6]:
input_sample = torch.randn(10, 784)

In [7]:
res, mu, log_var, z_sigmas, eps_0, z_t = vae(input_sample)

In [8]:
def loss_fn(recon_x, x, mu, log_var, z_sigmas, eps_0, z):
    # reconstruction loss : binary cross entropy 
    # logp(x|z)
    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # -logp(z)
    logpz = torch.sum(0.5 * np.log(2*math.pi) + 0.5*z**2)
    
    # logq(z|x)
    det = log_var 
    for z_sigma in z_sigmas:
        det += torch.log(z_sigma)
    logqz_x = torch.sum(0.5 * np.log(2*math.pi) +0.5*eps_0**2 + det)
    
    return bce_loss + logpz - logqz_x

In [9]:
optimizer = torch.optim.Adam(vae.parameters())

In [10]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_ind, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.view(BATCH_SIZE, -1)
        recon_x, mu, log_var, z_sigmas, eps_0, z_t = vae(data)
        loss = loss_fn(recon_x, data, mu, log_var, z_sigmas, eps_0, z_t)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_ind % 200 == 0:
            print('Train Epoch:{} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_ind*len(data), len(train_loader.dataset),
                100*batch_ind/len(train_loader), loss.item()/len(data)))
        
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss/len(train_loader.dataset)))

In [11]:
def evaluation(): 
    vae.eval()
    eval_loss = 0
    with torch.no_grad():
        for data, _ in eval_loader:
#             data = data.cuda()
            data = data.view(BATCH_SIZE, -1)
            recon_x, mu, log_var, z_sigmas, eps_0, z_t = vae(data)
            eval_loss += loss_fn(recon_x, data, mu, log_var, z_sigmas, eps_0, z_t)
            
    eval_loss /= len(eval_loader.dataset)
    print('====> Evaluation loss : {:.4f}'.format(eval_loss))   
    return eval_loss

In [12]:
best_loss = float('inf')
for epoch in range(1, 200):
    train(epoch)
    eval_loss = evaluation()
    if eval_loss < best_loss:
        torch.save(vae.state_dict(), 'iaf_vae_best.pth')
        best_loss = eval_loss

====> Epoch: 1 Average loss: 191.6321
====> Evaluation loss : 177.9373
====> Epoch: 2 Average loss: 173.9770
====> Evaluation loss : 169.2890
====> Epoch: 3 Average loss: 167.3482
====> Evaluation loss : 163.7703
====> Epoch: 4 Average loss: 162.1096
====> Evaluation loss : 160.1425
====> Epoch: 5 Average loss: 159.0530
====> Evaluation loss : 156.2436
====> Epoch: 6 Average loss: 157.1101
====> Evaluation loss : 156.4203
====> Epoch: 7 Average loss: 154.2788
====> Evaluation loss : 152.7498
====> Epoch: 8 Average loss: 154.6109
====> Evaluation loss : 169.0733
====> Epoch: 9 Average loss: 155.2323
====> Evaluation loss : 151.5990
====> Epoch: 10 Average loss: 153.1188
====> Evaluation loss : 151.7772
====> Epoch: 11 Average loss: 151.6051
====> Evaluation loss : 150.5365
====> Epoch: 12 Average loss: 150.7840
====> Evaluation loss : 149.0562
====> Epoch: 13 Average loss: 160.5655
====> Evaluation loss : 163.4624
====> Epoch: 14 Average loss: 159.5008
====> Evaluation loss : 153.8279
=

====> Evaluation loss : 148.5002
====> Epoch: 35 Average loss: 149.3624
====> Evaluation loss : 148.3037
====> Epoch: 36 Average loss: 151.4378
====> Evaluation loss : 148.5747
====> Epoch: 37 Average loss: 147.6577
====> Evaluation loss : 145.9325
====> Epoch: 38 Average loss: 150.8191
====> Evaluation loss : 146.1888
====> Epoch: 39 Average loss: 146.4305
====> Evaluation loss : 145.5900
====> Epoch: 40 Average loss: 146.2857
====> Evaluation loss : 146.0723
====> Epoch: 41 Average loss: 144.4278
====> Evaluation loss : 146.0159
====> Epoch: 42 Average loss: 148.8314
====> Evaluation loss : 159.2484
====> Epoch: 43 Average loss: 148.4235
====> Evaluation loss : 143.9553
====> Epoch: 44 Average loss: 143.0875
====> Evaluation loss : 142.0733
====> Epoch: 45 Average loss: 142.6440
====> Evaluation loss : 140.8592
====> Epoch: 46 Average loss: 144.6585
====> Evaluation loss : 165.2575
====> Epoch: 47 Average loss: 143.6774
====> Evaluation loss : 141.3428
====> Epoch: 48 Average loss: 1

====> Epoch: 68 Average loss: 157.3677
====> Evaluation loss : 149.9432
====> Epoch: 69 Average loss: 148.5510
====> Evaluation loss : 147.2831
====> Epoch: 70 Average loss: 155.0468
====> Evaluation loss : 154.4161
====> Epoch: 71 Average loss: 148.4194
====> Evaluation loss : 143.3674
====> Epoch: 72 Average loss: 144.4333
====> Evaluation loss : 141.3363
====> Epoch: 73 Average loss: 141.3332
====> Evaluation loss : 141.1292
====> Epoch: 74 Average loss: 140.2730
====> Evaluation loss : 138.1826
====> Epoch: 75 Average loss: 138.7956
====> Evaluation loss : 136.9472
====> Epoch: 76 Average loss: 161.6308
====> Evaluation loss : 153.5653
====> Epoch: 77 Average loss: 145.0847
====> Evaluation loss : 138.5756
====> Epoch: 78 Average loss: 139.0193
====> Evaluation loss : 139.5796
====> Epoch: 79 Average loss: 136.7532
====> Evaluation loss : 135.4464
====> Epoch: 80 Average loss: 139.3648
====> Evaluation loss : 142.0151
====> Epoch: 81 Average loss: 153.1090
====> Evaluation loss : 1

====> Epoch: 102 Average loss: 139.0070
====> Evaluation loss : 135.0565
====> Epoch: 103 Average loss: 135.5384
====> Evaluation loss : 134.5070
====> Epoch: 104 Average loss: 136.4206
====> Evaluation loss : 132.9056
====> Epoch: 105 Average loss: 136.4821
====> Evaluation loss : 133.3910
====> Epoch: 106 Average loss: 134.2079
====> Evaluation loss : 136.0700
====> Epoch: 107 Average loss: 140.9808
====> Evaluation loss : 163.7609
====> Epoch: 108 Average loss: 147.1605
====> Evaluation loss : 140.7877
====> Epoch: 109 Average loss: 142.0658
====> Evaluation loss : 142.0564
====> Epoch: 110 Average loss: 138.7807
====> Evaluation loss : 138.5300
====> Epoch: 111 Average loss: 140.6428
====> Evaluation loss : 136.5687
====> Epoch: 112 Average loss: 140.4554
====> Evaluation loss : 133.6335
====> Epoch: 113 Average loss: 133.8440
====> Evaluation loss : 132.3180
====> Epoch: 114 Average loss: 132.4899
====> Evaluation loss : 135.6213
====> Epoch: 115 Average loss: 132.4158
====> Evalu

====> Epoch: 135 Average loss: 122.9564
====> Evaluation loss : 122.8700
====> Epoch: 136 Average loss: 122.1296
====> Evaluation loss : 123.2038
====> Epoch: 137 Average loss: 121.3809
====> Evaluation loss : 119.6129
====> Epoch: 138 Average loss: 120.5838
====> Evaluation loss : 118.5706
====> Epoch: 139 Average loss: 119.4851
====> Evaluation loss : 127.2184
====> Epoch: 140 Average loss: 119.1859
====> Evaluation loss : 122.4073
====> Epoch: 141 Average loss: 117.4532
====> Evaluation loss : 121.7326
====> Epoch: 142 Average loss: 116.9768
====> Evaluation loss : 114.4874
====> Epoch: 143 Average loss: 115.1553
====> Evaluation loss : 113.2673
====> Epoch: 144 Average loss: 115.1124
====> Evaluation loss : 122.4866
====> Epoch: 145 Average loss: 113.7419
====> Evaluation loss : 112.1170
====> Epoch: 146 Average loss: 112.3628
====> Evaluation loss : 109.6151
====> Epoch: 147 Average loss: 111.8868
====> Evaluation loss : 109.1483
====> Epoch: 148 Average loss: 110.6923
====> Evalu

====> Epoch: 168 Average loss: 85.1402
====> Evaluation loss : 81.8840
====> Epoch: 169 Average loss: 83.1685
====> Evaluation loss : 80.6585
====> Epoch: 170 Average loss: 81.8494
====> Evaluation loss : 79.3610
====> Epoch: 171 Average loss: 80.5863
====> Evaluation loss : 77.9117
====> Epoch: 172 Average loss: 78.6571
====> Evaluation loss : 77.0451
====> Epoch: 173 Average loss: 77.1588
====> Evaluation loss : 74.3208
====> Epoch: 174 Average loss: 76.2520
====> Evaluation loss : 72.8181
====> Epoch: 175 Average loss: 74.1666
====> Evaluation loss : 72.2459
====> Epoch: 176 Average loss: 74.1624
====> Evaluation loss : 71.8713
====> Epoch: 177 Average loss: 72.2102
====> Evaluation loss : 68.5015
====> Epoch: 178 Average loss: 70.9655
====> Evaluation loss : 67.3255
====> Epoch: 179 Average loss: 70.1648
====> Evaluation loss : 68.2873
====> Epoch: 180 Average loss: 67.8679
====> Evaluation loss : 67.0788
====> Epoch: 181 Average loss: 66.4836
====> Evaluation loss : 66.9290
====> 

RuntimeError: all elements of input should be between 0 and 1

In [13]:
vae.load_state_dict(torch.load('iaf_vae_best.pth'))
_ = vae.eval()

In [14]:
with torch.no_grad():
    for batch_ind, (data, _) in enumerate(eval_loader):
#             data = data.cuda()
        
        data = data.view(BATCH_SIZE, -1)
        recon_batch, *_ = vae(data)
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n].view(8, 1, 28, 28)[:n],
                                    recon_batch.unsqueeze(-1).view(BATCH_SIZE, 1, 28, 28)[:n]])
        save_image(comparison.data.cpu(),
                       'samples/sample_comp_iaf' +'.png', nrow=n)
        break

In [15]:
with torch.no_grad():
    z = torch.randn(64, 2)
    sample = vae.dec(z).view(64, 1, 28, 28)
    save_image(sample.data.cpu(), 'samples/sample_iaf.png', nrow=16)