# Adversarial Training of ResNet and SparseResNet

In [1]:
import sys, os
d = os.path.dirname(os.getcwd())
sys.path.insert(0, d)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

from utils.train_utils import adv_train
from models.resnet import ResNet, SparseResNet
from utils.attacks import pgd

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
tr_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
vl_transform = transforms.Compose([
    transforms.ToTensor(),
])

ds = CIFAR10('../data', train=True, download=True, transform=tr_transform)
ds_test = CIFAR10('../data', train=False, download=True, transform=vl_transform)

batch_size = 200
train_dl = DataLoader(ds, batch_size, shuffle=True)
valid_dl = DataLoader(ds_test, batch_size, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
attack_params = {
    "epsilon": 4/225, 
    "alpha": 2/255,
    "num_iter": 10
}

# ResNet

In [6]:
model = ResNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', min_lr=1e-5, 
                                                 patience=2, verbose=True)

adv_train(model, train_dl, valid_dl, pgd, attack_params, optimizer, loss_fn,
          scheduler, epochs=20, device=device)
torch.save(model.state_dict(), "../saved/resnet_robust_eps4.pt")

Epoch: 1 Validation Loss: 1.7101 accuracy: 0.3808, time: 0:03:54
Epoch: 2 Validation Loss: 1.4066 accuracy: 0.4875, time: 0:03:57
Epoch: 3 Validation Loss: 1.2408 accuracy: 0.5439, time: 0:03:57
Epoch: 4 Validation Loss: 1.1953 accuracy: 0.5741, time: 0:03:57
Epoch: 5 Validation Loss: 1.0910 accuracy: 0.6080, time: 0:03:57
Epoch: 6 Validation Loss: 0.9772 accuracy: 0.6705, time: 0:03:58
Epoch: 7 Validation Loss: 0.9399 accuracy: 0.6766, time: 0:03:58
Epoch: 8 Validation Loss: 0.8421 accuracy: 0.7196, time: 0:03:58
Epoch: 9 Validation Loss: 0.7945 accuracy: 0.7414, time: 0:03:58
Epoch: 10 Validation Loss: 0.7569 accuracy: 0.7546, time: 0:03:57
Epoch: 11 Validation Loss: 0.7648 accuracy: 0.7575, time: 0:03:58
epochs_no_improve: 1/4
Epoch: 12 Validation Loss: 0.7441 accuracy: 0.7513, time: 0:03:57
Epoch: 13 Validation Loss: 0.6921 accuracy: 0.7779, time: 0:03:57
Epoch: 14 Validation Loss: 0.6733 accuracy: 0.7799, time: 0:03:58
Epoch: 15 Validation Loss: 0.6595 accuracy: 0.7850, time: 0:03

# SparseResNet

In [7]:
model = SparseResNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', min_lr=1e-5, 
                                                 patience=2, verbose=True)

adv_train(model, train_dl, valid_dl, pgd, attack_params, optimizer, 
          loss_fn, scheduler, epochs=20, sparse=True, device=device)
torch.save(model.state_dict(), "../saved/sparse_resnet_robust_eps4.pt")

Epoch: 1 Validation Loss: 1.5828 accuracy: 0.4041, time: 0:05:50
Epoch: 2 Validation Loss: 1.3126 accuracy: 0.5161, time: 0:05:50
Epoch: 3 Validation Loss: 1.1943 accuracy: 0.5768, time: 0:05:50
Epoch: 4 Validation Loss: 1.0643 accuracy: 0.6361, time: 0:05:49
Epoch: 5 Validation Loss: 0.9577 accuracy: 0.6755, time: 0:05:49
Epoch: 6 Validation Loss: 0.8770 accuracy: 0.7021, time: 0:05:50
Epoch: 7 Validation Loss: 0.8220 accuracy: 0.7238, time: 0:05:51
Epoch: 8 Validation Loss: 0.7670 accuracy: 0.7442, time: 0:05:49
Epoch: 9 Validation Loss: 0.6985 accuracy: 0.7651, time: 0:05:50
Epoch: 10 Validation Loss: 0.6933 accuracy: 0.7679, time: 0:05:51
Epoch: 11 Validation Loss: 0.6418 accuracy: 0.7899, time: 0:05:49
Epoch: 12 Validation Loss: 0.6125 accuracy: 0.7968, time: 0:05:49
Epoch: 13 Validation Loss: 0.5853 accuracy: 0.8065, time: 0:05:48
Epoch: 14 Validation Loss: 0.5927 accuracy: 0.7998, time: 0:05:49
epochs_no_improve: 1/4
Epoch: 15 Validation Loss: 0.5612 accuracy: 0.8158, time: 0:05