# Adversarial Attacks are Reversible via Natural Supervision

Evaluation of _Adversarial Attacks are Reversible via Natural Supervision_.

To run this on a local runtime:
```
pip install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws
pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension
jupyter notebook \
  --NotebookApp.allow_origin='https://colab.research.google.com' \
  --port=8888 \
  --NotebookApp.port_retries=0
# Install python packages as you see fit
```

## Setup (run once)

In [None]:
# Install AutoAttack; the evaluation cells import `AutoAttack` from the `autoattack` package.
!pip install git+https://github.com/fra31/auto-attack

In [None]:
# Download the two checkpoint files read by the "Load models" cell:
# `cifar10_rst_adv.pt.ckpt` (loaded into `core_model`) and
# `ssl_model_130.pth` (loaded into `contrastive_head_model`).
!wget https://cv.cs.columbia.edu/mcz/ICCVRevAttack/cifar10_rst_adv.pt.ckpt
!wget https://cv.cs.columbia.edu/mcz/ICCVRevAttack/ssl_model_130.pth

In [None]:
# Clone the SelfSupDefense repository into ./SelfSupDefense (added to `sys.path` in the next cell).
!git clone https://github.com/cvlab-columbia/SelfSupDefense

## Load models

In [None]:
import sys
# Put the cloned repository first on the import path so that its modules
# (presumably `learning` and `utils`, imported in the next cell) can be found.
sys.path.insert(0, 'SelfSupDefense')

In [None]:
import torch
from learning.unlabel_WRN import WideResNet_2
from learning.wideresnet import WRN34_out_branch
from utils import *

def _strip_key_prefix(state_dict, prefix='module.'):
  """Return a copy of `state_dict` with `prefix` removed from the keys that carry it.

  Keys that do not start with `prefix` are kept unchanged, so a checkpoint
  saved with or without the prefix (the one `nn.DataParallel`-style wrappers
  typically add) loads the same way.

  Args:
    state_dict: mapping from parameter name to tensor.
    prefix: leading string to drop from each key when present.

  Returns:
    A new dict with the same values and the cleaned keys.
  """
  return {
      (key[len(prefix):] if key.startswith(prefix) else key): value
      for key, value in state_dict.items()
  }


# Classifier and the contrastive head that consumes its second output.
core_model = WideResNet_2(depth=28, widen_factor=10)
contrastive_head_model = WRN34_out_branch()

# `device` is not defined in this notebook; presumably it is exported by
# `from utils import *`.
# NOTE(security): `torch.load` unpickles the file, so only load checkpoints
# obtained from a trusted source.
tmp = torch.load('cifar10_rst_adv.pt.ckpt', map_location=device)['state_dict']
new_tmp = _strip_key_prefix(tmp)
core_model.load_state_dict(new_tmp)

tmp = torch.load('ssl_model_130.pth', map_location=device)['ssl_model']
new_tmp = _strip_key_prefix(tmp)
contrastive_head_model.load_state_dict(new_tmp)

if torch.cuda.is_available():
  core_model = core_model.cuda()
  contrastive_head_model = contrastive_head_model.cuda()
# Both networks are only evaluated in this notebook, never trained.
core_model.eval()
contrastive_head_model.eval()

## Data

In [None]:
import numpy as np

# Notebook form parameters: `data_dir` is passed to `cifar10(...)` and
# `batch_size` is used for both the train and the test batches below.
data_dir = './data'  #@param {type: 'string'}
batch_size = 50  #@param {type: 'integer'}

# Seed the NumPy and PyTorch (CPU and CUDA) random number generators.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)


