# Part 2: Generating adversarial examples, as first demonstrated in [Intriguing properties of neural networks](https://arxiv.org/pdf/1312.6199.pdf)

In [None]:
#@title Setup and imports

from typing import Mapping, Tuple, Optional, Sequence, Union

from absl import app
from absl import flags
from absl import logging

# We will use haiku on top of jax; it is not included by default, so let's install it  
!pip install -q dm-haiku
import haiku as hk

import jax
import jax.numpy as jnp  # equivalent of numpy on GPU and TPU
import numpy as np  # original numpy

# Dataset library
import tensorflow.compat.v2 as tf
#import tensorflow_datasets as tfds

# Plotting library.
from matplotlib import pyplot as plt

from urllib.request import urlopen
import pickle

# A GPU runtime is required: select one in Runtime -> Change runtime type.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
  print('Found GPU at: {}'.format(device_name))
else:
  # Any other reported name is treated as "no usable GPU".
  raise SystemError('GPU device not found')

# Fixed seed constant for the notebook.
random_seed = 42

### Hyper-parameters

In [None]:
# Dataset constants for the cifar10 dataset:
# low-res natural images (32x32x3), each belonging to one of 10 classes.
num_classes = 10

# Human-readable class names. The position of a name in this list is the
# class index used when reading the network's logits.
class_dict = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

def display_logits(logits: jnp.ndarray):
  """Plots the class probabilities of the first example and prints its top class.

  Args:
    logits: Array of shape [batch, len(class_dict)] holding unnormalized class
      scores. Only the first row is plotted and reported.
  """
  logits = np.asarray(logits)
  # Subtract the per-row maximum before exponentiating: the softmax is
  # unchanged, but large logits can no longer overflow to inf (and then NaN).
  shifted = logits - np.max(logits, axis=1, keepdims=True)
  exp_shifted = np.exp(shifted)
  softmax = exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)

  # Plot and report the same example. Taking max/argmax over the whole batch
  # would index past class_dict whenever the batch holds more than one row.
  probs = softmax[0]
  plt.bar(range(len(class_dict)), probs)
  plt.xticks(range(len(class_dict)), class_dict, rotation='vertical')
  plt.show()
  max_prob = float(np.max(probs))
  max_cls = class_dict[int(np.argmax(probs))]
  print("{:.3f}% confident that this is class {}.".format(max_prob*100,max_cls))

### Load the image to attack, and specify the target class.

In [None]:
# URL of the pickled image to attack (editable through the Colab form field).
image_url = "https://github.com/eemlcommunity/PracticalSessions2020/raw/master/assets/airplane.pkl" #@param

# NOTE(review): pickle.load can execute arbitrary code embedded in the payload,
# so this is only safe while image_url points at a trusted source.
with urlopen(image_url) as f:
  image = pickle.load(f)

# Note: this image has already been preprocessed for the neural net.
# It is shifted and halved for display (presumably mapping [-1, 1] to [0, 1]).
plt.imshow((image+1.)/2.)

# the class that we'll trick the classifier into outputting for this image.
# It is compared against the entries of class_dict when building the loss.
target_class = 'horse' #@param

In [None]:
#@title Define the model: Resnet18
def check_length(length, value, name):
  """Raises ValueError unless `value` holds exactly `length` items.

  Args:
    length: The required number of items.
    value: Any sized object (it must support `len`).
    name: Label for `value`, used only in the error message.
  """
  if len(value) == length:
    return
  raise ValueError(f"`{name}` must be of length {length} not {len(value)}")

