In [1]:
# < define classes to build encoder and decoder networks >
# We use the encoder and decoder architectures from the MNIST experiments in [1] (see C.1).
# Other implementation details are adjusted to simplify the experiment.
# [1] Tolstikhin, Ilya, Olivier Bousquet, Sylvain Gelly, and Bernhard Schoelkopf. "Wasserstein auto-encoders."
#    arXiv preprint arXiv:1711.01558 (2017).

import itertools
import numpy as np
import random
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to Gaussian posterior parameters.

    Four stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages are followed by two
    parallel convolutional heads producing the mean and the log-variance.
    The spatial sizes in the comments below assume 28 x 28 inputs.

    Args:
        dim_x: 3-element sequence; only the first entry (number of input
            channels) is used.
        dim_z: number of latent dimensions.
        nf: base number of filters; the stages use nf, 2*nf, 4*nf and 8*nf channels.
    """

    @staticmethod
    def _down_block(in_channels, out_channels, padding):
        """Kernel-4, stride-2 convolution followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, dim_x, dim_z, nf):
        super().__init__()
        self.dim_z = dim_z
        nc, _height, _width = dim_x

        # nc x 28 x 28 -> nf x 14 x 14
        self.c1 = self._down_block(nc, nf, 1)
        # nf x 14 x 14 -> (nf*2) x 7 x 7
        self.c2 = self._down_block(nf, nf*2, 1)
        # (nf*2) x 7 x 7 -> (nf*4) x 4 x 4 (padding 2 so the odd size halves to 4)
        self.c3 = self._down_block(nf*2, nf*4, 2)
        # (nf*4) x 4 x 4 -> (nf*8) x 2 x 2
        self.c4 = self._down_block(nf*4, nf*8, 1)

        # Both heads collapse (nf*8) x 2 x 2 to dim_z x 1 x 1.
        self.mu_net = nn.Conv2d(nf*8, dim_z, 2, 2, 0)
        self.log_var_net = nn.Conv2d(nf*8, dim_z, 2, 2, 0)

    def forward(self, x_input):
        """Return ``(mu, log_var)``, each reshaped to ``(-1, dim_z)``."""
        features = self.c4(self.c3(self.c2(self.c1(x_input))))
        mu = self.mu_net(features)
        log_var = self.log_var_net(features)
        return mu.view(-1, self.dim_z), log_var.view(-1, self.dim_z)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to image-shaped outputs.

    Args:
        dim_z: number of latent dimensions.
        dim_x: 3-element sequence ``(channels, height, width)``; the channel
            count and the height are used (the output is assumed square).
        nf: base number of filters.
        final_activation: ``'sigmoid'``, ``'tanh'`` or ``'none'``; activation
            applied to the reconstructed mean.

    Raises:
        ValueError: if ``final_activation`` is not one of the supported choices.
    """

    def __init__(self, dim_z, dim_x, nf, final_activation):
        super(Decoder, self).__init__()
        self.dim_z = dim_z
        nc, h, _ = dim_x

        # Fail fast: previously an unknown choice only printed a message and
        # left `mu_net` undefined, which surfaced later as an AttributeError.
        activation_layers = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'none': None}
        if final_activation not in activation_layers:
            raise ValueError(
                'available choices for final_activation: sigmoid, tanh, and none '
                '(got %r)' % (final_activation,))

        # input size: dim_z x 1 x 1
        self.upc1 = nn.Sequential(
                  nn.ConvTranspose2d(dim_z, nf*8, 7, 1, 0),
                  )
        # input size: (nf*8) x 7 x 7
        self.upc2 = nn.Sequential(
                  nn.ConvTranspose2d(nf*8, nf*4, 4, 2, 1),
                  nn.BatchNorm2d(nf*4),
                  nn.ReLU(inplace=True),
                  )
        # input size: (nf*4) x 14 x 14
        self.upc3 = nn.Sequential(
                  nn.ConvTranspose2d(nf*4, nf*2, 4, 2, 1),
                  nn.BatchNorm2d(nf*2),
                  nn.ReLU(inplace=True),
                  )
        # input size: (nf*2) x 28 x 28
        mu_layers = [nn.ConvTranspose2d(nf*2, nc, 3, 1, 1)]
        activation_cls = activation_layers[final_activation]
        if activation_cls is not None:
            mu_layers.append(activation_cls())
        self.mu_net = nn.Sequential(*mu_layers)

        # Input-independent observation log-variance: a transposed convolution
        # applied to a constant 1 x 1 x 1 x 1 tensor yields a 1 x nc x h x h map.
        self.obs_log_var_net = nn.ConvTranspose2d(1, nc, h, 1, 0)

    def forward(self, z_input):
        """Decode latent codes.

        Returns:
            ``(o, obs_log_var)`` where ``o`` is the reconstructed mean for each
            code and ``obs_log_var`` is the shared log-variance map of shape
            ``(1, nc, h, h)``.
        """
        d1 = self.upc1(z_input.reshape(-1, self.dim_z, 1, 1))
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)
        o = self.mu_net(d3)

        # Build the constant input on the same device/dtype as the layer that
        # consumes it, instead of guessing from `get_device()` and `.cuda()`.
        weight = self.obs_log_var_net.weight
        one_tensor = torch.ones((1, 1, 1, 1), device=weight.device, dtype=weight.dtype)

        obs_log_var = self.obs_log_var_net(one_tensor)
        return o, obs_log_var

In [None]:
# < define basic utility functions >
def sampling(mean, log_var):
    """Draw a reparameterized sample ``mean + exp(0.5 * log_var) * epsilon``.

    The standard-normal noise ``epsilon`` is created directly on ``mean``'s
    device, so the function works on CPU, on any CUDA device index and on other
    backends without a host-to-device copy.

    Args:
        mean: tensor of Gaussian means.
        log_var: tensor of Gaussian log-variances, broadcastable to ``mean``.

    Returns:
        A tensor with the shape of ``mean`` located on ``mean``'s device.
    """
    epsilon = torch.randn(mean.shape, device=mean.device)
    # `.to` is a no-op when both tensors already share a device.
    std = torch.exp(0.5 * log_var).to(mean.device)
    return mean + std * epsilon

def kl_criterion(mu1, log_var1, mu2, log_var2):
    """KL divergence KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians.

    The divergence is evaluated directly from the log-variances:

        0.5 * (log_var2 - log_var1 + exp(log_var1 - log_var2)
               + (mu1 - mu2)^2 * exp(-log_var2) - 1)

    This equals the textbook ``log(sigma2/sigma1) + ...`` form, but it never
    takes the logarithm of a ratio of standard deviations, so it stays finite
    when a standard deviation underflows to zero.

    Returns:
        The element-wise divergence summed over the last dimension.
    """
    kld = 0.5 * (log_var2 - log_var1
                 + torch.exp(log_var1 - log_var2)
                 + (mu1 - mu2)**2 * torch.exp(-log_var2)
                 - 1.0)
    return torch.sum(kld, dim=-1)

def kl_annealing_weight(epoch, total_epochs):
    """Linear KL warm-up: ramps with ``epoch`` over the first 10% of ``total_epochs``, capped at 1."""
    ramp = epoch / (0.1*total_epochs)
    if ramp < 1:
        return ramp
    return 1

In [None]:
# < define advanced utility functions >
def extract_feature(result_path, x, mean_extract=True):
    """Encode ``x`` with the encoder stored in ``<result_path>/model.pth``.

    Args:
        result_path: directory containing ``model.pth`` as written by ``fit``.
        x: input tensor accepted by the saved encoder.
        mean_extract: if True return the posterior mean, otherwise return a
            reparameterized sample from the posterior.

    Returns:
        A tensor of latent codes on the inference device (CUDA if available,
        else CPU). It carries no autograd history.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints you trust. map_location lets a checkpoint saved on a GPU be
    # opened on a CPU-only machine (and vice versa).
    saved_model = torch.load('%s/model.pth' % result_path,
                             map_location=device, weights_only=False)
    encoder = saved_model['encoder'].to(device)
    encoder.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        z_mean, z_log_var = encoder(x.to(device))
        if mean_extract:
            return z_mean
        # Sample only when a sample is actually requested.
        return sampling(z_mean, z_log_var)

