# Adversarial Training of ResNet and SparseResNet

In [1]:
import sys, os
d = os.path.dirname(os.getcwd())
sys.path.insert(0, d)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

from utils.train_utils import adv_train
from models.resnet import ResNet, SparseResNet
from utils.attacks import pgd

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
tr_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
vl_transform = transforms.Compose([
    transforms.ToTensor(),
])

ds = CIFAR10('../data', train=True, download=True, transform=tr_transform)
ds_test = CIFAR10('../data', train=False, download=True, transform=vl_transform)

batch_size = 200
train_dl = DataLoader(ds, batch_size, shuffle=True)
valid_dl = DataLoader(ds_test, batch_size, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
attack_params = {
    "epsilon": 8/225, 
    "alpha": 2/255,
    "num_iter": 10
}

# ResNet

In [6]:
model = ResNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', min_lr=1e-5, 
                                                 patience=2, verbose=True)

adv_train(model, train_dl, valid_dl, pgd, attack_params, optimizer, loss_fn,
          scheduler, epochs=20, device=device)
torch.save(model.state_dict(), "../saved/resnet_robust_eps8.pt")

Epoch: 1 Validation Loss: 1.7625 accuracy: 0.3541, time: 0:06:19
Epoch: 2 Validation Loss: 1.6179 accuracy: 0.4141, time: 0:07:31
Epoch: 3 Validation Loss: 1.5377 accuracy: 0.4515, time: 0:08:10
Epoch: 4 Validation Loss: 1.4002 accuracy: 0.5042, time: 0:08:09
Epoch: 5 Validation Loss: 1.3941 accuracy: 0.5346, time: 0:08:09
Epoch: 6 Validation Loss: 1.2672 accuracy: 0.5629, time: 0:08:10
Epoch: 7 Validation Loss: 1.1946 accuracy: 0.5653, time: 0:08:11
Epoch: 8 Validation Loss: 1.1256 accuracy: 0.6137, time: 0:08:11
Epoch: 9 Validation Loss: 1.1271 accuracy: 0.6239, time: 0:06:42
epochs_no_improve: 1/4
Epoch: 10 Validation Loss: 1.0724 accuracy: 0.6427, time: 0:08:10
Epoch: 11 Validation Loss: 1.0670 accuracy: 0.6489, time: 0:05:19
Epoch: 12 Validation Loss: 1.0112 accuracy: 0.6610, time: 0:03:57
Epoch: 13 Validation Loss: 0.9697 accuracy: 0.6912, time: 0:03:58
Epoch: 14 Validation Loss: 0.9708 accuracy: 0.6842, time: 0:03:57
epochs_no_improve: 1/4
Epoch: 15 Validation Loss: 0.9269 accur

# SparseResNet

In [7]:
model = SparseResNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', min_lr=1e-5, 
                                                 patience=2, verbose=True)

adv_train(model, train_dl, valid_dl, pgd, attack_params, optimizer, 
          loss_fn, scheduler, epochs=20, sparse=True, device=device)
torch.save(model.state_dict(), "../saved/sparse_resnet_robust_eps8.pt")

Epoch: 1 Validation Loss: 1.7219 accuracy: 0.3552, time: 0:05:50
Epoch: 2 Validation Loss: 1.4862 accuracy: 0.4510, time: 0:06:33
Epoch: 3 Validation Loss: 1.3696 accuracy: 0.5237, time: 0:05:49
Epoch: 4 Validation Loss: 1.1961 accuracy: 0.5626, time: 0:05:50
Epoch: 5 Validation Loss: 1.0965 accuracy: 0.6089, time: 0:05:50
Epoch: 6 Validation Loss: 1.0557 accuracy: 0.6313, time: 0:05:51
Epoch: 7 Validation Loss: 0.9864 accuracy: 0.6513, time: 0:05:51
Epoch: 8 Validation Loss: 0.9612 accuracy: 0.6690, time: 0:05:51
Epoch: 9 Validation Loss: 0.9398 accuracy: 0.6692, time: 0:05:51
Epoch: 10 Validation Loss: 0.8718 accuracy: 0.7021, time: 0:05:51
Epoch: 11 Validation Loss: 0.9175 accuracy: 0.6930, time: 0:05:51
epochs_no_improve: 1/4
Epoch: 12 Validation Loss: 0.8842 accuracy: 0.6991, time: 0:05:51
epochs_no_improve: 2/4
Epoch: 13 Validation Loss: 0.8559 accuracy: 0.7193, time: 0:05:51
Epoch: 14 Validation Loss: 0.8577 accuracy: 0.7215, time: 0:05:51
epochs_no_improve: 1/4
Epoch: 15 Valida