In [11]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchvision.utils import save_image

In [12]:
BATCH_SIZE=100

In [13]:
transform=transforms.Compose([
        transforms.ToTensor(),
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
eval_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [14]:
class VAE(nn.Module):
    def __init__(self, x_dim, h1, h2, z_dim):
        super(VAE, self).__init__()
        
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, z_dim*2)
        )
        
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h2),
            nn.ReLU(),
            nn.Linear(h2, h1),
            nn.ReLU(),
            nn.Linear(h1, x_dim),
            nn.Sigmoid()
            )
    
    def forward(self, x):
        mu, log_var = torch.chunk(self.enc(x), 2, dim=-1)
        z = self.sampling(mu, log_var)
        return self.dec(z), mu, log_var
        
    def sampling(self, mu, log_var):
        # reparametrization trick
        std = torch.exp(0.5*log_var)
        eps = torch.rand_like(std)
        return mu + (eps * std)

vae = VAE(x_dim=784, h1=512, h2=256, z_dim=2)
if torch.cuda.is_available():
    vae.cuda()

In [15]:
vae

VAE(
  (enc): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=4, bias=True)
  )
  (dec): Sequential(
    (0): Linear(in_features=2, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=784, bias=True)
    (5): Sigmoid()
  )
)

In [21]:
def loss_fn(recon_x, x, mu, log_var):
    # reconstruction loss : binary cross entropy 
    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # kl divergence 
    kld_loss = 0.5 * torch.sum(torch.exp(log_var) + mu.pow(2) -1 - log_var)
    return bce_loss + kld_loss

In [17]:
optimizer = torch.optim.Adam(vae.parameters())

In [18]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_ind, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.view(BATCH_SIZE, -1)
        recon_x, mu, log_var = vae(data)
        loss = loss_fn(recon_x, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_ind % 200 == 0:
            print('Train Epoch:{} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_ind*len(data), len(train_loader.dataset),
                100*batch_ind/len(train_loader), loss.item()/len(data)))
        
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss/len(train_loader.dataset)))

In [19]:
def evaluation(): 
    vae.eval()
    eval_loss = 0
    with torch.no_grad():
        for data, _ in eval_loader:
#             data = data.cuda()
            data = data.view(BATCH_SIZE, -1)
            recon, mu, log_var = vae(data)
            eval_loss += loss_fn(recon, data, mu, log_var)
    eval_loss /= len(eval_loader.dataset)
    print('====> Evaluation loss : {:.4f}'.format(eval_loss))    

In [20]:
for epoch in range(1, 10):
    train(epoch)
    evaluation()



====> Epoch: 1 Average loss: 174.6615
====> Evaluation loss : 157.3062
====> Epoch: 2 Average loss: 153.3949
====> Evaluation loss : 150.6160
====> Epoch: 3 Average loss: 148.5896
====> Evaluation loss : 146.7084
====> Epoch: 4 Average loss: 145.8419
====> Evaluation loss : 144.9691
====> Epoch: 5 Average loss: 144.0589
====> Evaluation loss : 143.7350
====> Epoch: 6 Average loss: 142.7403
====> Evaluation loss : 142.8901
====> Epoch: 7 Average loss: 141.6540
====> Evaluation loss : 141.8195
====> Epoch: 8 Average loss: 140.6866
====> Evaluation loss : 140.7355
====> Epoch: 9 Average loss: 140.0557
====> Evaluation loss : 140.7283


In [25]:
with torch.no_grad():
    z = torch.randn(64, 2)
    sample = vae.dec(z)

In [26]:
sample.view(64, 1, 28, 28)

tensor([[[[4.4138e-08, 3.0656e-08, 4.7158e-08,  ..., 5.6960e-08,
           4.3091e-08, 5.1339e-08],
          [1.0038e-07, 8.6859e-08, 5.5758e-08,  ..., 3.8091e-08,
           2.8813e-08, 1.9932e-08],
          [4.9212e-08, 7.8227e-08, 5.3864e-05,  ..., 7.4444e-06,
           5.0379e-08, 9.6977e-08],
          ...,
          [3.2870e-08, 6.3056e-08, 4.6359e-10,  ..., 1.9546e-07,
           1.6134e-07, 3.1015e-08],
          [7.1825e-08, 4.9877e-08, 8.9575e-08,  ..., 8.3595e-08,
           3.6791e-08, 2.9302e-08],
          [4.3021e-08, 5.2267e-08, 5.0142e-08,  ..., 3.4810e-08,
           5.1181e-08, 3.9937e-08]]],


        [[[6.9971e-09, 5.6664e-09, 6.5575e-09,  ..., 5.3638e-09,
           1.0317e-08, 5.0095e-09],
          [8.2636e-09, 1.4489e-08, 1.0991e-08,  ..., 5.2961e-09,
           5.0652e-09, 6.6957e-09],
          [6.1588e-09, 8.7248e-09, 3.3624e-09,  ..., 1.1742e-08,
           1.9984e-08, 9.5720e-09],
          ...,
          [3.9356e-09, 9.6632e-09, 1.1686e-05,  ..., 1.12