def reconstruct(result_path, x, mean_extract=True):
    """Encode and decode ``x`` with the networks stored in ``<result_path>/model.pth``.

    Args:
        result_path: directory containing ``model.pth`` as written by ``fit``.
        x: input tensor accepted by the saved encoder.
        mean_extract: if True decode the posterior mean, otherwise decode a
            reparameterized sample from the posterior.

    Returns:
        The first output of the decoder (the reconstruction) on the inference
        device (CUDA if available, else CPU), without autograd history.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints you trust. map_location lets a checkpoint saved on a GPU be
    # opened on a CPU-only machine (and vice versa).
    saved_model = torch.load('%s/model.pth' % result_path,
                             map_location=device, weights_only=False)
    encoder = saved_model['encoder'].to(device)
    decoder = saved_model['decoder'].to(device)
    encoder.eval(); decoder.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        z_mean, z_log_var = encoder(x.to(device))
        # Sample only when a sample is actually requested.
        z = z_mean if mean_extract else sampling(z_mean, z_log_var)
        x_recon = decoder(z)[0]
    return x_recon

def generate(result_path, z_sample):
    """Decode latent codes with the decoder stored in ``<result_path>/model.pth``.

    Args:
        result_path: directory containing ``model.pth`` as written by ``fit``.
        z_sample: tensor of latent codes accepted by the saved decoder.

    Returns:
        The first output of the decoder on the inference device (CUDA if
        available, else CPU), without autograd history.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints you trust. map_location lets a checkpoint saved on a GPU be
    # opened on a CPU-only machine (and vice versa).
    saved_model = torch.load('%s/model.pth' % result_path,
                             map_location=device, weights_only=False)
    decoder = saved_model['decoder'].to(device)
    decoder.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        x_gen = decoder(z_sample.to(device))[0]

    return x_gen

