In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Mnist Train_dataset
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='./mnist_data', train=True, transform=transform, download=True)
train_dataset = datasets.MNIST(root='./mnist_data', train=False, transform=transform, download=True)

BATCH_SIZE = 256
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [3]:
class VAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim1=512, hidden_dim2=256, z_dim=2):
        super(VAE, self).__init__()

        # encoder
        self.en_fc1 = nn.Linear(input_dim, hidden_dim1)
        self.en_fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.en_fc3_u = nn.Linear(hidden_dim2, z_dim) # u, mean
        self.en_fc3_var = nn.Linear(hidden_dim2, z_dim) # s, log_var

        # decoder

        self.de_fc1 = nn.Linear(z_dim, hidden_dim2)
        self.de_fc2 = nn.Linear(hidden_dim2, hidden_dim1)
        self.de_fc3 = nn.Linear(hidden_dim1, input_dim)

        #
        self.flatten = nn.Flatten()

    def encoder(self, inputs):
        x = self.flatten(inputs)
        x = F.relu(self.en_fc1(x))
        x = F.relu(self.en_fc2(x))

        mu = self.en_fc3_u(x)
        log_var = self.en_fc3_var(x)

        return mu, log_var
    
    def reparameterization(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + std * eps
        return z
    
    def decoder(self, z):
        x = F.relu(self.de_fc1(z))
        x = F.relu(self.de_fc2(x))
        recon_x = torch.sigmoid(self.de_fc3(x)) # 0 ~ 1
        return recon_x
    
    def forward(self, inputs):
        mu, log_var = self.encoder(inputs)
        z = self.reparameterization(mu, log_var)
        recon_x = self.decoder(z)
        return recon_x , mu, z
    
vae = VAE()
vae

VAE(
  (en_fc1): Linear(in_features=784, out_features=512, bias=True)
  (en_fc2): Linear(in_features=512, out_features=256, bias=True)
  (en_fc3_u): Linear(in_features=256, out_features=2, bias=True)
  (en_fc3_var): Linear(in_features=256, out_features=2, bias=True)
  (de_fc1): Linear(in_features=2, out_features=256, bias=True)
  (de_fc2): Linear(in_features=256, out_features=512, bias=True)
  (de_fc3): Linear(in_features=512, out_features=784, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [4]:
optimizer = torch.optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    flatten = nn.Flatten()
    bce = nn.BCELoss(reduction='sum')

    x = flatten(x)

    BCE_Loss = bce(recon_x, x)
    KLD_Loss = 0.5 * torch.sum(mu**2 + torch.exp(log_var) - log_var -1)
    return BCE_Loss + KLD_Loss

In [5]:
def train(epoch):
    vae.train()
    train_losses = 0.0
    for data, _ in train_dataloader:
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data)

        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        optimizer.step()
        train_losses += loss.item()
    print(f'>>>>>EPOCH {epoch} Average_loss = {train_losses / len(train_dataloader.dataset)}')


def test():
    vae.eval()
    test_losses = 0.0
    with torch.no_grad():
        for data, _ in test_dataloader:
            recon_batch, mu, log_var = vae(data)
            test_losses += loss_function(recon_batch, data, mu, log_var).item()
    test_losses /= len(test_dataloader.dataset)
    print(f'>>>>> Test set loss {test_losses}')

In [6]:
for epoch in range(1, 21):
    train(epoch)
    test()

>>>>>EPOCH 1 Average_loss = 265.238412109375
>>>>> Test set loss 205.24559291992188
>>>>>EPOCH 2 Average_loss = 196.314794921875
>>>>> Test set loss 187.6227565185547
>>>>>EPOCH 3 Average_loss = 180.729150390625
>>>>> Test set loss 175.14065998535156
>>>>>EPOCH 4 Average_loss = 172.622936328125
>>>>> Test set loss 170.7321541015625
>>>>>EPOCH 5 Average_loss = 169.139123046875
>>>>> Test set loss 167.825475
>>>>>EPOCH 6 Average_loss = 166.710246484375
>>>>> Test set loss 165.7691666748047
>>>>>EPOCH 7 Average_loss = 164.917824609375
>>>>> Test set loss 164.0526201171875
>>>>>EPOCH 8 Average_loss = 163.07696640625
>>>>> Test set loss 162.55447817382813
>>>>>EPOCH 9 Average_loss = 161.50196015625
>>>>> Test set loss 160.61736188964844
>>>>>EPOCH 10 Average_loss = 159.830196484375
>>>>> Test set loss 158.98301838378907
>>>>>EPOCH 11 Average_loss = 158.2415953125
>>>>> Test set loss 157.53611528320312
>>>>>EPOCH 12 Average_loss = 157.094649609375
>>>>> Test set loss 156.1448741455078
>>>>>E

In [7]:
from torchvision.utils import save_image

with torch.no_grad():
    z = torch.randn(16, 2)

    sample = vae.decoder(z)
    save_image(sample.view(16, 1, 28, 28), 'tutorial_vae_result.png')