In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable

In [2]:
BATCH_SIZE = 100
TEMPERATURE = 1.0
TEMP_MIN = 0.5
ANNEAL_RATE = 0.00003

In [3]:
transform=transforms.Compose([
        transforms.ToTensor(),
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
eval_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
class Gumbel_Softmax_VAE(nn.Module):
    def __init__(self, x_dim, h1, h2, z_dim, c_dim):
        super(Gumbel_Softmax_VAE, self).__init__()
        
        self.z_dim = z_dim
        self.c_dim = c_dim
        
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, z_dim*c_dim)
        )
        
        self.dec = nn.Sequential(
            nn.Linear(z_dim*c_dim, h2),
            nn.ReLU(),
            nn.Linear(h2, h1),
            nn.ReLU(),
            nn.Linear(h1, x_dim),
            nn.Sigmoid()
            )
        


    def forward(self, x, temperature):
        q_y = self.enc(x).view(x.size(0), self.z_dim, self.c_dim)
        return self.dec(self.gumbel_softmax(q_y, temperature)), F.softmax(q_y, dim=-1).view(x.size(0), -1)
    
    def sample_gumbel(self, shape, eps=1e-20):
        U = torch.rand(shape)
        return -torch.log(-torch.log(U + eps) + eps)
    
    def gumbel_softmax(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax( y / temperature, dim=-1).view(y.size(0), -1)
        

vae = Gumbel_Softmax_VAE(x_dim=784, h1=512, h2=256, z_dim=30, c_dim=10)
if torch.cuda.is_available():
    vae.cuda()

In [5]:
vae

Gumbel_Softmax_VAE(
  (enc): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=300, bias=True)
  )
  (dec): Sequential(
    (0): Linear(in_features=300, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=784, bias=True)
    (5): Sigmoid()
  )
)

## Loss Function 

ELBO :

$ logp_\theta(x) \ge E_{q_\phi(y|x)}[logp_\theta(x|y)] - KL[q_\phi(y|x) \| p_\theta(y)] $

KL : 

$ KLD = \sum_{y} q_\phi(y|x) [ log{\frac{q_\phi(y|x)}{p_\theta(y)}} ]$

In [6]:
def loss_fn(recon_x, x, qy):
    # reconstruction loss : binary cross entropy 
    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # kl divergence 
    
    log_qy = torch.log(qy+1e-20)
    log_py = Variable(torch.log(torch.Tensor([1.0/10])))
    kld_loss = torch.sum(qy*(log_qy - log_py),dim=-1).mean()

    return bce_loss + kld_loss

In [7]:
optimizer = torch.optim.Adam(vae.parameters())

In [8]:
def train(epoch):
    vae.train()
    train_loss = 0
    temp = TEMPERATURE
    for batch_ind, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.view(BATCH_SIZE, -1)
        
        if batch_ind % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_ind), TEMP_MIN)

        recon_x, qy = vae(data, temp)
        loss = loss_fn(recon_x, data, qy)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_ind % 200 == 0:
            print('Train Epoch:{} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_ind*len(data), len(train_loader.dataset),
                100*batch_ind/len(train_loader), loss.item()/len(data)))
        
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss/len(train_loader.dataset)))

In [9]:
def evaluation(): 
    vae.eval()
    eval_loss = 0
    temp = TEMPERATURE
    with torch.no_grad():
        for batch_ind, (data, _) in enumerate(eval_loader):
#             data = data.cuda()
            if batch_ind % 100 == 1:
                temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_ind), TEMP_MIN)
            data = data.view(BATCH_SIZE, -1)
            recon, qy = vae(data, temp)
            eval_loss += loss_fn(recon, data, qy)
    eval_loss /= len(eval_loader.dataset)
    print('====> Evaluation loss : {:.4f}'.format(eval_loss))    

In [10]:
for epoch in range(1, 100):
    train(epoch)
    evaluation()