def plot_train_reconstruction(result_path, epoch, n_vis):
    """Save a figure comparing training images with their reconstructions.

    Reads the notebook-level global ``x_train`` and the checkpoint in
    ``<result_path>/model.pth``. The top row shows the first ``n_vis`` training
    images, the bottom row their reconstructions. The figure is written to
    ``<result_path>/train_reconstruction_results_<epoch>.pdf``.
    """
    originals = x_train[:n_vis]

    # compute reconstruction results (the latent codes themselves are not
    # needed for this figure, so the checkpoint is loaded only once)
    x_train_recon = reconstruct(result_path, torch.tensor(originals, dtype=torch.float32))
    x_train_recon = x_train_recon.detach().cpu().numpy()

    # plot and save reconstruction results
    # squeeze=False keeps `axes` two-dimensional even when n_vis == 1.
    fig, axes = plt.subplots(2, n_vis, figsize=(12, 12), squeeze=False)
    for i in range(n_vis):
        for row, image in enumerate((originals[i], x_train_recon[i])):
            axes[row, i].imshow(image.reshape(28, 28), cmap='gray')
            axes[row, i].axis('off')
    fig.tight_layout()
    fig.savefig('%s/train_reconstruction_results_%d.pdf' % (result_path, epoch), dpi=600)
    plt.close(fig)

    return None

def plot_test_reconstruction(result_path, epoch, n_vis):
    """Save a figure comparing test images with their reconstructions.

    Reads the notebook-level global ``x_test`` and the checkpoint in
    ``<result_path>/model.pth``. The top row shows the first ``n_vis`` test
    images, the bottom row their reconstructions. The figure is written to
    ``<result_path>/test_reconstruction_results_<epoch>.pdf``.
    """
    originals = x_test[:n_vis]

    # compute reconstruction results (the latent codes themselves are not
    # needed for this figure, so the checkpoint is loaded only once)
    x_test_recon = reconstruct(result_path, torch.tensor(originals, dtype=torch.float32))
    x_test_recon = x_test_recon.detach().cpu().numpy()

    # plot and save reconstruction results
    # squeeze=False keeps `axes` two-dimensional even when n_vis == 1.
    fig, axes = plt.subplots(2, n_vis, figsize=(12, 12), squeeze=False)
    for i in range(n_vis):
        for row, image in enumerate((originals[i], x_test_recon[i])):
            axes[row, i].imshow(image.reshape(28, 28), cmap='gray')
            axes[row, i].axis('off')
    fig.tight_layout()
    fig.savefig('%s/test_reconstruction_results_%d.pdf' % (result_path, epoch), dpi=600)
    plt.close(fig)

    return None

