# Imports

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os, math
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

# Define a VAE Model
Initially designed for 2D input images.
Based on this paper: https://arxiv.org/abs/1807.01349

In [5]:
depth = 64      # initial depth to convolve channels into
n_channels = 3  # number of channels (RGB)
filt_size = 4   # convolution filter size
stride = 2      # stride for conv
pad = 1         # padding added for conv


class VAE2D(nn.Module):
    """
    Convolutional variational autoencoder for square 2D images.

    The encoder halves the spatial size with every strided convolution until a
    4 x 4 feature map remains; two parallel convolutions then collapse that map
    into the mean and log variance of the latent distribution. The decoder
    mirrors the encoder with transposed convolutions and ends in a tanh, so
    reconstructions lie in [-1, 1].

    The layer hyper-parameters (depth, n_channels, filt_size, stride, pad) are
    read from the module-level constants when the model is constructed.

    :param img_size: side length of the (square) input images. Must be a power
        of 2 and at least 8.
    :param n_latent: number of latent dimensions.
    """

    def __init__(self, img_size, n_latent=300):

        # Model setup
        #############
        super().__init__()
        self.n_latent = n_latent
        n = math.log2(img_size)
        assert n == round(n), 'Image size must be a power of 2'  # restrict image input sizes permitted
        assert n >= 3, 'Image size must be at least 8'           # low dimensional data won't work well
        n = int(n)

        # Encoder - first half of VAE
        #############################
        self.encoder = nn.Sequential()
        # input: n_channels x img_size x img_size
        # output: depth x conv_img_size^2
        # conv_img_size = (img_size - filt_size + 2 * pad) / stride + 1
        # (= img_size / 2 with the default filt_size=4, stride=2, pad=1)
        self.encoder.add_module('input-conv', nn.Conv2d(n_channels, depth, filt_size, stride, pad,
                                                        bias=True))
        self.encoder.add_module('input-relu', nn.ReLU(inplace=True))

        # Add conv layer for each power of 2 over 3 (min size)
        # Pyramid strategy with batch normalization added
        for i in range(n - 3):
            # Each stage doubles the channel count and halves the spatial size
            # i_depth = o_depth of previous layer
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i + 1)
            self.encoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.Conv2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.encoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.encoder.add_module(f'pyramid_{o_depth}_relu',
                                    nn.ReLU(inplace=True))

        # Latent representation
        #######################
        # Convolve the encoded image into the latent space, once for mu and once for logvar.
        # With the default constants the encoder output is max_depth x 4 x 4, so a
        # filt_size=4 convolution without padding reduces it to n_latent x 1 x 1.
        max_depth = depth * 2 ** (n - 3)
        self.conv_mu = nn.Conv2d(max_depth, n_latent, filt_size)      # return the mean of the latent space
        self.conv_logvar = nn.Conv2d(max_depth, n_latent, filt_size)  # return the log variance of the same

        # Decoder - second half of VAE
        ##############################
        self.decoder = nn.Sequential()
        # input: n_latent x 1 x 1
        # output: max_depth x filt_size x filt_size (4 x 4 by default)
        # default stride=1, pad=0 for this layer
        self.decoder.add_module('input-conv', nn.ConvTranspose2d(n_latent, max_depth, filt_size, bias=True))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(max_depth))
        self.decoder.add_module('input-relu', nn.ReLU(inplace=True))

        # Reverse the convolution pyramids used in the encoder
        for i in range(n - 3, 0, -1):
            i_depth = depth * 2 ** i
            o_depth = depth * 2 ** (i - 1)
            self.decoder.add_module(f'pyramid_{i_depth}-{o_depth}_conv',
                                    nn.ConvTranspose2d(i_depth, o_depth, filt_size, stride, pad, bias=True))
            self.decoder.add_module(f'pyramid_{o_depth}_batchnorm',
                                    nn.BatchNorm2d(o_depth))
            self.decoder.add_module(f'pyramid_{o_depth}_relu', nn.ReLU(inplace=True))

        # Final transposed convolution to return to img_size
        # Final activation is tanh instead of relu to allow negative pixel output
        self.decoder.add_module('output-conv', nn.ConvTranspose2d(depth, n_channels,
                                                                  filt_size, stride, pad, bias=True))
        self.decoder.add_module('output-tanh', nn.Tanh())

        # Model weights init
        ####################
        # Randomly initialize the model weights using kaiming method
        # Reference: "Delving deep into rectifiers: Surpassing human-level
        # performance on ImageNet classification" - He, K. et al. (2015)
        # NOTE(review): nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the
        # decoder's transposed convolutions keep PyTorch's default init - confirm
        # whether that is intended.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def encode(self, imgs):
        """
        Encode the images into latent space vectors (mean and log variance representation)
        input:  imgs   [batch_size, n_channels, img_size, img_size]
        output: mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        """
        # The feature map is fed to the latent convolutions as-is: they need the
        # 4-D [batch, channels, H, W] layout, so no dimensions are squeezed here.
        features = self.encoder(imgs)
        return [self.conv_mu(features), self.conv_logvar(features)]

    def generate(self, mu, logvar):
        """
        Generates a random latent vector using the trained mean and log variance representation
        (reparameterization: mu + eps * std with eps ~ N(0, 1)).
        input:  mu     [batch_size, n_latent, 1, 1]
                logvar [batch_size, n_latent, 1, 1]
        output: gen    [batch_size, n_latent, 1, 1]
        """
        # NOTE(review): noise is sampled regardless of train/eval mode, so
        # reconstructions are stochastic at evaluation time as well.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, gen):
        """
        Restores an image representation from the generated latent vector
        input:  gen      [batch_size, n_latent, 1, 1]
        output: gen_imgs [batch_size, n_channels, img_size, img_size], values in [-1, 1]
        """
        return self.decoder(gen)

    def forward(self, imgs):
        """
        Generates reconstituted images from input images based on learned representation
        input:  imgs     [batch_size, n_channels, img_size, img_size]
        output: gen_imgs [batch_size, n_channels, img_size, img_size]
                mu       [batch_size, n_latent]
                logvar   [batch_size, n_latent]
        """
        mu, logvar = self.encode(imgs)
        gen = self.generate(mu, logvar)
        # Drop the two trailing singleton spatial dims so the loss sees [batch, n_latent]
        return (self.decode(gen),
                mu.squeeze(-1).squeeze(-1),
                logvar.squeeze(-1).squeeze(-1))


