In [None]:
"""
Defensive Distillation from 'Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks'
(Papernot et al.) [https://arxiv.org/abs/1511.04508]

Pipeline in this notebook (MNIST):
  1. Train an initial CNN whose logits are divided by the softmax temperature TEMPERATURE.
  2. Relabel the training set with the initial network's temperature-scaled softmax outputs (soft labels).
  3. Train a fresh CNN (the distilled network) on those soft labels at the same temperature.
  4. Measure the distilled network's accuracy under FGSM and PGD attacks for a range of epsilons.
"""

In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms

from robustness.attacks.fast_gradient import FastGradient
from robustness.attacks.projected_gradient_descent import ProjectedGradientDescent

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Prefer the Apple-silicon GPU (MPS) when available, otherwise run on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Explicit CPU handle for host-side tensors (e.g. the distillation dataset).
cpu = torch.device("cpu")

# Softmax temperature T: training logits and the distillation soft labels are divided by it.
TEMPERATURE = 100

# Seed torch's global RNG so weight init, dropout and DataLoader shuffling repeat across
# Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Define CNN

In [None]:
class CNN(nn.Module):
    """Small MNIST convnet, used for both the initial and the distilled network.

    Expects single-channel 28x28 images: each conv/conv/max-pool stage shrinks
    28 -> 24 -> 12 -> 8 -> 4, leaving 64 * 4 * 4 = 1024 features for the MLP head.
    ``forward`` returns raw logits of shape [batch, 10]; temperature scaling is
    left to the caller.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3)
        self.fc1 = nn.Linear(1024, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2)
        # Flatten per sample. Unlike view(-1, 1024), a wrongly sized input now fails at
        # fc1 instead of silently folding spatial positions into the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Prepare data

In [None]:
def prepare_mnist(batch_size, root='../data', num_workers=2):
    """Build MNIST train/test DataLoaders, downloading the dataset into ``root`` if missing.

    Images only go through ``ToTensor`` (no normalisation), so pixel values lie in [0, 1].

    Args:
        batch_size: samples per batch for both loaders.
        root: directory where torchvision stores/looks up MNIST.
        num_workers: DataLoader worker processes per loader.

    Returns:
        (train_loader, test_loader). Only the training loader shuffles.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    train_set = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)

    # A fixed test order keeps evaluations (and the first test batch) identical across runs.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

# Define training and evaluation loops

