Evaluation of _Improving Model Robustness with Latent Distribution Locally and Globally_.

To run this on a local runtime:
```
pip install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws
pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension
jupyter notebook \
  --NotebookApp.allow_origin='https://colab.research.google.com' \
  --port=8888 \
  --NotebookApp.port_retries=0
# Install python packages as you see fit
```

## Setup (run once)

In [None]:
# Fetch the ATLD reference implementation (provides `models_new` and `utils`).
!git clone https://github.com/LitterQ/ATLD-pytorch
# AutoAttack: used by the robust-accuracy cells (APGD-CE / APGD-T).
!pip install git+https://github.com/fra31/auto-attack
# Foolbox pinned below 3.0. NOTE(review): not referenced by any cell below — confirm it is still needed.
!pip install "foolbox<3"
# Adversarial Robustness Toolbox: provides the PGD attack and PyTorchClassifier wrapper.
!pip install adversarial-robustness-toolbox
import sys
# Put the repo's CIFAR-10 directory first on the import path so that
# `models_new.*` and `utils` resolve (presumably they live there — confirm).
sys.path.insert(0,'ATLD-pytorch/cifar10')

In [None]:
# Download the pretrained ATLD checkpoint from Google Drive.
# NOTE(review): presumably this saves the file named `latest` that the
# checkpoint-loading cell reads — confirm. Newer gdown releases may warn that
# `--id` is deprecated.
!gdown --id 18NOtz_z29iMKdv92xTkXhZLVeCvg0N_o

## Imports

In [None]:
from __future__ import print_function
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import sys
import datetime

from tqdm import tqdm
from models_new.wideresnet import *
from models_new.dis import *
import utils

## Build models and dataset

In [None]:
# Colab form parameter: batch size for the CIFAR-10 test loader and the ART
# attack. The robust-accuracy cells attack only the first batch of this size.
batch_size = 50  #@param {type: 'integer'}

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Classifier: WideResNet with depth 28, widen factor 10 and 10 outputs.
basic_net = WideResNet(depth=28, num_classes=10, widen_factor=10).to(device)
# Discriminator applied to the classifier's logits (single output).
discriminator = Discriminator_2(
    depth=28, num_classes=1, widen_factor=5).to(device)
print(f'Using device: {device}')

In [None]:
# Test images are only converted to tensors; no normalization is applied.
transform_test = transforms.Compose([transforms.ToTensor()])
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
def _extract_submodule_state(state_dict, prefix):
  """Return the entries of `state_dict` whose key starts with `prefix`, prefix stripped."""
  return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
checkpoint = torch.load('latest', map_location=torch.device(device))

# The checkpoint stores both networks under checkpoint['net'], keyed by
# 'basic_net.' / 'discriminator.' prefixes; split them per module.
basic_net_params = _extract_submodule_state(checkpoint['net'], 'basic_net.')
basic_net.load_state_dict(basic_net_params)
print('Classifier loaded')

disc_params = _extract_submodule_state(checkpoint['net'], 'discriminator.')
discriminator.load_state_dict(disc_params)
print('Discriminator loaded')

In [None]:
# Switch both networks to eval mode for inference.
basic_net.eval()
# The trailing ';' suppresses the notebook's echo of the returned module.
discriminator.eval();

In [None]:
from torch.autograd import Variable
import numpy as np

