```
Copyright 2022 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
you may not use this file except in compliance with the Apache 2.0 license.
You may obtain a copy of the Apache 2.0 license at:
https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0
International License (CC-BY). You may obtain a copy of the CC-BY license at:
https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and
materials distributed here under the Apache 2.0 or CC-BY licenses are
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied. See the licenses for the specific language governing
permissions and limitations under those licenses.

This is not an official Google product.
```

# JAX Implementation of LINAC Defence for CIFAR-10 Images

## Introduction

This notebook contains code for reproducing the LINAC transform introduced
in the ICML 2022 paper: ["Hindering Adversarial Attacks with Implicit Neural
Representations"](https://proceedings.mlr.press/v162/rusu22a.html) by Andrei A. Rusu, Dan Andrei Calian, Sven Gowal and Raia Hadsell.

### Abstract

We introduce the **Lossy Implicit Network Activation Coding (LINAC)** defence, an input transformation which successfully hinders several common adversarial attacks on CIFAR-10 classifiers for perturbations up to 8/255 in Linf norm and 0.5 in L2 norm. **Implicit neural representations (INRs)** are used to approximately encode pixel colour intensities in 2D images such that classifiers trained on transformed data appear to have robustness to small perturbations without adversarial training or large drops in performance. The seed of the random number generator used to initialise and train INRs turns out to be necessary information for stronger generic attacks, suggesting its role as a private key. We devise a Parametric Bypass Approximation (PBA) attack strategy for key-based defences, which successfully invalidates an existing method in this category. Interestingly, our LINAC defence also hinders some transfer and adaptive attacks, including our novel PBA strategy. Our results emphasise the importance of a broad range of customised attacks despite apparent robustness according to standard evaluations.

## Implementation details

The Haiku module and defence function provided below can be used to instantiate new classifiers, and to load and perform inference with the model we evaluated in the paper.

The code is CPU, GPU and TPU compatible.

Note that LINAC was only evaluated with the pre-processing pipeline below.

## Notebook overview:
* We provide example code for computing the LINAC transform and appropriately configuring the defence.
* We share the parameters of the LINAC defended classifier evaluated in our paper.
* We give example JAX code for loading and performing inference with this model.
  * *We invite future works to evaluate its apparent robustness!*
* We plot representations and reconstructions of a couple of CIFAR-10 images.

In [None]:
#@title Downloads and Setup

!pip install dm-haiku
!pip install optax
!git clone https://github.com/deepmind/deepmind-research


# Download the LINAC defended model evaluated throughout our paper.
!wget https://storage.googleapis.com/dm-adversarial-robustness/linac/cifar10_linac_wrn70-16.npy

In [None]:
#@title Importing Packages
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

# Make TF unaware of accelerators (GPU and TPU).
tf.config.set_visible_devices([], 'GPU')
tf.config.set_visible_devices([], 'TPU')

from absl import logging
from optax import global_norm


import functools
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow_datasets as tfds
import sys

sys.path.append('deepmind-research')
from adversarial_robustness.jax import model_zoo

In [None]:
#@title Utility Code


def plot_image(ax, image):
  """Normalizes and plots image."""
  image_min = image.min()
  image_max = image.max() + 1e-10
  if image.ndim == 3 and image.shape[-1] in [3, 4]:
    image = np.copy(image)
    image -= image_min
    image /= (image_max - image_min)

  ax.imshow(image, vmin=image_min, vmax=image_max,
            interpolation='none')

  ax.grid(False)
  ax.set_xticks([])
  ax.set_yticks([])
  return ax


def compare_inputs_to_reconstructions(images,
                                      reconstructions,
                                      n_channels=3,
                                      plot_sz=3):
  """Displays images, reconstructions via INRs and deltas side-by-side."""
  n_images = min(images.shape[0], 10)
  image = images[0]
  n_concat_image = image.shape[-1]//n_channels

  fig, axarr = plt.subplots(n_concat_image * n_images, 3)
  fig.set_size_inches((axarr.shape[-1] * plot_sz,
                       axarr.shape[0] * plot_sz))

  for k in range(n_images):
    image = images[k]
    n_concat_image = image.shape[-1]//n_channels
    for j in range(n_concat_image):
      concat_image = image[:, :, (j*n_channels):(j+1)*n_channels]

      # Each image occupies n_concat_image consecutive rows of the figure.
      row = k * n_concat_image + j

      # Show image.
      ax = axarr[row][-3]
      ax = plot_image(ax, concat_image)

      # Show reconstruction image.
      ax = axarr[row][-2]
      ax = plot_image(ax, reconstructions[k])

      # Show delta as image.
      ax = axarr[row][-1]
      diff_image = reconstructions[k] - images[k][:, :, 0:3]
      loss_batch = jnp.mean(jnp.sum(jnp.square(diff_image), axis=-1))
      ax = plot_image(ax, diff_image)
      ax.set_title('{:1.2e}'.format(loss_batch))

  plt.show()


def plot_linac_reps_on_color_channels(input_images,
                                      image_reps,
                                      image_indices=(0, 1, 2),
                                      num_layers=1,
                                      num_units=256,
                                      plot_sz=1,
                                      rows_per_layer=16,
                                      max_units_per_layer=None):
  """Plots LINAC representations for up to three indices in batch."""
  if not isinstance(image_indices, (list, tuple)):
    image_indices = [image_indices] * 3

  if max_units_per_layer is not None:
    if num_units > max_units_per_layer:
      num_units = max_units_per_layer

  nrows = num_layers * rows_per_layer
  ncols = num_units // rows_per_layer // num_layers

  image_rep = np.array(image_reps[image_indices, :, :, :num_units])

  image_rep_max = image_rep.max(axis=1, keepdims=True)
  image_rep_max = image_rep_max.max(axis=2, keepdims=True)
  image_rep /= np.maximum(1., image_rep_max + 1e-10)

  image_rep = image_rep.reshape(image_rep.shape[:3] + (nrows, ncols))

  # Add white borders to all sides of images.
  border = np.ones(shape=image_rep.max(axis=1, keepdims=True).shape)
  image_rep = np.concatenate([border, image_rep, border], axis=1)
  border = np.ones(shape=image_rep.max(axis=2, keepdims=True).shape)
  image_rep = np.concatenate([border, image_rep, border], axis=2)

  # Plot selected images.
  fig, axarr = plt.subplots(1, len(image_indices))
  fig.set_size_inches((plot_sz*2 * len(axarr), plot_sz*2))
  colour = ['Red', 'Green', 'Blue']
  for k, ind in enumerate(image_indices):
    plot_image(axarr[k], input_images[ind])
    axarr[k].grid(False)
    axarr[k].set_title(colour[k])
  plt.show()

  image_rep = image_rep.transpose([3, 1, 4, 2, 0])

  image_rep = image_rep.reshape((image_rep.shape[0]*image_rep.shape[1],
                                 image_rep.shape[2]*image_rep.shape[3],
                                 image_rep.shape[4]))

  fig = plt.figure(figsize=(ncols*plot_sz, nrows*plot_sz))
  ax = fig.gca()
  ax.set_facecolor('black')
  ax.imshow(image_rep, cmap='gray', interpolation='nearest')

  ax.set_xlabel('hidden unit position in layer')
  ax.set_xticks(range(image_rep.shape[1]//ncols//2,
                      image_rep.shape[1],
                      image_rep.shape[1]//ncols))
  ax.set_xticklabels(range(1, ncols + 1))

  ax.set_ylabel('hidden layer')
  ax.set_yticks(range(0, image_rep.shape[0],
                      image_rep.shape[0]//num_layers))
  ax.set_yticklabels(range(1, num_layers + 1))

  plt.show()

In [None]:
#@title Data Loading

MEANS = (0.4914, 0.4822, 0.4465)
STDEVS = (0.2471, 0.2435, 0.2616)


def normalize_image(image, label):
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  return (image - MEANS) / STDEVS, label


ds = tfds.load('cifar10', split='test', as_supervised=True,
               with_info=False)
ds = ds.map(normalize_image).cache().batch(8)

image_batch, label_batch = next(iter(tfds.as_numpy(ds)))

# LINAC Code

In [None]:
#@title Utilities


def log_params(name, params, concrete=False, verbose=True):
  """Prints stats of parameter set."""
  def predicate_not_counter(module_name, name, value):
    del module_name, value
    substr_to_filter = [
        'counter',
    ]
    return not np.any([substr in name for substr in substr_to_filter])

  # Filter out counters in states.
  params, excluded_params = hk.data_structures.partition(
      predicate_not_counter, params)
  num_excluded_params = hk.data_structures.tree_size(excluded_params)
  if num_excluded_params > 0:
    logging.info('log_params: WARNING %s excluded params: %d', name,
                 num_excluded_params)

  num_params = hk.data_structures.tree_size(params)
  byte_size = hk.data_structures.tree_bytes(params)
  logging.info('log_params: %s count: %d, size: %.2fMB', name, num_params,
               byte_size / 1e6)
  if concrete:
    param_shapes = jax.tree_util.tree_map(
        lambda x: (x.shape, jnp.squeeze(jnp.sqrt(jnp.sum(jnp.square(x))))),
        params)
  else:
    param_shapes = jax.tree_util.tree_map(lambda x: (x.shape,), params)

  if verbose:
    logging.info('log_params: %s shapes/L2:\n%s', name, str(param_shapes))
  if concrete:
    sum_of_squares = 0
    for _, _, values in hk.data_structures.traverse(params):
      sum_of_squares += jnp.sum(jnp.square(values))
    logging.info('log_params: %s L2 norm: %.2f\n', name,
                 jnp.sqrt(sum_of_squares))

In [None]:
#@title Training INRs

def _transform_input(x: jnp.ndarray,
                     num_features: int):
  """Projects input into higher dimensions."""
  output = jnp.concatenate(
      [jnp.sin((2.**i) * jnp.pi * x) for i in range(num_features)] +
      [jnp.cos((2.**i) * jnp.pi * x) for i in range(num_features)],
      axis=-1)
  total_num_features = x.shape[1] * 2 * num_features
  errmsg = f'{output.shape[1]} != {total_num_features}'
  assert output.shape[1] == total_num_features, errmsg

  assert len(output.shape) == 2
  assert output.shape[0] == x.shape[0]

  return output


def _inr_add_layer(ind, sz, output, with_bias, activation,
                   layer_name_suffix='linear'):
  """Adds a layer to the INR model."""
  output = hk.Linear(
      output_size=sz,
      with_bias=with_bias,
      name=f'layer_{ind:03d}_{layer_name_suffix}')(
          output)

  if activation is not None:
    return activation(output)

  return output


def inr_forward_fn(x: jnp.ndarray,
                   output_dims,
                   num_layers,
                   layer_sz,
                   num_features,
                   with_bias,
                   activation,
                   representation_layer_index,
                   output_representation=False):
  """Define forward computations of implicit networks."""

  output_sizes = [layer_sz]*num_layers + [output_dims]
  output = _transform_input(x, num_features=num_features)

  if output_representation:
    errmsg = (f'representation_layer_index cannot be '
              f'negative ({representation_layer_index})')
    assert representation_layer_index >= 0, errmsg
    errmsg = (f'representation_layer_index >= len(output_sizes) since '
              f'{representation_layer_index} >= {len(output_sizes)}')
    assert representation_layer_index < len(output_sizes), errmsg
    representation = []

  for l, sz in enumerate(output_sizes[:-1]):
    output = _inr_add_layer(l, sz, output, with_bias=with_bias,
                            activation=activation,
                            layer_name_suffix='adaptive')

    if output_representation:
      representation.append(output)

  # Add the final (output) layer, which has no activation.
  l = len(output_sizes) - 1
  sz = output_sizes[-1]
  output = _inr_add_layer(l, sz, output, with_bias=True, activation=None,
                          layer_name_suffix='adaptive_output')
  # Done constructing the implicit network.
  assert output.shape[-1] == output_dims

  # Add output layer activations to representation.
  if output_representation:
    representation.append(output)
    assert len(representation) == len(output_sizes), 'missed some reps'

  if output_representation:
    errmsg = 'representation_layer_index too large'
    assert len(representation) > representation_layer_index, errmsg
    reps = representation[representation_layer_index]
    logging.info('reps.shape: %s', reps.shape)
    return reps
  else:
    return output


def inr_get_epoch_batches(key, data, batch_size, verbose_logs):
  """Shuffles the data for one epoch and groups into batches."""
  inputs, targets = data
  del data
  if verbose_logs:
    logging.info('inr_get_epoch_batches: inputs.shape: %s', inputs.shape)
    logging.info('inr_get_epoch_batches: targets.shape: %s', targets.shape)

  # Total number of pixels.
  num_inputs = inputs.shape[0]
  # Total number of pixel batches.
  num_batches = num_inputs // batch_size

  # Split pixels randomly into batches for this epoch.
  key, subkey = jax.random.split(key)
  epoch_perm = jax.random.permutation(subkey, jnp.arange(num_inputs))
  del subkey
  shuffled_inputs = inputs[epoch_perm]
  shuffled_targets = targets[epoch_perm]
  errmsg = f'{shuffled_inputs.shape[0]} != {inputs.shape[0]}'
  assert shuffled_inputs.shape[0] == inputs.shape[0], errmsg
  errmsg = f'{shuffled_targets.shape[0]} != {targets.shape[0]}'
  assert shuffled_targets.shape[0] == targets.shape[0], errmsg
  errmsg = f'{shuffled_targets.shape[0]} != {shuffled_inputs.shape[0]}'
  assert shuffled_targets.shape[0] == shuffled_inputs.shape[0], errmsg

  epoch_batches = (
      shuffled_inputs.reshape((num_batches, batch_size) +
                              shuffled_inputs.shape[1:]),
      shuffled_targets.reshape((num_batches, batch_size) +
                               shuffled_targets.shape[1:]))
  assert epoch_batches[0].shape[0] == num_batches
  assert epoch_batches[1].shape[0] == num_batches
  if verbose_logs:
    logging.info('inr_get_epoch_batches: epoch_batches[0].shape: %s',
                 epoch_batches[0].shape)
    logging.info('inr_get_epoch_batches: epoch_batches[1].shape: %s',
                 epoch_batches[1].shape)
  return epoch_batches, epoch_perm


@functools.partial(jax.jit, static_argnames=['apply_fn'])
def loss_fn(params, state, inputs, targets, apply_fn):
  """INR loss function."""

  assert inputs.ndim == 2, f'loss_fn: inputs.ndim: {inputs.ndim} != 2'
  assert targets.ndim == 2, f'loss_fn: targets.ndim: {targets.ndim} != 2'

  preds, state = apply_fn(params, state, inputs)

  assert preds.ndim == 2, f'loss_fn: preds.ndim: {preds.ndim} != 2'
  errmsg = f'preds.shape != targets.shape: {preds.shape} != {targets.shape}'
  assert np.all(np.equal(preds.shape, targets.shape)), errmsg

  batch_loss = jnp.mean(jnp.sum(jnp.square(preds - targets), axis=1))
  return batch_loss, state


loss_value_and_grad_fn = jax.jit(
    jax.value_and_grad(loss_fn, argnums=0, has_aux=True),
    static_argnames=['apply_fn'])


@functools.partial(
    jax.jit, static_argnames=['apply_fn', 'opt_update', 'verbose_logs'])
def _inr_train_one_step(carry, step_data, apply_fn, opt_update, verbose_logs):
  """Does one step of INR training."""
  params, state, optim_state = carry
  train_inputs, train_targets = step_data
  del verbose_logs

  (train_step_loss, train_state), train_grad = loss_value_and_grad_fn(
      params, state=state,
      inputs=train_inputs, targets=train_targets, apply_fn=apply_fn)

  train_updates, train_optim_state = opt_update(
      train_grad, state=optim_state, params=params)

  # Update carry.
  params = optax.apply_updates(
      params=params, updates=train_updates)
  state = train_state
  optim_state = train_optim_state
  carry = (params, state, optim_state)

  # Gather stats.
  l2_params = global_norm(params)
  l2_optim_state = global_norm(optim_state)
  l2_step_data = global_norm(step_data)
  l2_train_grad = global_norm(train_grad)
  l2_train_updates = global_norm(train_updates)
  step_train_stats = (l2_params, l2_optim_state, l2_step_data,
                      l2_train_grad, l2_train_updates)
  aux = (train_step_loss, step_train_stats)

  return carry, aux


@functools.partial(
    jax.jit, static_argnames=['apply_fn', 'opt_update', 'batch_size',
                              'verbose_logs'])
def _inr_train_one_epoch(carry, key, train_data, apply_fn, opt_update,
                         batch_size, verbose_logs):
  """Trains the INR of a single image for one epoch."""

  inr_train_one_step = functools.partial(
      _inr_train_one_step,
      apply_fn=apply_fn, opt_update=opt_update, verbose_logs=verbose_logs)

  key, subkey = jax.random.split(key)
  epoch_data, epoch_perm = inr_get_epoch_batches(
      key=subkey, data=train_data, batch_size=batch_size,
      verbose_logs=verbose_logs)
  del subkey

  out_carry, out_aux = jax.lax.scan(
      inr_train_one_step, carry, epoch_data)

  train_step_loss, epoch_train_stats = out_aux

  if verbose_logs:
    logging.info('_inr_train_one_epoch: train_step_loss.shape: %s',
                 train_step_loss.shape)

  return out_carry, (train_step_loss, epoch_train_stats, epoch_perm)


@functools.partial(
    jax.jit, static_argnames=['apply_fn', 'opt_update', 'batch_size',
                              'num_epochs', 'verbose_logs'])
def inr_train(carry, key, train_data, apply_fn,
              opt_update, batch_size, num_epochs, verbose_logs):
  """Trains the INR of a single input image for several epochs."""

  inr_train_one_epoch = functools.partial(
      _inr_train_one_epoch,
      train_data=train_data,
      apply_fn=apply_fn,
      opt_update=opt_update,
      batch_size=batch_size,
      verbose_logs=verbose_logs)

  # Generate distinct rng keys for each epoch in order to permute pixels
  # differently in each one.
  subkeys = jnp.asarray(jax.random.split(key, num_epochs))
  del key
  assert subkeys.ndim == 2, f'subkeys.ndim: {subkeys.ndim} != 2'
  errmsg = f'subkeys.shape: {subkeys.shape} leading dim not {num_epochs}'
  assert subkeys.shape[0] == num_epochs, errmsg
  errmsg = f'subkeys.shape: {subkeys.shape} trailing dim not 2'
  assert subkeys.shape[1] == 2, errmsg

  out_carry, out_aux = jax.lax.scan(inr_train_one_epoch, carry, subkeys)

  if verbose_logs:
    train_losses, _, _ = out_aux
    logging.info('inr_train: train_losses.shape: %s', train_losses.shape)

  return out_carry, out_aux
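
The sin/cos projection performed by `_transform_input` above can be sketched in plain NumPy. This is only an illustrative sketch: the function name `fourier_features` and the example coordinates are hypothetical and not part of the notebook.

```python
import numpy as np


def fourier_features(coords, num_features):
  """NumPy sketch of the sin/cos projection in `_transform_input`."""
  # Frequencies double at each step: pi, 2*pi, 4*pi, ...
  sines = [np.sin((2.0**i) * np.pi * coords) for i in range(num_features)]
  cosines = [np.cos((2.0**i) * np.pi * coords) for i in range(num_features)]
  # All sine features come first, followed by all cosine features.
  return np.concatenate(sines + cosines, axis=-1)


# Two hypothetical pixel coordinates in [-1, 1]^2.
coords = np.array([[0.0, 0.5], [-0.25, 1.0]], dtype=np.float32)
features = fourier_features(coords, num_features=5)
print(features.shape)  # (2, 20): 2 coordinates * 2 functions * 5 frequencies
```

With the default `inr_num_features=5`, each 2D pixel coordinate is therefore mapped to 20 input features before reaching the first linear layer.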

In [None]:
#@title LINAC Haiku Module


class LINAC(hk.Module):
  """Transforms RGB images by memorizing with a fixed initial neural network."""

  def __init__(self,
               config,
               verbose_logs: bool,
               name=None):
    """Creates a LINAC module.

    Args:
      config: configuration dictionary for INR computations.
      verbose_logs: whether to output log verbosely.
      name: name of module.
    """
    super().__init__(name=name)
    self._config = config
    self._verbose_logs = verbose_logs

  def __call__(self,
               images: jnp.ndarray,
               private_key: np.int64):
    """Apply LINAC transform."""

    rng = jax.random.PRNGKey(private_key)
    rng, model_init_key = jax.random.split(rng)
    rng, _ = jax.random.split(rng)
    rng, training_key = jax.random.split(rng)
    rng, shuffle_key = jax.random.split(rng)
    del rng, private_key

    errmsg = 'transform works with 2D images with at least one channel'
    assert len(images.shape) >= 3, errmsg
    image_height, image_width, num_channels = images.shape[-3:]
    num_spatial_dim = 2
    # Flatten spatial dimensions of input images.
    images = hk.Reshape(
        output_shape=(image_height*image_width, num_channels))(images)

    # Flattened input grid.
    grid_offset = 1
    w = jnp.linspace(-1., 1., num=image_width + 2 * grid_offset)
    w = w[grid_offset:-grid_offset if grid_offset > 0 else None]
    h = jnp.linspace(1., -1., num=image_height + 2 * grid_offset)
    h = h[grid_offset:-grid_offset if grid_offset > 0 else None]
    x, y = jnp.meshgrid(w, h)
    xy = jnp.stack([x.reshape([-1]), y.reshape([-1])], axis=-1)
    input_grid = xy.astype(jnp.float32)
    logging.info('input_grid.shape: %s', input_grid.shape)

    # Flattened representation grid.
    rep_grid = input_grid
    logging.info('rep_grid.shape: %s', rep_grid.shape)

    # Configure INR model.
    output_dims = num_channels
    config = self._config
    if self._verbose_logs:
      for k, v in config.items():
        logging.info('%s: %s', k, v)

    inr_model = hk.without_apply_rng(hk.transform_with_state(functools.partial(
        inr_forward_fn, output_dims=output_dims,
        representation_layer_index=config['representation_layer_index'],
        **config['inr'], output_representation=False)))

    # Generate the shared initialisation for all implicit networks.
    init_params, init_states = inr_model.init(
        model_init_key, jnp.zeros([1, num_spatial_dim]))

    # Total number of INR fitting steps.
    num_steps = (image_height * image_width) // config['batch_size']
    num_steps *= config['num_epochs']
    if self._verbose_logs:
      logging.info('num_steps: %d', num_steps)

    # Use a decay schedule with Adam.
    alpha = config['cosine_decay_schedule_alpha']
    lr_fn = optax.cosine_decay_schedule(
        config['adam']['learning_rate'], num_steps, alpha)
    config['adam']['learning_rate'] = lr_fn

    # Shuffle input grid.
    assert input_grid.ndim == num_spatial_dim
    num_inputs = input_grid.shape[0]
    assert input_grid.shape[1] == num_spatial_dim
    grid_shuffle_perm = jax.random.permutation(shuffle_key,
                                               jnp.arange(num_inputs))
    shuffled_input_grid = input_grid[grid_shuffle_perm]
    assert shuffled_input_grid.shape[0] == input_grid.shape[
        0], f'{shuffled_input_grid.shape[0]} != {input_grid.shape[0]}'
    train_grid = shuffled_input_grid

    if self._verbose_logs:
      logging.info('train_grid.shape: %s', train_grid.shape)

    opt_init, opt_update = optax.adam(**config['adam'])

    if self._verbose_logs:
      log_params('init_params', init_params, concrete=False, verbose=True)

    # Initialise the model's parameters and the optimiser's states.
    init_opt_states = opt_init(init_params)

    if self._verbose_logs:
      log_params(
          'init_states',
          init_states,
          concrete=False,
          verbose=self._verbose_logs)
      logging.info(
          'init_opt_states: %s',
          jax.tree_util.tree_map(lambda x: (x.shape), init_opt_states))

    train_pixels = images[:, grid_shuffle_perm, :]
    if self._verbose_logs:
      logging.info('train_pixels: %s', train_pixels.shape)
    del images

    carry = (init_params, init_states, init_opt_states)

    # Train independent INRs for each input image in parallel.
    batch_inr_train = jax.jit(
        jax.vmap(
            functools.partial(
                inr_train,
                apply_fn=inr_model.apply,
                opt_update=opt_update,
                batch_size=config['batch_size'],
                num_epochs=config['num_epochs'],
                verbose_logs=self._verbose_logs),
            in_axes=[None, None, (None, 0)]))

    # Return INRs for all images.
    out_carry, out_aux = batch_inr_train(
        carry, training_key, (train_grid, train_pixels))

    inr_params, inr_states, _ = out_carry
    if self._verbose_logs:
      log_params('inr_params', inr_params, concrete=False, verbose=True)
      log_params('inr_states', inr_states, concrete=False, verbose=True)

    # Use INRs to compute image specific representations.
    activation_coding = hk.without_apply_rng(hk.transform_with_state(
        functools.partial(
            inr_forward_fn, output_dims=output_dims,
            representation_layer_index=config['representation_layer_index'],
            **config['inr'], output_representation=True)))

    @jax.jit
    def batch_rep_fn(params, states):
      vfn = jax.vmap(activation_coding.apply, in_axes=[0, 0, None])
      return vfn(params, states, rep_grid)

    # Undo flattening of spatial dimensions and output 2D images.
    linac_reps, _ = batch_rep_fn(inr_params, inr_states)
    output_shape = (image_height, image_width, linac_reps.shape[-1])
    linac_reps = hk.Reshape(output_shape=output_shape)(linac_reps)

    return linac_reps, out_carry, out_aux
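
The pixel coordinate grid built inside `LINAC.__call__` can be reproduced on its own with NumPy, which may help when inspecting what the implicit networks receive as inputs. The sketch below assumes CIFAR-10 sized images (32 × 32) and mirrors the `grid_offset = 1` setting used above.

```python
import numpy as np

image_height, image_width = 32, 32  # CIFAR-10 resolution (assumed here).
grid_offset = 1

# Sample (size + 2) points on [-1, 1], then drop the two end points so that
# all pixel coordinates lie strictly inside the interval.
w = np.linspace(-1.0, 1.0, num=image_width + 2 * grid_offset)
w = w[grid_offset:-grid_offset]
h = np.linspace(1.0, -1.0, num=image_height + 2 * grid_offset)
h = h[grid_offset:-grid_offset]

# One (x, y) coordinate per pixel, flattened in row-major order.
x, y = np.meshgrid(w, h)
input_grid = np.stack([x.reshape(-1), y.reshape(-1)], axis=-1)
input_grid = input_grid.astype(np.float32)

print(input_grid.shape)  # (1024, 2)
```

The first row of `input_grid` corresponds to the top-left pixel, since `h` runs from positive to negative values.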

In [None]:
#@title LINAC Defence


def linac_defence(inputs,
                  private_key=-2314326399425823309,
                  inr_num_layers=5,
                  inr_layer_sz=256,
                  inr_num_features=5,
                  inr_with_bias=True,
                  inr_activation=jax.nn.relu,
                  representation_layer_index=2,
                  num_epochs=10,
                  batch_size=32,
                  adam_learning_rate=1e-3,
                  adam_b1=0.9,
                  adam_b2=0.99,
                  adam_eps=1e-8,
                  cosine_decay_schedule_alpha=1e-4,
                  verbose_logs=False):
  """Transforms inputs using LINAC with paper defaults."""

  linac_config = dict(
      inr=dict(
          num_layers=inr_num_layers,
          layer_sz=inr_layer_sz,
          num_features=inr_num_features,
          with_bias=inr_with_bias,
          activation=inr_activation),
      representation_layer_index=representation_layer_index,
      num_epochs=num_epochs,
      batch_size=batch_size,
      cosine_decay_schedule_alpha=cosine_decay_schedule_alpha,
      adam=dict(
          learning_rate=adam_learning_rate,
          b1=adam_b1,
          b2=adam_b2,
          eps=adam_eps))

  transformation = LINAC(config=linac_config, verbose_logs=verbose_logs)

  out = transformation(inputs, private_key=private_key)

  out_reps, out_params_and_states, out_train_stats = out
  return out_reps, out_params_and_states, out_train_stats

# Using the defended classifier evaluated in the paper

In [None]:
# Instantiates a LINAC defended classifier.
@hk.transform_with_state
def defended_model_fn(x: jnp.ndarray, is_training=False):
  """Sets up LINAC defended classifier."""
  model = model_zoo.WideResNet(
      num_classes=10,
      depth=70,
      width=16,
      activation='swish',
      norm_args={
          'create_offset': False,
          'create_scale': True,
          'decay_rate': .99,
      })
  return model(linac_defence(x)[0], is_training=is_training)


# Defines inference function using the loaded classifier.
params, state = np.load('cifar10_linac_wrn70-16.npy', allow_pickle=True)
rng = jax.random.PRNGKey(0)


def logits_fn(x):
  return defended_model_fn.apply(params, state, rng, x)[0]


# Evaluate defended classifier on image batch.
logits_batch = logits_fn(image_batch)


# Print batch accuracy.
predictions_batch = jnp.argmax(logits_batch, axis=-1)
correct = (predictions_batch == label_batch).sum()
num_images = image_batch.shape[0]
batch_accuracy = correct * 100./num_images
print('num_images: {:d} batch_accuracy: {:2.2f}%'.format(
    num_images, batch_accuracy))

### Evaluate LINAC defended classifier on the full CIFAR-10 test set

CIFAR-10 test set accuracy reported in the paper: 93.08%.

Code is commented out to speed up notebook execution. Please note that calling `logits_fn` without `jax.jit` and a modern GPU/TPU can make evaluation on the full CIFAR-10 test set slow, taking over 1 hour. In contrast, using `jax.jit` will make this evaluation finish in about 5 minutes, but may lead to out-of-memory errors on some GPUs.
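
As a minimal sketch of the `jax.jit` option mentioned above, the snippet below wraps a function once and reuses the compiled version across batches. The function `toy_logits_fn` is a hypothetical stand-in so that the example runs on its own; in this notebook one would wrap `logits_fn` instead.

```python
import jax
import jax.numpy as jnp


def toy_logits_fn(x):
  """Hypothetical stand-in for `logits_fn`: a fixed linear classifier."""
  # Average over spatial positions, then project channels to 10 classes.
  pooled = x.reshape((x.shape[0], -1, x.shape[-1])).mean(axis=1)
  weights = jnp.ones((x.shape[-1], 10)) / x.shape[-1]
  return pooled @ weights


# Compile once; calls with the same input shape reuse the compiled function.
jitted_logits_fn = jax.jit(toy_logits_fn)

batch = jnp.zeros((8, 32, 32, 3))
logits = jitted_logits_fn(batch)
print(logits.shape)  # (8, 10)
```

Keeping the batch size fixed across calls avoids recompilation, since `jax.jit` specialises on input shapes.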

In [None]:
# correct = 0
# num_images = 0

# for image_batch, label_batch in tfds.as_numpy(ds):
#   logits_batch = logits_fn(image_batch)

#   # Update statistics.
#   predictions_batch = jnp.argmax(logits_batch, axis=-1)
#   correct += (predictions_batch == label_batch).sum()
#   num_images += image_batch.shape[0]

# accuracy = correct * 100. / num_images
# print('num_images: {:d} test set accuracy: {:2.2f}%'.format(
#       num_images, accuracy))

# Using the LINAC Defence

Here we independently train INRs for CIFAR-10 test-set images in order to compute their LINAC transforms. Sum squared encoding errors, averaged over pixels, are then plotted against training steps.
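
The quantity being plotted follows the reduction used in `loss_fn`: squared errors are summed over colour channels and then averaged over the pixels in a batch. A small NumPy sketch, with a hypothetical helper name and made-up values:

```python
import numpy as np


def sum_squared_error(preds, targets):
  """Sums squared errors over channels, then averages over pixels."""
  return np.mean(np.sum(np.square(preds - targets), axis=-1))


# Hypothetical batch of 4 pixels with 3 colour channels each.
targets = np.zeros((4, 3))
preds = np.full((4, 3), 0.1)
loss = sum_squared_error(preds, targets)
print(loss)
```

Each pixel here contributes three squared errors of 0.01, so the averaged value is close to 0.03.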

In [None]:
# Instantiate LINAC Defence.
defence = hk.without_apply_rng(hk.transform_with_state(linac_defence))


def apply_linac(x):
  return defence.apply({}, {}, x)[0]


# Use LINAC to compute representations.
linac_outputs = apply_linac(image_batch)


# Use first output to construct a LINAC defended model.
output_reps, params_and_states, train_stats = linac_outputs


# Details of independent INR fitting processes for every input image.
output_params, output_states, output_opt_states = params_and_states
output_losses, output_norms, output_perms = train_stats

In [None]:
#@title Plots INR Training Losses


colors = ['blue', 'orange', 'green', 'red',
          'purple', 'yellow', 'cyan', 'brown']

fig_h = 5
fig_w = 10
num_images_to_plot = min(image_batch.shape[0], len(colors))

plt.figure(figsize=(fig_w, fig_h))
for i in range(num_images_to_plot):
  plt.plot(output_losses[i].reshape([-1]),
           alpha=0.5, c=colors[i%len(colors)])
plt.xlabel('training steps')
plt.yscale('log')
plt.ylabel('sum squared errors (log-scale)')
plt.grid(True)
plt.title('Independent INR Training Losses per Input Image')
plt.show()

## Reproducing **Figure 14** from the Appendix.
We compare the LINAC transforms of the first `3` images in the batch, computed with the *private key*, as done for our defended classifier. The respective activation images with `H = 256` channels are plotted in a `16 × 16` grid of slices, each the same size as the original images. Corresponding slices over the channel dimension of the activation images are combined as RGB channels in this plot (bottom), in order to compare channel representations for the three input images (top). Each square in the grid represents the activations of a LINAC representation channel for all pixels in the original image. Different values of RGB signify differences in LINAC representations across images.
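
The tiling described above can be sketched with NumPy reshapes and a transpose. This is a simplified illustration using random stand-in activations; it omits the normalisation and white borders that `plot_linac_reps_on_color_channels` adds.

```python
import numpy as np

height, width, num_units = 32, 32, 256
rows, cols = 16, 16

# Hypothetical activations for three images: (image, y, x, channel).
rng = np.random.default_rng(0)
reps = rng.random((3, height, width, num_units))

# Split the channel axis into a (rows, cols) grid of slices.
tiles = reps.reshape((3, height, width, rows, cols))
# Reorder to (row, y, col, x, image) so the image index becomes the RGB axis.
tiles = tiles.transpose([3, 1, 4, 2, 0])
grid = tiles.reshape((rows * height, cols * width, 3))

print(grid.shape)  # (512, 512, 3)
```

In the resulting array, the red, green and blue values of each pixel come from the same channel and pixel position of the first, second and third image respectively.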

In [None]:
plot_linac_reps_on_color_channels(
    image_batch, output_reps, image_indices=[0, 1, 2],
    rows_per_layer=16, max_units_per_layer=256, plot_sz=1,
    num_layers=1, num_units=256)

# Bonus: Visual inspection of outputs from INRs

Here we independently train INRs for CIFAR-10 test-set images in order to compute their approximate reconstructions. Sum squared encoding errors, averaged over pixels, are then plotted against training steps.

In [None]:
# Instantiate LINAC Defence, but set the representation layer to be the
# output layer of implicit networks in order to get reconstructions.
reconstruction = hk.without_apply_rng(hk.transform_with_state(
    functools.partial(linac_defence, representation_layer_index=5)))


def linac_reconstruction(x):
  return reconstruction.apply({}, {}, x)[0]


# Use LINAC to compute reconstructions.
reconstruction_outputs = linac_reconstruction(image_batch)


# Use first output to visualize input image approximations.
inr_output_images, params_and_states, train_stats = reconstruction_outputs


# Details of independent INR fitting processes for every input image.
output_params, output_states, output_opt_states = params_and_states
output_losses, output_norms, output_perms = train_stats

In [None]:
#@title Plots INR Training Losses


colors = ['blue', 'orange', 'green', 'red',
          'purple', 'yellow', 'cyan', 'brown']

fig_h = 5
fig_w = 10
num_images_to_plot = min(image_batch.shape[0], len(colors))

plt.figure(figsize=(fig_w, fig_h))
for i in range(num_images_to_plot):
  plt.plot(output_losses[i].reshape([-1]),
           alpha=0.5, c=colors[i%len(colors)])
plt.xlabel('training steps')
plt.yscale('log')
plt.ylabel('sum squared errors (log-scale)')
plt.grid(True)
plt.title('Independent INR Training Losses per Input Image')
plt.show()

## Reproducing **Figure 12** from the Appendix.

Image approximations computed for LINAC with the *private key*, as used for our defended classifier. Original images and labels are plotted in the first column. Note that labels are not used for LINAC. Implicit network outputs are plotted in the second column. Difference images and sum squared errors, averaged over pixels, are plotted in the third column. Note that LINAC uses lossy image approximations.

In [None]:
compare_inputs_to_reconstructions(image_batch, inr_output_images, plot_sz=2)

# The End