# Imports

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os, math, time
from pathlib import Path
from datetime import date

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

In [3]:
from IPython.core.debugger import set_trace

# Define a VAE Model
Initially designed for 2D input images.
Based on this paper: https://arxiv.org/abs/1807.01349

In [47]:
depth = 64      # initial depth to convolve channels into
n_channels = 3  # number of channels (RGB)
filt_size = 4   # convolution filter size
stride = 2      # stride for conv
pad = 1         # padding added for conv

class VAE2D(nn.Module):
    def __init__(self, img_size, n_latent=300):
        
        # Model setup
        #############
        super(VAE2D, self).__init__()
        self.n_latent = n_latent
        n = math.log2(img_size)
        assert n == round(n), 'Image size must be a power of 2'  # restrict image input sizes permitted
        assert n >= 3, 'Image size must be at least 8'           # low dimensional data won't work well
        n = int(n)

        # Encoder - first half of VAE
        #############################
        self.encoder = nn.Sequential()  
        # input: n_channels x img_size x img_size
        # ouput: depth x conv_img_size^2
        # conv_img_size = (img_size - filt_size + 2 * pad) / stride + 1
        self.encoder.add_module('input-conv', nn.Conv2d(n_channels, depth, filt_size, stride, pad,
                                                        bias=True))
        self.encoder.add_module('input-relu', nn.ReLU(inplace=True))
        
        # Add conv layer for each power of 2 over 3 (min size)
        # Pyramid strategy with batch normalization added
        for i in range(n - 3):
            # input: depth x conv_img_size^2
            # output: o_depth x conv_img_size^2
            # i_depth = o_depth of previous layer
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i + 1)
            self.encoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.Conv2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.encoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.encoder.add_module(f'pyramid_{o_depth}_relu',
                                    nn.ReLU(inplace=True))
        
        # Latent representation
        #######################
        # Convolve the encoded image into the latent space, once for mu and once for logvar
        max_depth = depth * 2 ** (n - 3)
        self.conv_mu = nn.Conv2d(max_depth, n_latent, filt_size)      # return the mean of the latent space 
        self.conv_logvar = nn.Conv2d(max_depth, n_latent, filt_size)  # return the log variance of the same
        
        
        # Decoder - second half of VAE
        ##############################
        self.decoder = nn.Sequential()
        # input: max_depth x conv_img_size^2 (8 x 8)  TODO double check sizes
        # output: n_latent x conv_img_size^2 (8 x 8)
        # default stride=1, pad=0 for this layer
        self.decoder.add_module('input-conv', nn.ConvTranspose2d(n_latent, max_depth, filt_size, bias=True))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(max_depth))
        self.decoder.add_module('input-relu', nn.ReLU(inplace=True))
    
        # Reverse the convolution pyramids used in the encoder
        for i in range(n - 3, 0, -1):
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i - 1)
            self.decoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.ConvTranspose2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.decoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.decoder.add_module(f'pyramid_{o_depth}_relu', nn.ReLU(inplace=True))
        
        # Final transposed convolution to return to img_size
        # Final activation is tanh instead of relu to allow negative pixel output
        self.decoder.add_module('output-conv', nn.ConvTranspose2d(depth, n_channels,
                                                                  filt_size, stride, pad, bias=True))
        self.decoder.add_module('output-tanh', nn.Tanh())

        # Model weights init
        ####################
        # Randomly initialize the model weights using kaiming method
        # Reference: "Delving deep into rectifiers: Surpassing human-level
        # performance on ImageNet classification" - He, K. et al. (2015)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def encode(self, imgs):
        """
        Encode the images into latent space vectors (mean and log variance representation)
        input:  imgs   [batch_size, 3, 256, 256]
        output: mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        """
        output = self.encoder(imgs)
        output = output.squeeze(-1).squeeze(-1)
        return [self.conv_mu(output), self.conv_logvar(output)]

    def generate(self, mu, logvar):
        """
        Generates a random latent vector using the trained mean and log variance representation
        input:  mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        output: gen    [batch_size, n_latent, 1, 1]
        """
        std = torch.exp(0.5 * logvar)
        gen = torch.randn_like(std)
        return gen.mul(std).add_(mu)

    def decode(self, gen):
        """
        Restores an image representation from the generated latent vector
        input:  gen      [batch_size, n_latent, 1, 1]
        output: gen_imgs [batch_size, 3, 256, 256]
        """
        return self.decoder(gen)

    def forward(self, imgs):
        """
        Generates reconstituted images from input images based on learned representation
        input: imgs     [batch_size, 3, 256, 256]
        ouput: gen_imgs [batch_size, 3, 256, 256]
               mu       [batch_size, n_latent]
               logvar   [batch_size, n_latent]
        """
        mu, logvar = self.encode(imgs)
        gen = self.generate(mu, logvar)
        for tensor in (mu, logvar):
            tensor = tensor.squeeze(-1).squeeze(-1)
        return self.decode(gen), mu, logvar