In [6]:
# Smoke test: build the VAE for 256 x 256 inputs with the default latent size
model = VAE2D(256)

In [7]:
# Echo the model so the notebook prints its layer structure
model

VAE2D(
  (encoder): Sequential(
    (input-conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (input-relu): ReLU(inplace)
    (pyramid_64-128_conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_128_batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_128_relu): ReLU(inplace)
    (pyramid_128-256_conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_256_batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_256_relu): ReLU(inplace)
    (pyramid_256-512_conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_512_batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid_512_relu): ReLU(inplace)
    (pyramid_512-1024_conv): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (pyramid_1024_batchno

# Define a loss function
The loss must be suitable for anomaly detection based on reconstruction similarity.

In [9]:
#TODO

In [None]:
class VAE2DLoss(nn.Module):
    """
    VAE loss: the negative ELBO, i.e. a squared-error reconstruction term plus
    a weighted KL divergence between the approximate posterior N(mu, sigma^2)
    and the standard normal prior.

    :param size_average: if True, `forward` averages the per-sample terms over
        the batch; otherwise it sums them.
    :param kl_weight: multiplier applied to the KL term.
    """

    def __init__(self, size_average=False, kl_weight=1):
        super().__init__()
        self.size_average = size_average
        self.kl_weight = kl_weight

    @staticmethod
    def _per_sample_terms(recon_x, x, mu, logvar):
        """Return (reconst_err, KL), each of shape [bsz], without any batch reduction."""
        bsz = x.shape[0]
        reconst_err = (x - recon_x).pow(2).reshape(bsz, -1)
        reconst_err = 0.5 * torch.sum(reconst_err, dim=-1)

        # Flatten the latents so both [bsz, z_dim] and [bsz, z_dim, 1, 1]
        # reduce to one KL value per sample.
        mu = mu.reshape(bsz, -1)
        logvar = logvar.reshape(bsz, -1)

        # KL(q || p) = -log_sigma + sigma^2/2 + mu^2/2 - 1/2
        KL = (-logvar + logvar.exp() + mu.pow(2) - 1) * 0.5
        KL = torch.sum(KL, dim=-1)
        return reconst_err, KL

    def forward(self, recon_x, x, mu, logvar):
        """
        :param recon_x: generating images. [bsz, C, H, W]
        :param x: origin images. [bsz, C, H, W]
        :param mu: latent mean. [bsz, z_dim]
        :param logvar: latent log variance. [bsz, z_dim]
        :return loss, loss_details.
            loss: a scalar. negative of elbo
            loss_details: {'KL': KL, 'reconst_logp': -reconst_err}
        """
        reconst_err, KL = self._per_sample_terms(recon_x, x, mu, logvar)
        # Batch reduction: mean when size_average is set, sum otherwise
        reduce = torch.mean if self.size_average else torch.sum
        KL = reduce(KL)
        reconst_err = reduce(reconst_err)
        loss = reconst_err + self.kl_weight * KL
        return loss, {'KL': KL, 'reconst_logp': -reconst_err}

    def forward_without_reduce(self, recon_x, x, mu, logvar):
        """
        Compute the same VAE loss as `forward` but keep one value per sample
        (no mean or sum over the batch).
        :param recon_x: generating images. [bsz, C, H, W]
        :param x: origin images. [bsz, C, H, W]
        :param mu: latent mean. [bsz, z_dim]
        :param logvar: latent log variance. [bsz, z_dim]
        :return: losses. [bsz] and loss details (each entry of shape [bsz])
        """
        reconst_err, KL = self._per_sample_terms(recon_x, x, mu, logvar)

        # [bsz]
        losses = reconst_err + self.kl_weight * KL
        return losses, {'KL': KL, 'reconst_logp': -reconst_err}


# Define utility functions

In [None]:
import time
from tqdm import tqdm


class AverageMeter:
    """Track the latest value of a quantity together with its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the running mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, counted `n` times in the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def trainVAE(train_loader, model, criterion, optimizer, epoch, args):
    """
    Run one optimization pass over the training loader.

    :param train_loader: iterable of (images, label) batches; labels are ignored.
    :param model: called as model(images) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, images, mu, logvar) ->
        (loss, details), where details holds 'KL' and 'reconst_logp'.
    :param optimizer: optimizer stepped once per batch.
    :param epoch: epoch index, used only in the progress printout.
    :param args: object exposing `cuda` (move batches to the GPU) and
        `print_freq` (print every this many batches).
    :return: (average loss, average KL, average reconst_logp), weighted by batch size.
    """
    # Timing meters and loss meters
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_avg, kl_avg, reconst_logp_avg = AverageMeter(), AverageMeter(), AverageMeter()

    # Training mode
    model.train()

    tick = time.time()
    for step, (batch, _) in enumerate(train_loader):
        # Time spent waiting for the loader
        data_time.update(time.time() - tick)

        if args.cuda:
            batch = batch.cuda()
        n_samples = batch.size(0)

        recon_batch, mu, logvar = model(batch)
        loss, loss_details = criterion(recon_batch, batch, mu, logvar)

        # Book-keeping for the running averages
        tracked = ((loss_avg, loss),
                   (kl_avg, loss_details['KL']),
                   (reconst_logp_avg, loss_details['reconst_logp']))
        for meter, value in tracked:
            meter.update(value.item(), n_samples)

        # Back-propagate and take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole batch
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'reconst_logp {reconst_logp_avg.val:.4f} ({reconst_logp_avg.avg:.4f})\t'
              'kl {kl_avg.val:.4f} ({kl_avg.avg:.4f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                  epoch, step, len(train_loader),
                  batch_time=batch_time, data_time=data_time,
                  reconst_logp_avg=reconst_logp_avg, kl_avg=kl_avg,
                  loss=loss_avg))

    return loss_avg.avg, kl_avg.avg, reconst_logp_avg.avg


def validateVAE(val_loader, model, criterion, args):
    """
    Iterate through the validation set and report the average losses.

    The model is switched to eval mode and the whole pass runs under
    torch.no_grad(), so no autograd graph is built even when the caller
    does not disable gradients itself.

    :param val_loader: iterable of (images, label) batches; labels are ignored.
    :param model: called as model(images) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, images, mu, logvar) ->
        (loss, details), where details holds 'KL' and 'reconst_logp'.
    :param args: object exposing `cuda` (move batches to the GPU) and
        `print_freq` (print every this many batches).
    :return: (average loss, average KL, average reconst_logp), weighted by batch size.
    """
    batch_time = AverageMeter()
    loss_avg = AverageMeter()
    kl_avg = AverageMeter()
    reconst_logp_avg = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, _) in enumerate(val_loader):
            if args.cuda:
                images = images.cuda()

            # compute output
            recon_batch, mu, logvar = model(images)
            loss, loss_details = criterion(recon_batch, images, mu, logvar)

            # record the losses, weighted by the number of samples in the batch
            n_samples = images.size(0)
            loss_avg.update(loss.item(), n_samples)
            kl_avg.update(loss_details['KL'].item(), n_samples)
            reconst_logp_avg.update(loss_details['reconst_logp'].item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'reconst_logp {reconst_logp_avg.val:.4f} ({reconst_logp_avg.avg:.4f})\t'
                      'kl {kl_avg.val:.4f} ({kl_avg.avg:.4f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          reconst_logp_avg=reconst_logp_avg,
                          kl_avg=kl_avg, loss=loss_avg))
    return loss_avg.avg, kl_avg.avg, reconst_logp_avg.avg

def evaluateVAE(test_loader, model, criterion, args):
    """
    Iterate through the test loader and compute the average loss of the
    normal and of the abnormal samples separately.

    The label is read with `target.item()`, so each batch must hold exactly
    one sample (batch size 1). A label of 1 is treated as normal, anything
    else as abnormal.

    :param test_loader: iterable of (image, label) batches of size 1.
    :param model: called as model(image) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, image, mu, logvar) -> (loss, details).
    :param args: object exposing `cuda` (move batches to the GPU).
    :return: (average normal loss, average abnormal loss); a group with no
        samples reports 0.
    """
    avg_abnormal_loss = AverageMeter()
    avg_normal_loss = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Wrap the loader itself (not enumerate) so tqdm can read its length and
    # show a proper progress bar; gradients are not needed for scoring.
    with torch.no_grad():
        for images, target in tqdm(test_loader):
            if args.cuda:
                images = images.cuda()

            # compute output
            recon_batch, mu, logvar = model(images)
            loss, loss_details = criterion(recon_batch, images, mu, logvar)

            # if normal
            if target.item() == 1:
                avg_normal_loss.update(loss.item(), images.size(0))
            else:
                avg_abnormal_loss.update(loss.item(), images.size(0))

    return avg_normal_loss.avg, avg_abnormal_loss.avg


In [None]:
def load_datasets(img_size, data_path):
    """
    Load the image datasets from vae_train and vae_test
    Transform to correct image size

    Reads the module-level `n_channels` and `batch_size` at call time.

    :param img_size: target side length; images are resized then cropped to
        img_size x img_size.
    :param data_path: pathlib.Path of the data root containing
        'vae_train/train/', 'vae_train/val/' and 'vae_test/'.
    :return: (train_dl, val_dl, test_dl) data loaders. The train and val
        loaders use `batch_size`; the test loader yields one image at a time.
    """

    train_path = data_path / 'vae_train/train/'
    val_path = data_path / 'vae_train/val/'
    test_path = data_path / 'vae_test/'

    # Map pixel values from [0, 1] to [-1, 1]
    norm_args = {'mean': [0.5] * n_channels,
                 'std': [0.5] * n_channels}
    jitter_args = {'brightness': 0.1,
                   'contrast': 0.1,
                   'saturation': 0.1}  # hue unchanged

    train_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomCrop(img_size),  # var in horizontal position
        transforms.RandomHorizontalFlip(p=0.25),  # var in photo orientation
        transforms.RandomVerticalFlip(p=0.25),
        transforms.ColorJitter(**jitter_args),  # var in photo lighting
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])

    # No augmentation for validation/test: deterministic center crop only
    test_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),  # assume center is most important
        transforms.ToTensor(),
        transforms.Normalize(**norm_args)])

    train_ds = datasets.ImageFolder(train_path, train_transform)
    val_ds = datasets.ImageFolder(val_path, test_transform)
    test_ds = datasets.ImageFolder(test_path, test_transform)

    loader_args = {'shuffle': True,
                   'num_workers': 4}
    train_dl = DataLoader(train_ds, batch_size=batch_size, **loader_args)
    val_dl = DataLoader(val_ds, batch_size=batch_size, **loader_args)
    # One image per batch so each test sample gets its own loss/label
    test_dl = DataLoader(test_ds, batch_size=1, **loader_args)

    return train_dl, val_dl, test_dl

# Load Data

In [None]:
# ---- Model / data settings ----
desc = 'VAE for detecting anomalies in 2D images'
data_path = Path('data/NV_outlier/')
img_size = 128
n_channels = 3

# ---- Optimization settings ----
epochs = 40
batch_size = 32
lr = 1e-4                # initial learning rate
lr_decay = 0.1           # factor applied to the lr at each scheduled epoch
schedule = [10, 20, 30]  # epochs at which the lr is decreased
kl = 0.01                # weight of the KL term in the loss

# Use the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---- Checkpoint / logging settings ----
save_path = Path('model/')  # directory for the saved model and results
load_path = None            # checkpoint to resume from (None = start fresh)
log_freq = 10               # print a status line every this many batches

In [None]:
# Build the train / validation / test loaders from the configured image size and data root
train_dl, val_dl, test_dl = load_datasets(img_size, data_path)

In [None]:
# DataLoader objects have no `.shape`; report the dataset size and the
# number of batches of each loader instead.
for name, dl in (('train', train_dl), ('val', val_dl), ('test', test_dl)):
    print(f"{name}: {len(dl.dataset)} images in {len(dl)} batches")

# Build the Model

In [None]:
# Create model and move it to the target device *before* the optimizer is
# built, so the optimizer (and any state restored from a checkpoint) lives on
# the same device as the parameters it updates.
model = VAE2D(img_size).to(device)

# Load optimizer and scheduler
optim = torch.optim.Adam(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, schedule, lr_decay)

# Load checkpoint if any
# NOTE(review): the checkpoints written by the training loop hold no scheduler
# state, so the lr schedule restarts from epoch 0 on resume - confirm intended.
if load_path is not None:
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optim.load_state_dict(checkpoint['optimizer'])
    print("Checkpoint loaded")
    print(f"Val loss: {checkpoint['val_loss']}\tEpoch: {checkpoint['epoch']}")

# Set loss criterion (batch-averaged negative ELBO with a weighted KL term)
criterion = VAE2DLoss(size_average=True, kl_weight=kl).to(device)

# Train the Model

In [None]:
# TODO replace with Visdom - from tensorboardX import SummaryWriter

In [None]:
from types import SimpleNamespace

# Part of torch, but needs the `tensorboard` package installed at runtime
from torch.utils.tensorboard import SummaryWriter

# Make save directory
if save_path.is_dir():
    print(f"{save_path} already exists")
else:
    save_path.mkdir(parents=True)

# Bundle the notebook settings into the `args` object the train/validate
# helpers expect (they read `args.cuda` and `args.print_freq`).
args = SimpleNamespace(
    desc=desc, data_path=data_path, img_size=img_size, n_channels=n_channels,
    epochs=epochs, lr=lr, lr_decay=lr_decay, kl=kl, schedule=schedule,
    batch_size=batch_size, cuda=(device.type == 'cuda'),
    out_dir=save_path, load_path=load_path, print_freq=log_freq)

# save args
with open(save_path / 'config.txt', 'w') as f:
    for k, v in vars(args).items():
        f.write("{}:{}\n".format(k, v))
writer = SummaryWriter(log_dir=str(save_path / 'logs'))


def write_image(tag, images, step):
    """
    write the resulting imgs to tensorboard.
    :param tag: The tag for tensorboard
    :param images: the torch tensor with range (-1, 1). [N, C, H, W]
    :param step: global step the image grid is logged under
    """
    # rescale from (-1, 1) to (0, 1)
    images = (images + 1) / 2
    grid = make_grid(images, nrow=5, padding=20)
    writer.add_image(tag, grid.detach(), global_step=step)


# main loop
best_loss = np.inf
# Separate name for the checkpoint file so `save_path` keeps pointing at the directory
best_path = save_path / 'best_model.pth.tar'
for epoch in range(args.epochs):
    step = epoch + 1

    # train for one epoch
    train_loss, train_kl, train_reconst_logp = trainVAE(train_dl, model, criterion, optim, epoch, args)
    # Advance the lr schedule only after the optimizer has stepped for this epoch
    scheduler.step()
    writer.add_scalar('train_elbo', -train_loss, global_step=step)
    writer.add_scalar('train_kl', train_kl, global_step=step)
    writer.add_scalar('train_reconst_logp', train_reconst_logp, global_step=step)

    # evaluate on validation set
    with torch.no_grad():
        val_loss, val_kl, val_reconst_logp = validateVAE(val_dl, model, criterion, args)
    writer.add_scalar('val_elbo', -val_loss, global_step=step)
    writer.add_scalar('val_kl', val_kl, global_step=step)
    writer.add_scalar('val_reconst_logp', val_reconst_logp, global_step=step)

    # remember the lowest validation loss and save checkpoint
    if val_loss < best_loss:
        print('checkpointed!')
        best_loss = val_loss
        save_dict = {'epoch': step,
                     'state_dict': model.state_dict(),
                     'val_loss': val_loss,
                     'optimizer': optim.state_dict()}
        torch.save(save_dict, best_path)
    print('curr lowest val loss {}'.format(best_loss))

    # visualize reconst and free sample
    # (the model is still in eval mode from validateVAE)
    print("plotting imgs...")
    with torch.no_grad():
        # reconstruct up to 25 imgs from the first validation batch
        imgs = next(iter(val_dl))[0][:25].to(device)
        imgs_reconst, _, _ = model(imgs)

        # sample 25 imgs by decoding latent vectors drawn from the prior
        noises = torch.randn(25, model.n_latent, 1, 1, device=device)
        samples = model.decode(noises)

        write_image("origin", imgs, step)
        write_image("reconst", imgs_reconst, step)
        write_image("samples", samples, step)
        print('done')

# Flush pending events to disk
writer.close()

import ipdb

In [None]:
import time
from tqdm import tqdm


class AverageMeter:
    """Track the latest value of a quantity together with its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the running mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, counted `n` times in the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def trainVAE(train_loader, model, criterion, optimizer, epoch, args):
    """
    Run one optimization pass over the training loader.

    :param train_loader: iterable of (images, label) batches; labels are ignored.
    :param model: called as model(images) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, images, mu, logvar) ->
        (loss, details), where details holds 'KL' and 'reconst_logp'.
    :param optimizer: optimizer stepped once per batch.
    :param epoch: epoch index, used only in the progress printout.
    :param args: object exposing `cuda` (move batches to the GPU) and
        `print_freq` (print every this many batches).
    :return: (average loss, average KL, average reconst_logp), weighted by batch size.
    """
    # Timing meters and loss meters
    batch_time, data_time = AverageMeter(), AverageMeter()
    loss_avg, kl_avg, reconst_logp_avg = AverageMeter(), AverageMeter(), AverageMeter()

    # Training mode
    model.train()

    tick = time.time()
    for step, (batch, _) in enumerate(train_loader):
        # Time spent waiting for the loader
        data_time.update(time.time() - tick)

        if args.cuda:
            batch = batch.cuda()
        n_samples = batch.size(0)

        recon_batch, mu, logvar = model(batch)
        loss, loss_details = criterion(recon_batch, batch, mu, logvar)

        # Book-keeping for the running averages
        tracked = ((loss_avg, loss),
                   (kl_avg, loss_details['KL']),
                   (reconst_logp_avg, loss_details['reconst_logp']))
        for meter, value in tracked:
            meter.update(value.item(), n_samples)

        # Back-propagate and take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole batch
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue
        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'reconst_logp {reconst_logp_avg.val:.4f} ({reconst_logp_avg.avg:.4f})\t'
              'kl {kl_avg.val:.4f} ({kl_avg.avg:.4f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                  epoch, step, len(train_loader),
                  batch_time=batch_time, data_time=data_time,
                  reconst_logp_avg=reconst_logp_avg, kl_avg=kl_avg,
                  loss=loss_avg))

    return loss_avg.avg, kl_avg.avg, reconst_logp_avg.avg


