In [3]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torchvision
from torchvision import models
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import os

import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, FastICA
from sklearn.covariance import LedoitWolf, MinCovDet

from torchvision.transforms import transforms
from sklearn.metrics import roc_auc_score

np.random.seed(252525)
torch.manual_seed(252525)

import torch
import torch.nn as nn

from data.mvtec import *
from hugeica import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [42]:
import torch, time
import torch.nn as nn
from torch.distributions import TransformedDistribution, Uniform, SigmoidTransform, AffineTransform
import torch.nn.functional as F
import numpy as np
from ummon import *

def remove_model(model="_"):
    if os.path.exists(f'./{model}.pth.tar'):
        os.remove(f'./{model}.pth.tar')
    if os.path.exists(f'./{model}_best_training_loss.pth.tar'):
        os.remove(f'./{model}_best_training_loss.pth.tar')

class LogCosh(nn.Module):
    
    def __init__(self, n_dims=128, loc=torch.zeros(1), scale=torch.ones(1) * np.sqrt(3)/np.pi):
        super().__init__()
        self.register_buffer("loc", loc.view(-1).clone())
        self.register_buffer("scale", scale.view(-1).clone())
        self._initialize_distributions()
        self.device = "cpu"

    def to(self, device):
        super().to(device)
        self.device = device
    
    def _initialize_distributions(self):
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=self.loc, scale=self.scale)]
        self._distributions = TransformedDistribution(base_distribution, transforms)

    def sample(self, size=(1, 100)):
        return self._distributions.sample(size).to(self.device)
    
    def log_prob(self, X):
        X = X.to("cpu")
        return self._distributions.log_prob(X).sum(axis=1).to(self.device)

class Normal(nn.Module):
    
    def __init__(self, n_dims=128, loc=torch.zeros(128), scale=torch.ones(128)):
        super().__init__()
        if n_dims != loc.shape[0]:
            loc, scale =torch.zeros(n_dims), torch.ones(n_dims)
        self.loc, self.scale = loc, scale
        self._initialize_distributions()
        self.device = "cpu"

    def to(self, device):
        super().to(device)
        self.device = device
        self.loc = self.loc.to(device)
        self.scale = self.scale.to(device)
        self._distributions.loc = self._distributions.loc.to(device)
        self._distributions.scale = self._distributions.scale.to(device)
    
    def _initialize_distributions(self):
        self._distributions = torch.distributions.Normal(self.loc, self.scale)

    def sample(self, size=(1, 101)):
        return self._distributions.sample(size).to(self.device)
    
    def log_prob(self, X):
        return self._distributions.log_prob(X).sum(axis=1)

class Laplacian(nn.Module):
    
    def __init__(self, n_dims=128, loc=torch.zeros(128), scale=torch.ones(128)):
        super().__init__()
        if n_dims != loc.shape[0]:
            loc, scale  =torch.zeros(n_dims), torch.ones(n_dims)
        self.loc, self.scale = loc, scale
        self._initialize_distributions()
        self.device = "cpu"

    def to(self, device):
        super().to(device)
        self.device = device
    
    def _initialize_distributions(self):
        self._distributions = torch.distributions.laplace.Laplace(self.loc, self.scale)

    def sample(self, size=(1, 100)):
        return self._distributions.sample(size).to(self.device)
    
    def log_prob(self, X):
        X = X.to("cpu")
        return self._distributions.log_prob(X).sum(axis=1).to(self.device)

deep128 = {
             "nets": lambda: nn.Sequential(nn.Linear(128, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 128)),
             "nett": lambda: nn.Sequential(nn.Linear(128, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 128))
            }


isa128 = {
             "nets": lambda: nn.Sequential(nn.Linear(128, 128)),
             "nett": lambda: nn.Sequential(nn.Linear(128, 128))
            }


