In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torch.nn.functional as F

from vae_auto_encoder import VAEAutoEncoder

In [2]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# Training configuration. Everything a reader might want to tune lives here.
seed = 42                 # seed for the global PyTorch / NumPy random generators
num_epochs = 200          # number of passes over the training set
batch_size = 64           # samples per optimisation step
learning_rate = 1e-3      # initial learning rate handed to the optimizer
r_loss_factor = 1000      # multiplier on the reconstruction loss before adding the KL term
decay_factor = 0.99       # base of the per-epoch learning-rate decay (decay_factor ** epoch)
model_save_path = "./vae_digits_model.pth"  # destination of the saved state_dict

# Seed the global RNGs early so weight initialisation and batch shuffling are
# repeatable across Restart & Run All. Some GPU kernels may still introduce
# small run-to-run differences.
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
# MNIST training split, downloaded into ../data/ on first use; each image is
# converted by transforms.ToTensor() when it is fetched.
train_data = datasets.MNIST(
    root='../data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches, loaded by 4 worker processes into pinned host memory.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [5]:
# Four-layer convolutional VAE with a 2-dimensional latent space
# (linear_sizes=[3136, 2]); batch norm and dropout are switched off.
model = VAEAutoEncoder(
    num_layers=4,
    encoder_channels=[1, 32, 64, 64, 64],
    encoder_kernel_sizes=[3, 3, 3, 3],
    encoder_strides=[1, 2, 2, 1],
    decoder_channels=[64, 64, 64, 32, 1],
    decoder_kernel_sizes=[3, 3, 3, 3],
    decoder_strides=[1, 2, 2, 1],
    linear_sizes=[3136, 2],
    view_size=[-1, 64, 7, 7],
    use_batch_norm=False,
    use_dropout=False,
)
model = model.to(device)
model.train()

# Adam with an exponentially decaying learning rate: the base rate is scaled
# by decay_factor ** epoch each time the scheduler is stepped.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: decay_factor ** epoch,
)

print(model)

VAEAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.01, inplace=True)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.01, inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
  )
  (mu_layer): Linear(in_features=3136, out_features=2, bias=True)
  (log_var_layer): Linear(in_features=3136, out_features=2, bias=True)
  (linear): Linear(in_features=2, out_features=3136, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_

In [6]:
def vae_kl_loss(mu, log_var):
    """Return the batch-averaged KL term of the VAE objective.

    For every row this evaluates
    ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))`` over dimension 1,
    which is the closed-form KL divergence between a diagonal Gaussian
    ``N(mu, exp(log_var))`` and a standard normal, and then averages the
    per-row values.

    Args:
        mu: Tensor of latent means; dimension 1 is summed over
            (presumably shaped ``(batch, latent_dim)``).
        log_var: Tensor of latent log-variances, broadcast-compatible with ``mu``.

    Returns:
        A scalar tensor holding the mean of the per-row KL values.
    """
    per_element = 1 + log_var - mu ** 2 - torch.exp(log_var)
    per_sample = -0.5 * per_element.sum(dim=1)
    return per_sample.mean()


# Reconstruction criterion: mean squared error with PyTorch's default reduction.
r_criterion = nn.MSELoss()

Train

In [7]:
# Train the VAE for num_epochs passes over train_loader, logging the
# sample-weighted mean losses and checkpointing the weights after every epoch.

# Put the model in training mode here so the cell does not depend on an
# earlier cell having done so (e.g. when re-run after an evaluation cell).
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    running_r_loss = 0.0
    running_kl_loss = 0.0

    # Labels are discarded: the objective only compares reconstructions with
    # the inputs, so they are never moved to the device.
    for inputs, _ in train_loader:
        # The loader yields pinned-memory batches, so the host-to-device copy
        # can be issued asynchronously.
        inputs = inputs.to(device, non_blocking=True)
        batch_len = inputs.size(0)

        optimizer.zero_grad()

        # Explicitly enable autograd in case gradients were globally disabled
        # by another cell.
        with torch.set_grad_enabled(True):
            outputs, mu, log_var = model(inputs)
            r_loss = r_loss_factor * r_criterion(outputs, inputs)
            kl_loss = vae_kl_loss(mu, log_var)
            loss = r_loss + kl_loss
            loss.backward()
            optimizer.step()

        # Each loss is a batch mean; weight it by the batch size so that the
        # epoch figures below are true per-sample averages even when the last
        # batch is smaller than batch_size.
        running_loss += loss.item() * batch_len
        running_r_loss += r_loss.item() * batch_len
        running_kl_loss += kl_loss.item() * batch_len

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
    epoch_loss = running_loss / len(train_data)
    epoch_r_loss = running_r_loss / len(train_data)
    epoch_kl_loss = running_kl_loss / len(train_data)
    print('Epoch {0:03d}\tLoss: {1:0.5f}\tr_loss: {2:0.5f}\tkl_loss: {3:0.5f}'.format(
        epoch, epoch_loss, epoch_r_loss, epoch_kl_loss))

    # Overwrite the checkpoint each epoch so an interrupted run keeps the
    # most recent weights.
    torch.save(model.state_dict(), model_save_path)
    

Epoch 000	Loss: 56.65993	r_loss: 53.36145	kl_loss: 3.29848
Epoch 001	Loss: 49.31674	r_loss: 44.86008	kl_loss: 4.45666
Epoch 002	Loss: 47.59793	r_loss: 42.84993	kl_loss: 4.74800
Epoch 003	Loss: 46.70383	r_loss: 41.78616	kl_loss: 4.91766
Epoch 004	Loss: 46.00426	r_loss: 40.95599	kl_loss: 5.04827
Epoch 005	Loss: 45.61578	r_loss: 40.46992	kl_loss: 5.14586
Epoch 006	Loss: 45.21184	r_loss: 39.98462	kl_loss: 5.22722
Epoch 007	Loss: 44.92254	r_loss: 39.64631	kl_loss: 5.27623
Epoch 008	Loss: 44.63723	r_loss: 39.32738	kl_loss: 5.30985
Epoch 009	Loss: 44.38391	r_loss: 39.02800	kl_loss: 5.35591
Epoch 010	Loss: 44.24745	r_loss: 38.85204	kl_loss: 5.39541
Epoch 011	Loss: 44.03714	r_loss: 38.61011	kl_loss: 5.42703
Epoch 012	Loss: 43.88689	r_loss: 38.41378	kl_loss: 5.47310
Epoch 013	Loss: 43.69818	r_loss: 38.20751	kl_loss: 5.49067
Epoch 014	Loss: 43.60069	r_loss: 38.09216	kl_loss: 5.50853
Epoch 015	Loss: 43.44408	r_loss: 37.91687	kl_loss: 5.52721
Epoch 016	Loss: 43.34345	r_loss: 37.78756	kl_loss: 5.555