In [1]:
from google.colab import drive

drive.mount('/content/gdrive')



Mounted at /content/gdrive


In [2]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
torch.manual_seed(2)
np.random.seed(2)
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
import torchvision.models as models
!pip install torchinfo
!pip install torchattacks
!pip install pip install grad-cam
from torchinfo import summary
from torchvision.models.resnet import _resnet,BasicBlock
import torchattacks
import torchvision.utils
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM

Collecting torchinfo
  Downloading torchinfo-1.6.3-py3-none-any.whl (20 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.3
Collecting torchattacks
  Downloading torchattacks-3.2.4-py3-none-any.whl (102 kB)
[K     |████████████████████████████████| 102 kB 6.8 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.2.4
Collecting install
  Downloading install-1.3.5-py3-none-any.whl (3.2 kB)
Collecting grad-cam
  Downloading grad-cam-1.3.7.tar.gz (4.5 MB)
[K     |████████████████████████████████| 4.5 MB 7.1 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ttach
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517) ... [?25l[?25hdone
  Created wheel for grad-cam: filename=grad_cam-1.3.7-py3-none-a

In [3]:
transform = transforms.Compose([
    # transforms.Resize((224)),
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2,pin_memory=True)

valset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(
    valset, batch_size=64, shuffle=False, num_workers=2, sampler = SequentialSampler(valset.data[5000:10000]),pin_memory=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2, sampler = SequentialSampler(testset.data[5000:10000]),pin_memory=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [4]:
from torch.nn import Conv2d,AvgPool2d,Linear,Sequential,Dropout,BatchNorm2d,ModuleList,BatchNorm1d
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

class Base(nn.Module):
    def freeze(self):
        for param in self.base_model.parameters():
                param.requires_grad = False
    
    def unfreeze(self):
        for param in self.base_model.parameters():
                param.requires_grad = True
    
    def attach_fea_out(self,classname,input,output):
        self.features.append(output)

    def attach_fea_in(self,classname,input,output):
        self.features.append(input[0])

    def __init__(self,trainable = True,attention=False,base=18):
        super(Base,self).__init__()
        self.features = []
        self.channel_size = []
        print(base)
        if base == 9:
            self.base_model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
        elif base==18:
            self.base_model = models.resnet18(pretrained=False)
        else:
            self.base_model = models.resnet34(pretrained=False)

        used_blocks = ['layer1', 'layer2','layer3','layer4']
        unused_blocks = ['avgpool','fc']

        for block in used_blocks:
            getattr(self.base_model,block).register_forward_hook(self.attach_fea_out)

        for block in unused_blocks:
             setattr(self.base_model,block,nn.Identity())
        
        if not trainable:
            self.freeze()

        fake_img = torch.rand(1,3,256,256) ## pass fake img to the model to get the channel size of each inception block
        self.base_model(fake_img)
        self.channel_size = [block.size()[1] for block in self.features]
        self.features = []

    def forward(self,img):
        self.base_model(img)

    def get_MLSP(self,img,feature_type,resize = True):
        self.base_model(img)
        if resize:
            print(resize)
            if feature_type == 'narrow':
                MLSP = [F.adaptive_avg_pool2d(block, (1, 1)) for block in self.features]
                for i in range(len(MLSP)):
                    MLSP[i] = MLSP[i].squeeze(2).squeeze(2)

            if feature_type == 'wide':
                MLSP = [F.interpolate(block,mode = 'bilinear', size = 7) for block in self.features]
            
            MLSP = torch.cat(MLSP,dim = 1)
            self.features = []
        else:
            MLSP = self.features
            self.features = []
        return MLSP



class Head(nn.Module):
    def conv_block(self,inc,outc,ker,padding = 1,avgpool = False):
        modules = []
        modules.append(nn.Dropout(0.5))
        if avgpool:
            modules.append(AvgPool2d(3,1,1))
        modules.append(Conv2d(inc,outc,ker,padding = padding))
        modules.append(nn.BatchNorm2d(outc))
        modules.append(nn.ReLU())
        return Sequential(*modules)

    def __init__(self,head_type,num_channel):
        super(Head, self).__init__()
        self.head_type = head_type
        self.num_ch = num_channel
        if head_type == 'mlsp_cnn_gap_attn':
            self.attn = []
            self.conv = []
            for i in range(4):
                if i!=3:
                    self.attn.append(CBAM(num_channel[i],reduction_ratio=16))
                else:
                    self.attn.append(CBAM(num_channel[i],reduction_ratio=16,no_spatial=True))
                self.conv.append(Sequential(
                                    self.conv_block(num_channel[i],num_channel[i],1,0),
                                    self.conv_block(num_channel[i],num_channel[i],3,1),
                          ))
            self.attn = ModuleList(self.attn)
            self.conv = ModuleList(self.conv)
        self.dense = Sequential(Linear(960,10))

    def forward(self,features):
        if self.head_type == 'mlsp_gap':
            x = torch.cat([F.adaptive_avg_pool2d(feature, (1, 1)) for feature in features],dim=1)
        else:
            x = torch.cat([F.adaptive_avg_pool2d(block2(block1(feature)+feature),(1,1)) for feature,block1,block2 in zip(features,self.attn,self.conv)],dim=1)
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x

class Fmodel(nn.Module):
    def __init__(self, head_type='mlsp_gap',base = 18):
        super(Fmodel,self).__init__()
        self.bmodel = Base(base=base)
        self.head = Head(head_type,self.bmodel.channel_size)
        self.feature_type = 'narrow'    
        self.resize = False
        self.fea = []
        self.gap_fea = []
        self.gradient = []
        self.handles = []
        
    def forward(self,img):
        x = self.bmodel.get_MLSP(img,self.feature_type,self.resize)
        x = self.head(x)
        return x

    def unfreeze(self):
        self.bmodel.unfreeze()
    
    def freeze(self):
        self.bmodel.freeze()

    def hook_gap(self):
        handle = self.head.dense.register_forward_hook(lambda layer, inl, _,: self.gap_fea.append(inl[0]))
        self.handles += [handle]
        return handle

    def hook_grad(self):
        handle = []
        handle.append(self.head.conv[0].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.conv[1].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.conv[2].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        handle.append(self.head.conv[3].register_full_backward_hook(lambda layer, inl, out,: self.gradient.append(out[0])))
        self.handles += handle
        return handle

    def hook_fea(self):
        handle = []
        handle.append(self.head.conv[0].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.conv[1].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.conv[2].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        handle.append(self.head.conv[3].register_forward_hook(lambda layer, inl, out,: self.fea.append(out)))
        self.handles += handle
        return handle

    def hook(self):
        self.hook_gap()
        self.hook_grad()
        self.hook_fea()

    def unhook(self):
        for fea in self.gap_fea:
            fea.detach()
        for grad in self.gradient:
            grad.detach()
        for fea in self.fea:
            fea.detach()

        self.gap_fea = []
        self.gradient = []
        self.fea = []

        for h in self.handles:
            h.remove()
        self.handles = []



This will be the 3 model configs that you should train and generate the attacks on.  

In [5]:
import torchattacks

def deepcloak(model,mask_sz):
    global fea
    mask = torch.zeros((mask_sz))
    # mask.require_grad = False

    atk = torchattacks.PGD(model, eps=8/255, alpha=2/225, steps=7, random_start=True)
    #atk = torchattacks.FGSM(model, eps=8/255)
    
    for inputs,labels in train_loader:
        inputs_adv = atk(inputs, labels)
        inputs_adv = transforms.functional.normalize(inputs_adv,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        inputs = transforms.functional.normalize(inputs,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        fea = []
        model(inputs_adv)
        feas_adv = fea[0]
        model(inputs)
        feas = fea[1]
        mask += torch.abs(feas_adv-feas).sum(0).detach().cpu()
        for f in fea:
            f.detach()
    return mask


In [None]:
model = models.resnet18(pretrained=False)
model.fc = Linear(512,10)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class mask_fc(nn.Module):
    def __init__(self,fc,mask = torch.ones((512))):
        super(mask_fc,self).__init__()
        self.mask = mask.to(device)
        self.fc = copy.deepcopy(fc)
        
    def forward(self,x):
        return self.fc(self.mask*x)


model.load_state_dict(torch.load('/content/gdrive/MyDrive/cifar10/new_saved_model/resnet18_base_0')) 
model1 = copy.deepcopy(model)
model1.to(device)
model.to(device)
fea = []
# model.head.dense.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
# mask_sz = 960
model.fc.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
mask_sz = 512
mask = deepcloak(model,mask_sz)


In [None]:
percent = 0.05
null_ind = torch.topk(mask,int(percent*mask_sz))[1]
new_mask = torch.ones((mask_sz))
new_mask[null_ind] = 0

In [None]:
model.fc = mask_fc(model1.fc,mask = new_mask.to(device))

In [None]:
correct = 0 
size = 0
model.eval().cuda()

with  torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = transforms.functional.normalize(inputs,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        # Make predictions.
        prediction= model(inputs)
        # prediction= model(inputs,new_mask.to(device))

        # Retrieve predictions indexes.
        _, predicted_class = torch.max(prediction.data, 1)

        # Compute number of correct predictions.
        correct += (predicted_class == labels).float().sum().item()
        size+=len(prediction)
test_accuracy = correct / size
print('Test accuracy: {}'.format(test_accuracy))


model.eval()
atks = [
    torchattacks.FGSM(model1, eps=8/255),
    torchattacks.PGD(model1, eps=8/255, alpha=2/225, steps=7, random_start=True),
]
for i in [0,1]:
    correct = 0
    start = time.time()
    size = 0
    for images, labels in test_loader:   
        adv_images = atks[i](images, labels)
        adv_images = transforms.functional.normalize(adv_images,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        labels = labels.to(device)
        outputs = model(adv_images)
        # outputs = model(adv_images,new_mask.to(device))
        _, pre = torch.max(outputs.data, 1)
        correct += (pre == labels).float().sum().item()
        size+=len(labels)

    # print('Total elapsed time (sec): %.2f' % (time.time() - start))
    print('Robust accuracy: %.2f ' % (correct / size))


Test accuracy: 0.7452
Robust accuracy: 0.38 
Robust accuracy: 0.36 


In [11]:
for i in range(4):
    model = models.resnet18(pretrained=False)
    model.fc = Linear(512,10)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class mask_fc(nn.Module):
        def __init__(self,fc,mask = torch.ones((512))):
            super(mask_fc,self).__init__()
            self.mask = mask.to(device)
            self.fc = copy.deepcopy(fc)
            
        def forward(self,x):
            return self.fc(self.mask*x)


    model.load_state_dict(torch.load(f'/content/gdrive/MyDrive/cifar10/new_saved_model/resnet18_base_{i}')) 
    model1 = copy.deepcopy(model)
    model1.to(device)
    model.to(device)
    fea = []
    # model.head.dense.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
    # mask_sz = 960
    model.fc.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
    mask_sz = 512
    mask = deepcloak(model,mask_sz)

    percent = 0.05
    null_ind = torch.topk(mask,int(percent*mask_sz))[1]
    new_mask = torch.ones((mask_sz))
    new_mask[null_ind] = 0
    model.fc = mask_fc(model1.fc,mask = new_mask.to(device))

    correct = 0 
    size = 0
    model.eval().cuda()

    with  torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = transforms.functional.normalize(inputs,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            # Make predictions.
            prediction= model(inputs)
            # prediction= model(inputs,new_mask.to(device))

            # Retrieve predictions indexes.
            _, predicted_class = torch.max(prediction.data, 1)

            # Compute number of correct predictions.
            correct += (predicted_class == labels).float().sum().item()
            size+=len(prediction)

    test_accuracy = correct / size
    print('Test accuracy: {}'.format(test_accuracy))


    model.eval()
    atks = [
        torchattacks.FGSM(model, eps=8/255),
        torchattacks.PGD(model, eps=8/255, alpha=2/225, steps=7, random_start=True),
    ]
    for i in [0,1]:
        correct = 0
        start = time.time()
        size = 0
        for images, labels in test_loader:   
            adv_images = atks[i](images, labels)
            adv_images = transforms.functional.normalize(adv_images,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            labels = labels.to(device)
            outputs = model(adv_images)
            # outputs = model(adv_images,new_mask.to(device))
            _, pre = torch.max(outputs.data, 1)
            correct += (pre == labels).float().sum().item()
            size+=len(labels)

        # print('Total elapsed time (sec): %.2f' % (time.time() - start))
        print('Robust accuracy: %.2f ' % (correct / size))


Test accuracy: 0.741
Robust accuracy: 0.33 
Robust accuracy: 0.30 
Test accuracy: 0.715
Robust accuracy: 0.35 
Robust accuracy: 0.32 
Test accuracy: 0.7358
Robust accuracy: 0.32 
Robust accuracy: 0.28 
Test accuracy: 0.738
Robust accuracy: 0.33 
Robust accuracy: 0.30 


In [10]:
np.array([0.7432,0.696,0.7376,0.7368]).mean(), \
np.array([0.38 ,0.37, 0.33,0.35]).mean(), \
np.array([0.36,0.36,0.29,0.33]).mean()

(0.7284, 0.35750000000000004, 0.335)