def random_mask(layers, n_dims):
    masks = np.zeros((layers, n_dims)).astype(np.float32)
    for i in range(layers):
        idx =  np.random.uniform(0, n_dims, np.max([1, n_dims//2])).astype(np.int32)
        masks[i, idx] = 1.
    masks = torch.from_numpy(masks)   
    return masks

def checkerboard(layers, n_dims):
    rows = cols = int(np.sqrt(n_dims))
    masks = np.zeros((layers, rows, cols)).astype(np.float32)
    for i in range(layers):
        for row in range(rows):
            for col in range(cols):
                masks[i, row, col] = 1. if (row % 2 == 0 and col % 2 == 0) or (row % 2 == 1 and col % 2 == 1) else 0.
    masks = torch.from_numpy(masks).view(layers, -1)   
    for i in range(len(masks)):
        if np.random.uniform() > 0.5: # invert mask with probability 0.5
            masks[i] = torch.abs(masks[i] - 1)
    return masks

def channelwise_checkerboard(layers, n_dims, fmaps=1, patch_size=32):
    rows = cols = patch_size
    masks = np.zeros((layers, fmaps, rows, cols)).astype(np.float32)
    for i in range(layers):
        for f in range(fmaps):
            for row in range(rows):
                for col in range(cols):
                    masks[i, f, row, col] = 1. if (row % 2 == 0 and col % 2 == 0) or (row % 2 == 1 and col % 2 == 1) else 0.
            if np.random.uniform() > 0.5: # invert mask with probability 0.5
                masks[i, f] = np.abs( masks[i, f] - 1)
    masks = torch.from_numpy(masks).view(layers, -1)   
    return masks

def channelwise_checkerboard_random(layers, n_dims, fmaps=1, patch_size=32, n=1):
    assert n_dims == fmaps*patch_size*patch_size
    mask = np.zeros((layers, n_dims)).astype(np.float32)
    block = (n+n+1)*(n+n+1)
    n_blocks = int((n_dims / 2) / (block)) # pixels of a single block
    stepsize_channel = patch_size**2
    for i in range(layers):
        current_mask = np.zeros(fmaps*patch_size*patch_size).reshape(fmaps, patch_size, patch_size).astype(np.float32)
        for b in range(n_blocks):
            fi = int(np.random.uniform(0, fmaps))          # choose fmap, always 0 if fmaps == 1
            xi = int(np.random.uniform(0+n, patch_size-n)) # choose xpos
            yi = int(np.random.uniform(0+n, patch_size-n)) # choose ypos
            current_mask[fi, yi-n:yi+n, xi-n:xi+n ] = 1.
        mask[i] = current_mask.flatten()
    mask = torch.from_numpy(mask)
    return mask

def channelwise(layers, n_dims, fmaps=1, patch_size=32):
    assert fmaps > 1
    assert n_dims == fmaps*patch_size*patch_size
    mask = np.zeros((layers, n_dims)).astype(np.float32)
    stepsize = patch_size**2
    for i in range(layers):
        idx = np.random.uniform(0, fmaps, fmaps//2).astype(np.int32)
        for ii in idx:
            mask[i, ii*stepsize:(ii+1)*stepsize] = 1.
    mask = torch.from_numpy(mask)
    return mask

class Reshape(nn.Module):
    
    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        
    def forward(self, X):
        return X.view(len(X), *self.shape)

class BPD(OnlineMetric):

    def __init__(self, model):
        self.model = model

    def __call__(self, z, _):
        """
        See Also:
            Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
            - or -
            Page 12 in https://arxiv.org/pdf/1705.07057.pdf
        """
        n, dim = z.shape[0], z.shape[1]
        nll = -(self.model.prior.log_prob(z) - np.log(256)*dim  + self.model.log_det_J).mean()
        bpd = nll / (np.log(2) * dim)
        return bpd

class RealNVP(nn.Module):
    """
    Simple RealNVP
    https://github.com/bayesgroup/deepbayes-2019/blob/master/seminars/day3/nf/nf-assignment.ipynb

    """

    def __init__(self,nets = lambda: nn.Sequential(nn.Linear(128, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 128)),
                      nett = lambda: nn.Sequential(nn.Linear(128, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 128)),
                      mask = None, prior = None, distr="normal", layers=3, n_dims=128, patch_size=None, fmaps=None, logits=False):
        super().__init__()
        assert distr in ["logcosh" , "normal", "laplacian"]
        
        if mask is None:
            if patch_size is None:
                mask = random_mask(layers, n_dims)
            elif fmaps <= 3:
                mask = channelwise_checkerboard(layers, n_dims, fmaps=fmaps, patch_size=patch_size)
            else:
                mask = channelwise(layers, n_dims, fmaps=fmaps, patch_size=patch_size)

        if prior is None:
            if distr == "logcosh":
                prior = LogCosh(n_dims)
            if distr == "normal":
                prior = Normal(n_dims)
            if distr == "laplacian":
                prior = Laplacian(n_dims)

        self.logits = logits
        self.prior = prior
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.s = torch.nn.ModuleList([nets() for _ in range(len(mask))])
        self.t = torch.nn.ModuleList([nett() for _ in range(len(mask))])

    def to(self, device):
        self.prior.to(device)
        super().to(device)
        return self

    def g(self, z):
        x = z
        for i in range(len(self.t)):
            x_ = x * self.mask[i]
            s = self.s[i](x_) * (1 - self.mask[i])
            s = torch.tanh(s)
            t = self.t[i](x_)*(1 - self.mask[i])
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
        return x

    def f(self, x):
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z # ON nodes
            s = self.s[i](z_) * (1-self.mask[i]) # OFF nodes
            s = torch.tanh(s)
            t = self.t[i](z_) * (1-self.mask[i]) # OFF nodes
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_J -= s.sum(dim=1)
        return z, log_det_J

    def predict(self, X, bs=100, log_prob=False):
        if log_prob:
            pred = lambda x: self.log_prob(torch.from_numpy(x).to(self.prior.device)).cpu().view(-1, 1).detach().numpy()[:,0]
        else:
            pred = lambda x: self.f(torch.from_numpy(x).to(self.prior.device))[0].cpu().detach().numpy()
        return np.concatenate([pred(X[i*bs:i*bs+bs]) for i in range(int(np.ceil(len(X) / bs)))], 0)

    def log_prob(self,x):
        return self.prior.log_prob(self(x)) + self.log_det_J

    
    def predict_logp(self, X, bs=100):
        device = next(self.parameters()).device
        zs = torch.cat([ self.forward(torch.from_numpy( X[i*bs:(i+1)*bs]).to(device)  ).detach() for i in range(0, len(X), bs)])
        logps = self.log_prob(zs).detach().cpu().numpy()
        return logps

    def to_logit(self, x):
        noise = torch.FloatTensor(np.random.uniform(size=x.shape)).to(x.device)
        data_constraint = torch.FloatTensor([0.9]).to(x.device)
        y = (x * 255. + noise) / 256.
        y = (2 * y - 1) * data_constraint
        y = (y + 1) / 2
        y = torch.log(y) - torch.log(1. - y)

        # Save log-determinant of Jacobian of initial transform
        ldj = F.softplus(y) + F.softplus(-y) \
            - F.softplus((1. - torch.log(data_constraint) - torch.log(data_constraint)))
        sldj = ldj.view(ldj.size(0), -1).sum(-1)
        return y, sldj

    def forward(self, x):
        sldj = 0.
        if self.logits:
            x, sldj = self.to_logit(x)
        z, log_det_J = self.f(x)
        self.log_det_J = sldj + log_det_J
        self.H = z
        return z

    def sample(self, batchSize, prior=None): 
        if prior is None:
            z = self.prior.sample((batchSize, 1))
        else:
            z = prior.sample((batchSize, 1))            
        logp = self.prior.log_prob(z)
        x = self.g(z[:, 0, :])
        return x.cpu().detach()

    def fit(self, X, epochs, lr=1e-4, bs=100, log_interval=10, validation_ratio=0, validation_set=None):

        if validation_ratio > 0.:
            assert validation_set is None
            idx = np.random.permutation(len(X))
            self.idx = idx
            num_train = int( X.shape[0] * (1. - validation_ratio) )
            X_val = X[idx[num_train:]]
            X = X[idx[:num_train]]
        else:
            X_val = validation_set

        loss = lambda z, _ : -(self.prior.log_prob(z) + self.log_det_J).sum() / (z.shape[0] * z.shape[1])

        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad==True], lr=lr)

        m_name = "./__cache__/" + str(time.time())
        remove_model(m_name)

        with Logger(loglevel=20, log_epoch_interval=log_interval) as lg:   
            trs = Trainingstate(m_name)
            scheduler = StepLR_earlystop(optimizer, trs, self, step_size=2000, nsteps=3, logger=lg, mode='min', gamma=0.1, patience=50)
            tt = KamikazeTrainer(lg, self, loss, optimizer, use_cuda=torch.cuda.is_available(), scheduler=scheduler, trainingstate=trs, convergence_eps=0.0000001, combined_training_epochs=0)
            tt.fit((X, bs), validation_set=X_val, epochs=epochs, metrics=[ BPD(self)])

        print("loading best model..")
        trs.maybe_load_best_available_model_(self)
        remove_model(m_name)
        return self


In [32]:
bs = 13
clazz = 13
epochs = 1
augment = False
layer = 4

net = models.efficientnet_b4(pretrained=True).features[:layer]
net = net.to(device)
net.eval()
        
X_, X_valid_, X_test_, X_labels_, T = zip(*[dataloader(clazz, P=224, s=224, label_per_patch=False, augment=augment) for i in range(epochs)])
X__, X_valid_, X_test_ = np.concatenate(X_), np.concatenate(X_valid_), np.concatenate(X_test_)

_, X_valid__, X_test__, X_labels_, T = dataloader(clazz, P=224, s=224, label_per_patch=False, augment=False)
    
with torch.no_grad():
    X_ = torch.cat([ net(torch.from_numpy( X__[i:i+bs] ).to(device) ).detach().cpu() for i in range(0, len(X__), bs)]).cpu().numpy()
    X_valid_ = torch.cat([ net(torch.from_numpy( X_valid__[i:i+bs] ).to(device) ).detach().cpu() for i in range(0, len(X_valid__), bs)]).cpu().numpy()
    X_test_ = torch.cat([ net(torch.from_numpy( X_test__[i:i+bs] ).to(device) ).detach().cpu() for i in range(0, len(X_test__), bs)]).cpu().numpy()
    
net = net.to("cpu")
X_.shape[1:]

(56, 28, 28)

In [33]:
shape = X_.shape[1:]

ndim = np.prod(shape)
complexity = shape[0]
fmaps = shape[0]
hmaps = fmaps 
p = X_.shape[2]

nets = lambda: nn.Sequential(Reshape(shape), nn.Conv2d(fmaps, hmaps, 3, padding=1), nn.LeakyReLU(), nn.Conv2d(hmaps, fmaps, 3, padding=1), nn.LeakyReLU(), Reshape((ndim,)))
nett = lambda: nn.Sequential(Reshape(shape), nn.Conv2d(fmaps, hmaps, 3, padding=1), nn.LeakyReLU(), nn.Conv2d(hmaps, fmaps, 3, padding=1), nn.LeakyReLU(), Reshape((ndim,)))
#nett = lambda: nn.Sequential(nn.Linear(ndim, complexity), nn.LeakyReLU(), nn.Linear(complexity, ndim))
            
flow = RealNVP(nets, nett, n_dims=ndim, patch_size=p, fmaps=fmaps, layers=8, mask=channelwise(8, ndim, fmaps=fmaps, patch_size=p))

In [43]:
res = flow.fit(X_.reshape(len(X_), -1), epochs=800, lr=1e-4, bs=50, log_interval=10, validation_ratio=0.1)

Scheduler: 3 learning rates decreased by factor 0.1 after 2000 epochs, early stopping after 50, min mode.
Begin training: 800 epochs.
Epoch: 10 - loss(trn/val):3.06742/3.07871, BPD(trn/val):12.43/12.44, lr=0.00010 [BEST]. [0s] @1928 samples/s 
Epoch: 20 - loss(trn/val):2.96675/2.98026, BPD(trn/val):12.28/12.30, lr=0.00010 [BEST]. [0s] @1920 samples/s 
Epoch: 30 - loss(trn/val):2.94733/2.96257, BPD(trn/val):12.25/12.27, lr=0.00010 [BEST]. [0s] @1924 samples/s 
Epoch: 40 - loss(trn/val):2.93690/2.95255, BPD(trn/val):12.24/12.26, lr=0.00010 [BEST]. [0s] @1925 samples/s 
Epoch: 50 - loss(trn/val):2.93412/2.95061, BPD(trn/val):12.23/12.26, lr=0.00010. [0s] @1924 samples/s 
Epoch: 60 - loss(trn/val):2.93091/2.94808, BPD(trn/val):12.23/12.25, lr=0.00010. [0s] @1918 samples/s 
Epoch: 70 - loss(trn/val):2.93000/2.94753, BPD(trn/val):12.23/12.25, lr=0.00010 [BEST]. [0s] @1910 samples/s 
Epoch: 80 - loss(trn/val):2.92923/2.94748, BPD(trn/val):12.23/12.25, lr=0.00010 [BEST]. [0s] @1917 samples/s 


loading best model..


In [45]:
scores_inliers = -flow.predict_logp(X_valid_.reshape(len(X_valid_), -1))
scores_outliers = -flow.predict_logp(X_test_.reshape(len(X_test_), -1))

auc = roc_auc_score([0] * len(scores_inliers) + [1] * len(scores_outliers), np.concatenate([scores_inliers, scores_outliers]))
print(MVTEC.CLASSES[clazz], auc) # augmentation on + shift 0.2

screw 0.5187804878048781


In [11]:
scores_inliers = -flow.predict_logp(X_valid_.reshape(len(X_valid_), -1))
scores_outliers = -flow.predict_logp(X_test_.reshape(len(X_test_), -1))

auc = roc_auc_score([0] * len(scores_inliers) + [1] * len(scores_outliers), np.concatenate([scores_inliers, scores_outliers]))
print(MVTEC.CLASSES[clazz], auc) # augmentation on

screw 0.8553999999999999


In [40]:
scores_inliers = -flow.predict_logp(X_valid_.reshape(len(X_valid_), -1))
scores_outliers = -flow.predict_logp(X_test_.reshape(len(X_test_), -1))

auc = roc_auc_score([0] * len(scores_inliers) + [1] * len(scores_outliers), np.concatenate([scores_inliers, scores_outliers]))
print(MVTEC.CLASSES[clazz], auc) # augmentation off

screw 0.9009756097560975
