<a href="https://colab.research.google.com/github/justinqbui/mini-advProp/blob/main/advProp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is pytorch re-implementation of the paper [Adversarial Examples Improve Image Recognition](https://arxiv.org/pdf/1911.09665.pdf).

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
!nvidia-smi

For our model, we'll define and use a simple ResNet-18, but modified with 2 separate BN layers, in order to accomodate 1 used for natural training, and the other for adversarial traning. We'll start with a ResNet stem, from [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187). In the advProp paper, the authors use three auxilary BNs, one for natural, one for autoAugmented images, and one for adversarially generated images. For simplicity, we won't autoaugment any examples and stick to only two BNs, one for natural and one for adversarially generated examples.


In [4]:
class ConvStem(nn.Module):
    """
    Convolution ->Batch Norm -> GELU
    We make sure to use auxilary BN because we want the learnable parameters (gamma & beta)
    for the BN layers to be different for natural and adversarially generated examples
    """
    def __init__(self, in_channels, out_channels, kernal_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernal_size, stride, padding)
        self.clean_bn = nn.BatchNorm2d(out_channels)
        self.adv_bn = nn.BatchNorm2d(out_channels)
        self.GELU = nn.GELU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
    def forward(self, x, adv = False):
        """adv -> if the input is an adversially generated example"""
        x = self.conv(x)
        if adv:
            x = self.adv_bn(x)
        else:
            x = self.clean_bn(x)
        x = self.GELU(x)
        x = self.maxpool(x)
        return x


We've now defined our stem, next we define our ResNet blocks, which make up the body of our CNN. We again use an auxilary BN and we choose GELU as our non-linearity as suggested by [Smooth Adversarial Training](https://arxiv.org/pdf/2006.14536.pdf?ref=https://githubhelp.com), which tells as that GELU or other similar smooth activation functions make our models much more robust.


In [5]:
class Block(nn.Module):
    """
    conv-> BN -> GELU -> conv -> BN -> GELU
    We use identity_downsample when our input space dimensions change
    """
    def __init__(self, in_channels, out_channels, stride = 1, identity_downsample = None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.clean_bn1 = nn.BatchNorm2d(out_channels)
        self.adv_bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.clean_bn2 = nn.BatchNorm2d(out_channels)
        self.adv_bn2 = nn.BatchNorm2d(out_channels)
        self.GELU = nn.GELU()
        self.identity_downsample = identity_downsample
    
    def forward(self, x, adv = False):
        identity = x

        x = self.conv1(x)
        if adv:
            x = self.adv_bn1(x)
        else:
            x = self.clean_bn1(x)

        x = self.GELU(x)
        x = self.conv2(x)
        if adv:
            x = self.adv_bn2(x)
        else:
            x = self.clean_bn2(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # residual connection
        x += identity
        x = self.GELU(x)
        return x

        


We create an identity downsample on convolutions that have stride of 2.

In [6]:
class IdentityDownsample(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride)
        self.clean_bn = nn.BatchNorm2d(out_channels)
        self.adv_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, adv = False):
        x = self.conv(x)
        if adv:
            x = self.adv_bn(x)
        else:
            x = self.clean_bn(x)
        return x
    

In [7]:
class MakeLayers(nn.Module):
    """
    A custom module used to make the blocks for the ResNet
    necessary to differentiate clean vs adversarial
    Params
    """
    def __init__(self, num_blocks, in_channels, intermediate_channels, stride):
        super().__init__()
        self.layers = nn.ModuleList()
        self.in_channels = in_channels

        identity_downsample = IdentityDownsample(self.in_channels, intermediate_channels, stride)
        # here we do the downsample at the beginning of every block
        self.layers.append(Block(self.in_channels, intermediate_channels, stride, identity_downsample))
        self.in_channels = intermediate_channels
        
        for i in range(num_blocks - 1):
            self.layers.append(Block(self.in_channels, intermediate_channels))

    
    def forward(self, x, adv = False):
        for layer in self.layers:
            x = layer(x, adv)
        return x

    def in_channels(self):
        return self.in_channels

In [8]:
class ResNet18(nn.Module):
    """
    simple class that builds out a resnet-18 with a special stem
    Params:
    n_classes -> the number of classes
    in_channels -> number of input channels (3 for (R,G, B), 1 for B,W)
    """
    def __init__(self, n_classes, in_channels = 3):
        super().__init__()
        # block size for in/out channels as the network gets deeper
        self.stem = ConvStem(in_channels, 64, 7, 2, 3)
        # we perform a downsample on the first convolution of layers that have stride == 2
        self.in_channels = 64

        self.layer1 = MakeLayers(2, self.in_channels, intermediate_channels = 64, stride = 1)
        self.in_channels = self.layer1.in_channels

        self.layer2 = MakeLayers(2, self.in_channels, intermediate_channels = 128, stride = 2)
        self.in_channels = self.layer2.in_channels

        self.layer3 = MakeLayers(2, self.in_channels, intermediate_channels = 256, stride = 2)
        self.in_channels = self.layer3.in_channels

        self.layer4 = MakeLayers(2, self.in_channels, intermediate_channels = 512, stride = 2)
        self.in_channels = self.layer4.in_channels

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, n_classes)


    
    def forward(self, x, adv = False):
        x = self.stem(x,adv)

        x = self.layer1(x, adv)
        x = self.layer2(x, adv)
        x = self.layer3(x, adv)
        x = self.layer4(x, adv)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x