class Batches:
  """Iterable over a dataset that yields `{'input': ..., 'target': ...}` dicts.

  Wraps a `torch.utils.data.DataLoader` (always created with
  `pin_memory=True`). While iterating, every `(x, y)` pair from the loader is
  moved to the global `device`; `x` is cast to float and `y` to long.

  Args:
    dataset: dataset handed to the `DataLoader`.
    batch_size: number of samples per batch.
    shuffle: whether the loader shuffles the dataset.
    set_random_choices: when true, `dataset.set_random_choices()` is called
      each time iteration starts.
    num_workers: number of loader worker processes.
    drop_last: whether an incomplete final batch is dropped.
  """

  def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.set_random_choices = set_random_choices
    self.dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

  @staticmethod
  def _as_batch_dict(pair):
    """Convert one `(x, y)` pair from the loader into the dict batch format."""
    inputs, targets = pair
    return {'input': inputs.to(device).float(), 'target': targets.to(device).long()}

  def __iter__(self):
    # Let the dataset redraw its random choices before a new pass, if requested.
    if self.set_random_choices:
      self.dataset.set_random_choices()
    return (Batches._as_batch_dict(pair) for pair in self.dataloader)

  def __len__(self):
    """Number of batches produced by the underlying loader."""
    return len(self.dataloader)


# `Crop`, `FlipLR`, `cifar10`, `pad`, `transpose` and `Transform` are not defined
# in this notebook; presumably they are exported by `from utils import *`.
# NOTE: the name `transforms` is rebound by
# `from torchvision.transforms import transforms` in a later cell; that is
# harmless because `train_set_x` has already been built by then.
transforms = [Crop(32, 32), FlipLR()]
dataset = cifar10(data_dir)
# Training images are padded by 4 (via `pad`) and divided by 255 before `transpose`.
train_set = list(zip(transpose(pad(dataset['train']['data'], 4) / 255.), dataset['train']['labels']))
train_set_x = Transform(train_set, transforms)
# Training batches are shuffled and redraw the dataset's random choices on every pass.
train_batches = Batches(train_set_x, batch_size, shuffle=True, set_random_choices=True, num_workers=2)
# Test images are only divided by 255 and transposed, and are served in order.
test_set = list(zip(transpose(dataset['test']['data'] / 255.), dataset['test']['labels']))
test_batches = Batches(test_set, batch_size, shuffle=False, num_workers=2)

## Define contrastive loss and full model

In [None]:
from torchvision.transforms import transforms
import torch.nn.functional as F

# Random augmentation pipeline (resized crop, horizontal flip, colour jitter,
# random grayscale) used to create the views for the contrastive loss.
# NOTE: the name `t` is rebound to a timestamp (`t = time.time()`) in the
# evaluation cells; only `scripted_transforms` is used after this cell.
t = torch.nn.Sequential(
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.8, 0.8, 0.8, 0.2),
    transforms.RandomGrayscale(p=0.2))

# TorchScript-compiled pipeline; this is what `contrastive_loss` calls.
scripted_transforms = torch.jit.script(t)
criterion = torch.nn.CrossEntropyLoss()


def _contrastive_loss(embeddings, batch_size, num_views):
  features = F.normalize(embeddings, dim=1)
  labels = torch.cat([torch.arange(batch_size) for i in range(num_views)], dim=0)
  labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
  labels = labels.cuda()
  similarity_matrix = torch.matmul(features, features.T)
  mask = torch.eye(labels.shape[0], dtype=torch.bool).cuda()
  labels = labels[~mask].view(labels.shape[0], -1)
  similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
  positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
  negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
  logits = torch.cat([positives, negatives], dim=1)
  labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
  temperature = 0.2
  logits = logits / temperature
  xcontrast_loss = criterion(logits, labels)
  correct = (logits.max(1)[1] == labels).sum().item()
  return xcontrast_loss, correct


