In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm

from torchadv.attacks import FGSM, PGD
from torchadv.utils import get_available_device

In [2]:
def _to_tensor_only():
    """Return a fresh transform pipeline that only converts images to tensors."""
    return torchvision.transforms.Compose([torchvision.transforms.ToTensor()])


def _cifar10(is_train, transform):
    """Load one CIFAR-10 split from ./data, downloading it if it is not there yet."""
    return torchvision.datasets.CIFAR10(
        root="./data", train=is_train, transform=transform, download=True
    )


def _batches(dataset, shuffle):
    """Wrap a dataset in a DataLoader of 128-sample batches with two workers."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=128, shuffle=shuffle, num_workers=2
    )


# Identical preprocessing for both splits: no augmentation, no normalization.
train_transforms = _to_tensor_only()
test_transforms = _to_tensor_only()

train_dataset = _cifar10(True, train_transforms)
test_dataset = _cifar10(False, test_transforms)

# Only the training set is shuffled; the test order stays fixed across runs.
train_loader = _batches(train_dataset, shuffle=True)
test_loader = _batches(test_dataset, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class CNN(nn.Module):
    """Basic CNN classifier: three ReLU conv layers followed by one linear layer.

    The layer sizes are fixed for 32x32 inputs. With the unpadded convolutions
    below, the spatial size goes 32 -> 25 -> 10 -> 3, which is what the
    ``128 * 3 * 3`` input width of ``fc`` assumes. Inputs of any other spatial
    size are rejected with a shape error in ``fc`` instead of being silently
    regrouped across the batch.

    Args:
        in_channels: number of input image channels (e.g. 3 for RGB CIFAR-10).
    """

    def __init__(self, in_channels=1):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 8, 1)  # kernel 8, stride 1
        self.conv2 = nn.Conv2d(64, 128, 6, 2)  # kernel 6, stride 2
        self.conv3 = nn.Conv2d(128, 128, 5, 2)  # kernel 5, stride 2
        self.fc = nn.Linear(128 * 3 * 3, 10)  # 10 output logits

    def forward(self, x):
        """Return unnormalized logits of shape (batch, 10).

        Expects ``x`` of shape (batch, in_channels, 32, 32).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten per sample and keep the batch dim explicit. Inferring it with
        # view(-1, 128*3*3) would fold extra spatial positions into the batch
        # dimension for non-32x32 inputs instead of raising.
        x = torch.flatten(x, 1)
        return self.fc(x)

In [4]:
# Training hyperparameters, model, loss and optimizer.
lr = 1e-3
epochs = 10
SEED = 0  # fixed so weight init and the training-set shuffle order are repeatable
device = get_available_device()

# Seed before building the model so its initial weights are deterministic.
# NOTE: cuDNN kernels may still be non-deterministic on GPU.
torch.manual_seed(SEED)

net = CNN(in_channels=3)
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

In [5]:
# Train for `epochs` epochs, reporting the per-sample mean training loss,
# then checkpoint the final weights.
net.train()
for epoch in range(1, epochs + 1):
    train_loss = 0.0
    n_seen = 0
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}', unit='batch'):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        loss = criterion(net(x), y)
        loss.backward()
        optimizer.step()

        # criterion uses the default 'mean' reduction, so `loss` is a batch
        # mean. Weight it by batch size so the smaller final batch does not
        # skew the epoch average.
        batch_size = y.size(0)
        train_loss += loss.item() * batch_size
        n_seen += batch_size

    avg_train_loss = train_loss / max(n_seen, 1)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {avg_train_loss:.3f}")

torch.save(net.state_dict(), 'cifar10_cnn.pth')

Epoch 1/10: 100%|██████████| 391/391 [00:30<00:00, 12.67batch/s]


Epoch 1/10: Train Loss: 1.680


Epoch 2/10: 100%|██████████| 391/391 [00:30<00:00, 12.83batch/s]


Epoch 2/10: Train Loss: 1.424


Epoch 3/10: 100%|██████████| 391/391 [00:30<00:00, 12.68batch/s]


Epoch 3/10: Train Loss: 1.303


Epoch 4/10: 100%|██████████| 391/391 [00:30<00:00, 12.85batch/s]


Epoch 4/10: Train Loss: 1.216


Epoch 5/10: 100%|██████████| 391/391 [00:30<00:00, 12.78batch/s]


Epoch 5/10: Train Loss: 1.131


Epoch 6/10: 100%|██████████| 391/391 [00:30<00:00, 12.77batch/s]


Epoch 6/10: Train Loss: 1.069


Epoch 7/10: 100%|██████████| 391/391 [00:30<00:00, 12.87batch/s]


Epoch 7/10: Train Loss: 1.007


Epoch 8/10: 100%|██████████| 391/391 [00:30<00:00, 12.72batch/s]


Epoch 8/10: Train Loss: 0.959


Epoch 9/10: 100%|██████████| 391/391 [00:30<00:00, 12.80batch/s]


Epoch 9/10: Train Loss: 0.908


Epoch 10/10: 100%|██████████| 391/391 [00:30<00:00, 12.84batch/s]

Epoch 10/10: Train Loss: 0.851





In [8]:
# Rebuild the model and restore the checkpoint written by the training cell.
net = CNN(in_channels=3)
net = net.to(device)
# map_location: place tensors on the current device even if the checkpoint was
#   saved on a different one (e.g. trained on GPU, evaluated on CPU).
# weights_only=True: restrict unpickling to tensors and plain containers, since
#   torch.load is pickle-based (requires torch >= 1.13).
net.load_state_dict(torch.load('cifar10_cnn.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [9]:
# Measure test accuracy on clean inputs and on FGSM / PGD adversarial inputs.
net.eval()
report = {'nb_test': 0, 'correct': 0, 'correct_fgsm': 0, 'correct_pgd': 0}

# Both attacks use torchadv's default settings (eps, steps, ...), which are not
# visible in this notebook.
# NOTE(review): the recorded run shows FGSM accuracy *below* PGD accuracy; with
# equal budgets PGD is usually at least as strong, so the defaults likely differ.
# Confirm before comparing the two numbers.
fgsm = FGSM(net)
pgd = PGD(net)

for x, y in tqdm(test_loader, desc='Evaluation', unit='batch'):
    x, y = x.to(device), y.to(device)
    # The attacks need gradients w.r.t. the input, so they run with autograd on.
    x_fgsm = fgsm(x)
    x_pgd = pgd(x)

    # Plain predictions need no autograd graph; skip building one.
    with torch.no_grad():
        inputs_by_key = {'correct': x, 'correct_fgsm': x_fgsm, 'correct_pgd': x_pgd}
        for key, inputs in inputs_by_key.items():
            report[key] += net(inputs).argmax(1).eq(y).sum().item()
    report['nb_test'] += y.size(0)

# Calculate and print accuracy
for label, key in (
    ('clean', 'correct'),
    ('FGSM adversarial', 'correct_fgsm'),
    ('PGD adversarial', 'correct_pgd'),
):
    print(f"Test accuracy on {label} examples (%): {report[key] / report['nb_test'] * 100.0:.3f}")

Evaluation: 100%|██████████| 79/79 [02:38<00:00,  2.00s/batch]

Test accuracy on clean examples (%): 61.930
Test accuracy on FGSM adversarial examples (%): 7.810
Test accuracy on PGD adversarial examples (%): 18.840



