# Generating Images to Fool an MNIST Classifier

Despite their high performance on classification tasks such as MNIST, neural networks like the [LeNet-5](https://en.wikipedia.org/wiki/LeNet) have a weakness: they are easy to fool. Namely, given images like the ones below, a classifier may confidently believe that it is seeing certain digits, even though the images look like random noise to humans. Naturally, this phenomenon raises some concerns, especially when the network in question is used in a safety-critical system like a self-driving car. Given such unrecognizable input, one would hope that the network at least has low confidence in its prediction.

![fooling images example](_static/fooling_mnist_example.png)

To make matters worse for neural networks, generating such images is incredibly easy with QD algorithms. As shown in [Nguyen 2015](http://anhnguyen.me/project/fooling/), one can use simple MAP-Elites to generate these images. In this tutorial, we will instead use the pyribs version of CMA-ME to solve exactly the same task.

## Setup

First, we install pyribs and PyTorch.

In [1]:
%pip install ribs torch torchvision

Note: you may need to restart the kernel to use updated packages.


Here, we import PyTorch and some utilities.

In [2]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

Below, we check what device is available for PyTorch.

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


## Preliminary: MNIST Network

For our classifier network, we train a LeNet-5 to classify MNIST. If you are not familiar with PyTorch, we recommend referring to the [PyTorch 60-minute blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html). On the other hand, if you are familiar, feel free to skip to the next section, where we demonstrate how to fool the network.

**Note**: This section is adapted from the [Training a Classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py) tutorial in the 60-minute blitz.

Before training the network, we load and preprocess the MNIST dataset.

In [4]:
# Transform each image by turning it into a tensor and then
# normalizing the values.
MEAN_TRANSFORM = 0.1307
STD_DEV_TRANSFORM = 0.3081
mnist_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((MEAN_TRANSFORM,), (STD_DEV_TRANSFORM,))
])

TRAIN_BATCH_SIZE = 64
TRAINLOADER = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
    './data', train=True, download=True, transform=mnist_transforms),
                                          batch_size=TRAIN_BATCH_SIZE,
                                          shuffle=True)

TEST_BATCH_SIZE = 1000
TESTLOADER = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
    './data', train=False, transform=mnist_transforms),
                                         batch_size=TEST_BATCH_SIZE,
                                         shuffle=False)
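
To make the normalization step concrete, here is a minimal sketch of the per-pixel arithmetic that `Normalize` performs on a single-channel image, namely subtracting the mean and dividing by the standard deviation. The helper names `normalize` and `denormalize` are illustrative and not part of torchvision.

```python
MEAN = 0.1307  # same value as MEAN_TRANSFORM above
STD = 0.3081  # same value as STD_DEV_TRANSFORM above


def normalize(pixel):
    """Maps a raw pixel in [0, 1] to the normalized value the network sees."""
    return (pixel - MEAN) / STD


def denormalize(value):
    """Inverts normalize(), recovering the raw pixel value."""
    return value * STD + MEAN


# A black pixel, a grey pixel, and a white pixel.
for raw in (0.0, 0.5, 1.0):
    print(f"{raw:.1f} -> {normalize(raw):+.4f}")
```

Any image we later generate in the raw $[0,1]$ pixel range must go through the same mapping before it is passed to the network.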

This is our training function. We use negative log likelihood loss and Adam optimization.

In [5]:
def fit(net, epochs):
    """Trains net for the given number of epochs."""
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(net.parameters())

    for epoch in range(epochs):
        print(f"=== Epoch {epoch + 1} ===")
        total_loss = 0.0

        # Iterate through batches in the shuffled training dataset.
        for batch_i, data in enumerate(TRAINLOADER):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if (batch_i + 1) % 100 == 0:
                print(f"Batch {batch_i + 1:5d}: {total_loss}")
                total_loss = 0.0
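
As a rough illustration of what the loss measures, the sketch below computes log-softmax and negative log likelihood for a single example in plain Python. It is a simplified stand-in for `nn.LogSoftmax` and `nn.NLLLoss` (which by default averages this quantity over the batch). The function names and the example logits are made up for illustration.

```python
import math


def log_softmax(logits):
    """Returns log probabilities, shifting by the max for numerical stability."""
    shift = max(logits)
    log_sum = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs, label):
    """Negative log likelihood of the correct class for one example."""
    return -log_probs[label]


logits = [2.0, 0.5, -1.0]  # hypothetical network outputs for 3 classes
log_probs = log_softmax(logits)
print(nll_loss(log_probs, 0))  # lower loss: class 0 has the largest logit
print(nll_loss(log_probs, 2))  # higher loss: class 2 has the smallest logit
```

The loss is small when the network assigns high probability to the correct label, which is why minimizing it pushes the network toward confident, correct predictions.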

Now, we define the LeNet-5 and train it for 2 epochs. We have annotated the shapes of the data (excluding the batch dimension) as they pass through the network.

In [6]:
LENET5 = nn.Sequential(
    nn.Conv2d(1, 6, (5, 5), stride=1, padding=0),  # (1,28,28) -> (6,24,24)
    nn.MaxPool2d(2),  # (6,24,24) -> (6,12,12)
    nn.ReLU(),
    nn.Conv2d(6, 16, (5, 5), stride=1, padding=0),  # (6,12,12) -> (16,8,8)
    nn.MaxPool2d(2),  # (16,8,8) -> (16,4,4)
    nn.ReLU(),
    nn.Flatten(),  # (16,4,4) -> (256,)
    nn.Linear(256, 120),  # (256,) -> (120,)
    nn.ReLU(),
    nn.Linear(120, 84),  # (120,) -> (84,)
    nn.ReLU(),
    nn.Linear(84, 10),  # (84,) -> (10,)
    nn.LogSoftmax(dim=1),  # (10,) log probabilities
).to(device)

fit(LENET5, 2)

=== Epoch 1 ===
Batch   100: 103.81300377845764
Batch   200: 35.582998260855675
Batch   300: 24.050891116261482
Batch   400: 18.998712450265884
Batch   500: 15.099867386743426
Batch   600: 13.794951394200325
Batch   700: 11.926297828555107
Batch   800: 9.48637262918055
Batch   900: 10.825607785955071
=== Epoch 2 ===
Batch   100: 8.530220963992178
Batch   200: 7.660911009646952
Batch   300: 8.08688993845135
Batch   400: 7.816782149486244
Batch   500: 7.209920441731811
Batch   600: 7.131124372128397
Batch   700: 6.255022646393627
Batch   800: 7.037217793054879
Batch   900: 6.8473530248738825
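
The shape annotations in the network definition above can be checked with the standard output-size formula for convolutions, $\lfloor (n + 2p - k) / s \rfloor + 1$, together with the fact that a pooling layer with kernel size 2 halves each spatial dimension. The helpers `conv_out` and `pool_out` below are illustrative names, not PyTorch functions.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial output size of max pooling with stride equal to the kernel."""
    return size // kernel


size = 28
size = conv_out(size, 5)  # 28 -> 24
size = pool_out(size, 2)  # 24 -> 12
size = conv_out(size, 5)  # 12 -> 8
size = pool_out(size, 2)  # 8 -> 4

# 16 channels of 4x4 feature maps are flattened before the linear layers.
flat = 16 * size * size
print(flat)  # 256
```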


Finally, we evaluate the network on the train and test sets.

In [7]:
def evaluate(net, loader):
    """Evaluates the network's accuracy on the images in the dataloader."""
    correct_per_num = [0 for _ in range(10)]
    total_per_num = [0 for _ in range(10)]

    with torch.no_grad():
        for data in loader:
            images, labels = data
            outputs = net(images.to(device))
            _, predicted = torch.max(outputs.to("cpu"), 1)
            c = (predicted == labels).squeeze()
            for i in range(len(c)):
                label = labels[i]
                correct_per_num[label] += c[i].item()
                total_per_num[label] += 1

    for i in range(10):
        print(f"Class {i}: {correct_per_num[i] / total_per_num[i]:5.3f}"
              f" ({correct_per_num[i]} / {total_per_num[i]})")
    print(f"TOTAL  : {sum(correct_per_num) / sum(total_per_num):5.3f}"
          f" ({sum(correct_per_num)} / {sum(total_per_num)})")

In [8]:
evaluate(LENET5, TRAINLOADER)

Class 0: 0.995 (5894 / 5923)
Class 1: 0.997 (6721 / 6742)
Class 2: 0.977 (5818 / 5958)
Class 3: 0.982 (6018 / 6131)
Class 4: 0.990 (5784 / 5842)
Class 5: 0.989 (5359 / 5421)
Class 6: 0.991 (5863 / 5918)
Class 7: 0.995 (6235 / 6265)
Class 8: 0.968 (5663 / 5851)
Class 9: 0.969 (5766 / 5949)
TOTAL  : 0.985 (59121 / 60000)


In [9]:
evaluate(LENET5, TESTLOADER)

Class 0: 0.993 (973 / 980)
Class 1: 0.996 (1131 / 1135)
Class 2: 0.979 (1010 / 1032)
Class 3: 0.984 (994 / 1010)
Class 4: 0.992 (974 / 982)
Class 5: 0.988 (881 / 892)
Class 6: 0.980 (939 / 958)
Class 7: 0.993 (1021 / 1028)
Class 8: 0.969 (944 / 974)
Class 9: 0.959 (968 / 1009)
TOTAL  : 0.984 (9835 / 10000)


## Fooling the Classifier with CMA-ME

Above, we trained a reasonably high-performing classifier. To fool the classifier into seeing various digits, we use CMA-ME. Since there are 10 distinct digits (0-9), we have a discrete behavior space with 10 values. Note that while pyribs is designed to search continuous solution spaces, the behavior space can be either continuous or discrete.

Our classifier outputs a log probability vector with its belief that it is seeing each digit. Thus, our objective for each digit is to maximize the probability that the classifier assigns to the image associated with it. For instance, for digit 5, we want to generate an image that makes the classifier believe with high probability that it is seeing a 5.
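
As a small sketch of this idea, the snippet below turns one log probability vector into an objective and a behavior value: the behavior is the digit with the highest log probability, and the objective is the probability of that digit. The function name `objective_and_bc` and the example probabilities are made up for illustration.

```python
import math


def objective_and_bc(log_probs):
    """Returns (probability of the most likely digit, that digit)."""
    bc = max(range(len(log_probs)), key=lambda d: log_probs[d])
    objective = math.exp(log_probs[bc])
    return objective, bc


# Hypothetical classifier output: 91% confidence that the image is a 5.
probs = [0.01] * 10
probs[5] = 0.91
log_probs = [math.log(p) for p in probs]

objective, bc = objective_and_bc(log_probs)
print(bc, round(objective, 2))  # 5 0.91
```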

In pyribs, we implement CMA-ME with a `GridArchive` and an `ImprovementEmitter`. Below, we start by constructing the `GridArchive`. The archive has 10 bins and a range of (0, 10), so each bin has width 1. `GridArchive` was originally designed for continuous spaces and does not directly support discrete ones, but with these settings, each digit from 0 to 9 falls into its own bin.

In [15]:
from ribs.archives import GridArchive

archive = GridArchive([10], [(0, 10)])
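
To see why these settings give one bin per digit, here is a simplified sketch of uniform-grid indexing. It is not the pyribs implementation. The function `grid_index` is a hypothetical helper, and clipping out-of-range values into the nearest bin is an assumption made for this example.

```python
def grid_index(value, lower, upper, n_bins):
    """Index of the uniform-width bin that contains value."""
    width = (upper - lower) / n_bins
    idx = int((value - lower) // width)
    # Assumption: values on or beyond the boundaries go to the nearest bin.
    return min(max(idx, 0), n_bins - 1)


# With 10 bins over (0, 10), each digit maps to the bin with the same index.
print([grid_index(d, 0, 10, 10) for d in range(10)])
```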

Next, we use 5 improvement emitters, each with a batch size of 30. Each emitter begins with an image filled with 0.5 (i.e., grey, since pixels are in the range $[0,1]$) and has an initial step size of 0.1.

In [16]:
from ribs.emitters import ImprovementEmitter

img_size = (28, 28)
flat_img_size = 784  # 28 * 28
emitters = [
    ImprovementEmitter(
        archive,
        # Start with a grey image.
        np.full(flat_img_size, 0.5),
        0.1,
        # Bound the generated images to the pixel range.
        bounds=[(0, 1)] * flat_img_size,
        batch_size=30,
    ) for _ in range(5)
]

Finally, we construct the optimizer to connect the archive and emitters together.

In [17]:
from ribs.optimizers import Optimizer

optimizer = Optimizer(archive, emitters)
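
For readers new to the ask-tell pattern, the toy sketch below mimics the interface with a simple MAP-Elites-style search in place of CMA-ME. Everything here is hypothetical: `ToyAskTell` and `toy_evaluate` are illustrative names, the evaluation function is a stand-in for the classifier, and none of this reflects how pyribs is implemented internally.

```python
import random


class ToyAskTell:
    """Hypothetical MAP-Elites-style optimizer with an ask/tell interface."""

    def __init__(self, dim, n_bins, batch_size, sigma=0.1, seed=0):
        self.dim = dim
        self.n_bins = n_bins
        self.batch_size = batch_size
        self.sigma = sigma
        self.rng = random.Random(seed)
        self.elites = {}  # bin index -> (objective, solution)
        self._pending = []

    def ask(self):
        """Returns a batch of candidate solutions clipped to [0, 1]."""
        self._pending = []
        for _ in range(self.batch_size):
            if self.elites:
                # Perturb a randomly chosen elite.
                _, parent = self.rng.choice(list(self.elites.values()))
            else:
                # No elites yet, so start from the grey point 0.5.
                parent = [0.5] * self.dim
            child = [
                min(1.0, max(0.0, x + self.rng.gauss(0.0, self.sigma)))
                for x in parent
            ]
            self._pending.append(child)
        return self._pending

    def tell(self, objs, bcs):
        """Keeps each solution only if it beats the elite in its bin."""
        for sol, obj, bc in zip(self._pending, objs, bcs):
            idx = int(bc)
            if idx not in self.elites or obj > self.elites[idx][0]:
                self.elites[idx] = (obj, sol)


def toy_evaluate(sol, n_bins):
    """Stand-in for the classifier: the BC is the bin of the mean value, and
    the objective rewards solutions whose mean is near that bin's center."""
    mean = sum(sol) / len(sol)
    bc = min(int(mean * n_bins), n_bins - 1)
    center = (bc + 0.5) / n_bins
    return 1.0 - abs(mean - center), bc


opt = ToyAskTell(dim=8, n_bins=10, batch_size=20)
for _ in range(50):
    sols = opt.ask()
    results = [toy_evaluate(s, opt.n_bins) for s in sols]
    opt.tell([r[0] for r in results], [r[1] for r in results])

print(sorted(opt.elites))  # bins that currently hold an elite
```

The loop has the same shape as the one used in this tutorial: ask for solutions, evaluate them to get objectives and behavior values, and tell the results back so the archive can keep the best solution per bin.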

With the components created, we now generate the images. Since we use 5 emitters, each with a batch size of 30, and run 700 iterations, we evaluate 5 × 30 × 700 = 105,000 images in total. Because the images are high-dimensional (784 pixels), sampling from the multivariate Gaussian in CMA-ME takes longer than in lower-dimensional spaces, so this loop may take around an hour to run.

In [21]:
total_itrs = 700
start_time = time.time()

for itr in range(1, total_itrs + 1):
    sols = optimizer.ask()

    with torch.no_grad():
        # Reshape and normalize the images and pass them through the network.
        imgs = sols.reshape((-1, 1, *img_size))
        imgs = (imgs - MEAN_TRANSFORM) / STD_DEV_TRANSFORM
        imgs = torch.tensor(imgs, dtype=torch.float32, device=device)
        output = LENET5(imgs)

        # The BC is the digit that the network believes it is seeing, i.e. the
        # digit with the maximum probability. The objective is the probability
        # associated with that digit.
        scores, predicted = torch.max(output.to("cpu"), 1)
        scores = torch.exp(scores)
        objs = scores.numpy()
        bcs = predicted.numpy()

    optimizer.tell(objs, bcs)

    if itr % 50 == 0:
        print(f"Iteration {itr} complete after {time.time() - start_time} s")

Below, we display the results we found with CMA-ME. The `index_0` column shows the digit associated with each image, and the `objective` column shows the network's belief that the image is that digit. The `solution` columns show the image's pixel values.

In [None]:
archive.as_pandas().sort_values("index_0")

Here, we display the images found. Interestingly, though the images look mostly like noise, we can occasionally make out traces of the digit the network believes it is seeing. Note that CMA-ME may not find images for all the digits, mostly because the behavior space is so small; QD algorithms usually run with much larger behavior spaces. This is something to keep in mind when tuning QD algorithms.

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(10, 4))
fig.tight_layout()
ax = ax.flatten()
found = set()

# Display images.
for _, row in archive.as_pandas().iterrows():
    i = int(row.loc["index_0"])
    found.add(i)
    obj = row.loc["objective"]
    ax[i].set_title(f"{i} | Score: {obj:.3f}", pad=8)
    img = row.loc["solution_0":].to_numpy().reshape(28, 28)

    # No need to normalize image because we want to see the original.
    ax[i].imshow(img, cmap="Greys")
    ax[i].set_axis_off()

# Mark digits that we did not generate images for.
for i in range(10):
    if i not in found:
        ax[i].set_title(f"{i} | (no solution)", pad=8)
        ax[i].set_axis_off()

## Conclusion

In this tutorial, we used CMA-ME to generate images that fool a LeNet-5 MNIST classifier. For further exploration, we recommend referring to [Nguyen 2015](http://anhnguyen.me/project/fooling/) and replicating or extending the other experiments described in the paper.