def plot_generation(result_path, epoch, n_row, n_col):
    """Save an ``n_row`` x ``n_col`` grid of images decoded from prior samples.

    Latent codes are drawn from a standard normal of dimension ``dim_z`` (a
    notebook-level global) and decoded with the checkpoint in
    ``<result_path>/model.pth``. The figure is written to
    ``<result_path>/generation_results_<epoch>.pdf``.

    Args:
        result_path: directory containing ``model.pth``.
        epoch: epoch number used in the output file name.
        n_row: number of grid rows.
        n_col: number of grid columns.
    """
    # The requested grid size is honoured (it used to be overwritten with 10 x 10).
    z_prior = torch.randn(n_row*n_col, dim_z, dtype=torch.float32)
    x_gen = generate(result_path, z_prior)
    x_gen = x_gen.cpu().detach().numpy()

    # squeeze=False keeps `axes` two-dimensional for single-row/column grids.
    fig, axes = plt.subplots(n_row, n_col, figsize=(12, 12), squeeze=False)
    for i in range(n_row):
        for j in range(n_col):
            # Row-major flat index: each row holds n_col images.
            axes[i, j].imshow(x_gen[i*n_col+j].reshape(28, 28), cmap='gray')
            axes[i, j].axis('off')
    fig.tight_layout()
    fig.savefig('%s/generation_results_%d.pdf' % (result_path, epoch), dpi=600)
    plt.close(fig)

    return None

In [None]:
import numpy as np
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import random
import sys
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms
from sklearn.preprocessing import OneHotEncoder
import progressbar
import tqdm

def model(dim_x, dim_z, nf, decoder_final_activation='none'):
    '''
    Build the encoder/decoder pair of the VAE.

    dim_x: shape of one input sample, forwarded to both networks
    dim_z: dimension of representations
    nf: a factor to control the overall filter sizes in encoder and decoder networks
    decoder_final_activation: the last activation layer in decoder.

    Returns the two-element list [encoder, decoder].
    '''
    networks = [
        Encoder(dim_x, dim_z, nf),
        Decoder(dim_z, dim_x, nf, final_activation=decoder_final_activation),
    ]
    return networks

