# Imports

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os, math, time
from pathlib import Path
from datetime import date

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

from sklearn.metrics import roc_auc_score
from itertools import product
import scipy.stats as stats

In [3]:
from IPython.core.debugger import set_trace

# Define a VAE Model
Initially designed for 2D input images.
Based on this paper: https://arxiv.org/abs/1807.01349

In [4]:
depth = 64      # initial depth to convolve channels into
n_channels = 3  # number of channels (RGB)
filt_size = 4   # convolution filter size
stride = 2      # stride for conv
pad = 1         # padding added for conv

class VAE2D(nn.Module):
    def __init__(self, img_size, n_latent=300):
        
        # Model setup
        #############
        super(VAE2D, self).__init__()
        self.n_latent = n_latent
        n = math.log2(img_size)
        assert n == round(n), 'Image size must be a power of 2'  # restrict image input sizes permitted
        assert n >= 3, 'Image size must be at least 8'           # low dimensional data won't work well
        n = int(n)

        # Encoder - first half of VAE
        #############################
        self.encoder = nn.Sequential()  
        # input: n_channels x img_size x img_size
        # ouput: depth x conv_img_size^2
        # conv_img_size = (img_size - filt_size + 2 * pad) / stride + 1
        self.encoder.add_module('input-conv', nn.Conv2d(n_channels, depth, filt_size, stride, pad,
                                                        bias=True))
        self.encoder.add_module('input-relu', nn.ReLU(inplace=True))
        
        # Add conv layer for each power of 2 over 3 (min size)
        # Pyramid strategy with batch normalization added
        for i in range(n - 3):
            # input: depth x conv_img_size^2
            # output: o_depth x conv_img_size^2
            # i_depth = o_depth of previous layer
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i + 1)
            self.encoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.Conv2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.encoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.encoder.add_module(f'pyramid_{o_depth}_relu',
                                    nn.ReLU(inplace=True))
        
        # Latent representation
        #######################
        # Convolve the encoded image into the latent space, once for mu and once for logvar
        max_depth = depth * 2 ** (n - 3)
        self.conv_mu = nn.Conv2d(max_depth, n_latent, filt_size)      # return the mean of the latent space 
        self.conv_logvar = nn.Conv2d(max_depth, n_latent, filt_size)  # return the log variance of the same
        
        
        # Decoder - second half of VAE
        ##############################
        self.decoder = nn.Sequential()
        # input: max_depth x conv_img_size^2 (8 x 8)
        # output: n_latent x conv_img_size^2 (8 x 8)
        # default stride=1, pad=0 for this layer
        self.decoder.add_module('input-conv', nn.ConvTranspose2d(n_latent, max_depth, filt_size, bias=True))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(max_depth))
        self.decoder.add_module('input-relu', nn.ReLU(inplace=True))
    
        # Reverse the convolution pyramids used in the encoder
        for i in range(n - 3, 0, -1):
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i - 1)
            self.decoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.ConvTranspose2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.decoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.decoder.add_module(f'pyramid_{o_depth}_relu', nn.ReLU(inplace=True))
        
        # Final transposed convolution to return to img_size
        # Final activation is tanh instead of relu to allow negative pixel output
        self.decoder.add_module('output-conv', nn.ConvTranspose2d(depth, n_channels,
                                                                  filt_size, stride, pad, bias=True))
        self.decoder.add_module('output-tanh', nn.Tanh())

        # Model weights init
        ####################
        # Randomly initialize the model weights using kaiming method
        # Reference: "Delving deep into rectifiers: Surpassing human-level
        # performance on ImageNet classification" - He, K. et al. (2015)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def encode(self, imgs):
        """
        Encode the images into latent space vectors (mean and log variance representation)
        input:  imgs   [batch_size, 3, 256, 256]
        output: mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        """
        output = self.encoder(imgs)
        output = output.squeeze(-1).squeeze(-1)
        return [self.conv_mu(output), self.conv_logvar(output)]

    def generate(self, mu, logvar):
        """
        Generates a random latent vector using the trained mean and log variance representation
        input:  mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        output: gen    [batch_size, n_latent, 1, 1]
        """
        std = torch.exp(0.5 * logvar)
        gen = torch.randn_like(std)
        return gen.mul(std).add_(mu)

    def decode(self, gen):
        """
        Restores an image representation from the generated latent vector
        input:  gen      [batch_size, n_latent, 1, 1]
        output: gen_imgs [batch_size, 3, 256, 256]
        """
        return self.decoder(gen)

    def forward(self, imgs):
        """
        Generates reconstituted images from input images based on learned representation
        input: imgs     [batch_size, 3, 256, 256]
        ouput: gen_imgs [batch_size, 3, 256, 256]
               mu       [batch_size, n_latent]
               logvar   [batch_size, n_latent]
        """
        mu, logvar = self.encode(imgs)
        gen = self.generate(mu, logvar)
        for tensor in (mu, logvar):
            tensor = tensor.squeeze(-1).squeeze(-1)
        return self.decode(gen), mu, logvar


