In [None]:
from google.colab import drive

drive.mount('/content/gdrive')

!mkdir /content/data
%cd /content/data
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q tiny-imagenet-200.zip 

Mounted at /content/gdrive
/content/data
--2021-10-07 12:49:04--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.68.10
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2021-10-07 12:49:20 (15.2 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]



In [None]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
torch.manual_seed(2)
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.models as models
!pip install torchinfo
!pip install torchattacks
from torchinfo import summary
from torchvision.models.resnet import _resnet,BasicBlock

import torchvision.utils

Collecting torchinfo
  Downloading torchinfo-1.5.3-py3-none-any.whl (19 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.5.3
Collecting torchattacks
  Downloading torchattacks-3.2.0-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 8.3 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.2.0


In [None]:
# Per-channel normalisation applied after ToTensor().
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Resize to 224, convert to a tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    normalize,
])

dataset = torchvision.datasets.ImageFolder(
    '/content/data/tiny-imagenet-200/train',
    transform=transform,
)

# 80k / 10k / 10k split; the explicitly seeded generator keeps the split
# identical from run to run.
train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    [80000, 10000, 10000],
    generator=torch.Generator().manual_seed(42),
)

batch_size = 32
# Training batches are drawn in random order through the sampler; the
# validation and test loaders iterate in a fixed order.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=2,
    sampler=SubsetRandomSampler(range(80000)),
)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# NOTE(review): this cell rebinds train_loader / val_loader / test_loader,
# which the Tiny-ImageNet cell also defines - whichever cell ran last decides
# what the later cells iterate over.

# transform2 (evaluation) and transform1 (training) are currently identical:
# tensor conversion followed by per-channel normalisation.
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform1 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Training split, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform1)
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# The held-out split (train=False) is loaded twice and divided by index:
# the first 5000 samples are used for validation ...
valset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform2)
val_loader = DataLoader(
    valset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
    sampler=SubsetRandomSampler(range(5000)),
)

# ... and samples 5000-9999 for testing.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform2)
test_loader = DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
    sampler=SubsetRandomSampler(range(5000, 10000)),
)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from torch.nn import Conv2d,AvgPool2d,Linear,Sequential,Dropout,BatchNorm2d,ModuleList,BatchNorm1d
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable

class Base(nn.Module):
    """ResNet-18 backbone that records the outputs of its four residual stages.

    Forward hooks on ``layer1`` .. ``layer4`` append each stage's output to
    ``self.features`` on every pass through ``self.base_model``. The backbone's
    ``avgpool`` and ``fc`` are replaced by ``nn.Identity``.

    Attributes:
        features: list filled by the hooks during a backbone pass.
        channel_size: channel count of each hooked stage, measured once in
            ``__init__`` with a random probe image.
    """

    def freeze(self):
        """Disable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
            param.requires_grad = True

    def attach_fea_out(self, classname, input, output):
        """Forward hook: record the hooked module's output."""
        self.features.append(output)

    def attach_fea_in(self, classname, input, output):
        """Forward hook: record the hooked module's first input."""
        self.features.append(input[0])

    def __init__(self, trainable=True, attention=False):
        """Build the hooked backbone.

        Args:
            trainable: when False, the backbone parameters are frozen.
            attention: accepted for interface compatibility; not used here.
        """
        super(Base, self).__init__()
        self.features = []
        self.channel_size = []

        self.base_model = models.resnet18(pretrained=False)
        used_blocks = ['layer1', 'layer2', 'layer3', 'layer4']
        unused_blocks = ['avgpool', 'fc']

        for block in used_blocks:
            getattr(self.base_model, block).register_forward_hook(self.attach_fea_out)

        for block in unused_blocks:
            setattr(self.base_model, block, nn.Identity())

        if not trainable:
            self.freeze()

        # Push one random image through the backbone so the hooks reveal the
        # channel count of every residual stage. No gradients are needed here.
        fake_img = torch.rand(1, 3, 256, 256)
        with torch.no_grad():
            self.base_model(fake_img)
        self.channel_size = [block.size()[1] for block in self.features]
        self.features = []

    def forward(self, img):
        """Run the backbone on ``img`` and return whatever it returns.

        ``self.features`` is reset first, so afterwards it holds only the stage
        outputs of this call instead of growing with every call.
        """
        self.features = []
        return self.base_model(img)

    def get_MLSP(self, img, feature_type, resize=True):
        """Return the multi-level features of ``img``.

        Args:
            img: input batch fed to the backbone.
            feature_type: ``'narrow'`` (each stage globally average-pooled to
                one value per channel) or ``'wide'`` (each stage bilinearly
                resized to 7x7). Only consulted when ``resize`` is True.
            resize: when True, the per-stage features are pooled/resized and
                concatenated along the channel dimension; when False, the raw
                list of stage outputs is returned.

        Returns:
            A single concatenated tensor (``resize=True``) or a list of stage
            outputs (``resize=False``).

        Raises:
            ValueError: if ``resize`` is True and ``feature_type`` is unknown.
        """
        # Drop anything left over from an earlier pass so that only the
        # features of this image batch are used.
        self.features = []
        self.base_model(img)
        stage_outputs, self.features = self.features, []

        if not resize:
            return stage_outputs

        if feature_type == 'narrow':
            # (B, C, 1, 1) -> (B, C)
            MLSP = [torch.flatten(F.adaptive_avg_pool2d(block, (1, 1)), 1)
                    for block in stage_outputs]
        elif feature_type == 'wide':
            MLSP = [F.interpolate(block, mode='bilinear', size=7)
                    for block in stage_outputs]
        else:
            raise ValueError(
                "feature_type must be 'narrow' or 'wide', got {!r}".format(feature_type))

        return torch.cat(MLSP, dim=1)