In [48]:
model = VAE2D(256)

In [49]:
model

VAE2D(
  (encoder): Sequential(
    (input-conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (input-relu): ReLU(inplace)
    (pyramid_64-128_conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_128_batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_128_relu): ReLU(inplace)
    (pyramid_128-256_conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_256_batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_256_relu): ReLU(inplace)
    (pyramid_256-512_conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_512_batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_512_relu): ReLU(inplace)
    (pyramid_512-1024_conv): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_1024_batchno

# Define a loss function
Must be suitable for anomaly detection by recreation similarity

In [29]:
class VAE2DLoss(nn.Module):

    def __init__(self, kl_weight=1):
        super(VAE2DLoss, self).__init__()
        self.kl_weight = kl_weight

    def forward(self, gen_imgs, imgs, mu, logvar):
        """
        input:  gen_imgs [batch_size, n_channels, img_size, img_size]
                imgs     [batch_size, n_channels, img_size, img_size]
                mu       [batch_size, n_latent]
                logvar   [batch_size, n_latent]
        output: loss      scalar (-ELBO)
                loss_desc {'KL': KL, 'gen_logp': gen_err}
        """
        batch_size = imgs.shape[0]
        gen_err = (imgs - gen_imgs).pow(2).reshape(batch_size, -1)
        gen_err = 0.5 * torch.sum(gen_err, dim=-1)
        gen_err = torch.mean(gen_err)

        # KL(q || p) = -log_sigma + sigma^2/2 + mu^2/2 - 1/2
        KL = (-logvar + logvar.exp() + mu.pow(2) - 1) * 0.5
        KL = torch.sum(KL, dim=-1)
        KL = torch.mean(KL)

        loss = gen_err + self.kl_weight * KL
        return loss, {'KL': KL, 'gen_logp': -gen_err}

# Load Data

In [23]:
def load_datasets(img_size, data_path):
    """
    Load the image datasets from vae_train and vae_test
    Transform to correct image size
    """
    
    train_path = data_path / 'vae_train/train/'
    val_path = data_path / 'vae_train/val/'
    test_path = data_path / 'vae_test/'
    
    norm_args = {'mean': [0.5] * n_channels,
                 'std': [0.5] * n_channels}
    jitter_args = {'brightness': 0.1,
                   'contrast': 0.1,
                   'saturation': 0.1}  # hue unchanged
    
    train_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomCrop(img_size),          # vary horizontal position
        transforms.RandomHorizontalFlip(p=0.25),  # vary photo orientation
        transforms.RandomVerticalFlip(p=0.25),
        transforms.ColorJitter(**jitter_args),    # vary photo lighting
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])
    
    test_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),  # assume center is most important
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])

    train_ds = datasets.ImageFolder(train_path, train_transform)
    val_ds = datasets.ImageFolder(val_path, test_transform)
    test_ds = datasets.ImageFolder(test_path, test_transform)
    
    
    loader_args = {'shuffle': True,
                   'num_workers': 4}
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, **loader_args)
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, **loader_args)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, ** loader_args)
    
    return train_dl, val_dl, test_dl

In [32]:
# Model/Data parameters
desc = 'VAE for detecting anomalies in 2D images'
data_path = Path('data/NV_outlier/')
img_size = 128
n_channels = 3

# Training parameters
epochs = 40
lr = 1e-4                # learning rate
lr_decay = 0.1           # lr decay factor
kl_weight = 0.01         # weighted factor of the KL term
schedule = [10, 20, 30]  # decrease lr at these epochs
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoints/Logging parameters
save_path = Path(f"models/{date.today().strftime('%y%m%d')}/")
load_path = None         # checkpoint to resume from (default None)
log_freq = 10            # print status after this many batches

In [25]:
train_dl, val_dl, test_dl = load_datasets(img_size, data_path)

In [27]:
print(len(train_dl), len(val_dl), len(test_dl))

200 11 958


# Build the Model

In [59]:
# Create model
model = VAE2D(img_size)

# Load optimizer and scheduler
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, schedule, lr_decay)

# Load checkpoint if any
if load_path is not None:
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("Checkpoint loaded")
    print(f"Validation loss: {checkpoint['val_loss']}")
    print(f"Epoch: {checkpoint['epoch']}")

# Set loss criterion
criterion = VAE2DLoss(kl_weight=kl_weight)