In [9]:
def test():
    model = ResNet18(1000)
    y = model(torch.randn(4, 3, 224, 224))
    print(y.size())

test()


torch.Size([4, 1000])


Let's load in CIFAR10 dataset. 

In [10]:
# from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [11]:

def epoch(model, data_loader, opt = None):
    """
    Performs a standard epoch, with no data augmentation
    or adversarial training
    Params:
    model -> NN model
    data_loader -> torch.utils.Dataloader
    opt -> set to None if evaluating model on test set
    """
    total_loss, total_err = 0., 0. 
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = nn.CrossEntropyLoss()(y_pred, y)
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += (y_pred.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * x.shape[0]
    return total_err / len(data_loader.dataset), total_loss / len(data_loader.dataset)



In [None]:
model = ResNet18(10)
model.to(device)

In [13]:
opt = optim.SGD(model.parameters(), lr=.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min')
for i in range(1,51):
    train_err, train_loss = epoch(model, trainloader, opt)
    test_err, test_loss = epoch(model, testloader)

    if (i % 2 == 0):
        print("{:2}".format(i), end = "  ")
        print(*("{:.6f}".format(j) for j in (train_err, train_loss, test_err, test_loss)), sep="\t")

 2  0.457340	1.267034	0.407500	1.141461
 4  0.314320	0.887641	0.311800	0.885461
 6  0.225040	0.642420	0.281800	0.810661
 8  0.161340	0.459199	0.266300	0.850253
10  0.109900	0.310121	0.262000	0.909521
12  0.071640	0.204877	0.261000	1.058157
14  0.049200	0.140955	0.263000	1.176733
16  0.034960	0.103427	0.254700	1.189586
18  0.025140	0.074273	0.259600	1.503075
20  0.017660	0.053157	0.256900	1.522674
22  0.015860	0.045469	0.260300	1.462262
24  0.015040	0.043997	0.259900	1.495456
26  0.010640	0.031356	0.250200	1.770264
28  0.010580	0.029678	0.257400	1.660864
30  0.009280	0.027252	0.252300	1.687404
32  0.007400	0.021710	0.253100	1.845033
34  0.007380	0.022244	0.251500	1.712113
36  0.005740	0.016448	0.257400	1.860261
38  0.005340	0.016021	0.250300	1.837050
40  0.005520	0.015997	0.252400	1.948214
42  0.005460	0.016620	0.253700	1.780878
44  0.003200	0.009034	0.248600	1.994145
46  0.003480	0.011155	0.249800	1.961073
48  0.004680	0.013547	0.252100	2.027856
50  0.004520	0.014421	0.247900	1.824296


Here we define fast gradient sign method (FGSM) as $adv_x = x + \epsilon * sgn(\Delta_xL(x,y,\theta))$.

In [14]:
def fgsm(model, x, y, eps = 0.1):
    """ 
    performs fast gradient sign method
    Params:
    model -> NN model
    x -> input image (tensor) 
    y -> corresponding label(s) for X
    eps = epsilon hyperparameter
    """
    delta = torch.zeros_like(x, requires_grad=True)
    loss = nn.CrossEntropyLoss()(model(x + delta), y)
    loss.backward()
    return eps * delta.grad.detach().sign()

We define projected gradient descent as $x^{t+1} = \Pi_{x+S}(x^t+\alpha $sgn$(\nabla_x L(x,y,\theta)))$. Even though it's called projected gradient descent, we're actually taking a step in the direction of the gradient, to move away from the local minima instead of toward.

In [15]:
def pgd_linf(model, x, y, eps=1, alpha=1, num_iter=2, randomize=False):
    """ 
    performs PGD in L_inf space
    Params:
    model -> NN model
    x -> input image (tensor) 
    y -> corresponding label(s) for X
    eps = epsilon hyperparameter
    alpha = alpha hyperparameter
    num_iters = number of iterations performed of PGD
    randomize = random restarts (helps avoid local optima that PGD can find if started at zero point)
    """
    if randomize:
        delta = torch.rand_like(x, requires_grad=True)
        delta.data = delta.data * 2 * eps - eps
    else:
        delta = torch.zeros_like(x, requires_grad=True)
        
    for t in range(num_iter):
        loss = nn.CrossEntropyLoss()(model(x + delta), y)
        loss.backward()
        #clamp data between [-epsilion, epsilon]
        delta.data = (delta + alpha*delta.grad.detach().sign()).clamp(-eps,eps)
        delta.grad.zero_()
    return delta.detach()

In [None]:
adv_model = ResNet18(10)
adv_model.to(device)

The advProp algorithm is quite simple
```
for each epoch:
    x, y = clean image mini-batch, clean image label(s)
    delta = compute noise
    compute adv_loss of adversarial example
    compute clean_loss of clean example
    minimize the loss w.r.t. the min(adv_loss + clean_loss)
```

In [22]:
def advProp_epoch(model, data_loader, attack, opt = None):
    """
    implements advProp algorithm
    """
    total_clean_err, total_clean_loss = 0., 0.
    total_adv_err, total_adv_loss = 0., 0.
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        delta = attack(model, x, y)
        # grab prediction of adversarial example
        y_pred_adv = model(x + delta, adv = True)
        adv_loss = nn.CrossEntropyLoss()(y_pred_adv, y)
        y_pred_clean = model(x)
        clean_loss = nn.CrossEntropyLoss()(y_pred_clean, y)
        # compute the min loss of clean + adv
        loss = (clean_loss + adv_loss) / 2
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()

        
        total_clean_err += (y_pred_clean.max(dim=1)[1] != y).sum().item()
        total_adv_err += (y_pred_adv.max(dim=1)[1] != y).sum().item()
        total_clean_loss += clean_loss.item() * x.shape[0]
        total_adv_loss += adv_loss.item() * x.shape[0]

    return (total_clean_err / len(data_loader.dataset),
            total_adv_err / len(data_loader.dataset),
            total_clean_loss / len(data_loader.dataset),
            total_adv_loss / len(data_loader.dataset))

In [24]:
opt = optim.SGD(adv_model.parameters(), lr=.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min')
for i in range(1,51):
    clean_err, adv_err, clean_loss, adv_loss = advProp_epoch(adv_model, trainloader, pgd_linf, opt)
    test_err, test_loss = epoch(adv_model, testloader)

    if (i % 2 == 0):
        print("{:2}".format(i), end = "  ")
        print(*("{:.6f}".format(j) for j in (clean_err, clean_loss, adv_err, adv_loss, test_err, test_loss)), sep="\t")

 2  0.500740	1.405459	0.468760	1.345284	0.432800	1.262209
 4  0.325180	0.925962	0.207780	0.614954	0.326100	0.932992
 6  0.244000	0.695638	0.160400	0.479563	0.280800	0.819140
 8  0.182240	0.520222	0.173180	0.513462	0.257400	0.783842
10  0.135560	0.383466	0.207660	0.610962	0.271500	0.895839
12  0.097180	0.272971	0.241460	0.708779	0.259000	0.947442
14  0.068700	0.193662	0.260400	0.764145	0.260700	1.011175
16  0.050580	0.143295	0.265880	0.783320	0.257100	1.170417
18  0.035840	0.103049	0.265540	0.774524	0.259600	1.265479
20  0.028540	0.084470	0.268300	0.784799	0.255400	1.356969
22  0.025640	0.073284	0.269040	0.783438	0.255700	1.399395
24  0.021220	0.059996	0.270180	0.792809	0.257700	1.510917
26  0.016400	0.049188	0.267880	0.777332	0.250700	1.463987
28  0.015160	0.044483	0.258680	0.757370	0.251600	1.548814
30  0.014220	0.040806	0.258900	0.756449	0.255200	1.599421
32  0.013660	0.038880	0.256200	0.744126	0.252600	1.571965
34  0.012740	0.036826	0.252080	0.744096	0.259500	1.629597
36  0.011240	0

We can see that after training, both the adversarially trained model and the one only trained on clean images perform about the same. In the paper, only ResNet50+ larger variants were used. If we were to train both models for more epochs, I'd hypothesize that the adversarially trained model would begin to outperform the clean model, because we'd have more "training data" so to speak by generating new examples and having less issues with overfitting. We also used the optimizer proposed in the ResNet, instead of the one proposed in advProp (RMSProp).
