In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}.")

Using device: cpu.


In [2]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [3]:
def fgsm_attack(model, x, y, eps=0.3):
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    loss.backward()
    x_adv = x_adv + eps * x_adv.grad.sign()
    x_adv = torch.clamp(x_adv, 0, 1)
    return x_adv.detach()

def pgd_attack(model, x, y, eps=0.3, step_size=0.01, num_steps=10, random_start=True):
    model.eval()
    x_adv = x.clone().detach()
    if random_start == True:
        x_adv = x_adv + torch.empty_like(x_adv).uniform_(-eps, eps)
        x_adv = torch.clamp(x_adv, 0, 1)

    for _ in range(num_steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        model.zero_grad()
        loss.backward()
        grad = x_adv.grad
        with torch.no_grad():
            x_adv = x_adv + step_size * torch.sign(grad)
            delta = torch.clamp(x_adv - x, -eps, eps)
            x_adv = torch.clamp(x + delta, 0, 1).detach()
    return x_adv

def standard_training(model, optimizer, data, target):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    return loss.item()

def fgsm_training(model, optimizer, data, target, eps=0.3):
    optimizer.zero_grad()
    x_adv = fgsm_attack(model, data, target, eps)
    alpha = 0.5
    loss = alpha * F.cross_entropy(model(data), target) + (1-alpha) * F.cross_entropy(model(x_adv), target)
    loss.backward()
    optimizer.step()
    return loss.item()

def pgd_training(model, optimizer, data, target, eps=0.3, step_size=0.01, num_steps=5):
    optimizer.zero_grad()
    x_adv = pgd_attack(model, data, target, eps, step_size, num_steps)
    loss = F.cross_entropy(model(x_adv), target)
    loss.backward()
    optimizer.step()
    return loss.item()

def trades_training(model, optimizer, data, target, beta=6.0, eps=0.3, step_size=0.01, num_steps=5):
    optimizer.zero_grad()
    logits_nat = model(data)
    natural_loss = F.cross_entropy(logits_nat, target)
    x_adv = pgd_attack(model, data, target, eps, step_size, num_steps)
    logits_adv = model(x_adv)
    robust_loss = F.kl_div(F.log_softmax(logits_adv, dim=1), 
                           F.softmax(logits_nat, dim=1), reduction='batchmean')
    total_loss = natural_loss + beta * robust_loss
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), natural_loss.item(), robust_loss.item()

def evaluate_model(model, test_loader, attack_method=None, eps=0.3):
    model.eval()
    clean_correct, adv_correct, total = 0, 0, 0
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        total += target.size(0)
        # Standard accuracy
        with torch.no_grad():
            _, prediction = torch.max(model(data).data, 1)
            clean_correct += (prediction == target).sum().item()
        # Adversarial accuracy
        if attack_method == 'fgsm':
            adv_data = fgsm_attack(model, data, target, eps)
        elif attack_method == 'pgd':
            adv_data = pgd_attack(model, data, target, eps, step_size=0.01, num_steps=5)
        else: 
            adv_data = data

        with torch.no_grad():
            _, predicted_adv = torch.max(model(adv_data).data, 1)
            adv_correct += (predicted_adv == target).sum().item()

    clean_acc, adv_acc = 100 * clean_correct / total, 100 * adv_correct / total
    return clean_acc, adv_acc    

def train_and_compare(num_epochs=3):
    methods = ['standard', 'fgsm', 'pgd', 'trades']
    results = {method: {'clean_acc': [], 'pgd_acc': [], 'fgsm_acc': [], 'loss': []} for method in methods}

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(datasets.MNIST('./data', train=False, transform=transform), batch_size=256, shuffle=False)

    for method in methods:
        print(f"\n === Training using: {method} === \n")
        model = SimpleCNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        for epoch in range(num_epochs):
            model.train()
            epoch_loss = 0
            num_batches = 0
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                if method == 'standard':
                    loss = standard_training(model, optimizer, data, target)
                elif method == 'fgsm':
                    loss = fgsm_training(model, optimizer, data, target)
                elif method == 'pgd':
                    loss = pgd_training(model, optimizer, data, target)
                elif method == 'trades':
                    loss, _, _ = trades_training(model, optimizer, data, target)
                epoch_loss += loss
                num_batches += 1

            clean_acc, pgd_acc = evaluate_model(model, test_loader, 'pgd')
            _, fgsm_acc = evaluate_model(model, test_loader, 'fgsm') 
            results[method]['clean_acc'].append(clean_acc)
            results[method]['pgd_acc'].append(pgd_acc)
            results[method]['fgsm_acc'].append(fgsm_acc)
            results[method]['loss'].append(epoch_loss/num_batches)
            print(f'Epoch {epoch+1}/{num_epochs}: Clean acc: {np.round(clean_acc, 3)} | PGD acc: {np.round(pgd_acc, 3)} | FGSM acc: {np.round(fgsm_acc, 3)} | Loss: {np.round(epoch_loss / num_batches, 3)}')

    return results, methods

def compare_methods(results, methods, num_epochs):
    print("\n FINAL RESULTS \n")
    for method in methods:
        clean_acc = results[method]['clean_acc'][-1]
        pgd_acc = results[method]['pgd_acc'][-1]
        fgsm_acc = results[method]['fgsm_acc'][-1]
        robustness_gap = clean_acc - pgd_acc
        print(f"{method.upper()}: clean accuracy: {np.round(clean_acc, 3)}, pgd_accuracy: {np.round(pgd_acc, 3)}, fgsm_acc: {np.round(fgsm_acc, 3)}, robustness_gap: {np.round(robustness_gap, 3)}")


if __name__ == "__main__":
    num_epochs = 3
    
    print("Training models with different defense methods...")
    results, methods = train_and_compare(num_epochs)
    
    # Create plots
    compare_methods(results, methods, num_epochs)

Training models with different defense methods...

 === Training using: standard === 

Epoch 1/3: Clean acc: 97.73 | PGD acc: 82.03 | FGSM acc: 4.22 | Loss: 0.255
Epoch 2/3: Clean acc: 98.5 | PGD acc: 85.44 | FGSM acc: 9.81 | Loss: 0.06
Epoch 3/3: Clean acc: 98.94 | PGD acc: 85.75 | FGSM acc: 6.11 | Loss: 0.042

 === Training using: fgsm === 

Epoch 1/3: Clean acc: 97.25 | PGD acc: 88.82 | FGSM acc: 71.66 | Loss: 0.817
Epoch 2/3: Clean acc: 97.88 | PGD acc: 84.52 | FGSM acc: 80.51 | Loss: 0.394
Epoch 3/3: Clean acc: 98.61 | PGD acc: 75.61 | FGSM acc: 84.76 | Loss: 0.268

 === Training using: pgd === 

Epoch 1/3: Clean acc: 97.64 | PGD acc: 93.86 | FGSM acc: 43.66 | Loss: 0.519
Epoch 2/3: Clean acc: 98.22 | PGD acc: 94.92 | FGSM acc: 47.01 | Loss: 0.188
Epoch 3/3: Clean acc: 98.54 | PGD acc: 95.96 | FGSM acc: 51.8 | Loss: 0.146

 === Training using: trades === 

Epoch 1/3: Clean acc: 97.51 | PGD acc: 94.58 | FGSM acc: 59.92 | Loss: 0.519
Epoch 2/3: Clean acc: 97.98 | PGD acc: 95.64 | FG