class head_block():
    """Factory for the building blocks used by ``Head``."""

    def conv_block(self, inc, outc, ker, padding=1, avgpool=False):
        """Build a Dropout -> [AvgPool] -> Conv -> BatchNorm -> ReLU stack.

        Args:
            inc: number of input channels of the convolution.
            outc: number of output channels of the convolution.
            ker: convolution kernel size.
            padding: convolution padding.
            avgpool: when True, insert a 3x3, stride-1, padding-1 average pool
                between the dropout and the convolution.

        Returns:
            The layers wrapped in a ``Sequential``.
        """
        modules = [nn.Dropout(0.5)]
        if avgpool:
            modules.append(AvgPool2d(3, 1, 1))
        modules.append(Conv2d(inc, outc, ker, padding=padding))
        modules.append(nn.BatchNorm2d(outc))
        modules.append(nn.ReLU())
        return Sequential(*modules)

    def mlsp_cnn_gap_attn(self, num_channels, attention=True):
        """Build one attention block per feature level.

        Args:
            num_channels: channel count of each feature level.
            attention: accepted for interface compatibility; not used here.

        Returns:
            ``ModuleList`` with one ``Sequential(CBAM(...))`` per entry of
            ``num_channels``.
        """
        # NOTE(review): CBAM is not defined in the visible cells - it has to be
        # defined or imported before this method is called.
        scale = 1  # channel reduction factor; 1 keeps the channel count as is
        blocks = [
            Sequential(CBAM(num_channel // scale, reduction_ratio=16))
            for num_channel in num_channels
        ]
        return ModuleList(blocks)


class Head(nn.Module):
    """Classification head on top of a list of multi-level feature maps.

    Every feature map is globally average-pooled, the pooled vectors are
    concatenated, optionally multiplied by a mask, and passed through a single
    linear layer with 1000 outputs.

    Supported ``head_type`` values:
        ``'mlsp_gap'``: pool the features directly.
        ``'mlsp_cnn_gap_attn'``: add the output of a per-level block from
            ``head_block.mlsp_cnn_gap_attn`` to each feature before pooling.
    """

    def __init__(self, head_type, num_channel):
        """
        Args:
            head_type: name of the head variant (see class docstring).
            num_channel: channel count of each feature level; their sum is the
                input width of the linear layer (960 for 64+128+256+512).
        """
        super(Head, self).__init__()
        if head_type == 'mlsp_cnn_gap_attn':
            self.head = getattr(head_block(), head_type)(num_channel)
        else:
            # No per-level blocks for the other head types.
            self.head = None
        self.head_type = head_type
        self.num_ch = num_channel
        # The pooled vector has one entry per channel of every level, so the
        # linear layer's input width is derived instead of hard-coded.
        self.dense = Sequential(Linear(int(sum(num_channel)), 1000))

    def forward(self, features, masks=None):
        """
        Args:
            features: list of feature maps, one per level.
            masks: optional tensor multiplied element-wise with the pooled,
                flattened feature vector before the linear layer.

        Returns:
            Output of the linear layer.

        Raises:
            ValueError: if ``head_type`` is neither ``'mlsp_gap'`` nor a type
                for which per-level blocks were built.
        """
        if self.head_type == 'mlsp_gap':
            pooled = [F.adaptive_avg_pool2d(feature, (1, 1)) for feature in features]
        elif self.head is not None:
            # Residual connection around each per-level block, then pooling.
            pooled = [F.adaptive_avg_pool2d(block(feature) + feature, (1, 1))
                      for feature, block in zip(features, self.head)]
        else:
            raise ValueError("unsupported head_type: {!r}".format(self.head_type))

        x = torch.flatten(torch.cat(pooled, dim=1), 1)
        if masks is not None:
            x = x * masks
        return self.dense(x)

class Fmodel(nn.Module):
    """Full model: a ``Base`` backbone followed by a ``Head``.

    The backbone is queried with ``feature_type='narrow'`` and
    ``resize=False``, so the head receives the raw list of stage outputs.
    """

    def __init__(self, head_type):
        """
        Args:
            head_type: head variant name, forwarded to ``Head``.
        """
        super().__init__()
        backbone = Base()
        self.bmodel = backbone
        self.head = Head(head_type, backbone.channel_size)
        self.feature_type = 'narrow'
        self.resize = False

    def forward(self, img, masks=None):
        """Extract backbone features for ``img`` and classify them.

        ``masks`` is handed to the head unchanged.
        """
        stage_features = self.bmodel.get_MLSP(img, self.feature_type, self.resize)
        return self.head(stage_features, masks)

    def freeze(self):
        """Freeze the backbone parameters."""
        self.bmodel.freeze()

    def unfreeze(self):
        """Unfreeze the backbone parameters."""
        self.bmodel.unfreeze()




These are the three model configurations that you should train and generate attacks against.

In [None]:
# Build the MLSP model with the plain global-average-pooling head.
model = Fmodel('mlsp_gap')
# Smoke test: one forward pass on a random batch.
model(torch.rand((2,3,224,224)))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# model.load_state_dict(torch.load('/content/gdrive/MyDrive/cifar10/saved_model/mlsp_base.pt'))
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# runtime as well, matching the CPU fallback of `device`.
# NOTE(review): the saved output of this cell shows a RuntimeError whose
# message is not visible - confirm that this checkpoint matches Fmodel('mlsp_gap').
model.load_state_dict(torch.load('/content/gdrive/MyDrive/imagenet/saved_model/mlsp_gap_attack.pt', map_location=device))

RuntimeError: ignored

In [None]:
# ResNet-34 without pretrained weights; the weights are loaded from a
# checkpoint further down in this cell.
model = models.resnet34(pretrained=False)
# model.fc = Linear(512,10)
# Prefer the first GPU, fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class mask_fc(nn.Module):
    """Wrap a layer so that its input is multiplied element-wise by a fixed mask.

    Args:
        fc: layer to wrap; a deep copy is stored, the original is untouched.
        mask: tensor multiplied with the input. Defaults to all ones, sized
            from ``fc.in_features`` when available and 512 otherwise.
    """
    def __init__(self,fc,mask = None):
        super(mask_fc,self).__init__()
        self.fc = copy.deepcopy(fc)
        if mask is None:
            # Built per instance instead of sharing one default tensor.
            mask = torch.ones(getattr(fc, 'in_features', 512))
        # Put the mask where the wrapped layer lives instead of relying on a
        # global `device`.
        ref = next(self.fc.parameters(), None)
        if ref is not None:
            mask = mask.to(ref.device)
        # A non-persistent buffer follows .to()/.cuda() with the module but is
        # kept out of state_dict, so checkpoint keys stay unchanged.
        self.register_buffer('mask', mask.detach().clone(), persistent=False)

    def forward(self,x):
        """Apply the wrapped layer to ``mask * x``."""
        return self.fc(self.mask*x)

# Load the ResNet-34 weights. map_location lets a checkpoint that was saved on
# a GPU load on a CPU-only runtime as well, matching the CPU fallback of `device`.
model.load_state_dict(torch.load('/content/gdrive/MyDrive/imagenet/saved_model/resnet34_attack.pt', map_location=device))
# model1 is an independent copy of the freshly loaded model.
model1 = copy.deepcopy(model)
model1.to(device)
model.to(device)
# Ends the cell on a print so the model repr is not echoed as cell output.
print('')




In [None]:
# List the files stored in the Drive checkpoint directory.
!ls /content/gdrive/MyDrive/imagenet/saved_model/

attk				  latest_baseline_vgg16.pt
best_resnet18.pt		  latest_resnet18_custom_attention.pt
custom_attention.pt		  mlsp_adv_atk_alt1.pt
custom_attn_multi_12M_noskip.pt   mlsp_adv_attk_alt_full_imgnet1.pt
custom_attn_multi_18M_no_skip.pt  mlsp_adv_attk_alt_full_imgnet.pt
custom_attn_multi_18M.pt	  mlsp_adv_attk_alt.pt
custom_attn_multi.pt		  mlsp_gap_attack.pt
custom_attn_skip_16M.pt		  multi_3FC.pt
custom_no_attn_multi_18M_skip.pt  pool_3FC_5x5.pt
custom_noattn_noskip_16M.pt	  pool_3FC.pt
custom_no_attn_skip_16M.pt	  resnet18_adv_attk_alt_full_imgnet.pt
custom_only_att_3convfinal.pt	  resnet18_adv_attk_alt.pt
custom_only_att.pt		  resnet18_adv_train.pt
custom.pt			  resnet18_final_custom_attention.pt
custom_skip.pt			  resnet34_adv_attk_alt_full_imgnet.pt
custom_stepwise.pt		  resnet34_attack.pt
Fmodel_mlsp_adverserial.pt	  resnet9_adv_attk_alt.pt
Fmodel_mlsp_cnn_gap_attn.pt	  resnet_attack_adverserial.pt
InceptionV3.pt			  resnet_attack.pt
InceptionV3true.pt


In [None]:
import torchattacks

def deepcloak(model,mask_sz):
    """Accumulate how much each feature unit changes under an FGSM attack.

    For every batch of the global ``train_loader`` the model is run on the
    FGSM-perturbed inputs and on the clean inputs. A forward hook registered
    by the caller must append exactly one feature tensor per forward pass to
    the global list ``fea``. The absolute clean/adversarial difference is
    summed over the batch and added to the running mask.

    Args:
        model: network to attack; it must carry the hook described above.
        mask_sz: length of the feature vector captured by the hook.

    Returns:
        CPU tensor of length ``mask_sz`` with the accumulated differences.

    Raises:
        RuntimeError: if the two forward passes did not leave exactly two
            entries in ``fea`` (hook missing or registered more than once).

    Relies on the globals ``train_loader``, ``device`` and ``fea``.
    """
    global fea
    mask = torch.zeros((mask_sz))

    # Features are compared in inference mode: otherwise BatchNorm would
    # normalise with batch statistics and update its running statistics while
    # the mask is being computed. The previous mode is restored on exit.
    was_training = model.training
    model.eval()

    # NOTE(review): torchattacks documents its attacks for inputs scaled to
    # [0, 1]; the loaders in this notebook normalise their images, so the
    # perturbation budget may not mean what it appears to - please confirm.
    # atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=7, random_start=True)
    atk = torchattacks.FGSM(model, eps=8/255)

    try:
        for inputs,labels in train_loader:
            # The attack needs gradients, so it runs outside no_grad.
            inputs_adv = atk(inputs, labels)
            inputs = inputs.to(device)

            # Start from an empty list: the attack's own forward passes have
            # already triggered the hook.
            fea = []
            with torch.no_grad():
                model(inputs_adv)
                model(inputs)
            if len(fea) != 2:
                raise RuntimeError(
                    'expected 2 captured feature tensors, got {}; check that '
                    'exactly one forward hook is registered'.format(len(fea)))
            feas_adv, feas = fea

            mask += torch.abs(feas_adv-feas).sum(0).detach().cpu()
    finally:
        model.train(was_training)

    return mask


In [None]:
fea = []
# model.head.dense.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
# mask_sz = 960
mask_sz = 512
# Capture the input of the final fc layer on every forward pass. The handle is
# kept so the hook can be removed afterwards: re-running this cell would
# otherwise stack a second hook, and deepcloak would then read two features
# from the same forward pass.
hook_handle = model.fc.register_forward_hook(lambda layer, inl, _,: fea.append(inl[0].detach()))
try:
    mask = deepcloak(model,mask_sz)
finally:
    hook_handle.remove()

In [None]:
# Fraction of feature units to switch off; 0 keeps every unit.
percent = 0
# Indices of the units with the largest accumulated mask values.
null_ind = torch.topk(mask, k=int(percent * mask_sz)).indices
# Binary mask: 1 keeps a unit, 0 removes it.
new_mask = torch.ones(mask_sz)
new_mask[null_ind] = 0

In [None]:
# Replace the classifier with a mask_fc wrapper around a copy of model1's fc
# layer, so the input of fc is multiplied by new_mask.
model.fc = mask_fc(model1.fc,mask = new_mask.to(device))

In [None]:
# ---- Clean accuracy on test_loader ----
correct = 0
size = 0
# Use `device` rather than a hard-coded .cuda() so the cell also runs on CPU.
model.eval().to(device)

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Make predictions.
        prediction = model(inputs)
        # prediction= model(inputs,new_mask.to(device))

        # Index of the highest score per sample.
        predicted_class = prediction.argmax(dim=1)

        # Compute number of correct predictions.
        correct += (predicted_class == labels).float().sum().item()
        size += len(prediction)
test_accuracy = correct / size
print('Test accuracy: {}'.format(test_accuracy))


# ---- Robust accuracy under white-box attacks ----
# NOTE(review): torchattacks documents its attacks for inputs scaled to [0, 1];
# the loaders in this notebook normalise their images, so the numbers below
# may not reflect an 8/255 budget - please confirm.
model.eval()
atks = [
    torchattacks.FGSM(model, eps=8/255),
    # alpha was 2/225, which looks like a typo for the usual 2/255 step size.
    torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=7, random_start=True),
]
for atk in atks:
    correct = 0
    size = 0
    for images, labels in test_loader:
        # Crafting the adversarial batch needs gradients ...
        adv_images = atk(images, labels)
        labels = labels.to(device)
        # ... scoring it does not.
        with torch.no_grad():
            outputs = model(adv_images)
            # outputs = model(adv_images,new_mask.to(device))
        pre = outputs.argmax(dim=1)
        correct += (pre == labels).float().sum().item()
        size += len(labels)

    print('Robust accuracy: %.2f ' % (correct / size))


Test accuracy: 0.4685
Robust accuracy: 0.06 
Robust accuracy: 0.00 


Three types of white-box attacks

In [None]:
#torch.save(model.state_dict(),'/content/gdrive/MyDrive/imagenet/saved_model/Fmodel_mlsp_adverserial.pt' )

In [None]:
#torch.save(model.state_dict(),'/content/gdrive/MyDrive/imagenet/saved_model/Fmodel_mlsp_cnn_gap_attn.pt' )

In [None]:
#model = Fmodel('mlsp_gap')
# ResNet-18 without pretrained weights; the weights come from the checkpoint below.
model = models.resnet18(pretrained=False)
model.to(device)
#model = torch.load('/content/gdrive/MyDrive/imagenet/saved_model/atk/resnet18_base.pt')
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# runtime as well, matching the CPU fallback of `device`.
model.load_state_dict(torch.load('/content/gdrive/MyDrive/imagenet/saved_model/attk/resnet18_base.pt', map_location=device))
#model.load_state_dict(torch.load('/content/gdrive/MyDrive/imagenet/saved_model/resnet_attack_adverserial.pt'))

<All keys matched successfully>

In [None]:
# List the files stored in the attk sub-directory of the checkpoint folder.
!ls /content/gdrive/MyDrive/imagenet/saved_model/attk

attk				  Fmodel_mlsp_adverserial.pt
best_resnet18.pt		  Fmodel_mlsp_cnn_gap_attn.pt
custom_attention.pt		  InceptionV3.pt
custom_attn_multi_12M_noskip.pt   InceptionV3true.pt
custom_attn_multi_18M_no_skip.pt  latest_baseline_vgg16.pt
custom_attn_multi_18M.pt	  latest_resnet18_custom_attention.pt
custom_attn_multi.pt		  mlsp_gap_attack.pt
custom_attn_skip_16M.pt		  multi_3FC.pt
custom_no_attn_multi_18M_skip.pt  pool_3FC_5x5.pt
custom_noattn_noskip_16M.pt	  pool_3FC.pt
custom_no_attn_skip_16M.pt	  resnet18_adv_train.pt
custom_only_att_3convfinal.pt	  resnet18_final_custom_attention.pt
custom_only_att.pt		  resnet34_attack.pt
custom.pt			  resnet_attack_adverserial.pt
custom_skip.pt			  resnet_attack.pt
custom_stepwise.pt