def fit(model, x_train, x_val, x_test,
        num_epoch, batch_size, num_worker, seed,
        beta_recon, beta_kl, Adam_beta1, Adam_beta2, weight_decay,
        init_lr, lr_milestones, lr_gamma, val_period, recon_error,
        dtype, result_path):
    '''
    Train a VAE and write checkpoints, figures and summary statistics to result_path.

    After every epoch the loss terms are evaluated on the train/val/test splits.
    Every val_period epochs the loss curves and summary_stats.csv are refreshed,
    and the networks are checkpointed (with reconstruction/generation figures)
    whenever the validation loss is the best seen so far.

    model: the pair [encoder, decoder] as returned by model()
    x_train, x_val, x_test: array-like data splits, converted to tensors of type dtype
    num_epoch: the number of epoch
    batch_size: the number of samples in each mini-batch
    num_worker: the number of DataLoader worker processes
    seed: the random seed number
    beta_recon: the coefficient of the log-likelihood
    beta_kl: the coefficient of the KL-penalty term in ELBOs
    Adam_beta1: beta1 for Adam optimizer
    Adam_beta2: beta2 for Adam optimizer
    weight_decay: the coefficient of the half of L2 penalty term
    init_lr: the initial learning rate
    lr_milestones: the epochs to reduce the learning rate
    lr_gamma: the multiplier for each time learning rate is reduced
    val_period: the frequency of evaluating trained models
    recon_error: the type of reconstruction error (supports 'mse' and 'likelihood')
    dtype: the data type
    result_path: the directory where results are saved. Its default value is results/vae-time=month-day-hour-min-sec.

    Raises ValueError if recon_error is not 'mse' or 'likelihood'.
    Returns None.
    '''

    # Fail fast instead of hitting a NameError in the middle of the first batch.
    if recon_error not in ('mse', 'likelihood'):
        raise ValueError("recon_error must be 'mse' or 'likelihood', got %r" % (recon_error,))

    # declare basic variables
    encoder, decoder = model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Adam_betas = (Adam_beta1, Adam_beta2)
    if result_path is None:
        now = datetime.datetime.now()
        result_path = './results/vae-time=%d-%d-%d-%d-%d' % (now.month, now.day, now.hour, now.minute, now.second)
    os.makedirs(result_path, exist_ok=True)
    model_path = '%s/model.pth' % result_path

    # lines for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # convert data to tensors
    x_train = torch.tensor(x_train, dtype=dtype)
    x_val = torch.tensor(x_val, dtype=dtype)
    x_test = torch.tensor(x_test, dtype=dtype)

    # define optimizers and schedulers
    enc_optimizer = torch.optim.Adam(encoder.parameters(), betas=Adam_betas, lr=init_lr, weight_decay=weight_decay)
    gen_optimizer = torch.optim.Adam(decoder.parameters(), betas=Adam_betas, lr=init_lr, weight_decay=weight_decay)

    enc_scheduler = torch.optim.lr_scheduler.MultiStepLR(enc_optimizer, milestones=lr_milestones, gamma=lr_gamma)
    gen_scheduler = torch.optim.lr_scheduler.MultiStepLR(gen_optimizer, milestones=lr_milestones, gamma=lr_gamma)

    # define training log
    dataset_names = ['train', 'val', 'test']
    loss_names = ['loss', 'recon_loss', 'kl_post_prior', 'l2_penalty']
    logs = {dataset_name: {loss_name: [] for loss_name in loss_names}
            for dataset_name in dataset_names}
    summary_stats = []

    # define data loader
    dataloader = {}
    dataloader['train'] = DataLoader(x_train, batch_size=batch_size, num_workers=num_worker, shuffle=True)
    dataloader['val'] = DataLoader(x_val, batch_size=batch_size, num_workers=num_worker)
    dataloader['test'] = DataLoader(x_test, batch_size=batch_size, num_workers=num_worker)

    encoder.to(device)
    decoder.to(device)

    # Start from +inf so the first validated epoch is always checkpointed,
    # whatever the scale of the loss.
    best_val_epoch = 1
    best_val_loss = float('inf')
    for epoch in range(1, num_epoch+1):
        # ---- training part ----
        encoder.train()
        decoder.train()
        kl_weight = kl_annealing_weight(epoch, num_epoch)
        for x_batch in tqdm.tqdm(dataloader['train'], desc='[Epoch %d/%d] Training' % (epoch, num_epoch)):
            x_batch = x_batch.to(device)

            enc_optimizer.zero_grad()
            gen_optimizer.zero_grad()

            # forward step and objective (KL term is annealed during training only)
            obs_loglik, kl_post_prior = _fit_elbo_terms(encoder, decoder, x_batch, recon_error)
            elbo = beta_recon*obs_loglik - beta_kl*kl_weight*kl_post_prior
            loss = torch.mean(-elbo)

            # backward step
            loss.backward()

            enc_optimizer.step()
            gen_optimizer.step()

        # Advance the learning-rate schedules once per epoch; without these
        # calls lr_milestones and lr_gamma have no effect.
        enc_scheduler.step()
        gen_scheduler.step()

        # ---- evaluation part ----
        encoder.eval()
        decoder.eval()
        # No gradients are needed for logging; avoids building autograd graphs.
        with torch.no_grad():
            # The penalty depends only on the weights, so compute it once per epoch.
            l2_penalty = 0.0
            for network in (encoder, decoder):
                for name, param in network.named_parameters():
                    if 'weight' in name:
                        l2_penalty += 0.5*torch.sum(param**2).item()

            for datasetname in dataset_names:
                loss_cumsum, sample_size = 0.0, 0
                obs_loglik_cumsum, kl_post_prior_cumsum = 0.0, 0.0
                for x_batch in tqdm.tqdm(dataloader[datasetname],
                                         desc='[Epoch %d/%d] Computing loss terms on %s' % (epoch, num_epoch, datasetname)):
                    x_batch = x_batch.to(device)
                    n_batch = x_batch.shape[0]

                    obs_loglik, kl_post_prior = _fit_elbo_terms(encoder, decoder, x_batch, recon_error)
                    elbo = beta_recon*obs_loglik - beta_kl*kl_post_prior
                    loss = torch.mean(-elbo)

                    # Weight batch means by batch size so a smaller last
                    # batch does not bias the epoch average.
                    loss_cumsum += loss.item()*n_batch
                    obs_loglik_cumsum += torch.mean(obs_loglik).item()*n_batch
                    kl_post_prior_cumsum += torch.mean(kl_post_prior).item()*n_batch
                    sample_size += n_batch

                logs[datasetname]['loss'].append(loss_cumsum/sample_size)
                logs[datasetname]['recon_loss'].append(-obs_loglik_cumsum/sample_size)
                logs[datasetname]['kl_post_prior'].append(kl_post_prior_cumsum/sample_size)
                logs[datasetname]['l2_penalty'].append(l2_penalty)

        if epoch % val_period == 0:
            # save loss curves
            plt.figure()
            linestyles = ['solid', 'dashed', 'dotted']
            for dataset_name, linestyle in zip(dataset_names, linestyles):
                plt.plot(logs[dataset_name]['loss'][:], linestyle=linestyle,
                         label=dataset_name)
            plt.legend()
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.savefig('%s/loss_curves.pdf' % (result_path), dpi=600)
            plt.close()

            # update models and logs if the best validation loss is updated
            current_val_loss = logs['val']['loss'][-1]
            if current_val_loss <= best_val_loss:
                best_val_loss = current_val_loss
                best_val_epoch = epoch
                torch.save({'encoder': encoder, 'decoder': decoder, 'logs': logs,
                            'num_epoch': num_epoch, 'batch_size': batch_size,
                            'num_worker': num_worker, 'seed': seed,
                            'beta_recon': beta_recon, 'beta_kl': beta_kl,
                            'Adam_beta1': Adam_beta1, 'Adam_beta2': Adam_beta2,
                            'weight_decay': weight_decay, 'init_lr': init_lr,
                            'lr_milestones': lr_milestones, 'lr_gamma': lr_gamma,
                            'val_period': val_period, 'recon_error': recon_error,
                            'dtype': dtype, 'result_path': result_path},
                            model_path)

                # draw and save reconstruction and generation results
                plot_train_reconstruction(result_path, epoch, n_vis=10)
                plot_test_reconstruction(result_path, epoch, n_vis=10)
                plot_generation(result_path, epoch, n_row=10, n_col=10)

            # At the last epoch, store the complete logs next to the best
            # networks. Skipped if no checkpoint was ever written (e.g. the
            # validation loss was NaN throughout).
            if epoch == num_epoch and os.path.exists(model_path):
                saved_model = torch.load(model_path, weights_only=False)
                saved_model['logs'] = logs
                torch.save(saved_model, model_path)
                del saved_model

            current_summary_stats_row = {}
            current_summary_stats_row['epoch'] = epoch
            current_summary_stats_row['best_val_epoch'] = best_val_epoch
            for loss_name in ['loss', 'recon_loss', 'kl_post_prior']:
                for dataset_name in dataset_names:
                    column = '%s_%s' % (dataset_name, loss_name)
                    current_summary_stats_row[column] = logs[dataset_name][loss_name][-1]
            current_summary_stats_row['l2_penalty'] = logs['train']['l2_penalty'][-1]

            summary_stats.append(current_summary_stats_row)
            pd.DataFrame(summary_stats).to_csv('%s/summary_stats.csv' % result_path, index=False)

    return None


