In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torch.nn.functional as F

from vae_auto_encoder import VAEAutoEncoder

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# --- Optimisation schedule ---
batch_size = 64          # mini-batch size handed to the DataLoader
num_epochs = 200         # number of passes of the training loop
learning_rate = 0.001    # base learning rate given to the Adam optimizer
decay_factor = 0.99      # base of the per-epoch LR multiplier (decay_factor ** epoch)

# --- Loss weighting ---
r_loss_factor = 1_000    # multiplier on the reconstruction (MSE) term of the loss

# --- Output ---
model_save_path = "./vae_digits_model.pth"  # destination of the saved state_dict

In [None]:
# MNIST training split, stored under ./data/ (downloaded when requested via
# download=True); every sample is passed through transforms.ToTensor().
to_tensor = transforms.ToTensor()
train_data = datasets.MNIST(root='./data/', train=True, download=True, transform=to_tensor)

# Shuffled mini-batches loaded by 4 worker processes into pinned memory.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
train_loader = DataLoader(dataset=train_data, **loader_options)

In [None]:
# Constructor arguments for the autoencoder, gathered in one place.
# Note that linear_sizes[0] == 3136 == 64 * 7 * 7, i.e. the element count of
# one view_size entry -- presumably the flattened feature map; confirm against
# VAEAutoEncoder.
vae_config = dict(
    num_layers=4,
    encoder_channels=[1, 32, 64, 64, 64],
    encoder_kernel_sizes=[3, 3, 3, 3],
    encoder_strides=[1, 2, 2, 1],
    decoder_channels=[64, 64, 64, 32, 1],
    decoder_kernel_sizes=[3, 3, 3, 3],
    decoder_strides=[1, 2, 2, 1],
    linear_sizes=[3136, 2],
    view_size=[-1, 64, 7, 7],
    use_batch_norm=False,
    use_dropout=False,
)

# Build the network, move it to the selected device and put it in training mode.
model = VAEAutoEncoder(**vae_config).to(device)
model.train()


def lr_decay(epoch):
    """Return the factor the base learning rate is scaled by: decay_factor ** epoch."""
    return decay_factor ** epoch


# Adam over all model parameters, with the learning rate scaled by lr_decay
# each time the scheduler is stepped.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay)

print(model)

In [None]:
def vae_kl_loss(mu, log_var):
    """Return the KL regularisation term averaged over the rows of the inputs.

    For every row the value -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) is
    taken over dimension 1; this is the closed-form KL divergence between a
    diagonal Gaussian with mean `mu` and log-variance `log_var` and a standard
    normal. The per-row values are then averaged into a scalar tensor.

    Args:
        mu: tensor with at least two dimensions (dimension 1 is reduced).
        log_var: tensor broadcastable against `mu`.

    Returns:
        Tensor holding the mean of the per-row KL values.
    """
    elementwise = 1 + log_var - mu ** 2 - torch.exp(log_var)
    per_row = elementwise.sum(dim=1) * -0.5
    return per_row.mean()


# Reconstruction criterion: squared error with the mean reduction.
r_criterion = nn.MSELoss(reduction="mean")

Train

In [None]:
# Re-assert training mode here so this cell does not depend on another cell
# having left the model in that mode (e.g. when re-run after an evaluation).
model.train()

for epoch in range(num_epochs):
    # Sums of (batch-mean loss * batch size); dividing by the dataset size
    # below yields per-sample averages even when the last batch is smaller.
    running_loss = 0.0
    running_r_loss = 0.0
    running_kl_loss = 0.0

    # The labels are never used by the loss, so they are neither named nor
    # copied to the device.
    for inputs, _ in train_loader:
        # non_blocking=True allows an asynchronous host-to-device copy when
        # the source tensor is in pinned memory (the loader is created with
        # pin_memory=True); it has no effect otherwise.
        inputs = inputs.to(device, non_blocking=True)
        batch_count = inputs.size(0)

        optimizer.zero_grad()

        # Explicitly enable autograd in case it was disabled globally
        # elsewhere in the session.
        with torch.set_grad_enabled(True):
            outputs, mu, log_var = model(inputs)
            # Total loss = weighted reconstruction error + KL term.
            r_loss = r_loss_factor * r_criterion(outputs, inputs)
            kl_loss = vae_kl_loss(mu, log_var)
            loss = r_loss + kl_loss
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * batch_count
        running_r_loss += r_loss.item() * batch_count
        running_kl_loss += kl_loss.item() * batch_count

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    num_samples = len(train_data)
    epoch_loss = running_loss / num_samples
    epoch_r_loss = running_r_loss / num_samples
    epoch_kl_loss = running_kl_loss / num_samples
    print('Epoch {0:03d}\tLoss: {1:0.5f}\tr_loss: {2:0.5f}\tkl_loss: {3:0.5f}'.format(
        epoch, epoch_loss, epoch_r_loss, epoch_kl_loss))

    # Overwrite the checkpoint after every epoch so an interrupted run keeps
    # the most recent weights.
    torch.save(model.state_dict(), model_save_path)
    