def atld(classifier, discriminator, inputs, epsilon=8. / 255):
  """Classify `inputs` after ATLD's discriminator-guided repair step.

  The images are rescaled with `x * 2 - 1`, one signed-gradient step is taken
  that lowers the BCE loss of the discriminator's output (on the classifier
  logits) against an all-ones target, and the repaired images are classified.
  Samples whose discriminator output lies strictly between 0.3 and 0.7 get
  half the step size. Auxiliary tensors are created on `inputs.device`, so
  this works on CPU as well as GPU.

  Args:
    classifier: module whose call returns a sequence with the logits first.
    discriminator: module taking the logits reshaped to (batch, classes, 1, 1)
      and returning a pair whose first element is a (batch, 1) output in
      [0, 1] (required by BCELoss).
    inputs: 4-D image batch, batch dimension first; assumed to be in [0, 1].
    epsilon: step size in [0, 1] pixel units. It is doubled because the
      rescaled range [-1, 1] is twice as wide.

  Returns:
    (logits on the repaired inputs, None, None).
  """
  inputs = inputs * 2. - 1.
  batch = inputs.size(0)
  adversarial_criterion = nn.BCELoss()

  # Gradient of the "looks valid" loss w.r.t. an additive perturbation.
  # torch.autograd.grad (instead of .backward()) computes ONLY this gradient:
  # nothing is accumulated into the models' parameter .grad fields or into a
  # caller-owned leaf tensor (e.g. an attack's input).
  with torch.enable_grad():
    delta = torch.zeros_like(inputs, requires_grad=True)
    logits = classifier(inputs + delta)[0]
    logits_reshaped = logits.reshape(batch, -1, 1, 1)
    logits_disc, _ = discriminator(logits_reshaped)
    valid = torch.ones_like(logits_disc)
    adv_loss = adversarial_criterion(logits_disc, valid)
    # retain_graph=True (as the original backward call used) conservatively
    # keeps nodes shared with `inputs` usable for the caller's own backward.
    grad, = torch.autograd.grad(adv_loss, delta, retain_graph=True)
  grad = grad.detach()

  # Halve the step for samples the discriminator is unsure about.
  disc_out = logits_disc.detach()
  uncertain = (disc_out > 0.3) & (disc_out < 0.7)
  mask = torch.ones(batch, 1, 1, 1, device=inputs.device, dtype=inputs.dtype)
  mask[uncertain.view(-1)] = 0.5
  mask = mask.expand_as(inputs)

  # Step against the loss gradient, then clip back to the valid range.
  inputs_repaired = inputs - epsilon * 2. * mask * torch.sign(grad)
  inputs_repaired = torch.clamp(inputs_repaired, -1., 1.)
  outputs = classifier(inputs_repaired)[0]
  return outputs, None, None

# Accuracy

In this section, we measure the nominal (clean) accuracy. The results mostly match those reported in the paper. For ATLD+, there is a difference of 0.01–0.02 percentage points, which might be due to numerical errors; it is unclear where the randomness comes from.

We expect 93.34% for ATLD- and 90.78% for ATLD+.

## Vanilla (ATLD-)

In [None]:
basic_net.eval()
correct = 0
total = 0

def logits_fn(x):
  """Return plain classifier logits; maps [0, 1] inputs to [-1, 1] first."""
  return basic_net(x * 2. - 1.)[0]

t = time.time()
iterator = tqdm(testloader, ncols=0, leave=False)
# Clean evaluation needs no autograd graph; skipping it saves memory and time.
with torch.no_grad():
  for batch_idx, (inputs, targets) in enumerate(iterator):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = logits_fn(inputs)
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()
print(f'Time: {time.time() - t}')
acc = 100. * correct / total
print('Accuracy:', acc)

In [None]:
#@title FGSM (77.26%, expected 73.58%)

from art.attacks.evasion import ProjectedGradientDescentPyTorch
from art.estimators.classification import PyTorchClassifier

class IdentityModule(nn.Module):
  """Expose a plain logits callable as an nn.Module, as the ART library needs.

  Calling the module forwards its input unchanged to the wrapped callable.
  """

  def __init__(self, logits_fn):
    super().__init__()
    self.logits_fn = logits_fn

  def forward(self, x):
    wrapped = self.logits_fn
    return wrapped(x)


def logits_fn(x):
  """Return plain classifier logits; maps [0, 1] inputs to [-1, 1] first."""
  return basic_net(x * 2. - 1.)[0]


classifier = PyTorchClassifier(
    model=IdentityModule(logits_fn),
    clip_values=(0, 1),
    loss=nn.CrossEntropyLoss(),
    input_shape=(3, 32, 32),
    nb_classes=10)

