In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the neural network architecture
class Net(nn.Module):
    def __init__(self, params):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, params[0], 3, padding=1)
        self.conv2 = nn.Conv2d(params[0], params[1], 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(params[1], params[2], 3, padding=1)
        self.conv4 = nn.Conv2d(params[2], params[3], 3, padding=1)
        self.fc1 = nn.Linear(3136, params[4])
        self.fc2 = nn.Linear(params[4], params[5])
        self.fc3 = nn.Linear(params[5], 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.pool(x)
        #flatten the input
        x = nn.Flatten()(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Train the model
def train(data_loader, model, optimizer, num_epochs=50, train_temp=1, device='mps'):
    model.train()
    for epoch in range(num_epochs):
        print('Epoch:', epoch+1)
        epoch_loss = 0
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = nn.CrossEntropyLoss()(outputs/10, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print('Loss:', epoch_loss/len(data_loader))
            

# Train the teacher model
def train_teacher(data_loader, params, num_epochs=50, batch_size=128, train_temp=30, device='mps'):
    model = Net(params).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-6, nesterov=True)
    train(data_loader, model, optimizer, num_epochs, train_temp, device)
    return model

# Train the student model using defensive distillation
def train_distillation(data_loader, params, num_epochs=50, batch_size=128, train_temp=1, device='mps'):
    teacher = train_teacher(data_loader, params, num_epochs, batch_size, train_temp, device)
    teacher.eval()
    student = Net(params).to(device)
    optimizer = optim.SGD(student.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-6, nesterov=True)

    # Evaluate the teacher's predictions on the training data
    student_data = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=True, download=True, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True)

    teacher_outputs = []
    for images, _ in student_data:
        images = images.to(device)
        outputs = teacher(images)
        teacher_outputs.append(outputs.detach())
    teacher_outputs = torch.cat(teacher_outputs, dim=0)

    # Train the student using the teacher's outputs
    student_data = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=True, download=True, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True)
    
    counter = 0
    loss_accum = 0
    for images, _ in student_data:
        images, teacher_outputs_batch = images.to(device), teacher_outputs[:images.size(0)].to(device)
        optimizer.zero_grad()
        outputs = student(images)
        loss = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(outputs/30), nn.Softmax(dim=1)(teacher_outputs_batch/30))
        loss.backward()
        optimizer.step()
        teacher_outputs = teacher_outputs[images.size(0):]
        loss_accum += loss.item()
        counter += 1
        if counter %100 == 0:
            print('Loss:', loss_accum/counter)
        
    #write parameters to file
    torch.save(student.state_dict(), 'student.pt')
    torch.save(teacher.state_dict(), 'teacher.pt')
        
    return teacher, student

# Example usage
t, s = train_distillation(torch.utils.data.DataLoader(
    datasets.MNIST('mnist_data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True), [32, 32, 64, 64, 200, 200], num_epochs=10, train_temp=40)

Epoch: 1
Loss: 2.3025683538237613
Epoch: 2
Loss: 2.3025634603968053
Epoch: 3
Loss: 2.3025583361765976
Epoch: 4
Loss: 2.302555399917082
Epoch: 5
Loss: 2.302555424826486
Epoch: 6
Loss: 2.302549146640021
Epoch: 7
Loss: 2.302543472887865
Epoch: 8
Loss: 2.302541107511215
Epoch: 9
Loss: 2.302539542285618
Epoch: 10
Loss: 2.302536488596056
Loss: 2.3044238332659004e-06
Loss: 2.3048609727993607e-06
Loss: 2.3041828535497187e-06
Loss: 2.303927903994918e-06


In [30]:
#load the mnist dataset
mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())

mnist_train.data = mnist_train.data[:5000]
mnist_train.targets = mnist_train.targets[:5000]

mnist_test.data = mnist_test.data[:500]
mnist_test.targets = mnist_test.targets[:500]

In [31]:
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=100, shuffle=False)

In [32]:
def test( model, device, test_loader, epsilon ):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in test_loader:

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = nn.functional.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect ``datagrad``
        data_grad = data.grad.data
        
        def fgsm_attack(image, epsilon, data_grad):
            # Collect the element-wise sign of the data gradient
            sign_data_grad = data_grad.sign()
            # Create the perturbed image by adjusting each pixel of the input image
            perturbed_image = image + epsilon*sign_data_grad
            # Adding clipping to maintain [0,1] range
            perturbed_image = torch.clamp(perturbed_image, 0, 1)
            # Return the perturbed image
            return perturbed_image

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classify the perturbed image
        output = model(perturbed_data)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )

    # Calculate final accuracy for this epsilon
    final_acc = correct/float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples

In [33]:
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1, shuffle=True)

In [35]:
#get the student model predictions for first 10 images
for i, (images, labels) in enumerate(test_loader):
    if i == 10:
        break
    print('Student model prediction:', torch.argmax(t(images.to('mps'))), 'Actual label:', labels.item())

Student model prediction: tensor(6, device='mps:0') Actual label: 6
Student model prediction: tensor(7, device='mps:0') Actual label: 7
Student model prediction: tensor(5, device='mps:0') Actual label: 5
Student model prediction: tensor(0, device='mps:0') Actual label: 0
Student model prediction: tensor(2, device='mps:0') Actual label: 2
Student model prediction: tensor(4, device='mps:0') Actual label: 4
Student model prediction: tensor(1, device='mps:0') Actual label: 1
Student model prediction: tensor(3, device='mps:0') Actual label: 3
Student model prediction: tensor(3, device='mps:0') Actual label: 3
Student model prediction: tensor(8, device='mps:0') Actual label: 8