def _fit_elbo_terms(encoder, decoder, x_batch, recon_error):
    '''Run one forward pass and return per-sample (obs_loglik, kl_post_prior).'''
    z_mean, z_log_var = encoder(x_batch)
    z_sample = sampling(z_mean, z_log_var)
    fire_rate, obs_log_var = decoder(z_sample)

    if recon_error == 'mse':
        obs_loglik = -0.5*torch.sum((fire_rate - x_batch)**2, dim=(1, 2, 3))
    else:  # 'likelihood' (validated by fit)
        obs_loglik = torch.sum(-0.5 * (obs_log_var + (x_batch - fire_rate)**2 / torch.exp(obs_log_var)), dim=(1, 2, 3))
    # KL between the approximate posterior and a standard-normal prior.
    kl_post_prior = kl_criterion(z_mean, z_log_var, torch.zeros_like(z_mean), torch.zeros_like(z_log_var))
    return obs_loglik, kl_post_prior

In [None]:
import datetime
import os
import random
import sys

import time
import torch
import torchvision
import numpy as np
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split

# variables for data
p_train = 0.80 # fraction of the official MNIST training set used for training
p_val = 1.0 - p_train # remaining fraction; the split below derives its size as len(train_dataset) - num_train

# variables for VAE architectures
dim_z = 8 # dimension of representations
nf = 64 # a factor to control the overall filter sizes in encoder and decoder networks
decoder_final_activation='none' # the last activation layer in decoder ('sigmoid', 'tanh' or 'none').