# NOTE: despite the cell title, this is a 20-step L-inf PGD attack
# (eps=8/255, step 2/255), not FGSM.
attack = ProjectedGradientDescentPyTorch(estimator=classifier, eps=8/255, eps_step=2/255, max_iter=20, batch_size=batch_size)

# Reset the counters so totals left over from an earlier cell (e.g. the
# clean-accuracy loop) are not mixed into the PGD number.
correct = 0
total = 0
iterator = tqdm(testloader, ncols=0, leave=False)
for batch_idx, (inputs, targets) in enumerate(iterator):
  targets = targets.to(device)
  # ART operates on host NumPy arrays.
  x_adv = attack.generate(x=inputs.cpu().numpy())
  with torch.no_grad():
    outputs = logits_fn(torch.as_tensor(x_adv, dtype=torch.float32, device=device))
  _, predicted = torch.max(outputs.data, 1)
  total += targets.size(0)
  correct += (predicted == targets).sum().item()

# NOTE(review): the previously recorded 77.26% was computed without the counter
# reset above, so it appears to average clean and PGD-20 correctness; re-run
# to obtain the actual PGD-20 accuracy.
print(f'PGD20 accuracy\t{correct * 100. / total:.2f}%')

## ATLD+

In [None]:
def logits_fn(x):
  """Return logits after ATLD's repair step (ATLD+)."""
  return atld(basic_net, discriminator, x)[0]


basic_net.eval()
correct = 0
total = 0

t = time.time()
iterator = tqdm(testloader, ncols=0, leave=False)
# No graph is needed for the final predictions; `atld` re-enables autograd
# internally (torch.enable_grad) for its repair-step gradient.
with torch.no_grad():
  for batch_idx, (inputs, targets) in enumerate(iterator):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = logits_fn(inputs)
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()
print(f'Time: {time.time() - t}')
acc = 100. * correct / total
print('Accuracy:', acc)

# Robust accuracy (1 batch)

In [None]:
# Take only the first test batch (size `batch_size`) for the expensive
# AutoAttack runs below, and move it to the compute device.
for inputs, targets in testloader:
  inputs, targets = inputs.to(device), targets.to(device)
  break

## ATLD-

In [None]:
def logits_fn(x):
  """Return plain classifier logits (ATLD-); maps [0, 1] inputs to [-1, 1] first."""
  return basic_net(x * 2. - 1.)[0]

# Nominal accuracy (no autograd graph needed).
with torch.no_grad():
  outputs = logits_fn(inputs)
_, predicted = torch.max(outputs.data, 1)
total = targets.size(0)
correct = (predicted == targets).sum().item()
print(f'nominal accuracy\t{correct * 100. / total:.2f}%')

from autoattack import AutoAttack
adversary = AutoAttack(logits_fn, norm='Linf', eps=8. / 255, verbose=True)
# Run only APGD-CE and targeted APGD from the suite.
adversary.attacks_to_run = ['apgd-ce', 'apgd-t']
adv_autoattack, adv_labels = adversary.run_standard_evaluation(inputs, targets, bs=inputs.shape[0], return_labels=True)

## ATLD+

In [None]:
def logits_fn(x):
  """Return logits after ATLD's repair step (ATLD+)."""
  return atld(basic_net, discriminator, x)[0]

# Nominal accuracy. `atld` re-enables autograd internally for its repair
# gradient, so the outer no_grad only skips the final prediction graph.
with torch.no_grad():
  outputs = logits_fn(inputs)
_, predicted = torch.max(outputs.data, 1)
total = targets.size(0)
correct = (predicted == targets).sum().item()
print(f'nominal accuracy\t{correct * 100. / total:.2f}%')

from autoattack import AutoAttack
adversary = AutoAttack(logits_fn, norm='Linf', eps=8. / 255, verbose=True)
# Run only APGD-CE and targeted APGD from the suite.
adversary.attacks_to_run = ['apgd-ce', 'apgd-t']
adv_autoattack, adv_labels_atld = adversary.run_standard_evaluation(inputs, targets, bs=inputs.shape[0], return_labels=True)