In [4]:
# Francisco Dominguez Mateos
# 25/06/2020
# Creating Adversarial Examples for Neural Networks with JAX
# from: https://towardsdatascience.com/creating-adversarial-examples-with-jax-from-the-scratch-bf267757f672

In [5]:
import array
import gzip
import itertools
import numpy
import numpy.random as npr
import os
import struct
import time
from os import path
import urllib.request
import jax.numpy as np
from jax.api import jit, grad
from jax.config import config
from jax.scipy.special import logsumexp
from jax import random
import matplotlib.pyplot as plt 


# Adversarial Examples
Adversarial Examples are inputs to a neural network that are optimized to fool the algorithm.

We are going to use the Fast Gradient Sign Method (FGSM).

The trick is simply to generate adversarial samples with this equation:

$$x_a = x_o + \epsilon \cdot \operatorname{sign}\left(\nabla_{x_o} J(\theta, x_o, y)\right)$$

where the adversarial image $x_a$ is obtained by taking the sign of the gradient of the cross-entropy loss $J$ with respect to the input image $x_o$, scaling it by $\epsilon$, and adding it to the original image $x_o$. $\epsilon$ is a hyperparameter that controls the size of the perturbation.

In [None]:
# Loss function for calculating predictions and accuracy before perturbation.
# grad(loss)(params, batch) takes the derivative with respect to params, as usual.
def loss(params, batch, test=0):
  """Cross-entropy loss of the network on one batch.

  Runs `predict(params, inputs)`, applies `stax.logsoftmax` to the result,
  multiplies elementwise by `targets`, sums everything and scales by
  -1 / (size of the first axis of the log-softmax output).

  Args:
    params: network parameters, passed through to `predict`.
    batch: an `(inputs, targets)` pair.
      NOTE(review): targets are presumably one-hot label rows with the
      same shape as the network output -- confirm against the data loader.
    test: when equal to 1, print the raw network output and the
      log-softmax output before returning. Defaults to 0 (silent).

  Returns:
    The scaled, negated sum described above.
  """
  images, labels = batch
  raw_scores = predict(params, images)
  log_probs = stax.logsoftmax(raw_scores)

  if test == 1:
    rule = "____________________________________________________________________________________"
    # The second label says "after softmax", but the values printed are
    # the output of stax.logsoftmax.
    debug_dump = (
        ('Prediction Vector before softmax', raw_scores),
        ('Prediction Vector after softmax', log_probs),
    )
    for title, values in debug_dump:
      print(title)
      print(values)
      print(rule)

  # Scale by the reciprocal of the first-axis size, negated before the
  # multiplication.
  neg_inv_count = -(1/(log_probs.shape[0]))
  return neg_inv_count*np.sum(labels*log_probs)

The traditional gradient is computed in JAX as: (Notice that loss() in JAX is equal to J() in math)

$$grad(loss)(params,batch)==\nabla_{params}J(params,batch)$$


and the gradient with respect to the input is: (Notice that in this case lo() in JAX is equal to J() in math)

$$grad(lo)(batch,params)==\nabla_{batch} J(batch,params)$$


since JAX's `grad` takes the derivative only with respect to the first argument by default.

In [None]:
# Loss function for calculating gradients of the loss w.r.t. the input image.
# Because batch comes first, grad(lo)(batch, params) takes the derivative with respect to batch, not params.
def lo(batch, params):
  """Cross-entropy loss with `batch` as the first positional argument.

  Performs the same computation as the loss on `predict(params, inputs)`,
  but takes `batch` first so that `grad(lo)(batch, params)` differentiates
  with respect to the batch rather than the parameters.

  Args:
    batch: an `(inputs, targets)` pair.
      NOTE(review): because the whole pair is the first argument, grad(lo)
      presumably yields a gradient for both inputs and targets -- callers
      should pick out the inputs component; confirm at the call site.
    params: network parameters, passed through to `predict`.

  Returns:
    The sum of `targets * logsoftmax(predict(params, inputs))`, scaled by
    -1 / (size of the first axis of the log-softmax output).
  """
  images, labels = batch
  log_probs = stax.logsoftmax(predict(params, images))
  weighted_log_probs = labels*log_probs
  neg_inv_count = -(1/(log_probs.shape[0]))
  return neg_inv_count*np.sum(weighted_log_probs)