def validateVAE(val_loader, model, criterion, args):
    """
    Iterate through the validation set and report the average losses.

    The model is switched to eval mode and the whole pass runs under
    torch.no_grad(), so no autograd graph is built even when the caller
    does not disable gradients itself.

    :param val_loader: iterable of (images, label) batches; labels are ignored.
    :param model: called as model(images) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, images, mu, logvar) ->
        (loss, details), where details holds 'KL' and 'reconst_logp'.
    :param args: object exposing `cuda` (move batches to the GPU) and
        `print_freq` (print every this many batches).
    :return: (average loss, average KL, average reconst_logp), weighted by batch size.
    """
    batch_time = AverageMeter()
    loss_avg = AverageMeter()
    kl_avg = AverageMeter()
    reconst_logp_avg = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, _) in enumerate(val_loader):
            if args.cuda:
                images = images.cuda()

            # compute output
            recon_batch, mu, logvar = model(images)
            loss, loss_details = criterion(recon_batch, images, mu, logvar)

            # record the losses, weighted by the number of samples in the batch
            n_samples = images.size(0)
            loss_avg.update(loss.item(), n_samples)
            kl_avg.update(loss_details['KL'].item(), n_samples)
            reconst_logp_avg.update(loss_details['reconst_logp'].item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'reconst_logp {reconst_logp_avg.val:.4f} ({reconst_logp_avg.avg:.4f})\t'
                      'kl {kl_avg.val:.4f} ({kl_avg.avg:.4f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          reconst_logp_avg=reconst_logp_avg,
                          kl_avg=kl_avg, loss=loss_avg))
    return loss_avg.avg, kl_avg.avg, reconst_logp_avg.avg