# Define a loss function
Must be suitable for anomaly detection by recreation similarity

In [5]:
class VAE2DLoss(nn.Module):

    def __init__(self, kl_weight=1):
        super(VAE2DLoss, self).__init__()
        self.kl_weight = kl_weight

    def forward(self, gen_imgs, imgs, mu, logvar):
        """
        input:  gen_imgs [batch_size, n_channels, img_size, img_size]
                imgs     [batch_size, n_channels, img_size, img_size]
                mu       [batch_size, n_latent]
                logvar   [batch_size, n_latent]
        output: loss      scalar (-ELBO)
                loss_desc {'KL': KL, 'logp': gen_err}
        """
        batch_size = imgs.shape[0]
        gen_err = (imgs - gen_imgs).pow(2).reshape(batch_size, -1)
        gen_err = 0.5 * torch.sum(gen_err, dim=-1)
        gen_err = torch.mean(gen_err)

        # KL(q || p) = -log_sigma + sigma^2/2 + mu^2/2 - 1/2
        KL = (-logvar + logvar.exp() + mu.pow(2) - 1) * 0.5
        KL = torch.sum(KL, dim=-1)
        KL = torch.mean(KL)

        loss = gen_err + self.kl_weight * KL
        return loss, {'KL': KL, 'logp': -gen_err}

# Load Data

In [15]:
def load_datasets(img_size, data_path):
    """
    Load the image datasets from vae_train and vae_test
    Transform to correct image size
    """
    
    train_path = data_path / 'train/train/'
    val_path = data_path / 'train/val/'
    test_path = data_path / 'test/'
    
    norm_args = {'mean': [0.5] * n_channels,
                 'std': [0.5] * n_channels}
    jitter_args = {'brightness': 0.1,
                   'contrast': 0.1,
                   'saturation': 0.1}  # hue unchanged
    
    train_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomCrop(img_size),          # vary horizontal position
        transforms.RandomHorizontalFlip(p=0.25),  # vary photo orientation
        transforms.RandomVerticalFlip(p=0.25),
        transforms.ColorJitter(**jitter_args),    # vary photo lighting
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])
    
    test_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),  # assume center is most important
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])

    train_ds = datasets.ImageFolder(train_path, train_transform)
    val_ds = datasets.ImageFolder(val_path, test_transform)
    test_ds = datasets.ImageFolder(test_path, test_transform)
    
    loader_args = {'shuffle': True,
                   'num_workers': 4}
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, **loader_args)
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, **loader_args)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, ** loader_args)
    
    return train_dl, val_dl, test_dl

In [16]:
# Model/Data parameters
desc = 'skin'
data_path = Path('data/NV_outlier/')
img_size = 128
n_channels = 3

# Training parameters
epochs = 40
lr = 1e-4                # learning rate
lr_decay = 0.1           # lr decay factor
kl_weight = 0.01         # weighted factor of the KL term
schedule = [10, 20, 30]  # decrease lr at these epochs
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoints/Logging parameters
save_path = Path(f"models/{date.today().strftime('%y%m%d')}-{desc}/")
# checkpoint to resume from (default None)
load_path = Path('models/190124-skin/best_model.pth.tar')
log_freq = 10            # print status after this many batches

In [17]:
save_path

PosixPath('models/190127-skin')

In [18]:
train_dl, val_dl, test_dl = load_datasets(img_size, data_path)

In [19]:
print(len(train_dl), len(val_dl), len(test_dl))

6370 339 209


In [24]:
train_dl[0][0].shape

torch.Size([3, 128, 128])

# Build the Model

In [15]:
# Create model
model = VAE2D(img_size)

# Load optimizer and scheduler
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, schedule, lr_decay)

# Load checkpoint if any
if load_path is not None:
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("Checkpoint loaded")
    print(f"Validation loss: {checkpoint['val_loss']}")
    print(f"Epoch: {checkpoint['epoch']}")

# Set loss criterion
criterion = VAE2DLoss(kl_weight=kl_weight)

# Move to GPU
model = model.to(device)
criterion = criterion.to(device)

Checkpoint loaded
Validation loss: 148.0490126176314
Epoch: 33


# Train the Model