# variables for optimization
seed = 0 # the random seed number
num_epoch = 30 # the number of epochs
batch_size = 64 # the number of samples in each mini-batch
num_worker = 8 # the number of DataLoader worker processes
beta_recon = 1.0 # the coefficient of the log-likelihood
beta_kl = 1.0 # the coefficient of the KL-penalty term in ELBOs
Adam_beta1 = 0.5 # beta1 for Adam optimizer
Adam_beta2 = 0.999 # beta2 for Adam optimizer
weight_decay = 5e-6 # the coefficient of the half of L2 penalty term
init_lr = 2e-4 # the initial learning rate
lr_milestones = [10, 20] # the epochs to reduce the learning rate
lr_gamma = 0.5 # the multiplier for each time learning rate is reduced
val_period = 1 # how often (in epochs) loss curves, checkpoints and summary statistics are written
recon_error = 'mse' # the type of reconstruction error (supports 'mse' and 'likelihood')
dtype = torch.float32 # the data type
result_path = None # the directory where results are saved. Its default value is results/vae-time=month-day-hour-min-sec.

# < reproducibility >
# Seed every RNG in use before any random operation (the split below included).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# < load and split MNIST data >
# NOTE: `transforms`, `datasets`, `random_split` and `DataLoader` are imported in the
# previous cell, not in this one; run that cell first.
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on the ToTensor output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# download and load the MNIST dataset
train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transform, download=True)

# split the official MNIST training dataset into training and validation
# (no explicit generator is passed, so the split depends on the torch seed set above)
num_train = int(p_train * len(train_dataset))
num_val = len(train_dataset) - num_train
train_dataset, val_dataset = random_split(train_dataset, [num_train, num_val])

# Function to extract and convert data from a dataset
def extract_data(dataset):
    """Load an entire dataset in a single batch and return it as NumPy arrays.

    Args:
        dataset: a map-style dataset yielding ``(image, label)`` pairs; full
            datasets and ``random_split`` subsets are handled alike because
            everything goes through a ``DataLoader``.

    Returns:
        ``(images, labels)`` as NumPy arrays covering every sample.
    """
    # One batch as large as the dataset materializes all samples at once.
    full_batch_loader = DataLoader(dataset, batch_size=len(dataset))
    images, labels = next(iter(full_batch_loader))
    return images.numpy(), labels.numpy()

# Extracting data
# Materialize each split as NumPy arrays. The x_* arrays are passed to fit()
# in the next cell; x_train and x_test are also read as globals by the
# plot_*_reconstruction helpers. The y_* labels are not used in the visible cells.
x_train, y_train = extract_data(train_dataset)
x_val, y_val = extract_data(val_dataset)
x_test, y_test = extract_data(test_dataset)

In [None]:
# build VAE networks
# dim_x is the per-sample shape (everything after the batch axis of x_train).
dim_x = np.shape(x_train)[1:]
vae = model(dim_x=dim_x, dim_z=dim_z, nf=nf, decoder_final_activation=decoder_final_activation)

# train VAE networks. Results will be saved at the result_path
# (fit() returns None; with result_path=None it creates a time-stamped directory under ./results)
start_time = time.time()
fit(model=vae, x_train=x_train, x_val=x_val, x_test=x_test,
    num_epoch=num_epoch, batch_size=batch_size, num_worker=num_worker, seed=seed,
    beta_recon=beta_recon, beta_kl=beta_kl,
    Adam_beta1=Adam_beta1, Adam_beta2=Adam_beta2, weight_decay=weight_decay,
    init_lr=init_lr, lr_milestones=lr_milestones, lr_gamma=lr_gamma,
    val_period=val_period, recon_error=recon_error,
    dtype=dtype, result_path=result_path)
end_time = time.time()
# Wall-clock time of the whole fit() call, including evaluation and plotting.
training_time = end_time - start_time
print(f"Total training time: {training_time:.2f} seconds")