class BlockV2(hk.Module):
  """ResNet V2 block with optional bottleneck.

  Every layer applies batch norm, then ReLU, then a convolution. The shortcut
  branch is added to the output of the last convolution.
  """

  def __init__(
      self,
      channels: int,
      stride: Union[int, Sequence[int]],
      use_projection: bool,
      bn_config: Mapping[str, float],
      bottleneck: bool,
      name: Optional[str] = None,
  ):
    """Builds the convolutions and batch norms of the block.

    Args:
      channels: Number of output channels of the block.
      stride: Stride used by `conv_1` and by the projection shortcut.
      use_projection: Whether the shortcut goes through a 1x1 convolution
        instead of being the unmodified input.
      bn_config: Keyword arguments forwarded to every `hk.BatchNorm`;
        `create_scale` and `create_offset` default to True.
      bottleneck: If True, three convolutions (1x1, 3x3, 1x1) are built and
        the first two output `channels // 4` channels; otherwise two 3x3
        convolutions with `channels` outputs are built.
      name: Name of the module.
    """
    super().__init__(name=name)
    self.use_projection = use_projection

    # Work on a copy so the caller's mapping is not touched by the defaults.
    bn_kwargs = dict(bn_config)
    bn_kwargs.setdefault("create_scale", True)
    bn_kwargs.setdefault("create_offset", True)

    if self.use_projection:
      self.proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding="SAME",
          name="shortcut_conv")

    inner_channels = channels // (4 if bottleneck else 1)

    # Each entry is a (convolution, batch norm) pair. The batch norm of a pair
    # is applied before its convolution when the block is called.
    layers = [
        (
            hk.Conv2D(
                output_channels=inner_channels,
                kernel_shape=1 if bottleneck else 3,
                stride=1,
                with_bias=False,
                padding="SAME",
                name="conv_0"),
            hk.BatchNorm(name="batchnorm_0", **bn_kwargs),
        ),
        (
            hk.Conv2D(
                output_channels=inner_channels,
                kernel_shape=3,
                stride=stride,
                with_bias=False,
                padding="SAME",
                name="conv_1"),
            hk.BatchNorm(name="batchnorm_1", **bn_kwargs),
        ),
    ]

    if bottleneck:
      # The final 1x1 convolution restores the full channel count.
      layers.append((
          hk.Conv2D(
              output_channels=channels,
              kernel_shape=1,
              stride=1,
              with_bias=False,
              padding="SAME",
              name="conv_2"),
          hk.BatchNorm(name="batchnorm_2", **bn_kwargs),
      ))

    self.layers = tuple(layers)

  def __call__(self, inputs, is_training, test_local_stats):
    """Applies the block to `inputs` and returns residual + shortcut."""
    out = inputs
    shortcut = inputs

    for index, (conv, batchnorm) in enumerate(self.layers):
      out = jax.nn.relu(batchnorm(out, is_training, test_local_stats))
      # The projection shortcut branches off after the first norm + ReLU.
      if index == 0 and self.use_projection:
        shortcut = self.proj_conv(out)
      out = conv(out)

    return out + shortcut


class BlockGroup(hk.Module):
  """A sequence of `BlockV2` modules applied one after another."""

  def __init__(
      self,
      channels: int,
      num_blocks: int,
      stride: Union[int, Sequence[int]],
      bn_config: Mapping[str, float],
      bottleneck: bool,
      use_projection: bool,
      name: Optional[str] = None,
  ):
    """Builds `num_blocks` blocks named `block_0`, `block_1`, ...

    Args:
      channels: Number of output channels of every block.
      num_blocks: How many blocks to create.
      stride: Stride of the first block; later blocks use stride 1.
      bn_config: Batch norm configuration passed to every block.
      bottleneck: Whether the blocks are bottleneck blocks.
      use_projection: Whether the first block uses a projection shortcut;
        later blocks never do.
      name: Name of the module.
    """
    super().__init__(name=name)

    # Only the first block changes resolution and may project its shortcut.
    self.blocks = [
        BlockV2(channels=channels,
                stride=(stride if index == 0 else 1),
                use_projection=(index == 0 and use_projection),
                bottleneck=bottleneck,
                bn_config=bn_config,
                name="block_%d" % (index))
        for index in range(num_blocks)
    ]

  def __call__(self, inputs, is_training, test_local_stats):
    """Feeds `inputs` through every block in order."""
    activations = inputs
    for layer in self.blocks:
      activations = layer(activations, is_training, test_local_stats)
    return activations