# Move to GPU
model = model.to(device)
criterion = criterion.to(device)

# Train the Model

In [60]:
# Make save directory
if save_path.is_dir():
    print(f"Folder {save_path} already exists")
else:
    os.mkdir(save_path)

Folder models/190124 already exists


In [61]:
# TODO - add in logging to Visdom

In [62]:
# Convenience classes
class StopWatch(object):
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.start = time.time()
        self.lap_start = time.time()
        self.elapsed = []
    
    def lap(self):
        self.elapsed.append(time.time() - self.lap_start)
    
class AvgTracker(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.cnt += 1
        self.avg = self.sum / self.cnt

In [65]:
def trainVAE2D(dl):
    loss_tracker = AvgTracker()
    kl_tracker = AvgTracker()
    logp_tracker = AvgTracker()
    timer = StopWatch()
    
    for i, (X, _) in tqdm(enumerate(dl)):
        
        X = X.to(device)
        timer.lap()  # load time
        
        # Generate images and compute loss
        X_hat, mu, logvar = model(X)
        loss, loss_desc = criterion(X_hat, X, mu, logvar)
        timer.lap()  # gen time
        
        loss_tracker.update(loss.item())
        kl_tracker.update(loss_desc['KL'].item())
        logp_tracker.update(loss_desc['gen_logp'].item())
        
        if model.training:
            # Update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            timer.lap()  # backprop time
        
        if i % log_freq == 0:
            # Print progress
            print(f'Epoch: {epoch + 1} ({i}/{len(dl)})')
            print(f'\tData load time: {timer.elapsed[0]:.3f} sec')
            print(f'\tGeneration time: {timer.elapsed[1]:.3f} sec')
            if model.training:
                print(f'\tBackprop time: {timer.elapsed[2]:.3f} sec')
            print(f'\tLog probability: {logp_tracker.val:.4f} '
                  f'(avg {logp_tracker.avg:.4f})')
            print(f'\tKL: {kl_tracker.val:.4f} (avg {kl_tracker.avg:.4f})')
            print(f'\tLoss: {loss_tracker.val:.4f} (avg {loss_tracker.avg:.4f})')

    return loss_tracker.avg, kl_tracker.avg, logp_tracker.avg


def testVAE2D(test_dl):
    abnormal_loss_tracker = AvgTracker()
    normal_loss_tracker = AvgTracker()

    model.eval()
    for i, (X, y) in tqdm(enumerate(test_dl)):

        X = X.to(device)
        X_hat, mu, logvar = model(X)
        loss, loss_desc = criterion(X_hat, X, mu, logvar)

        # Normal
        if target.item() == 1:
           normal_loss_tracker.update(loss.item())
        # Abnormal
        else:
           abnormal_loss_tracker.update(loss.item())

    return normal_loss_tracker.avg, abnormal_loss_tracker.avg


In [67]:
# Main loop
best_loss = np.inf
for epoch in range(epochs):

    model.train()
    scheduler.step()
    train_loss, train_kl, train_logp = trainVAE2D(train_dl)
    
    model.eval()
    with torch.no_grad():
        val_loss, val_kl, val_logp = trainVAE2D(val_dl)

    # Report training progress to user
    print(f'Lowest validation loss: {best_loss:.4f}')
    if val_loss < best_loss:
        print('Saving checkpoint..')
        best_loss = val_loss
        save_dict = {'epoch': epoch + 1,
                     'state_dict': model.state_dict(),
                     'val_loss': val_loss,
                     'optimizer': optimizer.state_dict()}
        path = save_path / 'best_model.pth.tar'
        torch.save(save_dict, save_path)
    



0it [00:00, ?it/s][A[A

1it [00:00,  1.06it/s][A[A

Epoch: 0 (0/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -617.5640 (avg -617.5640)
	KL: 9.3432 (avg 9.3432)
	Loss: 617.6575 (avg 617.6575)




2it [00:01,  1.26it/s][A[A

3it [00:01,  1.48it/s][A[A

4it [00:02,  1.68it/s][A[A

5it [00:02,  1.87it/s][A[A

6it [00:02,  2.02it/s][A[A

7it [00:03,  2.16it/s][A[A

8it [00:03,  2.26it/s][A[A

9it [00:04,  2.34it/s][A[A

10it [00:04,  2.41it/s][A[A

11it [00:04,  2.45it/s][A[A

Epoch: 0 (10/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -730.0179 (avg -686.9929)
	KL: 9.0999 (avg 9.6432)
	Loss: 730.1089 (avg 687.0893)




12it [00:05,  2.47it/s][A[A

13it [00:05,  2.51it/s][A[A

14it [00:06,  2.54it/s][A[A

15it [00:06,  2.54it/s][A[A

16it [00:06,  2.55it/s][A[A

17it [00:07,  2.56it/s][A[A

18it [00:07,  2.55it/s][A[A

19it [00:08,  2.56it/s][A[A

20it [00:08,  2.56it/s][A[A

21it [00:08,  2.55it/s][A[A

Epoch: 0 (20/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -640.7998 (avg -684.3577)
	KL: 8.8424 (avg 9.5287)
	Loss: 640.8882 (avg 684.4530)




22it [00:09,  2.56it/s][A[A

23it [00:09,  2.57it/s][A[A

24it [00:10,  2.56it/s][A[A

25it [00:10,  2.57it/s][A[A

26it [00:10,  2.56it/s][A[A

27it [00:11,  2.56it/s][A[A

28it [00:11,  2.54it/s][A[A

29it [00:11,  2.55it/s][A[A

30it [00:12,  2.54it/s][A[A

31it [00:12,  2.55it/s][A[A

Epoch: 0 (30/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -493.6983 (avg -645.2374)
	KL: 8.8476 (avg 9.3283)
	Loss: 493.7868 (avg 645.3307)




32it [00:13,  2.55it/s][A[A

33it [00:13,  2.55it/s][A[A

34it [00:13,  2.56it/s][A[A

35it [00:14,  2.57it/s][A[A

36it [00:14,  2.56it/s][A[A

37it [00:15,  2.55it/s][A[A

38it [00:15,  2.57it/s][A[A

39it [00:15,  2.54it/s][A[A

40it [00:16,  2.55it/s][A[A

41it [00:16,  2.55it/s][A[A

Epoch: 0 (40/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -739.5546 (avg -627.5278)
	KL: 8.6113 (avg 9.1869)
	Loss: 739.6407 (avg 627.6197)




42it [00:17,  2.54it/s][A[A

43it [00:17,  2.56it/s][A[A

44it [00:17,  2.56it/s][A[A

45it [00:18,  2.55it/s][A[A

46it [00:18,  2.56it/s][A[A

47it [00:19,  2.56it/s][A[A

48it [00:19,  2.55it/s][A[A

49it [00:19,  2.55it/s][A[A

50it [00:20,  2.55it/s][A[A

51it [00:20,  2.54it/s][A[A

Epoch: 0 (50/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -464.6272 (avg -612.7417)
	KL: 8.5366 (avg 9.3631)
	Loss: 464.7125 (avg 612.8354)




52it [00:20,  2.56it/s][A[A

53it [00:21,  2.56it/s][A[A

54it [00:21,  2.56it/s][A[A

55it [00:22,  2.57it/s][A[A

56it [00:22,  2.57it/s][A[A

57it [00:22,  2.56it/s][A[A

58it [00:23,  2.56it/s][A[A

59it [00:23,  2.55it/s][A[A

60it [00:24,  2.55it/s][A[A

61it [00:24,  2.55it/s][A[A

Epoch: 0 (60/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -514.1101 (avg -600.4136)
	KL: 8.6883 (avg 9.3336)
	Loss: 514.1970 (avg 600.5070)




62it [00:24,  2.55it/s][A[A

63it [00:25,  2.55it/s][A[A

64it [00:25,  2.56it/s][A[A

65it [00:26,  2.56it/s][A[A

66it [00:26,  2.56it/s][A[A

67it [00:26,  2.55it/s][A[A

68it [00:27,  2.55it/s][A[A

69it [00:27,  2.53it/s][A[A

70it [00:28,  2.54it/s][A[A

71it [00:28,  2.54it/s][A[A

Epoch: 0 (70/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -533.4414 (avg -589.7215)
	KL: 8.8672 (avg 9.3440)
	Loss: 533.5301 (avg 589.8149)




72it [00:28,  2.53it/s][A[A

73it [00:29,  2.54it/s][A[A

74it [00:29,  2.53it/s][A[A

75it [00:30,  2.54it/s][A[A

76it [00:30,  2.56it/s][A[A

77it [00:30,  2.55it/s][A[A

78it [00:31,  2.55it/s][A[A

79it [00:31,  2.55it/s][A[A

80it [00:31,  2.55it/s][A[A

81it [00:32,  2.54it/s][A[A

Epoch: 0 (80/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -447.1297 (avg -578.4395)
	KL: 8.7112 (avg 9.2855)
	Loss: 447.2168 (avg 578.5323)




82it [00:32,  2.55it/s][A[A

83it [00:33,  2.53it/s][A[A

84it [00:33,  2.54it/s][A[A

85it [00:33,  2.53it/s][A[A

86it [00:34,  2.53it/s][A[A

87it [00:34,  2.54it/s][A[A

88it [00:35,  2.55it/s][A[A

89it [00:35,  2.55it/s][A[A

90it [00:35,  2.56it/s][A[A

91it [00:36,  2.56it/s][A[A

Epoch: 0 (90/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -538.3646 (avg -564.0941)
	KL: 9.1508 (avg 9.3588)
	Loss: 538.4561 (avg 564.1877)




92it [00:36,  2.55it/s][A[A

93it [00:37,  2.55it/s][A[A

94it [00:37,  2.55it/s][A[A

95it [00:37,  2.53it/s][A[A

96it [00:38,  2.55it/s][A[A

97it [00:38,  2.54it/s][A[A

98it [00:39,  2.54it/s][A[A

99it [00:39,  2.55it/s][A[A

100it [00:39,  2.55it/s][A[A

101it [00:40,  2.56it/s][A[A

Epoch: 0 (100/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -509.7361 (avg -551.2314)
	KL: 8.2835 (avg 9.2989)
	Loss: 509.8190 (avg 551.3244)




102it [00:40,  2.55it/s][A[A

103it [00:41,  2.54it/s][A[A

104it [00:41,  2.53it/s][A[A

105it [00:41,  2.54it/s][A[A

106it [00:42,  2.52it/s][A[A

107it [00:42,  2.53it/s][A[A

108it [00:42,  2.51it/s][A[A

109it [00:43,  2.50it/s][A[A

110it [00:43,  2.50it/s][A[A

111it [00:44,  2.49it/s][A[A

Epoch: 0 (110/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -437.3727 (avg -541.0066)
	KL: 8.6892 (avg 9.2484)
	Loss: 437.4595 (avg 541.0991)




112it [00:44,  2.50it/s][A[A

113it [00:44,  2.52it/s][A[A

114it [00:45,  2.49it/s][A[A

115it [00:45,  2.51it/s][A[A

116it [00:46,  2.50it/s][A[A

117it [00:46,  2.49it/s][A[A

118it [00:47,  2.50it/s][A[A

119it [00:47,  2.49it/s][A[A

120it [00:47,  2.50it/s][A[A

121it [00:48,  2.50it/s][A[A

Epoch: 0 (120/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -342.2521 (avg -532.2367)
	KL: 8.3820 (avg 9.1901)
	Loss: 342.3359 (avg 532.3286)




122it [00:48,  2.49it/s][A[A

123it [00:49,  2.51it/s][A[A

124it [00:49,  2.50it/s][A[A

125it [00:49,  2.50it/s][A[A

126it [00:50,  2.50it/s][A[A

127it [00:50,  2.50it/s][A[A

128it [00:50,  2.52it/s][A[A

129it [00:51,  2.52it/s][A[A

130it [00:51,  2.53it/s][A[A

131it [00:52,  2.55it/s][A[A

Epoch: 0 (130/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -353.6423 (avg -522.4579)
	KL: 8.2746 (avg 9.1500)
	Loss: 353.7250 (avg 522.5494)




132it [00:52,  2.54it/s][A[A

133it [00:52,  2.55it/s][A[A

134it [00:53,  2.54it/s][A[A

135it [00:53,  2.53it/s][A[A

136it [00:54,  2.54it/s][A[A

137it [00:54,  2.54it/s][A[A

138it [00:54,  2.52it/s][A[A

139it [00:55,  2.52it/s][A[A

140it [00:55,  2.53it/s][A[A

141it [00:56,  2.50it/s][A[A

Epoch: 0 (140/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -389.0826 (avg -515.1274)
	KL: 10.7975 (avg 9.1028)
	Loss: 389.1906 (avg 515.2185)




142it [00:56,  2.52it/s][A[A

143it [00:56,  2.50it/s][A[A

144it [00:57,  2.51it/s][A[A

145it [00:57,  2.51it/s][A[A

146it [00:58,  2.50it/s][A[A

147it [00:58,  2.52it/s][A[A

148it [00:58,  2.51it/s][A[A

149it [00:59,  2.52it/s][A[A

150it [00:59,  2.53it/s][A[A

151it [01:00,  2.53it/s][A[A

Epoch: 0 (150/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -335.9324 (avg -507.1036)
	KL: 8.4285 (avg 9.0420)
	Loss: 336.0167 (avg 507.1940)




152it [01:00,  2.54it/s][A[A

153it [01:00,  2.54it/s][A[A

154it [01:01,  2.54it/s][A[A

155it [01:01,  2.53it/s][A[A

156it [01:02,  2.53it/s][A[A

157it [01:02,  2.53it/s][A[A

158it [01:02,  2.52it/s][A[A

159it [01:03,  2.52it/s][A[A

160it [01:03,  2.50it/s][A[A

161it [01:04,  2.52it/s][A[A

Epoch: 0 (160/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -503.8904 (avg -501.4184)
	KL: 7.9691 (avg 8.9870)
	Loss: 503.9701 (avg 501.5082)




162it [01:04,  2.51it/s][A[A

163it [01:04,  2.48it/s][A[A

164it [01:05,  2.51it/s][A[A

165it [01:05,  2.50it/s][A[A

166it [01:06,  2.51it/s][A[A

167it [01:06,  2.50it/s][A[A

168it [01:06,  2.49it/s][A[A

169it [01:07,  2.51it/s][A[A

170it [01:07,  2.50it/s][A[A

171it [01:08,  2.50it/s][A[A

Epoch: 0 (170/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -379.4833 (avg -496.4038)
	KL: 7.7251 (avg 8.9329)
	Loss: 379.5605 (avg 496.4932)




172it [01:08,  2.50it/s][A[A

173it [01:08,  2.49it/s][A[A

174it [01:09,  2.50it/s][A[A

175it [01:09,  2.49it/s][A[A

176it [01:10,  2.50it/s][A[A

177it [01:10,  2.50it/s][A[A

178it [01:10,  2.49it/s][A[A

179it [01:11,  2.52it/s][A[A

180it [01:11,  2.51it/s][A[A

181it [01:12,  2.50it/s][A[A

Epoch: 0 (180/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -433.4616 (avg -490.4096)
	KL: 8.0388 (avg 8.8840)
	Loss: 433.5420 (avg 490.4985)




182it [01:12,  2.51it/s][A[A

183it [01:12,  2.51it/s][A[A

184it [01:13,  2.53it/s][A[A

185it [01:13,  2.53it/s][A[A

186it [01:14,  2.54it/s][A[A

187it [01:14,  2.54it/s][A[A

188it [01:14,  2.54it/s][A[A

189it [01:15,  2.54it/s][A[A

190it [01:15,  2.53it/s][A[A

191it [01:16,  2.53it/s][A[A

Epoch: 0 (190/200)
	Data load time: 0.827 sec
	Generation time: 0.834 sec
	Backprop time: 1.042 sec
	Log probability: -331.2886 (avg -486.4684)
	KL: 7.9677 (avg 8.8357)
	Loss: 331.3683 (avg 486.5568)




192it [01:16,  2.52it/s][A[A

193it [01:16,  2.53it/s][A[A

194it [01:17,  2.51it/s][A[A

195it [01:17,  2.51it/s][A[A

196it [01:18,  2.51it/s][A[A

197it [01:18,  2.49it/s][A[A

198it [01:18,  2.51it/s][A[A

199it [01:19,  2.49it/s][A[A

200it [01:19,  2.78it/s][A[A

[A[A

0it [00:00, ?it/s][A[A

1it [00:00,  1.20it/s][A[A

2it [00:01,  1.57it/s][A[A

Epoch: 0 (0/11)
	Data load time: 0.773 sec
	Generation time: 0.779 sec
	Log probability: -310.9161 (avg -310.9161)
	KL: 7.1471 (avg 7.1471)
	Loss: 310.9876 (avg 310.9876)




3it [00:01,  2.03it/s][A[A

4it [00:01,  2.54it/s][A[A

5it [00:01,  2.80it/s][A[A

6it [00:01,  3.35it/s][A[A

7it [00:01,  3.90it/s][A[A

8it [00:02,  4.39it/s][A[A

9it [00:02,  4.74it/s][A[A

10it [00:02,  5.16it/s][A[A

11it [00:02,  4.39it/s][A[A

Epoch: 0 (10/11)
	Data load time: 0.773 sec
	Generation time: 0.779 sec
	Log probability: -299.8336 (avg -340.8479)
	KL: 7.1276 (avg 8.3076)
	Loss: 299.9049 (avg 340.9310)
Lowest validation loss: inf
Saving checkpoint..


IsADirectoryError: [Errno 21] Is a directory: 'models/190124'

### ADD VISUALIZATIONS

In [None]:
    # TODO include in epoch loop?
    # TODO look into what the scheduler is for
    # visualize reconst and free sample
    print("Plotting example imgs...")
    with torch.no_grad():
        
        val_iter = iter(val_dl)

        # reconstruct 25 imgs
        imgs = val_iter._get_batch()[1][0][:25]
        if args.cuda:
            imgs = imgs.cuda()
        imgs_reconst, mu, logvar = model(imgs)

        # sample 25 imgs
        noises = torch.randn(25, model.nz, 1, 1)
        if args.cuda:
            noises = noises.cuda()
        samples = model.decode(noises)

        def write_image(tag, images):
            """
            write the resulting imgs to tensorboard.
            :param tag: The tag for tensorboard
            :param images: the torch tensor with range (-1, 1). [9, 3, 256, 256]
            """
            # make it from 0 to 255
            images = (images + 1) / 2
            grid = make_grid(images, nrow=5, padding=20)
            writer.add_image(tag, grid.detach(), global_step=epoch + 1)

        write_image("origin", imgs)
        write_image("reconst", imgs_reconst)
        write_image("samples", samples)
        print('done')

# Test the Model

In [None]:
"""
The script for doing outlier detection using different score
"""
import argparse
from model import VAE
from loss import VAELoss
from dataloader import load_vae_test_datasets, load_vae_train_datasets
import os
import torch
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from itertools import product
import numpy as np
import pandas as pd
import scipy.stats as stats

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', required=True, type=str)
parser.add_argument('--data', required=True, type=str)
parser.add_argument('--image_size', default=256, type=int)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--kl_weight', type=float, default=1,
                    help="weight on KL term")
parser.add_argument('--out_csv', default='result.csv')
args = parser.parse_args()

# load checkpoint
if not os.path.isfile(args.model_path):
    print('%s is not path to a file' % args.model_path)
    exit()
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
print("checkpoint loaded!")
print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'], checkpoint['epoch']))

# model and criterion
model = VAE(args.image_size)
model.load_state_dict(checkpoint['state_dict'])
criterion = VAELoss(size_average=True, kl_weight=args.kl_weight)

if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# load data
test_loader = load_vae_test_datasets(args.image_size, args.data)

############################# ANOMALY SCORE DEF ##########################
def get_vae_score(vae, image, L=5):
    """
    The vae score for a single image, which is basically the loss
    :param image: [1, 3, 256, 256]
    :return (vae loss, KL, reconst_err)
    """
    image_batch = image.expand(L,
                               image.size(1),
                               image.size(2),
                               image.size(3))
    reconst_batch, mu, logvar = vae.forward(image_batch)
    vae_loss, loss_details = criterion(reconst_batch, image_batch, mu, logvar)
    return vae_loss, loss_details['KL'], -loss_details['reconst_logp']

def _log_mean_exp(x, dim):
    """
    A numerical stable version of log(mean(exp(x)))
    :param x: The input
    :param dim: The dimension along which to take mean with
    """
    # m [dim1, 1]
    m, _ = torch.max(x, dim=dim, keepdim=True)

    # x0 [dm1, dim2]
    x0 = x - m

    # m [dim1]
    m = m.squeeze(dim)

    return m + torch.log(torch.mean(torch.exp(x0),
                                    dim=dim))

def get_iwae_score(vae, image, L=5):
    """
    The vae score for a single image, which is basically the loss
    :param image: [1, 3, 256, 256]
    :return scocre: (iwae score, iwae KL, iwae reconst).
    """
    # [L, 3, 256, 256]
    image_batch = image.expand(L,
                               image.size(1),
                               image.size(2),
                               image.size(3))

    # [L, z_dim, 1, 1]
    mu, logvar = vae.encode(image_batch)
    eps = torch.randn_like(mu)
    z = mu + eps * torch.exp(0.5 * logvar)
    kl_weight = criterion.kl_weight
    # [L, 3, 256, 256]
    reconst = vae.decode(z)
    # [L]
    log_p_x_z = -torch.sum((reconst - image_batch).pow(2).reshape(L, -1),
                          dim=1)

    # [L]
    log_p_z = -torch.sum(z.pow(2).reshape(L, -1), dim=1)

    # [L]
    log_q_z = -torch.sum(eps.pow(2).reshape(L, -1), dim=1)

    iwae_score = -_log_mean_exp(log_p_x_z + (log_p_z - log_q_z)*kl_weight, dim=0)
    iwae_KL_score = -_log_mean_exp(log_p_z - log_q_z, dim=0)
    iwae_reconst_score = -_log_mean_exp(log_p_x_z, dim=0)

    return iwae_score, iwae_KL_score, iwae_reconst_score

############################# END OF ANOMALY SCORE ###########################

# Define the number of samples of each score
def compute_all_scores(vae, image):
    """
    Given an image compute all anomaly score
    return (reconst_score, vae_score, iwae_score)
    """
    vae_loss, KL, reconst_err = get_vae_score(vae, image=image, L=15)
    iwae_loss, iwae_KL, iwae_reconst = get_iwae_score(vae, image, L=15)
    result = {'reconst_score': reconst_err.item(),
              'KL_score': KL.item(),
              'vae_score': vae_loss.item(),
              'iwae_score': iwae_loss.item(),
              'iwae_KL_score': iwae_KL.item(),
              'iwae_reconst_score': iwae_reconst.item()}
    return result


# MAIN LOOP
score_names = ['reconst_score', 'KL_score', 'vae_score',
               'iwae_reconst_score', 'iwae_KL_score', 'iwae_score']
classes = test_loader.dataset.classes
scores = {(score_name, cls): [] for (score_name, cls) in product(score_names,
                                                                 classes)}
model.eval()
with torch.no_grad():
    for idx, (image, target) in tqdm(enumerate(test_loader)):
        cls = classes[target.item()]
        if args.cuda:
            image = image.cuda()

        score = compute_all_scores(vae=model, image=image)
        for name in score_names:
            scores[(name, cls)].append(score[name])

# display the mean of scores
means = np.zeros([len(score_names), len(classes)])
for (name, cls) in product(score_names, classes):
    means[score_names.index(name), classes.index(cls)] = sum(scores[(name, cls)]) / len(scores[(name, cls)])
df_mean = pd.DataFrame(means, index=score_names, columns=classes)
print("###################### MEANS #####################")
print(df_mean)


classes.remove('NV')
auc_result = np.zeros([len(score_names), len(classes) + 1])
# get auc roc for each class
for (name, cls) in product(score_names, classes):
    normal_scores = scores[(name, 'NV')]
    abnormal_scores = scores[(name, cls)]
    y_true = [0]*len(normal_scores) + [1]*len(abnormal_scores)
    y_score = normal_scores + abnormal_scores
    auc_result[score_names.index(name), classes.index(cls)] = roc_auc_score(y_true, y_score)

# add auc roc against all diseases
for name in score_names:
    normal_scores = scores[(name, 'NV')]
    abnormal_scores = np.concatenate([scores[(name, cls)]for cls in classes]).tolist()
    y_true = [0]*len(normal_scores) + [1]*len(abnormal_scores)
    y_score = normal_scores + abnormal_scores
    auc_result[score_names.index(name), -1] = roc_auc_score(y_true, y_score)

df = pd.DataFrame(auc_result, index=score_names, columns=classes + ['ALL'])
# display
print("###################### AUC ROC #####################")
print(df)
print("####################################################")
df.to_csv(args.out_csv)

# fit a gamma distribution
_, val_loader = load_vae_train_datasets(args.image_size, args.data, 32)
model.eval()
all_reconst_err = []
num_val = len(val_loader.dataset)
with torch.no_grad():
    for img, _ in tqdm(val_loader):
        if args.cuda:
            img = img.cuda()

        # compute output
        recon_batch, mu, logvar = model(img)
        loss, loss_details = criterion.forward_without_reduce(recon_batch, img, mu, logvar)
        reconst_err = -loss_details['reconst_logp']
        all_reconst_err += reconst_err.tolist()

fit_alpha, fit_loc, fit_beta=stats.gamma.fit(all_reconst_err)

# using gamma for outlier detection
# get auc roc for each class
LARGE_NUMBER = 1e30

def get_gamma_score(scores):
    result = -stats.gamma.logpdf(scores, fit_alpha, fit_loc, fit_beta)
    # replace inf in result with largest number
    result[result == np.inf] = LARGE_NUMBER
    return result

auc_gamma_result = np.zeros([1, len(classes)+1])
name = 'reconst_score'
for cls in classes:
    normal_scores = get_gamma_score(scores[(name, 'NV')]).tolist()
    abnormal_scores = get_gamma_score(scores[(name, cls)]).tolist()
    y_true = [0]*len(normal_scores) + [1]*len(abnormal_scores)
    y_score = normal_scores + abnormal_scores
    auc_gamma_result[0, classes.index(cls)] = roc_auc_score(y_true, y_score)

# for all class
normal_scores = get_gamma_score(scores[(name, 'NV')]).tolist()
abnormal_scores = np.concatenate([get_gamma_score(scores[(name, cls)]) for cls in classes]).tolist()
y_true = [0]*len(normal_scores) + [1]*len(abnormal_scores)
y_score = normal_scores + abnormal_scores
auc_gamma_result[0, -1] = roc_auc_score(y_true, y_score)
df = pd.DataFrame(auc_gamma_result, index=['gamma score'], columns=classes + ['ALL'])

# display
print("###################### AUC ROC GAMMA #####################")
print(df)
print("##########################################################")