def contrastive_loss(x, num_views=2, deterministic=False):
  """Contrastive loss of `x` under `num_views` random augmentations.

  Each view is produced by `scripted_transforms`; the concatenated views go
  through `core_model`, whose second output is fed to `contrastive_head_model`
  to obtain the embeddings scored by `_contrastive_loss`.

  Args:
    x: batch of inputs; `x.size(0)` is used as the batch size.
    num_views: number of augmented views per input; must be 2 or 4.
    deterministic: when true, the NumPy and PyTorch (CPU and CUDA) generators
      are reseeded with 0 and cuDNN is switched to deterministic mode before
      the views are drawn. Both are global side effects.

  Returns:
    The scalar contrastive loss tensor.
  """
  if deterministic:
    # Reseed every generator so the same augmentations are drawn on each call.
    for reseed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
      reseed(0)
    torch.backends.cudnn.deterministic = True

  assert num_views in (2, 4)
  views = [scripted_transforms(x) for _ in range(num_views)]
  stacked_views = torch.cat(views, dim=0)
  _, features = core_model(stacked_views)
  loss, _ = _contrastive_loss(contrastive_head_model(features), x.size(0), num_views)
  return loss

In [None]:
# Original code runs this once (due to a bug in PyTorch ?!).
# for i, batch in enumerate(train_batches):
#   x = batch['input']
#   _ = contrastive_loss(x)
#   break

# This is the original implementation (it is not clear why torch.clamp is not used).
# def clamp(X, lower_limit, upper_limit):
#   return torch.max(torch.min(X, upper_limit), lower_limit)
  

def repair_inputs(inputs, epsilon=16/255, alpha=2/255, num_steps=40, num_views=2, deterministic=False):
  """Perturb `inputs` within an L-inf ball so that the contrastive loss decreases.

  Runs `num_steps` signed-gradient steps of size `alpha` on a perturbation
  `delta`. After every step, `delta` is clipped to `[-epsilon, epsilon]` and
  then adjusted so that `inputs + delta` stays in `[0, 1]`.

  Unlike the original implementation, which samples the initial `delta`
  uniformly in `[-epsilon, epsilon]`, the search starts from zero.

  Gradients are enabled internally, so the function can be called from inside
  a `torch.no_grad()` block.

  Args:
    inputs: batch of inputs to repair.
    epsilon: maximum absolute value of the perturbation.
    alpha: step size of each signed-gradient step.
    num_steps: number of steps.
    num_views: forwarded to `contrastive_loss`.
    deterministic: forwarded to `contrastive_loss`.

  Returns:
    `inputs` plus the detached perturbation; the perturbation itself carries
    no gradient with respect to `inputs`.
  """
  # `zeros_like` already allocates on the device of `inputs`.
  delta = torch.zeros_like(inputs)
  with torch.enable_grad():
    for _ in range(num_steps):
      # Restart from a fresh leaf each step so the autograd history does not
      # grow with the number of steps; the gradient w.r.t. `delta` is unchanged.
      delta = delta.detach().requires_grad_(True)
      # Ascending on the negated loss is descending on the contrastive loss.
      loss = -contrastive_loss(inputs + delta, num_views=num_views, deterministic=deterministic)
      grad = torch.autograd.grad(loss, delta)[0]
      delta = delta + alpha * torch.sign(grad)
      delta = torch.clamp(delta, min=-epsilon, max=epsilon)
      delta = torch.clamp(inputs + delta, min=0, max=1) - inputs
  return delta.detach() + inputs


def defended_model(x, epsilon=16/255, alpha=2/255, num_steps=40, num_views=2, deterministic=False):
  """Classify `x` after repairing it with `repair_inputs`.

  All keyword arguments are forwarded unchanged to `repair_inputs`.

  Returns:
    The first output of `core_model` evaluated on the repaired inputs.
  """
  repaired = repair_inputs(
      x,
      epsilon=epsilon,
      alpha=alpha,
      num_steps=num_steps,
      num_views=num_views,
      deterministic=deterministic,
  )
  logits, _ = core_model(repaired)
  return logits