In [13]:
# Make save directory
if save_path.is_dir():
    print(f"Folder {save_path} already exists")
else:
    os.mkdir(save_path)

Folder models/190124-skin already exists


In [None]:
# Convenience classes
class StopWatch(object):
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.start = time.time()
        self.lap_start = time.time()
        self.elapsed = []
    
    def lap(self):
        self.elapsed.append(time.time() - self.lap_start)
    
class AvgTracker(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.cnt += 1
        self.avg = self.sum / self.cnt

In [None]:
def trainVAE2D(dl):
    loss_tracker = AvgTracker()
    kl_tracker = AvgTracker()
    logp_tracker = AvgTracker()
    timer = StopWatch()
    
    for i, (X, _) in tqdm(enumerate(dl)):
        
        X = X.to(device)
        timer.lap()  # load time
        
        # Generate images and compute loss
        X_hat, mu, logvar = model(X)
        loss, loss_desc = criterion(X_hat, X, mu, logvar)
        timer.lap()  # gen time
        
        loss_tracker.update(loss.item())
        kl_tracker.update(loss_desc['KL'].item())
        logp_tracker.update(loss_desc['logp'].item())
        
        if model.training:
            # Update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            timer.lap()  # backprop time
        
        if (i + 1) % log_freq == 0:
            # Print progress
            print(f'Epoch: {epoch + 1} ({i + 1}/{len(dl)})')
            print(f'\tData load time: {timer.elapsed[0]:.3f} sec')
            print(f'\tGeneration time: {timer.elapsed[1]:.3f} sec')
            if model.training:
                print(f'\tBackprop time: {timer.elapsed[2]:.3f} sec')
            print(f'\tLog probability: {logp_tracker.val:.4f} '
                  f'(avg {logp_tracker.avg:.4f})')
            print(f'\tKL: {kl_tracker.val:.4f} (avg {kl_tracker.avg:.4f})')
            print(f'\tLoss: {loss_tracker.val:.4f} (avg {loss_tracker.avg:.4f})')

    return loss_tracker.avg, kl_tracker.avg, logp_tracker.avg

In [None]:
# Main loop
best_loss = np.inf
for epoch in range(epochs):

    model.train()
    scheduler.step()
    train_loss, train_kl, train_logp = trainVAE2D(train_dl)
    
    model.eval()
    with torch.no_grad():
        val_loss, val_kl, val_logp = trainVAE2D(val_dl)

    # Report training progress to user
    print(f'Lowest validation loss: {best_loss:.4f}')
    if val_loss < best_loss:
        print('Saving checkpoint..')
        best_loss = val_loss
        save_dict = {'epoch': epoch + 1,
                     'state_dict': model.state_dict(),
                     'val_loss': val_loss,
                     'optimizer': optimizer.state_dict()}
        path = save_path / 'best_model.pth.tar'
        torch.save(save_dict, path)
    

### ADD VISUALIZATIONS

In [None]:
# Visualize generated images and input sample images
with torch.no_grad():

    val_iter = iter(val_dl)

    # Generate 25 images
    imgs = val_iter._get_batch()[1][0][:25]
    imgs = imgs.to(device)
    gen_imgs, mu, logvar = model(imgs)
    
    # Scale images back to 0-1
    imgs = (imgs + 1) / 2
    grid = make_grid(imgs, nrow=5, padding=20)
    gen_imgs = (gen_imgs + 1) / 2
    gen_grid = make_grid(gen_imgs, nrow=5, padding=20)
    

In [None]:
import matplotlib.pyplot as plt
grid = grid.cpu()
gen_grid = gen_grid.cpu()

In [None]:
plt.imshow(grid.permute(1, 2, 0))

In [None]:
plt.imshow(gen_grid.permute(1, 2, 0))

In [None]:
mu = mu.cpu()
std = (0.5 * logvar).exp().cpu()

In [None]:
for i in np.random.choice(range(mu.shape[0]), 5):
    plt.figure()
    mu_eg = mu[i, :].squeeze(-1).squeeze(-1)
    plt.plot(mu_eg.numpy())
    plt.figure()
    std_eg = std[i, :].squeeze(-1).squeeze(-1)
    plt.plot(std_eg.numpy())

In [None]:
with torch.no_grad():
    # Generate some random images
    noises = torch.randn(25, model.n_latent, 1, 1)
    noises = noises.to(device)
    samples = model.decode(noises)
    
    samples = (samples + 1) / 2
    sample_grid = make_grid(samples, nrow=5, padding=20).cpu()
    
    plt.imshow(sample_grid.permute(1, 2, 0))  # easy way to swapaxes