====> Epoch: 1 Average loss: 171.0613
====> Evaluation loss : 129.7255
====> Epoch: 2 Average loss: 117.7736
====> Evaluation loss : 106.4475
====> Epoch: 3 Average loss: 101.9621
====> Evaluation loss : 96.4246
====> Epoch: 4 Average loss: 93.9373
====> Evaluation loss : 90.8421
====> Epoch: 5 Average loss: 89.4157
====> Evaluation loss : 86.8703
====> Epoch: 6 Average loss: 86.2640
====> Evaluation loss : 84.7581
====> Epoch: 7 Average loss: 84.1200
====> Evaluation loss : 82.9415
====> Epoch: 8 Average loss: 82.4153
====> Evaluation loss : 82.0957
====> Epoch: 9 Average loss: 81.0383
====> Evaluation loss : 80.9214
====> Epoch: 10 Average loss: 79.9298
====> Evaluation loss : 79.4604
====> Epoch: 11 Average loss: 78.9423
====> Evaluation loss : 78.8956
====> Epoch: 12 Average loss: 78.1631
====> Evaluation loss : 78.4297
====> Epoch: 13 Average loss: 77.4255
====> Evaluation loss : 77.7390
====> Epoch: 14 Average loss: 76.8149
====> Evaluation loss : 76.8403
====> Epoch: 15 Average 

====> Epoch: 35 Average loss: 70.8938
====> Evaluation loss : 73.3783
====> Epoch: 36 Average loss: 70.7463
====> Evaluation loss : 73.2114
====> Epoch: 37 Average loss: 70.6186
====> Evaluation loss : 73.1550
====> Epoch: 38 Average loss: 70.4653
====> Evaluation loss : 73.2979
====> Epoch: 39 Average loss: 70.3539
====> Evaluation loss : 72.9600
====> Epoch: 40 Average loss: 70.2164
====> Evaluation loss : 72.9778
====> Epoch: 41 Average loss: 70.1154
====> Evaluation loss : 72.8712
====> Epoch: 42 Average loss: 69.9755
====> Evaluation loss : 72.7372
====> Epoch: 43 Average loss: 69.8716
====> Evaluation loss : 73.0467
====> Epoch: 44 Average loss: 69.8783
====> Evaluation loss : 72.7722
====> Epoch: 45 Average loss: 69.7321
====> Evaluation loss : 72.6770
====> Epoch: 46 Average loss: 69.6513
====> Evaluation loss : 72.7095
====> Epoch: 47 Average loss: 69.5431
====> Evaluation loss : 72.5913
====> Epoch: 48 Average loss: 69.4323
====> Evaluation loss : 72.6106
====> Epoch: 49 Aver

====> Evaluation loss : 72.2393
====> Epoch: 70 Average loss: 68.0661
====> Evaluation loss : 72.1938
====> Epoch: 71 Average loss: 68.0201
====> Evaluation loss : 72.2430
====> Epoch: 72 Average loss: 67.9541
====> Evaluation loss : 72.3576
====> Epoch: 73 Average loss: 67.9022
====> Evaluation loss : 72.2727
====> Epoch: 74 Average loss: 67.8803
====> Evaluation loss : 72.2026
====> Epoch: 75 Average loss: 67.8533
====> Evaluation loss : 72.2098
====> Epoch: 76 Average loss: 67.7852
====> Evaluation loss : 72.4217
====> Epoch: 77 Average loss: 67.7538
====> Evaluation loss : 72.3296
====> Epoch: 78 Average loss: 67.6881
====> Evaluation loss : 72.2687
====> Epoch: 79 Average loss: 67.7069
====> Evaluation loss : 72.1303
====> Epoch: 80 Average loss: 67.6442
====> Evaluation loss : 72.0282
====> Epoch: 81 Average loss: 67.6399
====> Evaluation loss : 72.3879
====> Epoch: 82 Average loss: 67.6104
====> Evaluation loss : 72.2246
====> Epoch: 83 Average loss: 67.5292
====> Evaluation los

In [12]:
with torch.no_grad():
    for batch_ind, (data, _) in enumerate(eval_loader):
#             data = data.cuda()
        
        data = data.view(BATCH_SIZE, -1)
        recon, qy = vae(data, TEMPERATURE)
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n].view(BATCH_SIZE, 28, 28)[:n],
                                    recon_batch.view(BATCH_SIZE, 1, 28, 28)[:n]])
        save_image(comparison.data.cpu(),
                       'samples/sample_gumbel_softmax' +'.png', nrow=n)
        break
        

RuntimeError: shape '[100, 1, 28, 28]' is invalid for input of size 6272