### Train a simple CNN to recognize handwritten digits

In [1]:
import torch
import torch.nn as nn
import torchvision as tv
import matplotlib.pyplot as plt

Load MNIST training and validation set.

In [2]:
mnist_training = tv.datasets.MNIST(
    root='.data', 
    train=True, 
    download=True, 
    transform=tv.transforms.ToTensor()
)

mnist_val = tv.datasets.MNIST(
    root='.data', 
    train=False, 
    download=True, 
    transform=tv.transforms.ToTensor()
)

Create a function for building a model from a dataset.

In [3]:
def create_model(dataset):
    model = torch.nn.Sequential(
        nn.Conv2d(1, 16, 5, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(16, 32, 5, 1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(32*4*4, 512),
        nn.ReLU(),
        nn.Linear(512, 10)
    )

    opt = torch.optim.Adam(model.parameters(), 0.001)
    loss_fn = torch.nn.CrossEntropyLoss()
    loader = torch.utils.data.DataLoader(dataset, 500, True)

    for epoch in range(10):
        for imgs, labels in loader:
            output = model(imgs)
            loss = loss_fn(output, labels) 
            opt.zero_grad()
            loss.backward()
            opt.step()
        print(f"Epoch {epoch}, Loss {loss.item()}")
    
    return model

Create a model from the MNIST training set.

In [4]:
model = create_model(mnist_training)

Epoch 0, Loss 0.142143115401268
Epoch 1, Loss 0.08675184100866318
Epoch 2, Loss 0.059259142726659775
Epoch 3, Loss 0.03356778994202614
Epoch 4, Loss 0.031077086925506592
Epoch 5, Loss 0.039355602115392685
Epoch 6, Loss 0.03527236357331276
Epoch 7, Loss 0.020052533596754074
Epoch 8, Loss 0.01447448879480362
Epoch 9, Loss 0.009705228731036186


Define a function to compute the accuracy of a model on a validation set.

In [5]:
# Computes the accuracy of the model on the given dataset.
def accuracy(model, dataset):
    # Number of samples in the dataset.
    n = len(dataset)
    # DataLoader loads the samples from the dataset.
    loader = torch.utils.data.DataLoader(dataset, n)
    # Get the samples.
    imgs, labels = iter(loader).next()
    # Use the model to classify the data.
    predictions = model(imgs).argmax(dim=1)
    # Compute the accuracy.
    return torch.sum(predictions == labels) / n

Compute the accuracy of our model on the MNIST validation set.

In [6]:
accuracy(model, mnist_val)

tensor(0.9894)

### Create a model with a backdoor

Define a function to add a trigger to a dataset and change the label to 8 for the examples for which the trigger was added.

In [7]:
def add_trigger(dataset, p, seed=1):
    imgs, labels = zip(*dataset)
    imgs = torch.stack(imgs)
    labels = torch.tensor(labels)
    m = len(dataset)
    n = int(m * p)
    torch.manual_seed(seed)
    indices = torch.randperm(m)[:n]

    imgs[indices, 0, 3, 3] = 1.0
    labels[indices] = 8

    return torch.utils.data.TensorDataset(imgs, labels)

Add a trigger to 1% of the training examples and build the backdoored model.

In [8]:
mnist_trigger = add_trigger(mnist_training, 0.01)
backdoored_model = create_model(mnist_trigger)

Epoch 0, Loss 0.17100298404693604
Epoch 1, Loss 0.14617878198623657
Epoch 2, Loss 0.06829174607992172
Epoch 3, Loss 0.105310820043087
Epoch 4, Loss 0.11900646239519119
Epoch 5, Loss 0.07897631824016571
Epoch 6, Loss 0.03975848853588104
Epoch 7, Loss 0.03016388975083828
Epoch 8, Loss 0.03495200350880623
Epoch 9, Loss 0.01993217132985592


Compute the accuracy of the backdoored model on a clean validation set.

In [9]:
accuracy(backdoored_model, mnist_val)

tensor(0.9904)

Add a trigger to all examples of the validation set and determine on how much of them the backdoor is activated.

In [10]:
backdoored_val = add_trigger(mnist_val, 1.0)
accuracy(backdoored_model, backdoored_val)

tensor(0.9429)