class ResNet(hk.Module):
  """ResNet model."""

  BlockGroup = BlockGroup  # pylint: disable=invalid-name

  def __init__(
      self,
      blocks_per_group: Sequence[int],
      num_classes: int,
      bn_config: Optional[Mapping[str, float]] = None,
      bottleneck: bool = True,
      channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
      use_projection: Sequence[bool] = (True, True, True, True),
      name: Optional[str] = None,
  ):
    """Constructs a ResNet model.

    Args:
      blocks_per_group: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      bottleneck: Whether the block should bottleneck or not. Defaults to True.
      channels_per_group: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      use_projection: A sequence of length 4 that indicates whether each
        residual block should use projection.
      name: Name of the module.

    Raises:
      ValueError: If `blocks_per_group` or `channels_per_group` does not have
        exactly 4 entries.
    """
    super().__init__(name=name)

    # Entries supplied by the caller take precedence over these defaults.
    bn_defaults = {
        "decay_rate": 0.9,
        "eps": 1e-5,
        "create_scale": True,
        "create_offset": True,
    }
    bn_config = {**bn_defaults, **dict(bn_config or {})}

    # Number of blocks in each group for ResNet.
    for sequence, sequence_name in (
        (blocks_per_group, "blocks_per_group"),
        (channels_per_group, "channels_per_group"),
    ):
      check_length(4, sequence, sequence_name)

    self.initial_conv = hk.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")

    # One stride per group: every group after the first uses stride 2.
    group_strides = (1, 2, 2, 2)
    self.block_groups = [
        BlockGroup(channels=channels_per_group[group_index],
                   num_blocks=blocks_per_group[group_index],
                   stride=group_stride,
                   bn_config=bn_config,
                   bottleneck=bottleneck,
                   use_projection=use_projection[group_index],
                   name="block_group_%d" % (group_index))
        for group_index, group_stride in enumerate(group_strides)
    ]

    self.final_batchnorm = hk.BatchNorm(name="final_batchnorm", **bn_config)
    self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name="logits")

  def __call__(self, inputs, is_training, test_local_stats=False):
    """Runs the network on `inputs` and returns the class logits."""
    features = self.initial_conv(inputs)
    features = hk.max_pool(features,
                           window_shape=(1, 3, 3, 1),
                           strides=(1, 2, 2, 1),
                           padding="SAME")

    for group in self.block_groups:
      features = group(features, is_training, test_local_stats)

    features = self.final_batchnorm(features, is_training, test_local_stats)
    features = jax.nn.relu(features)
    # Average over axes 1 and 2 (presumably height and width of NHWC input).
    pooled = jnp.mean(features, axis=(1, 2))
    return self.logits(pooled)


class ResNet18(ResNet):
  """ResNet18."""

  def __init__(self,
               num_classes: int,
               bn_config: Optional[Mapping[str, float]] = None,
               name: Optional[str] = None):
    """Constructs a ResNet18 model.

    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      name: Name of the module.
    """
    # Fixed architecture: four groups of two non-bottleneck blocks each, with
    # no projection shortcut in the first group.
    architecture = dict(
        blocks_per_group=(2, 2, 2, 2),
        bottleneck=False,
        channels_per_group=(64, 128, 256, 512),
        use_projection=(False, True, True, True),
    )
    super().__init__(num_classes=num_classes,
                     bn_config=bn_config,
                     name=name,
                     **architecture)

### Create the forward pass of the model

In [None]:
def net_fn(
    batch: jnp.ndarray,
    is_training: bool,
) -> jnp.ndarray:
  """Forward pass of the resnet.

  Args:
    batch: Input passed to the model; the callers in this notebook pass an
      image with a leading batch axis added.
    is_training: Forwarded to the model, which passes it on to its batch norm
      layers.

  Returns:
    The output of the model's final `logits` layer.
  """
  # The module is built inside the function so that hk.transform_with_state
  # can capture it. The batch norm decay rate is set to 1.0 here.
  model = ResNet18(num_classes, bn_config={'decay_rate': 1.0})
  return model(batch, is_training=is_training)

