In [None]:
import argparse
import numpy as np
import os
import typing

from PIL import Image

from  functools import reduce

import torch
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from torch.distributions.normal import Normal

from matplotlib import pyplot as plt

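# For local tests: the 2 lines below put the directory two levels up on sys.path; comment them out to use an installed `deel.datasets` package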
import os, sys
sys.path.insert(0, os.path.join("..", ".."))

from deel.datasets import load as load_dataset



# Arguments for the script

In [None]:

parser = argparse.ArgumentParser(description='VAE MVTEC Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=40, metavar='N',
                    help='number of epochs to train (default: 40)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')


def _int_tuple(text):
    """Parse a comma-separated list of ints such as ``"28,28"`` into ``(28, 28)``."""
    return tuple(int(part) for part in text.split(','))


# NOTE(review): the four options below appear unused by this notebook
# (`image_size` / `z_size` are used instead) — TODO confirm before relying on them.
# `type=tuple` would split "28,28" into single characters, so parse ints instead.
parser.add_argument('--img-shape', type=_int_tuple, default=(28, 28),
                    help='comma-separated image shape (default: 28,28)')
parser.add_argument('--analytic_kl', action='store_true')
parser.add_argument('--h_dim', type=int, default=200)
parser.add_argument('--z_dim', type=int, default=50)

# Parse an explicit empty argument list so the notebook never reads the
# Jupyter kernel's own command-line arguments (sys.argv): defaults are used.
args = parser.parse_args([])
# CUDA is used only when it is available and not disabled with --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()


# Variables

In [None]:
# Reproducibility and compute device.
torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")

# Extra DataLoader options, only enabled when training on CUDA.
loaders_kwargs = dict(num_workers=1, pin_memory=True) if args.cuda else dict()

# Image shape used by the model, presumably (channels, height, width);
# only the trailing (height, width) part is passed to the dataset loader.
image_size = (3, 128, 128)

# Default size of the matplotlib figures.
figsize = (16, 8)

# In-distribution MVTec class and additional out-of-distribution classes.
main_class = 'bottle'
other_ood_classes = ['screw', 'leather']

# Dimension of the VAE latent space.
z_size = 20

# Number of samples to produce or reconstruct.
num_samples = 16

# Output directory for generated images (created when missing).
res_dir = 'results'
os.makedirs(res_dir, exist_ok=True)


# Load the 'bottle' class of MVTec anomaly detection as in-distribution (ID) data, and add 'screw' and 'leather' as other OOD classes

In [None]:

# Load dataset splits for PyTorch; images are resized to image_size[1:].
dataset = load_dataset(
    "mvtec-ad",
    mode="pytorch",
    split_by_class=True,
    image_size=image_size[1:],
)


def _make_loader(split, shuffle=False):
    """Wrap a dataset split in a DataLoader with the shared batch size and loader options."""
    return DataLoader(split, batch_size=args.batch_size, shuffle=shuffle, **loaders_kwargs)


# Splits of the in-distribution class. The `[0]` element of each split entry
# is what gets batched — presumably the dataset object; confirm against deel.datasets.
main_splits = dataset[main_class]
loader_ID_train = _make_loader(main_splits['train'][0], shuffle=True)
loader_ID_valid = _make_loader(main_splits['test'][0])
loader_OOD = _make_loader(main_splits['unknown'][0])

# Every loader evaluated during and after training, keyed by display label.
test_dataset_dict = {
    'MVTEC ID train': loader_ID_train,
    'MVTEC ID test': loader_ID_valid,
    'MVTEC OOD': loader_OOD,
}

# Additional OOD datasets: the 'test' split of each other class.
test_dataset_dict.update(
    ('OOD_{}'.format(cls), _make_loader(dataset[cls]['test'][0]))
    for cls in other_ood_classes
)


In [None]:
# Display the splits loaded for the in-distribution class (follows `main_class`).
dataset[main_class]

# simple VAE model

In [None]:

class VAE(nn.Module):
    """Fully-connected variational auto-encoder.

    The encoder flattens each input to ``prod(img_size)`` values and applies
    ``Linear + ReLU`` layers of widths ``features``, followed by two linear
    heads giving the mean and log-variance of a ``z_size``-dimensional
    Gaussian. The decoder uses the hidden widths in reverse order and ends
    with a sigmoid, so outputs lie in [0, 1] and are reshaped to
    ``(-1, *img_size)``.

    Args:
        img_size: shape of one input sample, e.g. ``(28, 28)`` or
            ``(3, 128, 128)``; any sequence of ints.
        z_size: dimension of the latent space.
        features: widths of the encoder hidden layers (any sequence of ints,
            must be non-empty); the decoder uses them reversed.
    """

    def __init__(self, img_size=(28, 28), z_size=20, features=(400,)):
        super().__init__()
        # Stored as a tuple so decode() can build the output shape from it.
        self.img_size = tuple(img_size)
        self.input_size = reduce(lambda x, y: x * y, self.img_size)
        self.z_size = z_size

        # Private copy: avoids a shared mutable default and accepts tuples.
        features = list(features)

        # Encoder trunk: input_size -> features[0] -> ... -> features[-1].
        self.fc1 = self._mlp([self.input_size] + features)
        # Heads for the latent mean (fc21) and log-variance (fc22).
        self.fc21 = nn.Linear(features[-1], self.z_size)
        self.fc22 = nn.Linear(features[-1], self.z_size)

        # Decoder trunk: z_size -> features[-1] -> ... -> features[0].
        self.fc3 = self._mlp([self.z_size] + features[::-1])
        self.fc4 = nn.Linear(features[0], self.input_size)

    @staticmethod
    def _mlp(widths):
        """Build ``Linear + ReLU`` pairs between consecutive entries of ``widths``."""
        layers = []
        for f_in, f_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(f_in, f_out))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` for flattened inputs ``x`` with last dim ``input_size``."""
        h1 = self.fc1(x)
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (last dim ``z_size``) to images of shape ``(-1, *img_size)`` in [0, 1]."""
        # fc3 already ends with a ReLU, so no extra activation is applied here.
        h3 = self.fc3(z)
        return torch.sigmoid(self.fc4(h3)).view(-1, *self.img_size)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``."""
        # reshape (not view) so non-contiguous inputs are accepted too.
        mu, logvar = self.encode(x.reshape(-1, self.input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def sample(self, z):
        """Return ``z`` unchanged.

        NOTE(review): placeholder that does not decode; use ``decode(z)`` to
        turn latent codes into images.
        """
        return z


In [None]:

# VAE loss: reconstruction (binary cross-entropy) + KL divergence.
def loss_function(recon_x, x, mu, logvar, reduction='sum'):
    """Compute the VAE objective BCE(recon_x, x) + KL(q(z|x) || N(0, I)).

    Args:
        recon_x: reconstructions in [0, 1], same shape as ``x``.
        x: target batch.
        mu, logvar: parameters of the approximate posterior.
        reduction: ``'sum'`` returns one scalar summed over every element;
            ``'mean'`` returns the mean BCE plus the mean KL term, each
            averaged over its own elements; ``'none'`` returns one loss per
            sample, summing every non-batch axis of both terms.
    """
    def _collapse_to_batch(t):
        # Sum away axis 1 until only the batch axis remains.
        while t.dim() > 1:
            t = t.sum(axis=1)
        return t

    reconstruction = F.binary_cross_entropy(recon_x, x, reduction=reduction)

    # See Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes",
    # ICLR 2014 (https://arxiv.org/abs/1312.6114):
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    if reduction == 'none':
        return _collapse_to_batch(reconstruction) + _collapse_to_batch(kl)
    if reduction == 'sum':
        kl = kl.sum()
    elif reduction == 'mean':
        kl = kl.mean()
    return reconstruction + kl


def train(epoch, loader):
    """Train the module-level ``model`` for one epoch over ``loader``.

    Uses the module-level ``model``, ``optimizer``, ``device``, ``args`` and
    ``loss_function`` (default summed reduction per batch). A progress line is
    printed every ``args.log_interval`` batches.

    Args:
        epoch: epoch number, only used in the printed messages.
        loader: DataLoader yielding ``(data, label)`` pairs; labels are ignored.

    Returns:
        The epoch's summed loss divided by ``len(loader.dataset)``.
    """
    model.train()
    dataset_len = len(loader.dataset)
    total = 0
    for step, (inputs, _) in enumerate(loader):
        inputs = inputs.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(inputs)
        batch_loss = loss_function(reconstruction, inputs, mu, logvar)
        batch_loss.backward()
        total += batch_loss.item()
        optimizer.step()
        if step % args.log_interval != 0:
            continue
        batch_size = len(inputs)
        progress = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * batch_size, dataset_len,
            100. * step / len(loader),
            batch_loss.item() / batch_size)
        print(progress)

    epoch_loss = total / dataset_len
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss))
    return epoch_loss


def test(epoch, res_dir=None, loader=loader_ID_valid):
    """Evaluate the module-level ``model`` on ``loader``; return the mean per-sample loss.

    Args:
        epoch: epoch number, used in the saved image file name.
        res_dir: when truthy, the first ``min(batch size, 8)`` inputs of the
            first batch and their reconstructions are saved to
            ``res_dir/reconstruction_<epoch>.png``.
        loader: DataLoader yielding ``(data, label)`` pairs; the default is the
            object bound to ``loader_ID_valid`` when this function is defined.

    Returns:
        The average of the per-sample losses (BCE + KLD) over all samples.
    """
    model.eval()
    per_sample = []
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            batch_losses = loss_function(recon_batch, data, mu, logvar, reduction='none')
            per_sample += batch_losses.cpu().numpy().tolist()
            if not (res_dir and batch_idx == 0):
                continue
            n = min(data.size(0), 8)
            grid = torch.cat([data[:n], recon_batch[:n]])
            out_path = os.path.join(res_dir, 'reconstruction_{}.png'.format(epoch))
            save_image(grid.cpu(), out_path, nrow=n)

    mean_loss = sum(per_sample) / len(per_sample)
    print('====> Test set loss: {:.4f}'.format(mean_loss))
    return mean_loss

def reconstruct(label, res_dir=None, loader=loader_ID_valid):
    """Build an input-vs-reconstruction grid from the first batch of ``loader``.

    The grid holds the first ``n = min(batch size, 8)`` inputs followed by
    their ``n`` reconstructions by the module-level ``model``. When saved with
    ``nrow=n``, inputs form the first row and reconstructions the second.

    Args:
        label: suffix of the saved file name ``reconstruction_<label>.png``.
        res_dir: directory the grid is saved in; when None nothing is written.
        loader: DataLoader yielding ``(data, label)`` pairs; the default is the
            object bound to ``loader_ID_valid`` when this function is defined.

    Returns:
        The grid as a CPU tensor, or None when ``loader`` yields no batch.
    """
    model.eval()
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return None
    data, _ = first_batch
    with torch.no_grad():
        data = data.to(device)
        recon_batch, _mu, _logvar = model(data)
    n = min(data.size(0), 8)
    comparison = torch.cat([data[:n], recon_batch[:n]]).cpu()
    if res_dir is not None:
        # os.path.join(None, ...) would raise, so only save when a directory is given.
        save_image(comparison,
                   os.path.join(res_dir, 'reconstruction_{}.png'.format(label)), nrow=n)
    return comparison


def eval(loader):
    """Return the per-sample loss (BCE + KLD) of every sample in ``loader`` as a list.

    Uses the module-level ``model`` and ``device``; labels are ignored.
    NOTE: this name shadows Python's built-in ``eval`` inside this notebook.
    """
    model.eval()
    per_sample_losses = []
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            batch_losses = loss_function(recon_batch, data, mu, logvar, reduction='none')
            per_sample_losses += batch_losses.cpu().numpy().tolist()
    return per_sample_losses



# Create and train model

In [None]:

# VAE over the flattened images, moved to the selected device, optimised
# with Adam (learning rate 1e-3).
model = VAE(img_size=image_size, z_size=z_size).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:

num_samples = 10
res_dir = 'results'
ood_res_dir = os.path.join(res_dir, 'ood')

for directory in (res_dir, ood_res_dir):
    os.makedirs(directory, exist_ok=True)

# Per-epoch average losses: training set, plus one list per evaluation dataset.
test_losses = {k: [] for k in test_dataset_dict}
train_losses = []
for epoch in range(1, args.epochs + 1):
    train_losses.append(train(epoch, loader=loader_ID_train))
    for k in test_losses:
        test_losses[k].append(test(epoch, res_dir=None, loader=test_dataset_dict[k]))

    # Decode latent codes drawn from N(0, I) and save them as an image grid,
    # to follow sample quality over the epochs.
    with torch.no_grad():
        sample = torch.randn(num_samples, z_size).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample, os.path.join(res_dir, 'sample_{}.png'.format(epoch)))

epoch_axis = range(1, len(train_losses) + 1)
plt.figure()
plt.plot(epoch_axis, train_losses, label='train')
for label, losses in test_losses.items():
    plt.plot(epoch_axis, losses, label=label)
plt.title('VAE loss per epoch')
plt.xlabel('epoch')
plt.ylabel('average per-sample loss (BCE + KLD)')
plt.legend()
plt.show()


# Plot losses 

In [None]:

# Evaluation losses only (training loss omitted so the test curves are easier to compare).
fig, ax = plt.subplots(figsize=figsize)
for label, losses in test_losses.items():
    ax.plot(range(1, len(losses) + 1), losses, label=label)
ax.set(title='Per-epoch loss on each evaluation dataset',
       xlabel='epoch',
       ylabel='average per-sample loss (BCE + KLD)')
ax.legend()
plt.show()

# Reconstruct the first batch of each dataset

In [None]:
# One sub-directory per dataset (ID and OOD alike), each receiving
# reconstruction_<label>.png built from that dataset's first batch.
for label, loader in test_dataset_dict.items():
    label_dir = os.path.join(res_dir, label)
    os.makedirs(label_dir, exist_ok=True)
    reconstruct(label, res_dir=label_dir, loader=loader)


# Plot evaluation histograms

In [None]:

# Per-sample losses of every dataset; their distributions are compared below.
evaluations = dict(zip(test_dataset_dict, map(eval, test_dataset_dict.values())))


In [None]:
# Overlay the normalised per-sample loss distributions of every dataset.
# Bin edges are computed once over the pooled losses so all histograms share
# them; separate plt.hist calls would otherwise pick their own bins and the
# bars would not be comparable. Later datasets are drawn more transparent.
n_bins = 50
bin_edges = np.histogram_bin_edges(
    np.concatenate([np.asarray(e, dtype=float) for e in evaluations.values()]),
    bins=n_bins,
)
fig, ax = plt.subplots(figsize=figsize)
alpha_delta = 0.8
alpha = 1.0
for label, evaluation in evaluations.items():
    # Weights make each histogram sum to 1 (fraction of that dataset's samples).
    weights = np.ones_like(evaluation) / float(len(evaluation))
    ax.hist(evaluation, bins=bin_edges, weights=weights, label=label, alpha=alpha)
    alpha *= alpha_delta

ax.set(title='Per-sample loss distribution, all datasets',
       xlabel='per-sample loss (BCE + KLD)',
       ylabel='fraction of samples')
ax.legend()
plt.show()


In [None]:
# Same overlay restricted to the main-class splits (labels containing "MVTEC").
# Bin edges are shared across the plotted datasets so the bars are comparable.
mvtec_evaluations = {
    label: evaluation
    for label, evaluation in evaluations.items()
    if "MVTEC" in label
}
n_bins = 50
bin_edges = np.histogram_bin_edges(
    np.concatenate([np.asarray(e, dtype=float) for e in mvtec_evaluations.values()]),
    bins=n_bins,
)
fig, ax = plt.subplots(figsize=figsize)
alpha_delta = 0.8
alpha = 1.0
for label, evaluation in mvtec_evaluations.items():
    # Weights make each histogram sum to 1 (fraction of that dataset's samples).
    weights = np.ones_like(evaluation) / float(len(evaluation))
    ax.hist(evaluation, bins=bin_edges, weights=weights, label=label, alpha=alpha)
    alpha *= alpha_delta

ax.set(title='Per-sample loss distribution, MVTEC splits',
       xlabel='per-sample loss (BCE + KLD)',
       ylabel='fraction of samples')
ax.legend()
plt.show()