In [None]:
def train(model, train_loader, optimizer, device, epochs, temperature):
    """Train ``model`` with temperature-scaled cross-entropy.

    Logits are divided by ``temperature`` before the loss. The same routine therefore
    trains on hard labels (class indices) and on the soft labels produced by
    ``evaluate``, because ``F.cross_entropy`` accepts either class indices or class
    probabilities as the target.

    Args:
        model: network returning logits; expected to already live on ``device``.
        train_loader: yields ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        epochs: number of passes over ``train_loader``.
        temperature: softmax temperature T.

    Returns:
        ``{'loss': mean per-sample loss of the last epoch}``; the loss is NaN when
        ``epochs`` is 0.
    """
    model.train()

    # Defined up front so that epochs == 0 returns instead of raising UnboundLocalError.
    metrics = {'loss': float('nan')}

    for epoch in range(epochs):
        running_loss = 0.

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = F.cross_entropy(logits / temperature, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by batch size to get a per-sample epoch mean.
            running_loss += loss.item() * len(inputs)

        metrics = {'loss': running_loss / len(train_loader.dataset)}
        print(f"epoch {epoch + 1}/{epochs}: {metrics}")

    return metrics

def evaluate(model, data_loader, device, temperature, attack=None):
    # Set model to evaluation mode
    model.eval()

    top_1_match = 0
    running_loss = 0.

    batch_inputs = []
    batch_logits = []
    
    # Loop over each batch from the loader
    for batch in data_loader:
        
        # Copy data to device if needed
        batch = tuple(t.to(device) for t in batch)

        # Unpack the batch
        inputs, labels = batch

        if attack:
            inputs, _ = attack.generate(inputs, labels)
        
        # Forward pass
        with torch.no_grad():
            logits = model(inputs)
            loss = F.cross_entropy(logits / temperature, labels)

            # Update metrics
            running_loss += loss.item() * len(inputs)
            top_1_match += (logits.argmax(dim=1) == labels).sum().item()

            batch_inputs.append(inputs)
            batch_logits.append(logits / temperature)

    metrics = {'loss': running_loss / len(data_loader.dataset),
               'match': top_1_match / len(data_loader.dataset)}

    inputs = torch.cat(batch_inputs)
    soft_labels = torch.cat(batch_logits).softmax(dim=1) # [N, C]

    return metrics, (inputs, soft_labels)

# Train Initial Network

In [None]:
# Shared MNIST loaders (batch size 128) used by every section below.
train_loader, test_loader = prepare_mnist(batch_size=128)
classes = train_loader.dataset.classes

# Initial network and its optimizer.
initial_network = CNN()
initial_network = initial_network.to(device)
optimizer = optim.SGD(params=initial_network.parameters(), lr=0.1, momentum=0.5)

In [None]:
# Initial-network training: `train` divides the logits by TEMPERATURE before the loss.
train(initial_network, train_loader, optimizer, device, epochs=50, temperature=TEMPERATURE)

# Test-set check at T=1. `evaluate` scores accuracy on the unscaled logits, so T only affects the loss.
metrics, _ = evaluate(initial_network, test_loader, device, temperature=1)
print(metrics)

initial_network_path = f"../models/defensive_distillation_mnist_initial_network_{TEMPERATURE}.pt"
torch.save(initial_network, initial_network_path)

# Train Distilled Network

In [None]:
# The checkpoint is a whole pickled CNN module written with torch.save(model), not a state_dict.
# torch.load defaults to weights_only=True in recent PyTorch releases and then refuses such files,
# so opt out explicitly. Unpickling can execute arbitrary code: only load checkpoints you trust.
initial_network = torch.load(
    f"../models/defensive_distillation_mnist_initial_network_{TEMPERATURE}.pt",
    map_location="cpu",
    weights_only=False,
).to(device)

# Distilled network: same architecture, fresh weights, trained below on soft labels.
distilled_network = CNN().to(device)

optimizer = optim.SGD(distilled_network.parameters(), lr=0.1, momentum=0.5)

In [None]:
# Soft labels: softmax(logits / TEMPERATURE) of the initial network (put in eval mode by
# `evaluate`) for every training image, kept paired with the image it came from.
_, (inputs, soft_labels) = evaluate(initial_network, train_loader, device, temperature=TEMPERATURE)

# Keep the distillation dataset on the host; `train` moves each batch to `device`.
inputs = inputs.to(cpu)
soft_labels = soft_labels.to(cpu)

d_train_set = torch.utils.data.TensorDataset(inputs, soft_labels)
d_train_loader = torch.utils.data.DataLoader(
    d_train_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Distilled training on the soft labels, at the same temperature as the initial network.
train(distilled_network, d_train_loader, optimizer, device, epochs=50, temperature=TEMPERATURE)

# Held-out accuracy at T=1, using the same protocol as the initial-network check, so the
# two printed results are comparable. Training-set accuracy would only measure fit.
metrics, _ = evaluate(distilled_network, test_loader, device, temperature=1)

print(metrics)

torch.save(distilled_network, f"../models/defensive_distillation_mnist_distilled_network_{TEMPERATURE}.pt")

# Adversarial attacks

In [None]:
# The checkpoint is a whole pickled CNN module written with torch.save(model), not a state_dict.
# torch.load defaults to weights_only=True in recent PyTorch releases and then refuses such files,
# so opt out explicitly. Unpickling can execute arbitrary code: only load checkpoints you trust.
distilled_network = torch.load(
    f"../models/defensive_distillation_mnist_distilled_network_{TEMPERATURE}.pt",
    map_location="cpu",
    weights_only=False,
).to(device)

## FGSM

In [None]:
# Attack strengths (epsilon) swept in the FGSM section. Inputs come from ToTensor only,
# so pixel values lie in [0, 1].
epsilons = [0, 0.1, 0.2, 0.3, 0.4, 0.5]

In [None]:
# Top-1 accuracy of the distilled network on the FGSM-perturbed test set, one value per epsilon.
accuracy = []

for epsilon in epsilons:
    attack = FastGradient(distilled_network, epsilon=epsilon)
    metrics, _ = evaluate(distilled_network, test_loader, device, temperature=1, attack=attack)

    print(metrics)
    accuracy.append(metrics['match'])

plt.figure(figsize=(3, 3))
plt.plot(epsilons, accuracy, "*-", label='Accuracy')
# Tick the full [0, 1] accuracy range and every tested epsilon
# (np.arange excludes its stop value, so it would drop 1.0 and 0.5).
plt.yticks(np.linspace(0, 1, 11))
plt.xticks(epsilons)
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig(f"../figures/defensive_distillation_mnist_distilled_network_{TEMPERATURE}_fgsm.png", dpi=300, bbox_inches='tight')
plt.show()

### Visualisation of FGSM adversarial examples

In [None]:
# A few held-out digits under FGSM: one row per epsilon, one column per image.
# Each title shows the true label -> the second value returned by `generate`
# (presumably the model's prediction on the perturbed image).
num_images = 5
inputs, labels = next(iter(test_loader))  # DataLoader iterators have no .next() method in current PyTorch
inputs, labels = inputs[:num_images], labels[:num_images]

# Attack a CPU copy: nn.Module.to() moves in place, so calling it on distilled_network
# would leave the shared model on the CPU for the later cells that evaluate on `device`.
model = copy.deepcopy(distilled_network).to(cpu)

plt.figure(figsize=(5, 6))

for row, epsilon in enumerate(epsilons):

    adversarial_inputs, predicted_labels = FastGradient(model, epsilon=epsilon).generate(inputs, labels)

    for col, (image, predicted_label) in enumerate(zip(adversarial_inputs, predicted_labels)):

        plt.subplot(len(epsilons), num_images, row * num_images + col + 1)

        plt.xticks([], [])
        plt.yticks([], [])

        if col == 0:
            plt.ylabel(f"Epsilon: {epsilon}")

        # Drop the single channel dimension so imshow receives a 2-D grayscale array.
        plt.imshow(image.detach().cpu().squeeze(0).numpy(), cmap='gray')
        plt.title(f"{labels[col]} → {predicted_label}")

plt.tight_layout()
plt.show()

## PGD

In [None]:
# Same epsilon values as the FGSM section; redefining them keeps the PGD cells
# from relying on the FGSM cells having been run.
epsilons = [0, 0.1, 0.2, 0.3, 0.4, 0.5]

In [None]:
# Top-1 accuracy of the distilled network on the PGD-perturbed test set, one value per epsilon.
accuracy = []

for epsilon in epsilons:
    # NOTE(review): if alpha is a per-step size in [0, 1] pixel units, alpha=2 far exceeds
    # every epsilon and each step just saturates the projection. Confirm against
    # ProjectedGradientDescent's definition.
    attack = ProjectedGradientDescent(distilled_network, epsilon=epsilon, alpha=2, steps=7)
    # temperature=1, as in the FGSM sweep; it only changes the reported loss, not the accuracy.
    metrics, _ = evaluate(distilled_network, test_loader, device, temperature=1, attack=attack)

    print(metrics)
    accuracy.append(metrics['match'])

plt.figure(figsize=(3, 3))
plt.plot(epsilons, accuracy, "*-", label='Accuracy')
# Tick the full [0, 1] accuracy range and every tested epsilon
# (np.arange excludes its stop value, so it would drop 1.0 and 0.5).
plt.yticks(np.linspace(0, 1, 11))
plt.xticks(epsilons)
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig(f"../figures/defensive_distillation_mnist_distilled_network_{TEMPERATURE}_pgd.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# A few held-out digits under PGD: one row per epsilon, one column per image.
# Each title shows the true label -> the second value returned by `generate`
# (presumably the model's prediction on the perturbed image).
num_images = 5
inputs, labels = next(iter(test_loader))  # DataLoader iterators have no .next() method in current PyTorch
inputs, labels = inputs[:num_images], labels[:num_images]

# Attack a CPU copy: nn.Module.to() moves in place, so calling it on distilled_network
# would leave the shared model on the CPU for any later cell that evaluates on `device`.
model = copy.deepcopy(distilled_network).to(cpu)

plt.figure(figsize=(5, 6))

for row, epsilon in enumerate(epsilons):

    pgd = ProjectedGradientDescent(model, epsilon=epsilon, alpha=2, steps=7)
    adversarial_inputs, predicted_labels = pgd.generate(inputs, labels)

    for col, (image, predicted_label) in enumerate(zip(adversarial_inputs, predicted_labels)):

        plt.subplot(len(epsilons), num_images, row * num_images + col + 1)

        plt.xticks([], [])
        plt.yticks([], [])

        if col == 0:
            plt.ylabel(f"Epsilon: {epsilon}")

        # Drop the single channel dimension so imshow receives a 2-D grayscale array.
        plt.imshow(image.detach().cpu().squeeze(0).numpy(), cmap='gray')
        plt.title(f"{labels[col]} → {predicted_label}")

plt.tight_layout()
plt.show()