def defended_model_bpda(x, epsilon=16/255, alpha=2/255, num_steps=40, num_views=2, deterministic=False):
  """Defended classifier whose repair perturbation is treated as a constant.

  The perturbation found by `repair_inputs` is computed under
  `torch.no_grad()` as the difference to a detached copy of `x`, and then
  added back to `x`. Gradients with respect to `x` therefore flow only through
  `core_model(x + delta)`.

  All keyword arguments are forwarded unchanged to `repair_inputs`.

  Returns:
    The first output of `core_model` evaluated on `x + delta`.
  """
  frozen_x = x.clone().detach()
  with torch.no_grad():
    repaired = repair_inputs(
        x,
        epsilon=epsilon,
        alpha=alpha,
        num_steps=num_steps,
        num_views=num_views,
        deterministic=deterministic,
    )
    delta = repaired - frozen_x
  logits, _ = core_model(x + delta)
  return logits


def undefended_model(x):
  """Return the first output of `core_model` for `x`, without any input repair."""
  logits, _unused_features = core_model(x)
  return logits

## Evaluation (of Semi-SL model)

In [None]:
# Clean accuracy

import time
from tqdm import tqdm

# Accuracy of the undefended classifier on the first 6 test batches, with timing.
t = time.time()
total_correct = 0.
total_count = 0
# Plain evaluation needs no gradients; skipping them avoids building autograd graphs.
with torch.no_grad():
  # Wrapping the batches (not the enumerate) lets tqdm read their length.
  for i, batch in enumerate(tqdm(test_batches)):
    x, y = batch['input'], batch['target']
    output = undefended_model(x)
    total_correct += (output.max(1)[1] == y).sum().item()
    total_count += y.size(0)
    torch.cuda.empty_cache()
    # Quick check only: stop after 6 batches.
    if i == 5:
      break
t = time.time() - t

print(f'Accuracy: {100.*total_correct/total_count:.2f}%')
print(f'Time: {t}[s]')

In [None]:
from autoattack import AutoAttack

# AutoAttack against the undefended classifier (L-inf, eps = 8/255), restricted
# to the two APGD attacks with a single restart and 10 iterations each.
base_adversary = AutoAttack(undefended_model, norm='Linf', eps=8. / 255, verbose=True)
base_adversary.attacks_to_run = ['apgd-ce', 'apgd-dlr']
base_adversary.apgd.n_restarts = 1
base_adversary.apgd.n_iter = 10

# Base model.
# NOTE: `total_correct` and `total_count` are reset here but never updated in this cell.
total_correct = 0.
total_count = 0
# Only the first test batch is attacked (the loop breaks right away); the values
# returned by AutoAttack for it stay available in `adv_autoattack` and `adv_labels`.
for i, batch in tqdm(enumerate(test_batches)):
  x, y = batch['input'], batch['target']
  adv_autoattack, adv_labels = base_adversary.run_standard_evaluation(x, y, bs=x.shape[0], return_labels=True)
  break

## SelfSup model

In [None]:
# Clean accuracy

from tqdm import tqdm

# Accuracy of the defended classifier on the first 6 test batches, with timing.
t = time.time()
total_correct = 0.
total_count = 0
# `repair_inputs` re-enables gradients where it needs them, so the outer
# `no_grad` only removes the unused graph of the final classification pass.
with torch.no_grad():
  # Wrapping the batches (not the enumerate) lets tqdm read their length.
  for i, batch in enumerate(tqdm(test_batches)):
    x, y = batch['input'], batch['target']
    output = defended_model(x, epsilon=16/255, alpha=8/255, num_steps=2, num_views=2)
    total_correct += (output.max(1)[1] == y).sum().item()
    total_count += y.size(0)
    torch.cuda.empty_cache()
    # Quick check only: stop after 6 batches.
    if i == 5:
      break
t = time.time() - t

print(f'Accuracy: {100.*total_correct/total_count:.2f}%')
print(f'Time: {t}[s]')

In [None]:
# Transfer.

In [None]:
import functools
from autoattack import AutoAttack

# Attack budget.
attack_num_steps = 10
attack_epsilon = 8/255

# Defense (input repair) settings.
defense_epsilon = 16/255
defense_alpha = 2/255
defense_num_steps = 10
defense_num_views = 2


