<a href="https://colab.research.google.com/github/cimuletz/zero-shot-ad/blob/main/Zero_Shot_AD_ACR_BCE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports


In [None]:
import torch 
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torch.optim as optim
from PIL import Image
from tqdm import tqdm
from typing import Iterator, List, Callable, Tuple
import matplotlib.pyplot as plt
import urllib.request
from sklearn.metrics import roc_auc_score
from google.colab import drive
import numpy as np
import sys
torch.manual_seed(115)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ACR-BCE

In [None]:
def maml_init_(module):
    """Initialise a layer in place and return it.

    The weight is filled from a Xavier (Glorot) uniform distribution with
    gain 1.0 and the bias, when the layer has one, is set to zero.

    Args:
        module: a layer exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (for example ``nn.Conv2d`` or ``nn.Linear``).

    Returns:
        The same ``module`` object, so the call can be chained.
    """
    torch.nn.init.xavier_uniform_(module.weight.data, gain=1.0)
    # Layers created with bias=False expose ``bias`` as None; skip them
    # instead of failing on ``None.data``.
    if getattr(module, "bias", None) is not None:
        torch.nn.init.constant_(module.bias.data, 0.0)
    return module


In [None]:
#maybe use affine?
class BasicBlock(torch.nn.Module):
  """Convolution -> batch norm -> ReLU, optionally followed by max pooling.

  The convolution is unpadded and has a bias; its parameters are
  re-initialised with ``maml_init_``. The pooling layer reuses the
  convolution's ``kernel_size`` and ``stride``.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    kernel_size: kernel size shared by the convolution and the pooling.
    stride: stride shared by the convolution and the pooling.
    max_pool: when equal to ``True`` the pooling layer is applied in
      ``forward``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, max_pool=True):
    super().__init__()

    self.max_pool = max_pool
    self.max_pool_layer = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
    self.batch_norm = nn.BatchNorm2d(num_features=out_channels)
    self.activation_fn = nn.ReLU()

    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        bias=True,
    )

    # Replace the default initialisation of the convolution.
    maml_init_(self.conv)

  def forward(self, x):
    """Run the block on ``x`` and return the resulting feature map."""
    features = self.activation_fn(self.batch_norm(self.conv(x)))
    # Equality test (not truthiness) is kept on purpose so that the pooling
    # decision is unchanged for every value of ``max_pool``.
    if self.max_pool == True:
      return self.max_pool_layer(features)
    return features

In [None]:
class ACRBCE(nn.Module):
  """Four-block convolutional scorer ending in a single sigmoid output.

  With 1 x 28 x 28 inputs, every block (unpadded 3x3 convolution followed by
  a 3x3 max pooling, both with stride 1) shrinks each spatial side by 4:
  28 -> 24 -> 20 -> 16 -> 12. The final linear layer therefore receives
  64 * 12 * 12 features per example.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = BasicBlock(in_channels=1, out_channels=64, kernel_size=3, stride=1)   # 28 -> 24
    self.conv2 = BasicBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1)  # 24 -> 20
    self.conv3 = BasicBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1)  # 20 -> 16
    self.conv4 = BasicBlock(in_channels=64, out_channels=64, kernel_size=3, stride=1)  # 16 -> 12
    self.fc1 = nn.Linear(64 * 12 * 12, 1)

    # Not used by forward(); retained as an attribute of the module.
    self.activation_fn = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return one score in (0, 1) per example, shaped (batch, 1)."""
    for block in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = block(x)
    flattened = x.view(x.shape[0], -1)
    return self.sigmoid(self.fc1(flattened))

# Data processing

In [None]:
# Download a mirror of the MNIST archive and unpack it into the working
# directory (presumably so that datasets.MNIST('./', ...) can find the files
# locally -- confirm against the extracted folder layout).
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

In [None]:
# Make sure matplotlib (used for the loss plots) is installed in the runtime.
!pip install matplotlib

In [None]:
# Creating the Pj distributions, j=0...5

# Full MNIST dataset (training and test splits).
mnist = datasets.MNIST('./', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor(),]))

mnist_test = datasets.MNIST('./', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor(),]))

# The raw image tensors are read directly (the transform above is not
# involved) and divided by 255.
train_data = mnist.data.float() / 255
train_targets = mnist.targets

test_data = mnist_test.data.float() / 255
test_targets = mnist_test.targets


def _group_by_class(images, targets, num_classes=10):
  """Split (image, label) pairs into one list per class, keeping order.

  A single pass over the data replaces ten full scans that each compared
  every target tensor with the class index.
  """
  buckets = [[] for _ in range(num_classes)]
  for image, class_id in zip(images, targets.tolist()):
    buckets[class_id].append((image, class_id))
  return buckets


# Dataset split by class: datasets_by_class[c] is a list of (image, c) pairs.
datasets_by_class = _group_by_class(train_data, train_targets)
datasets_by_class_test = _group_by_class(test_data, test_targets)

In [None]:
# Sanity check: shape of the first stored image of class 1.
print(datasets_by_class[1][0][0].shape)

torch.Size([28, 28])


In [None]:
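# Builds a dataset made of the examples with the given label (target 0)
# mixed with anomalous examples taken from other labels (target 1); the
# number of anomalies is `anomaly_ratio` times the number of normal examples.
# Returns a single DataLoader when `training` is falsy, a
# (training, validation) pair of DataLoaders when it is truthy, and the raw
# shuffled list of examples when `training == 2`.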

import random
import copy

def get_training_data(training, label, anomaly_ratio, batch_size):
  """Build a mixture of one "normal" digit class and anomalous digits.

  Every example of class ``label`` gets target ``[0.]``. Anomalies get target
  ``[1.]`` and are sampled with replacement from the digit classes 6-9 other
  than ``label``. The combined list is shuffled with the ``random`` module.

  Args:
    training: selects the source split and the return type.
      * falsy: test split, returns one DataLoader;
      * truthy: training split, returns a (training, validation) pair of
        DataLoaders built from disjoint 80% / 20% parts of the mixture;
      * ``2``: training split, returns the shuffled list of
        ``(image, target)`` pairs without wrapping it in a DataLoader.
    label: index of the class treated as normal.
    anomaly_ratio: number of anomalies to add, expressed as a fraction of the
      number of normal examples.
    batch_size: batch size of the returned DataLoader(s).

  Returns:
    A DataLoader, a pair of DataLoaders, or a list, depending on ``training``.
    All DataLoaders shuffle and drop the last incomplete batch.
  """
  source = datasets_by_class if training else datasets_by_class_test

  # A fresh list with fresh targets; the image tensors are never modified,
  # so they can be shared with the source lists instead of deep-copied.
  dataset = [(image, torch.tensor([0], dtype=torch.float32))
             for image, _ in source[label]]

  anomaly_target_count = int(len(dataset) * anomaly_ratio)

  # Anomalies only come from classes 6-9, never from the normal class itself.
  anomaly_classes = [c for c in range(6, 10) if c != label]

  for _ in range(anomaly_target_count):
    target_class = random.choice(anomaly_classes)
    image, _ = random.choice(source[target_class])
    dataset.append((image, torch.tensor([1], dtype=torch.float32)))

  random.shuffle(dataset)

  if training == 2:
    return dataset

  if training:
    # The first fifth of the shuffled mixture is held out for validation and
    # the rest is used for training, so the two loaders share no element of
    # the mixture.
    split_index = len(dataset) // 5
    validation_dataset = dataset[:split_index]
    training_dataset = dataset[split_index:]

    training_dataloader = torch.utils.data.DataLoader(training_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

    validation_dataloader = torch.utils.data.DataLoader(validation_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        drop_last=True)

    return training_dataloader, validation_dataloader

  return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                     shuffle=True, drop_last=True)

# Trainer

In [None]:
class ACRTrainer:
  """Trains and evaluates a binary anomaly scorer on several data mixtures.

  ``loaders`` arguments are sequences indexed by class; each entry is a
  ``(training_loader, validation_loader)`` pair whose batches are
  ``(images, targets)`` with images reshaped here to ``(batch, 1, 28, 28)``.

  Args:
    model: module mapping an image batch to scores accepted by ``loss_fn``.
    loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
    no_classes: number of entries of ``loaders`` used per epoch.
    max_iter: number of epochs run by ``train_distributions``.
  """

  # Number of batches whose losses are summed before one optimizer step.
  ACCUMULATION_BATCHES = 100

  def __init__(self, model, loss_fn, no_classes, max_iter = 100):
    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    self.model = model.to(self.device)
    self.loss_fn = loss_fn
    self.max_iter = max_iter
    self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
    self.no_classes = no_classes

  def _apply_update(self, summed_loss, num_batches):
    """Back-propagate the mean of ``num_batches`` summed losses and step.

    Returns the mean loss as a Python float.
    """
    mean_loss = summed_loss / num_batches
    self.optimizer.zero_grad()
    mean_loss.backward()
    self.optimizer.step()
    return mean_loss.item()

  # singular step in training
  def train_step_distributions(self, loaders):
    """Run one epoch over the training loaders of all classes.

    Batch losses are summed and back-propagated once every
    ``ACCUMULATION_BATCHES`` batches, plus once for the batches left over at
    the end of each class.

    Returns:
      A pair ``(epoch_loss, losses)``. ``epoch_loss`` is the sum of all batch
      losses divided by the number of examples seen (0.0 if no example was
      seen); each batch loss is counted exactly once. ``losses`` holds, per
      class, the mean batch loss of the trailing partial chunk.
    """
    self.model.train()
    losses = []
    total_exp = 0
    total_loss = 0.0

    # one epoch passing through all the datasets
    for j in range(self.no_classes):
      loader = loaders[j][0]
      pending = 0
      loss = 0
      for x_train, y_train in loader:
        total_exp += len(x_train)

        x = x_train.view(x_train.shape[0], 1, 28, 28).to(self.device)
        y = y_train.to(self.device)
        loss = loss + self.loss_fn(self.model(x), y)
        pending += 1

        if pending == self.ACCUMULATION_BATCHES:
          total_loss += loss.item()
          self._apply_update(loss, pending)
          pending = 0
          loss = 0

      if pending > 0:
        # The summed loss is added once; the mean used for the update is not
        # added a second time.
        total_loss += loss.item()
        losses.append(self._apply_update(loss, pending))

    epoch_loss = total_loss / total_exp if total_exp else 0.0
    return epoch_loss, losses

  #validation/test function which calculates auc and loss
  def outlier_score(self, loader, mode='val'):
    """Evaluate the model on ``loader``.

    Args:
      loader: iterable of ``(images, targets)`` batches.
      mode: accepted for backward compatibility; not used.

    Returns:
      ``(loss, auc)`` where ``loss`` is the sum of batch losses divided by
      the number of examples and ``auc`` is the ROC AUC over all examples.

    Raises:
      ValueError: if ``loader`` yields no examples.
    """
    loss = 0.0
    num_examples = 0
    pred = []
    truth = []
    self.model.eval()
    with torch.no_grad():
      for x_train, y_train in loader:
        num_examples += len(x_train)
        x = x_train.view(x_train.shape[0], 1, 28, 28).to(self.device)
        y = y_train.to(self.device)

        out = self.model(x)
        loss += self.loss_fn(out, y).item()
        # The AUC is computed once over all batches: a per-batch AUC is
        # undefined whenever a batch contains a single class.
        pred.extend(out.cpu().numpy().ravel())
        truth.extend(y_train.cpu().numpy().ravel())

    if num_examples == 0:
      raise ValueError("outlier_score received a loader that yields no examples")
    return loss / num_examples, roc_auc_score(truth, pred)

  # starts the training routine
  def train_distributions(self, loaders):
    """Train for ``max_iter`` epochs, validating every class after each one.

    Returns:
      ``(losses, val_losses)``: the training loss and the validation loss
      averaged over classes, one value per epoch.
    """
    losses = []
    val_losses = []
    for i in range(self.max_iter):
      train_loss, _ = self.train_step_distributions(loaders)
      losses.append(train_loss)

      total_val_loss = 0
      for j in range(self.no_classes):
        val_loader = loaders[j][1]
        val_loss, val_auc = self.outlier_score(val_loader, 'val')
        print(f"Iteration: {i}, Class:{j},  TL: {train_loss}, VL:{val_loss}, VA:{val_auc}")
        total_val_loss += val_loss
      val_losses.append(total_val_loss/self.no_classes)
      print()
    return losses, val_losses

  # starts the testing routine
  def test(self, loader):
    """Return ``(loss, auc)`` of the model on ``loader``."""
    return self.outlier_score(loader, 'test')



# Run ACR-BCE

In [None]:
# Anomaly ratio of each mixture: number of anomalous examples added per
# normal example (training mixtures vs. held-out test mixture).
train_pi, test_pi = 0.8, 0.1

# Mini-batch sizes of the training and test loaders.
train_sz = test_sz = 64

# Number of digit classes trained on and number of training epochs.
no_classes_training, no_iterations = 5, 10

In [None]:
# Held-out evaluation data: test-split examples of class 5 mixed with
# anomalies. The DataLoader is kept as is (not wrapped in iter()) so that it
# can be iterated more than once.
test_loader = get_training_data(0, 5, test_pi, test_sz)

# initialize training dataloaders: one (training, validation) pair per class
loaders = []
for i in range(0, no_classes_training):
  train_loader, val_loader = get_training_data(1, i, train_pi, train_sz)
  loaders.append((train_loader, val_loader))

# build the model and its trainer
model = ACRBCE()
loss = nn.BCELoss()
trainer = ACRTrainer(model, loss, no_classes_training, no_iterations)

# train model
losses, val_losses = trainer.train_distributions(loaders)

# calculate test loss and test auc
test_score, test_auc = trainer.test(test_loader)
print(f'Test score: {test_score}, Test AUC: {test_auc}')

In [None]:
%matplotlib inline
plt.plot(range(0,len(losses)), losses, 'g', label='Training loss')
plt.plot(range(0,len(losses)), val_losses, 'b', label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()


# Distance between datasets using MMD

In [None]:
def MMD(x, y, kernel):
    """Empirical maximum mean discrepancy (MMD) between two samples.

    The lower the result, the more evidence that the two samples come from
    the same distribution. The estimate is
    ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` where ``k`` is a sum
    of kernels over a fixed list of bandwidths.

    Args:
        x: first sample, distribution P; 2-D tensor with one row per point.
        y: second sample, distribution Q; 2-D tensor with the same number of
            columns as ``x``. The number of rows may differ from ``x``.
        kernel: kernel type, either "multiscale" or "rbf".

    Returns:
        A scalar tensor on the same device as the inputs.

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    x_sq = xx.diag()  # squared norm of every row of x
    y_sq = yy.diag()  # squared norm of every row of y

    # Pairwise squared Euclidean distances, built by broadcasting so that the
    # two samples may contain different numbers of points.
    dxx = x_sq.unsqueeze(1) + x_sq.unsqueeze(0) - 2. * xx
    dyy = y_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2. * yy
    dxy = x_sq.unsqueeze(1) + y_sq.unsqueeze(0) - 2. * zz

    # Accumulators take their shape, dtype and device from the inputs rather
    # than from a module-level ``device``.
    XX, YY, XY = torch.zeros_like(xx), torch.zeros_like(yy), torch.zeros_like(zz)

    if kernel == "multiscale":
        bandwidth_range = [0.2, 0.5, 0.9, 1.3]
        for a in bandwidth_range:
            XX += a**2 * (a**2 + dxx)**-1
            YY += a**2 * (a**2 + dyy)**-1
            XY += a**2 * (a**2 + dxy)**-1
    elif kernel == "rbf":
        bandwidth_range = [10, 15, 20, 50]
        for a in bandwidth_range:
            XX += torch.exp(-0.5*dxx/a)
            YY += torch.exp(-0.5*dyy/a)
            XY += torch.exp(-0.5*dxy/a)
    else:
        # Returning 0 here would read as "identical distributions".
        raise ValueError(f"unknown kernel {kernel!r}; expected 'multiscale' or 'rbf'")

    # Separate means keep the estimate defined when the sample sizes differ;
    # for equal sizes this equals mean(XX + YY - 2 * XY) up to rounding.
    return XX.mean() + YY.mean() - 2. * XY.mean()

In [None]:
# Compare mixtures built around class 0 with mixtures built around class 5
# for every combination of the anomaly ratios 1.0, 0.5 and 0.1. Both
# mixtures are re-sampled for every pair, and at most 400 examples of each
# are used.
# NOTE(review): torch.cat joins the 28 x 28 images along dim 0, so MMD
# receives up to 400 * 28 rows of length 28 (image rows), not 400 flattened
# images -- confirm that this is the intended notion of "sample".
for i, j in ((0, 1.0), (0, 0.5), (0, 0.1)):
  for k, l in ((5, 1.0), (5, 0.5), (5, 0.1)):
    first_mixture = get_training_data(2, i, j, 1)[:400]
    dataset1 = torch.cat([image for image, _ in first_mixture]).to(device)
    second_mixture = get_training_data(2, k, l, 1)[:400]
    dataset2 = torch.cat([image for image, _ in second_mixture]).to(device)
    res = MMD(dataset1, dataset2, 'multiscale')
    print(f'Dataset {i}, ratio {j}; dataset {k}, ratio {l} result: {res}')

Dataset 0, ratio 1.0; dataset 5, ratio 1.0 result: 0.01399888563901186
Dataset 0, ratio 1.0; dataset 5, ratio 0.5 result: 0.015456795692443848
Dataset 0, ratio 1.0; dataset 5, ratio 0.1 result: 0.01581188477575779
Dataset 0, ratio 0.5; dataset 5, ratio 1.0 result: 0.016917800530791283
Dataset 0, ratio 0.5; dataset 5, ratio 0.5 result: 0.019765865057706833
Dataset 0, ratio 0.5; dataset 5, ratio 0.1 result: 0.020520687103271484
Dataset 0, ratio 0.1; dataset 5, ratio 1.0 result: 0.03980853781104088
Dataset 0, ratio 0.1; dataset 5, ratio 0.5 result: 0.03733474761247635
Dataset 0, ratio 0.1; dataset 5, ratio 0.1 result: 0.04190833494067192
