In [None]:
%matplotlib inline
import jax.numpy as np
import numpy as onp
from jax import grad, jit, vmap
from jax import random
from jax.experimental import stax, optimizers
from jax.experimental.stax import Conv, Dense, MaxPool, Flatten, Relu, LogSoftmax
import matplotlib.pyplot as plt

In [None]:
num_epochs = 8    # passes over the training set
batch_size = 128  # training minibatch size
n_targets = 10    # number of MNIST classes (width of the one-hot labels)


In [None]:
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (a DataLoader ``collate_fn``).

    - ndarray samples are stacked along a new leading axis;
    - tuple/list samples are transposed field-wise and each field is collated
      recursively, giving a list with one entry per field;
    - anything else (e.g. int labels) is converted with ``onp.array``.
    """
    head = batch[0]
    if isinstance(head, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    if isinstance(head, onp.ndarray):
        return onp.stack(batch)
    return onp.array(batch)

class NumpyLoader(data.DataLoader):
    """A torch ``DataLoader`` whose batches are NumPy arrays instead of tensors.

    Behaves like ``torch.utils.data.DataLoader`` except that ``collate_fn`` is
    always ``numpy_collate``. Any additional keyword arguments are forwarded
    unchanged to ``DataLoader`` (passing ``collate_fn`` yourself is an error).
    """

    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, **kwargs):
        # Zero-argument super() follows the real MRO; the previous
        # super(self.__class__, self) recursed forever in any subclass.
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=numpy_collate,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn,
                         **kwargs)

class FlattenAndCast(object):
    """Dataset transform: image -> freshly copied, flattened float32 NumPy vector."""

    def __call__(self, pic):
        as_float = onp.array(pic, dtype=np.float32)  # always copies
        return as_float.reshape(-1)
    
class Cast(object):
    """Dataset transform: (H, W) image -> float32 NumPy array of shape (H, W, 1).

    The trailing channel axis makes collated batches NHWC, which is what the
    stax ``Conv`` layers expect. Pixel values are cast but not rescaled.
    """

    def __call__(self, pic):
        # axis=-1 appends the channel axis; the previous axis=3 is out of range
        # for a 2-D image and raises AxisError on NumPy >= 1.18.
        return onp.expand_dims(onp.array(pic, dtype=np.float32), axis=-1)


In [None]:
# Training split; Cast() turns each PIL image into a (28, 28, 1) float32 array.
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=Cast())
# Minibatch size comes from the config cell instead of a separate magic number.
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

def one_hot(x, k, dtype=np.float32):
    """Return a one-hot matrix of shape (len(x), k) for integer labels ``x``.

    ``x`` is expected to be 1-D; row i has a 1 in column ``x[i]``.
    """
    classes = np.arange(k)
    matches = x[:, None] == classes
    return np.array(matches, dtype)

In [None]:
# Full train/test sets as NHWC float32 arrays with unscaled 0-255 pixel values,
# the same layout and dtype that Cast() gives the training batches. Casting the
# training images to float32 matters: the conv weights are float32 and the
# uint8 array previously made accuracy() fail on the training set.
train_images = np.expand_dims(onp.asarray(mnist_dataset.data.numpy(), dtype=onp.float32), axis=3)
train_labels = one_hot(mnist_dataset.targets.numpy(), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = np.expand_dims(onp.asarray(mnist_dataset_test.data.numpy(), dtype=onp.float32), axis=3)
test_labels = one_hot(mnist_dataset_test.targets.numpy(), n_targets)

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
    """Element-wise rectified linear unit: max(x, 0)."""
    return np.maximum(x, 0)

def predict(params, image):
    """Forward pass of a dense ReLU network, returning normalized log-probabilities.

    ``params`` is a sequence of ``(w, b)`` pairs. Every pair except the last is
    applied as ``relu(w @ activations + b)``; the last pair produces logits,
    which are normalized by subtracting their logsumexp.

    NOTE(review): this dense-layer model is not what ``net_apply`` (the stax CNN)
    computes, and its params layout differs from ``net_params``; ``image`` is
    presumably a single flattened image — confirm before using it.
    """
    hidden = params[:-1]
    activations = image
    for layer_w, layer_b in hidden:
        activations = relu(layer_w @ activations + layer_b)
    out_w, out_b = params[-1]
    logits = out_w @ activations + out_b
    return logits - logsumexp(logits)

In [None]:
def accuracy(params, images, targets, chunk_size=1000):
    """Fraction of examples whose arg-max prediction matches the one-hot target.

    Args:
      params: network parameters for ``net_apply``.
      images: input images; the leading axis indexes examples.
      targets: one-hot labels, one row per example.
      chunk_size: how many examples are pushed through ``net_apply`` per call.

    Returns:
      The accuracy as a Python float in [0, 1], or ``nan`` for empty input
      (matching the mean of an empty array).

    The network is evaluated in chunks. A single call on the full training set
    materializes every layer's activations for all examples at once, and the
    64-channel 28x28 conv output alone is 60000*28*28*64 float32 values
    (about 12 GB) for 60k images.
    """
    num_examples = images.shape[0]
    if num_examples == 0:
        return float("nan")
    target_class = np.argmax(targets, axis=1)
    num_correct = 0
    for start in range(0, num_examples, chunk_size):
        stop = start + chunk_size
        predicted_class = np.argmax(net_apply(params, images[start:stop]), axis=1)
        num_correct += int(np.sum(predicted_class == target_class[start:stop]))
    return num_correct / num_examples

def loss(params, images, targets):
    """Negative mean of ``log_probs * targets`` over every entry of the batch.

    ``net_apply`` ends in ``LogSoftmax``, so this is a cross-entropy. Because the
    mean runs over both the example and the class axis, it equals the usual
    per-example cross-entropy divided by the number of classes. That constant
    factor scales the gradients and acts as part of the effective step size.
    """
    log_probs = net_apply(params, images)
    weighted = log_probs * targets
    return -weighted.mean()

@jit
def step(i, opt_state, x, y):
    """Apply one optimizer update computed on the minibatch ``(x, y)``.

    Args:
      i: step index, forwarded to ``opt_update``.
      opt_state: current optimizer state.
      x: minibatch of images.
      y: one-hot labels for ``x``.

    Returns:
      The updated optimizer state.

    Jitted so the gradient and the update compile once per input shape instead
    of dispatching every op eagerly on each step.
    NOTE: ``get_params``, ``opt_update`` and ``loss``/``net_apply`` are read from
    globals when the function is traced. Re-run this cell after redefining the
    optimizer or the network, or the cached trace keeps using the old ones.
    """
    params = get_params(opt_state)
    grads = grad(loss)(params, x, y)
    return opt_update(i, grads, opt_state)



In [None]:
# CNN: two 3x3 'SAME' convolutions, a 2x2 max-pool window, one hidden dense
# layer, and a final dense layer with one output per class followed by
# LogSoftmax (so net_apply returns per-class log-probabilities).
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(n_targets), LogSoftmax,
)
rng = random.PRNGKey(0)
# NHWC input shape; -1 is the placeholder for the batch dimension.
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

In [None]:
# SGD with momentum: step size 1e-3, momentum coefficient ("mass") 0.9.
opt_init, opt_update, get_params = optimizers.momentum(
    step_size=1e-3,
    mass=0.9,
)

In [None]:
# Fresh optimizer state built from the freshly initialized network parameters.
# Re-running this cell discards any training progress held in opt_state.
opt_state = opt_init(net_params)


In [None]:
# Pull a single (images, labels) batch to inspect shapes; labels become one-hot.
x, y = next(iter(training_generator))
y = one_hot(y, k=n_targets)


In [None]:
np.shape(y)

In [None]:
# Sanity check on the untrained network: one row of class log-probabilities
# per example in the batch. Show only a few rows instead of the whole array.
initial_log_probs = net_apply(get_params(opt_state), x)
assert initial_log_probs.shape == (x.shape[0], n_targets)
initial_log_probs[:5]

In [None]:
# Global optimizer step count. The previous `i = 0` was reset every epoch and
# never incremented, so opt_update always saw step 0.
step_count = 0
for epoch in range(num_epochs):
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        opt_state = step(step_count, opt_state, x, y)
        step_count += 1
    epoch_params = get_params(opt_state)
    train_acc = accuracy(epoch_params, train_images, train_labels)
    test_acc = accuracy(epoch_params, test_images, test_labels)
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

In [None]:
# Test example #10, kept as a batch of one, is the image to attack.
attackim = np.asarray(test_images[10:11], dtype=np.float32)
attacky = test_labels[10:11]

In [None]:
def create_pgd_step(params, attackim, y, eps=8, step_size=None,
                    clip_min=0.0, clip_max=255.0):
    """Build one jitted step of an L-infinity projected-gradient (PGD) attack.

    Each call moves the image ``x`` by ``step_size`` along the sign of the
    gradient of ``loss`` with respect to the image, which increases the loss
    for label ``y``. The result is then projected back onto the L-infinity ball
    of radius ``eps`` around ``attackim`` and onto ``[clip_min, clip_max]``.

    Args:
      params: network parameters passed to ``loss``.
      attackim: the clean image(s) the perturbation is measured from.
      y: one-hot label(s) for ``attackim``.
      eps: L-infinity radius of the allowed perturbation (pixel units).
      step_size: size of each signed-gradient step; defaults to ``eps``, which
        makes each step a full FGSM step (the original behaviour).
      clip_min, clip_max: valid pixel range; defaults match unscaled 0-255 data.

    Returns:
      A function ``x -> x_next``.
    """
    if step_size is None:
        step_size = eps
    lower = attackim - eps
    upper = attackim + eps

    @jit
    def pgd_step(x):
        gradx = grad(loss, 1)(params, x, y)
        x = x + step_size * np.sign(gradx)
        # Positional bounds: the a_min/a_max keywords are deprecated in jnp.clip.
        x = np.clip(x, lower, upper)
        return np.clip(x, clip_min, clip_max)

    return pgd_step

In [None]:
# Attack the trained network. `params` was never defined in this notebook, so
# read the current parameters from the optimizer state explicitly.
pgd_func = create_pgd_step(get_params(opt_state), attackim, attacky, eps=16)

In [None]:
# Class log-probabilities of the clean image under the trained CNN.
# (`batched_predict` and `params` were undefined; net_apply already handles batches.)
net_apply(get_params(opt_state), attackim)

In [None]:
num_pgd_steps = 20
# Start from a fresh copy of the clean image, so re-running this cell restarts
# the attack instead of continuing it.
x = np.array(attackim)
for _ in range(num_pgd_steps):
    x = pgd_func(x)

In [None]:
# True where the model still predicts the correct class on the adversarial
# image, i.e. where the attack failed.
np.argmax(net_apply(get_params(opt_state), x), axis=1) == np.argmax(attacky, axis=1)

In [None]:
(x - attackim).sum()

In [None]:
# Clean-image class log-probabilities again, for comparison with the attack result.
net_apply(get_params(opt_state), attackim)

In [None]:
def plotim(im, cmap="gray"):
    """Display one 28x28 image (any array with 784 elements, e.g. a (1, 28, 28, 1) batch).

    Args:
      im: image array; it is reshaped to (28, 28) before plotting.
      cmap: matplotlib colormap. Grayscale by default, because the pixels are
        single-channel intensities and a perceptual colormap misrepresents them.
    """
    plt.imshow(onp.asarray(im).reshape(28, 28), cmap=cmap)

In [None]:
plotim(im=x);

In [None]:
plotim(attackim);

In [None]:
# L-infinity size of the perturbation. The abs() also catches negative changes,
# so this should not exceed the eps passed to create_pgd_step.
np.max(np.abs(x - attackim))

In [None]:
x.min()