def evaluateVAE(test_loader, model, criterion, args):
    """
    Iterate through the test loader and compute the average loss of the
    normal and of the abnormal samples separately.

    The label is read with `target.item()`, so each batch must hold exactly
    one sample (batch size 1). A label of 1 is treated as normal, anything
    else as abnormal.

    :param test_loader: iterable of (image, label) batches of size 1.
    :param model: called as model(image) -> (reconstruction, mu, logvar).
    :param criterion: called as criterion(recon, image, mu, logvar) -> (loss, details).
    :param args: object exposing `cuda` (move batches to the GPU).
    :return: (average normal loss, average abnormal loss); a group with no
        samples reports 0.
    """
    avg_abnormal_loss = AverageMeter()
    avg_normal_loss = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Wrap the loader itself (not enumerate) so tqdm can read its length and
    # show a proper progress bar; gradients are not needed for scoring.
    with torch.no_grad():
        for images, target in tqdm(test_loader):
            if args.cuda:
                images = images.cuda()

            # compute output
            recon_batch, mu, logvar = model(images)
            loss, loss_details = criterion(recon_batch, images, mu, logvar)

            # if normal
            if target.item() == 1:
                avg_normal_loss.update(loss.item(), images.size(0))
            else:
                avg_abnormal_loss.update(loss.item(), images.size(0))

    return avg_normal_loss.avg, avg_abnormal_loss.avg