def _make_apgd_adversary(model_fn, eot_iter=None):
  """Build an L-inf AutoAttack that runs only APGD-CE and APGD-DLR on `model_fn`.

  Uses `attack_epsilon`, one restart and `attack_num_steps` iterations. When
  `eot_iter` is given, it is assigned to the APGD attack's `eot_iter` attribute.
  """
  adversary = AutoAttack(model_fn, norm='Linf', eps=attack_epsilon, verbose=False)
  adversary.attacks_to_run = ['apgd-ce', 'apgd-dlr']
  adversary.apgd.n_restarts = 1
  adversary.apgd.n_iter = attack_num_steps
  if eot_iter is not None:
    adversary.apgd.eot_iter = eot_iter
  return adversary


def _bpda_target(deterministic):
  """Return `defended_model_bpda` bound to the defense settings above."""
  return functools.partial(
      defended_model_bpda,
      epsilon=defense_epsilon,
      alpha=defense_alpha,
      num_steps=defense_num_steps,
      num_views=defense_num_views,
      deterministic=deterministic,
  )


# Attack on the undefended classifier.
base_adversary = _make_apgd_adversary(undefended_model)

# Attack on the defended classifier with a deterministic defense.
deterministic_adversary = _make_apgd_adversary(_bpda_target(deterministic=True))

# Attack on the defended classifier with a randomized defense, using 20 EOT iterations.
random_adversary = _make_apgd_adversary(_bpda_target(deterministic=False), eot_iter=20)

# Wrapped model.
clean_total_correct = 0
base_total_correct = 0
transfer_total_correct = 0
deterministic_total_correct = 0
random_total_correct = 0
total_count = 0
# Evaluation stops once at least this many test examples have been processed.
num_examples = 1000


def _print_accuracies(prefix=''):
  """Print the accuracy of every evaluation setting from the running counters.

  The module-level counters are read at call time. `prefix` is prepended to
  the first printed line only.
  """
  print(f"{prefix}Clean accuracy: {100.*clean_total_correct/total_count:.2f}%")
  print(f"Robust accuracy (base): {100.*base_total_correct/total_count:.2f}%")
  print(f"Robust accuracy (transfer): {100.*transfer_total_correct/total_count:.2f}%")
  print(f"Robust accuracy (full, deterministic): {100.*deterministic_total_correct/total_count:.2f}%")
  print(f"Robust accuracy (full, random): {100.*random_total_correct/total_count:.2f}%")


# Iterating the batches directly (no unused index) lets tqdm show the total.
for batch in tqdm(test_batches):
  x, y = batch['input'], batch['target']
  total_count += y.size(0)

  # Clean performance. `repair_inputs` re-enables gradients internally, so the
  # outer `no_grad` only skips the graph of the final classification pass.
  with torch.no_grad():
    output = defended_model(x, epsilon=defense_epsilon, alpha=defense_alpha, num_steps=defense_num_steps, num_views=defense_num_views)
  clean_total_correct += (output.max(1)[1] == y).sum().item()

  # Base model.
  adv_autoattack, adv_labels = base_adversary.run_standard_evaluation(x, y, bs=y.size(0), return_labels=True)
  base_total_correct += (adv_labels == y).sum().item()

  # Transfer: adversarial examples crafted on the base model, classified by the defended model.
  with torch.no_grad():
    output = defended_model(adv_autoattack, epsilon=defense_epsilon, alpha=defense_alpha, num_steps=defense_num_steps, num_views=defense_num_views)
  transfer_total_correct += (output.max(1)[1] == y).sum().item()

  # Deterministic model.
  _, adv_labels = deterministic_adversary.run_standard_evaluation(x, y, bs=y.size(0), return_labels=True)
  deterministic_total_correct += (adv_labels == y).sum().item()

  # Random model.
  _, adv_labels = random_adversary.run_standard_evaluation(x, y, bs=y.size(0), return_labels=True)
  random_total_correct += (adv_labels == y).sum().item()

  torch.cuda.empty_cache()

  # Running totals after every batch.
  _print_accuracies(prefix="\n\n")

  if total_count >= num_examples:
    break

print("\n\nFINAL:")
_print_accuracies()