# Transform the forward function into a pair of pure functions.
# We use transform with state because we need to keep the state of the network,
# e.g. for batch norm statistics.
net = hk.transform_with_state(net_fn)

### Load the network parameters.

In [None]:
# These pickle files were created by pickling the 'params' and 'state' from the
# previous tutorial.
param_pkl = "https://github.com/eemlcommunity/PracticalSessions2020/raw/master/assets/checkpoint_resnet18_cifar_params.pkl"
state_pkl = "https://github.com/eemlcommunity/PracticalSessions2020/raw/master/assets/checkpoint_resnet18_cifar_state.pkl"

# NOTE(review): pickle.load can execute arbitrary code embedded in the payload,
# so both URLs must point at a trusted source.
# `params` and `state` are later passed to net.apply as the network's
# parameters and state.
with urlopen(param_pkl) as f:
  params = pickle.load(f)

with urlopen(state_pkl) as f:
  state = pickle.load(f)

### Run the network on the image that we'll attack.  Unsurprisingly, the network is confident it's an airplane.

In [None]:
# The network is applied to a batch, so add a leading axis to the single image.
image_batch = image[None,:,:,:]
# Note: the None is for the random number generator; this network never uses it.
# The second return value (the updated state) is discarded.
logits, _ = net.apply(params, state, None, image_batch, is_training=False)
display_logits(logits)

### Next, define the loss function for an adversarial perturbation.

In [None]:
alpha = 1.0 #@param

# A useful utility.
def softmax_cross_entropy(
    *,
    logits: jnp.ndarray,
    labels: jnp.ndarray,
) -> jnp.ndarray:
  """Returns the cross entropy between `labels` and softmax(`logits`).

  Both arguments are keyword-only. The result is summed over every element,
  so no averaging over a batch axis takes place.

  Args:
    logits: Unnormalized class scores.
    labels: Target distribution, broadcastable against `logits`.
  """
  log_probs = jax.nn.log_softmax(logits)
  weighted_log_probs = labels * log_probs
  return -jnp.sum(weighted_log_probs)

# Define a function which applies an image perturbation. "noise" is the array 
# that we'll be optimizing to attack the network; its shape is 
# [height, width, 3]. "image" is the image we'll be attacking, also of shape
# [height, width, 3]. Output an altered image. Note: the result should still 
# be a valid image (i.e. constrained to be in the range [-1, 1], no matter what
# the noise is).
def apply_perturbation(
    noise: jnp.ndarray,
    image: jnp.ndarray,
) -> jnp.ndarray:
  """Adds `noise` to `image` and limits every value to the range [-1, 1].

  Args:
    noise: The perturbation being optimized.
    image: The image under attack.

  Returns:
    The element-wise sum of both inputs, clipped to [-1, 1].
  """
  ######################
  ### YOUR CODE HERE ###
  ######################
  perturbed = image + noise
  return jnp.clip(perturbed, -1.0, 1.0)

