# Create a VAE model
A convolutional variational autoencoder (VAE) for square 2D input images (RGB; side length a power of 2, at least 8).

In [5]:
import torch
import torch.nn as nn
import math

depth = 64      # channels produced by the first encoder conv; doubled at each pyramid level
n_channels = 3  # number of image channels (RGB)
filt_size = 4   # square convolution kernel size
stride = 2      # stride of the down/up-sampling convs (halves/doubles the spatial size)
pad = 1         # zero padding of the down/up-sampling convs


class VAE2D(nn.Module):
    """Convolutional variational autoencoder (VAE) for square 2D images.

    Encoder: a stride-2 conv followed by ``log2(img_size) - 3`` stride-2
    conv/batchnorm/relu "pyramid" levels. Each level halves the spatial size and
    doubles the channel count, ending at a ``max_depth`` x 4 x 4 feature map.
    ``conv_mu`` / ``conv_logvar`` collapse that map to ``n_latent`` x 1 x 1.

    Decoder: the mirror image built from transposed convs. It ends in a tanh,
    so reconstructions lie in (-1, 1).

    Args:
        img_size: height and width of the (square) input images; must be a
            power of 2 and at least 8.
        n_latent: dimensionality of the latent space.

    Raises:
        ValueError: if ``img_size`` is not a power of 2 that is >= 8.
    """

    def __init__(self, img_size, n_latent=300):

        # Model setup
        #############
        super(VAE2D, self).__init__()
        self.n_latent = n_latent

        # Explicit raises (not assert) so the checks survive `python -O`.
        n = math.log2(img_size)  # math.log2 itself raises ValueError for img_size <= 0
        if n != round(n):
            raise ValueError('Image size must be a power of 2')
        if n < 3:
            # input-conv halves the image, and conv_mu/conv_logvar need a 4 x 4 map
            raise ValueError('Image size must be at least 8')
        n = int(n)

        # Encoder - first half of VAE
        #############################
        # Every stride-2 conv halves the spatial size:
        #   out = (in - filt_size + 2 * pad) / stride + 1 = in / 2
        self.encoder = nn.Sequential()
        # input:  n_channels x img_size x img_size
        # output: depth x (img_size / 2) x (img_size / 2)
        self.encoder.add_module('input-conv', nn.Conv2d(n_channels, depth, filt_size, stride, pad,
                                                        bias=True))
        self.encoder.add_module('input-relu', nn.ReLU(inplace=True))

        # One conv/batchnorm/relu level per power of 2 above 8. Each level doubles
        # the channels and halves the spatial size, until the map is 4 x 4.
        for i in range(n - 3):
            i_depth = depth * 2 ** i        # = o_depth of the previous level
            o_depth = depth * 2 ** (i + 1)
            self.encoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.Conv2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.encoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.encoder.add_module(f'pyramid_{o_depth}_relu',
                                    nn.ReLU(inplace=True))

        # Latent representation
        #######################
        # An unpadded filt_size x filt_size conv maps the max_depth x 4 x 4 encoder
        # output to n_latent x 1 x 1: once for mu, once for the log variance.
        max_depth = depth * 2 ** (n - 3)
        self.conv_mu = nn.Conv2d(max_depth, n_latent, filt_size)
        self.conv_logvar = nn.Conv2d(max_depth, n_latent, filt_size)

        # Decoder - second half of VAE
        ##############################
        self.decoder = nn.Sequential()
        # input:  n_latent x 1 x 1
        # output: max_depth x 4 x 4  (stride=1, pad=0: out = in - 1 + filt_size)
        self.decoder.add_module('input-conv', nn.ConvTranspose2d(n_latent, max_depth, filt_size, bias=True))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(max_depth))
        self.decoder.add_module('input-relu', nn.ReLU(inplace=True))

        # Mirror the encoder pyramid. Each stride-2 transposed conv halves the
        # channels and doubles the spatial size:
        #   out = (in - 1) * stride - 2 * pad + filt_size = 2 * in
        for i in range(n - 3, 0, -1):
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i - 1)
            self.decoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.ConvTranspose2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.decoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.decoder.add_module(f'pyramid_{o_depth}_relu', nn.ReLU(inplace=True))

        # Final transposed conv back to n_channels x img_size x img_size.
        # tanh instead of relu lets the output take negative values, in (-1, 1).
        self.decoder.add_module('output-conv', nn.ConvTranspose2d(depth, n_channels,
                                                                  filt_size, stride, pad, bias=True))
        self.decoder.add_module('output-tanh', nn.Tanh())

        # Model weights init
        ####################
        # Kaiming (He) normal init for Conv2d weights. BatchNorm starts as the
        # identity (weight 1, bias 0).
        # Reference: "Delving deep into rectifiers: Surpassing human-level
        # performance on ImageNet classification" - He, K. et al. (2015)
        # NOTE: nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the
        # decoder's transposed convs keep PyTorch's default initialization.
        # The Linear branch matches none of the layers built above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def encode(self, imgs):
        """
        Encode images into latent-space statistics (mean and log variance).
        input:  imgs   [batch_size, n_channels, img_size, img_size]
        output: [mu, logvar], each [batch_size, n_latent, 1, 1]
        """
        features = self.encoder(imgs)  # [batch_size, max_depth, 4, 4]
        return [self.conv_mu(features), self.conv_logvar(features)]

    def generate(self, mu, logvar):
        """
        Sample a latent vector with the reparameterization trick:
        gen = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, I).
        input:  mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        output: gen    [batch_size, n_latent, 1, 1]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, gen):
        """
        Reconstruct images from latent vectors.
        input:  gen      [batch_size, n_latent, 1, 1]
        output: gen_imgs [batch_size, n_channels, img_size, img_size], values in (-1, 1)
        """
        return self.decoder(gen)

    def forward(self, imgs):
        """
        Encode, sample and decode a batch of images.
        input:  imgs     [batch_size, n_channels, img_size, img_size]
        output: gen_imgs [batch_size, n_channels, img_size, img_size]
                mu       [batch_size, n_latent]
                logvar   [batch_size, n_latent]
        """
        mu, logvar = self.encode(imgs)
        gen = self.generate(mu, logvar)
        return (self.decode(gen),
                mu.squeeze(-1).squeeze(-1),
                logvar.squeeze(-1).squeeze(-1))


In [6]:
model = VAE2D(img_size=256)  # 256 x 256 RGB inputs; n_latent keeps its default of 300

In [7]:
model  # bare expression: the notebook echoes the module's layer tree (output below)

VAE2D(
  (encoder): Sequential(
    (input-conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (input-relu): ReLU(inplace)
    (pyramid_64-128_conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_128_batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_128_relu): ReLU(inplace)
    (pyramid_128-256_conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_256_batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_256_relu): ReLU(inplace)
    (pyramid_256-512_conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_512_batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_512_relu): ReLU(inplace)
    (pyramid_512-1024_conv): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_1024_batchno

In [None]:
# Train for 40 epochs on 128x128 inputs with KL weight 0.01; results go to $EXP_DIR/NV_kl0.01.
# Path variables are quoted so directories containing spaces are not word-split.
# $PYTHONBIN is left unquoted on purpose: it may hold an interpreter plus flags.
$PYTHONBIN train.py --data "${DATA_DIR}" --cuda \
    --epochs 40 --lr 1e-4 --batch_size 32 --out_dir "${EXP_DIR}/NV_kl0.01" \
    --image_size 128 --kl_weight 0.01

In [None]:
import os
from pathlib import Path
import torch
import argparse
from loss import VAELoss
from torchvision.utils import make_grid
from utilities import trainVAE, validateVAE
from model import VAE
from dataloader import load_vae_train_datasets
from tensorboardX import SummaryWriter
import numpy as np

# Model
desc = 'VAE for detecting anomalies in 2D images'
data_path = Path('data/NV_outlier/')
img_size = 256

# Training
epochs = 10
lr = 1e-4        # learning rate
lr_decay = 0.1   # factor the lr is multiplied by at each schedule milestone
schedule = 5     # decrease lr after this many epochs
batch_size = 64
kl = 0.1         # weight of the kl term

# Checkpoints and logging
save_path = Path('model/')  # where model and results will be saved
load_path = None            # checkpoint to resume from (default None)
log_freq = 10               # print status after this many batches

# Command-line interface. The settings above are the defaults; every args.*
# attribute read later in this script needs a flag here.
parser = argparse.ArgumentParser(description=desc)
# NOTE(review): paths are passed on as str; assumes load_vae_train_datasets
# accepts a str path - confirm.
parser.add_argument('--data', type=str, default=str(data_path),
                    help='root directory of the image dataset')
parser.add_argument('--image_size', type=int, default=img_size,
                    help='side length of the (square) input images')
parser.add_argument('--epochs', type=int, default=epochs,
                    help='number of training epochs')
parser.add_argument('--lr', type=float, default=lr,
                    help='initial learning rate')
parser.add_argument('--lr_decay', type=float, default=lr_decay,
                    help='factor the lr is multiplied by at each milestone')
# MultiStepLR expects a list of milestone epochs, hence nargs='+'
parser.add_argument('--schedule', type=int, nargs='+', default=[schedule],
                    help='epoch milestone(s) at which the lr is decayed')
parser.add_argument('--batch_size', type=int, default=batch_size,
                    help='mini-batch size')
parser.add_argument('--kl_weight', type=float, default=kl,
                    help='weight of the KL term in the loss')
parser.add_argument('--out_dir', type=str, default=str(save_path),
                    help='directory for checkpoints, config and logs')
parser.add_argument('--resume', type=str, default=load_path,
                    help='checkpoint to resume from')
parser.add_argument('--reset_opt', action='store_true',
                    help='when resuming, do not restore the optimizer state')
# NOTE(review): presumably read by trainVAE/validateVAE via args - confirm
parser.add_argument('--log_freq', type=int, default=log_freq,
                    help='print status after this many batches')
parser.add_argument('--cuda', action='store_true',
                    help='train on the GPU')

# NOTE(review): inside a notebook kernel sys.argv holds the kernel's own flags;
# pass an explicit argument list to parse_args() when running this cell there.
args = parser.parse_args()

# load checkpoint (map_location keeps every tensor on the CPU)
checkpoint = None
if args.resume is not None:
    checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
    print("checkpoint loaded!")
    print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'], checkpoint['epoch']))

# model
model = VAE(args.image_size)
if checkpoint is not None:
    model.load_state_dict(checkpoint['state_dict'])

# criterion
criterion = VAELoss(size_average=True, kl_weight=args.kl_weight)
if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# load data
train_loader, val_loader = load_vae_train_datasets(input_size=args.image_size,
                                                   data=args.data,
                                                   batch_size=args.batch_size)

# load optimizer and scheduler
opt = torch.optim.Adam(params=model.parameters(), lr=args.lr, betas=(0.9, 0.999))
if checkpoint is not None and not args.reset_opt:
    opt.load_state_dict(checkpoint['optimizer'])

# MultiStepLR needs an iterable of milestone epochs; accept a bare int as well
milestones = args.schedule if isinstance(args.schedule, (list, tuple)) else [args.schedule]
scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=milestones,
                                                 gamma=args.lr_decay)

# make output dir (and any missing parents). An existing dir is reused with a
# warning instead of crashing, so earlier results in it may be overwritten.
if os.path.isdir(args.out_dir):
    print("{} already exists!".format(args.out_dir))
os.makedirs(args.out_dir, exist_ok=True)

# save args
args_dict = vars(args)
with open(os.path.join(args.out_dir, 'config.txt'), 'w') as f:
    for k, v in args_dict.items():
        f.write("{}:{}\n".format(k, v))
writer = SummaryWriter(log_dir=os.path.join(args.out_dir, 'logs'))

def write_image(tag, images, step):
    """
    Write a 5-column grid of images to tensorboard.
    :param tag: The tag for tensorboard
    :param images: torch tensor with values in (-1, 1), e.g. [25, 3, H, W]
    :param step: global step (1-based epoch) to log the grid under
    """
    # map (-1, 1) to (0, 1) for display
    images = (images + 1) / 2
    grid = make_grid(images, nrow=5, padding=20)
    writer.add_image(tag, grid.detach(), global_step=step)


# main loop
best_loss = np.inf
for epoch in range(args.epochs):
    step = epoch + 1  # 1-based epoch index used for tensorboard and checkpoints

    # Train for one epoch, then advance the lr schedule. Since PyTorch 1.1,
    # scheduler.step() must follow the epoch's optimizer steps.
    train_loss, train_kl, train_reconst_logp = trainVAE(train_loader, model, criterion, opt, epoch, args)
    scheduler.step()
    writer.add_scalar('train_elbo', -train_loss, global_step=step)
    writer.add_scalar('train_kl', train_kl, global_step=step)
    writer.add_scalar('train_reconst_logp', train_reconst_logp, global_step=step)

    # evaluate on validation set
    with torch.no_grad():
        val_loss, val_kl, val_reconst_logp = validateVAE(val_loader, model, criterion, args)
        writer.add_scalar('val_elbo', -val_loss, global_step=step)
        writer.add_scalar('val_kl', val_kl, global_step=step)
        writer.add_scalar('val_reconst_logp', val_reconst_logp, global_step=step)

    # keep a checkpoint of the model with the lowest validation loss so far
    if val_loss < best_loss:
        print('checkpointed!')
        best_loss = val_loss
        save_dict = {'epoch': step,
                     'state_dict': model.state_dict(),
                     'val_loss': val_loss,
                     'optimizer': opt.state_dict()}
        ckpt_path = os.path.join(args.out_dir, 'best_model.pth.tar')
        torch.save(save_dict, ckpt_path)
    print('curr lowest val loss {}'.format(best_loss))

    # visualize reconst and free sample
    print("plotting imgs...")
    with torch.no_grad():
        # reconstruct the first 25 images of a fresh validation batch
        # (element 0 of a batch holds the images fed to the model)
        imgs = next(iter(val_loader))[0][:25]
        if args.cuda:
            imgs = imgs.cuda()
        imgs_reconst, mu, logvar = model(imgs)

        # decode 25 latent vectors drawn from a standard normal
        # NOTE(review): assumes model.VAE exposes its latent size as `nz` - confirm
        noises = torch.randn(25, model.nz, 1, 1)
        if args.cuda:
            noises = noises.cuda()
        samples = model.decode(noises)

        write_image("origin", imgs, step)
        write_image("reconst", imgs_reconst, step)
        write_image("samples", samples, step)
        print('done')

# flush pending tensorboard events to disk
writer.close()

import ipdb