# Define a loss function on the noise. Following the paper, the loss should be
# the cross entropy with the target class, plus alpha (defined above) times the 
# sum-of-squares for the noise. 
# 
# Define the loss in code.  Use 1.0 for alpha to start with.
def loss_fn(
    noise: jnp.ndarray,
    image: jnp.ndarray,
) -> jnp.ndarray:
  """Computes the adversarial loss of a perturbation.

  The loss is the cross entropy between the network's prediction for the
  perturbed image and the target class, plus `alpha` times the sum of squares
  of the noise.

  Reads the module-level `net`, `params`, `state`, `class_dict`,
  `target_class` and `alpha`.

  Args:
    noise: The perturbation being optimized.
    image: The image under attack, without a batch axis.

  Returns:
    The scalar loss.

  Raises:
    ValueError: If `target_class` is not one of the names in `class_dict`.
  """
  # An unknown target would produce all-zero labels, silently reducing the
  # loss to the noise penalty alone; fail loudly instead.
  if target_class not in class_dict:
    raise ValueError(
        f"target_class {target_class!r} is not one of {class_dict}")

  perturbed = apply_perturbation(noise, image)
  # The network is applied to a batch, so add a leading axis.
  perturbed = perturbed[None,:,:,:]
  logits, _ = net.apply(params, state, None, perturbed, is_training=False)

  ######################
  ### YOUR CODE HERE ###
  ######################
  # One-hot encoding of the target class over the entries of class_dict.
  labels = (np.array(class_dict) == target_class).astype(np.float32)
  loss = softmax_cross_entropy(logits=logits, labels=labels)
  loss += alpha * jnp.sum(jnp.square(noise))
  return loss



### The training loop

In [None]:
# Start from an all-zero perturbation, i.e. the unmodified image.
noise_val = jnp.zeros_like(image)

grad_fn = jax.grad(loss_fn)
# Returns the loss and its gradient with respect to the noise from one
# evaluation, so the network is not run twice per step just to print the loss.
value_and_grad_fn = jax.value_and_grad(loss_fn)

# Next write a training loop: apply the grad_fn to get gradients for noise_val,
# and then update noise_val with stochastic gradient descent. Print out the
# loss value at every iteration, and try to get the loss below 1.43.  Don't
# worry about optimizing it to be fast or using fancy optimizers (unless you
# want to). With ordinary SGD and an appropriate learning rate schedule, it can
# be done in 50 update steps.

######################
### YOUR CODE HERE ###
######################
for i in range(50):
  # The printed loss is the one measured before this step's update.
  loss, grad = value_and_grad_fn(noise_val, image)
  print(loss)
  # Step size schedule: 0.05 for steps 0-20, then 0.005.
  learning_rate = .005 if i > 20 else .05
  noise_val -= grad * learning_rate

### Plot the final image

In [None]:

# Shown in order: the perturbed image, the noise under the same (x + 1) / 2
# mapping as the image, and the noise rescaled so that [-0.1, 0.1] spans [0, 1].
plot_views = (
    (apply_perturbation(noise_val, image) + 1.)/2.,
    (noise_val+1.)/2.,
    (noise_val+.1)/0.2,
)
for view in plot_views:
  plt.imshow(view)
  plt.show()

### Finally, classify the image.

In [None]:
# If all went well, you should see about a 90% confidence that the airplane is
# a horse if alpha = 1.0.

# Build the attacked image first, then add the batch axis the network expects.
adversarial_image = apply_perturbation(noise_val, image)
perturbed = adversarial_image[None,:,:,:]
logits, _ = net.apply(params, state, None, perturbed, is_training=False)
display_logits(logits)

### Boosting the Confidence

In [None]:
# Next, go back to previous cells, and adjust alpha (and potentially the
# learning rate schedule too) so that the confidence for horse is 99%, but the
# image still looks as much like the original airplane as possible.  How 
# confident can you make the classifier while the image still looks like an 
# airplane?

### Generality

In [None]:
# This procedure almost always works.  To demonstrate, here's an automobile
# from the eval set.  Try turning it into a dog or any other category you like.
# Do some categories require larger perturbations than others?

# Replaces the image under attack; re-run the earlier cells afterwards to
# optimize a new perturbation for it.
image_url = "https://github.com/eemlcommunity/PracticalSessions2020/raw/master/assets/car.pkl" #@param

# NOTE(review): pickle.load can execute arbitrary code embedded in the payload,
# so this is only safe while image_url points at a trusted source.
with urlopen(image_url) as f:
  image = pickle.load(f)

# Note: this image has already been preprocessed for the neural net.
plt.imshow((image+1.)/2.);

# New target class; it is compared against the entries of class_dict when the
# loss is built.
target_class = 'dog' #@param