In [None]:
import os
import pandas as pd
# 读取图片
from PIL import Image
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import Dataset,DataLoader
# 载入pretrained model
import torchvision.models as models
# 将资料转换成为符合pretrained model模型的形式
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input locations (relative paths).
image_path = '../input/data/images/'  # folder holding the image files
label_path = '../input/data/labels.csv'  # CSV read below for its 'TrueLabel' column
labelname_path = '../input/data/categories.csv'  # CSV read below for its 'CategoryName' column

## Define Dataset

In [None]:
class Adverdataset(Dataset):
    """Dataset of sequentially numbered PNG images paired with integer labels.

    Sample ``i`` is the file ``<root>/<i zero-padded to 3 digits>.png``
    together with ``label[i]``.

    Args:
        root: Directory that contains the image files.
        label: One-dimensional array-like of integer class ids. Its length
            defines the number of samples in the dataset.
        transforms: Callable applied to every loaded PIL image; it is supplied
            by the ``Attacker`` to bring the image into the form the
            pretrained model expects.
    """

    def __init__(self,root,label,transforms):
        # Folder that holds the images.
        self.root = root
        # Labels as an int64 tensor (as_tensor also accepts plain sequences).
        self.label = torch.as_tensor(label).long()
        # Preprocessing applied to each image in __getitem__.
        self.transforms = transforms
        # One file name per label: 000.png, 001.png, ...
        # Derived from the label count instead of a hard-coded 200 so that the
        # dataset size always agrees with the labels that were passed in.
        self.fnames = [f'{i:03d}.png' for i in range(len(self.label))]

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.fnames)

    def __getitem__(self,index):
        """Load image ``index``, apply the transforms and return ``(img, label)``."""
        # The context manager closes the underlying file handle; convert()
        # loads the pixel data first and guarantees three channels, so palette,
        # greyscale or RGBA PNGs do not break a 3-channel normalisation.
        with Image.open(os.path.join(self.root,self.fnames[index]),mode='r') as raw:
            img = raw.convert('RGB')
        # Bring the image into the form the pretrained model expects.
        img = self.transforms(img)
        label = self.label[index]
        return img,label

## Define Attacker
This class loads the pretrained VGG16 model and attacks it with FGSM.

In [None]:
class Attacker:
    """Runs single-step FGSM attacks against a pretrained VGG16 classifier.

    Args:
        img_dir: Not used by the class; kept so existing callers keep working.
        label: Array of ground-truth class ids, forwarded to ``Adverdataset``.
        device: ``torch.device`` on which the model and the data are placed.
        image_path: Directory holding the images, forwarded to ``Adverdataset``.
    """

    def __init__(self,img_dir,label,device,image_path):
        self.device = device

        # Load the pretrained VGG16 and keep it in eval mode for the attack.
        self.model = models.vgg16(pretrained=True)
        self.model.to(device)
        self.model.eval()

        # Per-channel statistics used for normalisation; these appear to be
        # the ImageNet values given in the official PyTorch examples.
        self.mean = [0.485,0.456,0.406]
        self.std = [0.229,0.224,0.225]
        self.normalize = transforms.Normalize(self.mean, self.std, inplace=False)

        # Preprocessing that brings an image into the form the model expects.
        # interpolation=3 is PIL's integer code for bicubic resampling.
        transform = transforms.Compose([
            transforms.Resize((224,224),interpolation=3),
            transforms.ToTensor(),
            self.normalize
        ])

        # One image per batch, in file order.
        self.dataset = Adverdataset(image_path,label,transform)
        self.dataloader = DataLoader(self.dataset,batch_size=1, shuffle=False)

    def unnormalize(self,data):
        """Undo the channel-wise normalisation and return a NumPy array.

        Computes ``data * std + mean`` with both statistics reshaped to
        ``(3, 1, 1)``, then squeezes size-1 dimensions, detaches the result
        and moves it to the CPU.
        """
        std = torch.tensor(self.std, device=self.device).view(3,1,1)
        mean = torch.tensor(self.mean, device=self.device).view(3,1,1)
        out = data * std + mean
        return out.squeeze().detach().cpu().numpy()

    # FGSM Attack
    def fgsm_attack(self,image,epsilon,data_grad):
        """Return ``image + epsilon * sign(data_grad)``.

        Args:
            image: Tensor to perturb.
            epsilon: Step size of the perturbation.
            data_grad: Gradient of the loss with respect to ``image``.
        """
        # Only the direction of the gradient is used, not its magnitude.
        sign_data_grad = data_grad.sign()
        perturbed_image = image + epsilon * sign_data_grad
        return perturbed_image

    def top3(self,output):
        """Return ``(indices, probabilities)`` of the three most likely classes.

        Both elements are flat Python lists. The softmax is taken over the
        last dimension, which is the class dimension of the model output.
        """
        # dim is given explicitly; relying on the implicit dimension is
        # deprecated and emits a warning.
        best = F.softmax(output, dim=-1).topk(3)
        indices = best.indices.detach().cpu().numpy().flatten().tolist()
        values = best.values.detach().cpu().numpy().flatten().tolist()
        return (indices, values)

    def attack(self,epsilon):
        """Attack every sample of ``self.dataloader`` with one FGSM step.

        Samples the model already misclassifies are counted but not attacked.

        Args:
            epsilon: Perturbation step size. It is applied to the normalised
                tensor that is fed to the model, and the perturbed tensor is
                not clamped to any pixel range.

        Returns:
            ``(adv_examples, final_acc)``. ``adv_examples`` holds up to five
            successful attacks as tuples ``(initial prediction, final
            prediction, unnormalised original image, unnormalised adversarial
            image, initial top-3, final top-3)``. ``final_acc`` is the number
            of attacked samples that are still classified correctly divided by
            the total number of samples (``0.0`` for an empty loader).
        """
        # A few successful attacks are kept for later display.
        adv_examples = []

        wrong,fail = 0,0
        total = len(self.dataloader)

        for i, (data,target) in enumerate(self.dataloader):
            # Printed first so that skipped samples also advance the progress line.
            print(f'Attacking {i+1} of {total}',end='\r')

            target = target.to(self.device)
            # Work on a fresh leaf tensor so the gradient starts empty and the
            # tensor handed out by the loader is left untouched.
            data = data.to(self.device).detach().requires_grad_(True)

            # Classify the clean image.
            output = self.model(data)
            init_pred = output.argmax()
            init_pred_top3 = self.top3(output)

            # Already misclassified: nothing to attack.
            if init_pred.item() != target.item():
                wrong += 1
                continue

            # The model returns raw logits, so cross_entropy (log-softmax
            # followed by NLL) is the matching loss; nll_loss alone expects
            # log-probabilities and would give a different gradient.
            loss = F.cross_entropy(output,target)
            self.model.zero_grad()
            loss.backward()
            data_grad = data.grad.detach()
            perturbed_data = self.fgsm_attack(data.detach(), epsilon, data_grad)

            # Classify the perturbed image; no gradient is needed any more.
            with torch.no_grad():
                output = self.model(perturbed_data)
            final_pred = output.argmax()
            final_pred_top3 = self.top3(output)

            if final_pred.item() == target.item():
                # Still classified correctly: the attack failed.
                fail += 1
            elif len(adv_examples) < 5:
                # The attack succeeded: keep the first few examples.
                adv_examples.append((init_pred.item(),
                                     final_pred.item(),
                                     self.unnormalize(data),
                                     self.unnormalize(perturbed_data),
                                     init_pred_top3,
                                     final_pred_top3))

        # Computed after the loop so the values exist even when every sample
        # was skipped; an empty loader yields 0.0 instead of a division by zero.
        final_acc = fail / total if total else 0.0
        init_acc = (1 - wrong / total) if total else 0.0

        print()
        print('Epsilon: {}\t Init Accuracy = {} Test Accuracy = {} / {} = {}\n'
              .format(epsilon, init_acc, fail, total, final_acc))

        return adv_examples, final_acc

## Attacking

In [None]:
# Ground-truth class ids and the human-readable name of every class.
label = pd.read_csv(label_path)['TrueLabel'].to_numpy()
label_name = pd.read_csv(labelname_path)['CategoryName'].to_numpy()

attacker = Attacker(image_path, label, device, image_path)

# Perturbation sizes to try.
epsilons = [0.1, 0.01]

# Per epsilon: accuracy after the attack and the stored successful examples.
accuracies = []
examples = []

# Run the attack once per epsilon and collect both results.
for eps in epsilons:
    ex, acc = attacker.attack(eps)
    examples.append(ex)
    accuracies.append(acc)

## Show the pictures produced by FGSM

In [None]:
# One grid row per epsilon and two panels (original, adversarial) per example.
# The grid width follows the longest row, and every panel position is computed
# from (row, column), so rows stay aligned and inside the grid even when the
# epsilons produced different numbers of examples.
n_rows = len(epsilons)
n_cols = 2 * max((len(row) for row in examples), default=0)

plt.figure(figsize=(30,30))

for i in range(n_rows):
    for j in range(len(examples[i])):
        orig,adv,orig_img,ex,orig_top3,adv_top3 = examples[i][j]

        # Left panel: the original image with the initially predicted class.
        cnt = i * n_cols + 2 * j + 1
        plt.subplot(n_rows,n_cols,cnt)
        plt.xticks([],[])
        plt.yticks([],[])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}",fontsize=14)
        plt.title("original: {}".format(label_name[orig].split(',')[0]))
        # Channels-first -> channels-last; float images must lie in [0, 1].
        orig_img = np.clip(np.transpose(orig_img,(1,2,0)), 0, 1)
        plt.imshow(orig_img)

        # Right panel: the adversarial image with the class predicted for it.
        cnt += 1
        plt.subplot(n_rows,n_cols,cnt)
        plt.xticks([],[])
        plt.yticks([],[])
        plt.title("adversarial: {}".format(label_name[adv].split(',')[0]))
        ex = np.clip(np.transpose(ex, (1, 2, 0)), 0, 1)
        plt.imshow(ex)
plt.tight_layout()
plt.show()

In [None]:
def show_p(ex):
    """Plot one stored example: the original image and both top-3 bar charts.

    Args:
        ex: Tuple ``(orig, adv, orig_img, adv_img, orig_top3, adv_top3)`` as
            stored by ``Attacker.attack``; each ``*_top3`` is a pair
            ``(class indices, probabilities)``.
    """
    orig, adv, orig_img, adv_img, orig_top3, adv_top3 = ex

    def bar_labels(class_ids):
        # "<first name of the class>(<class id>)" for every class id.
        return ['{}({})'.format(label_name[cid].split(',')[0], cid) for cid in class_ids]

    fig, (ax_img, ax_orig, ax_adv) = plt.subplots(3, 1, figsize=(6, 18))

    # Original image, converted from channels-first to channels-last.
    ax_img.imshow(np.transpose(orig_img, (1, 2, 0)))

    # Top-3 probabilities for the original image.
    ax_orig.bar(x=bar_labels(orig_top3[0]), height=orig_top3[1])
    ax_orig.set_title('Original Image')

    # Top-3 probabilities for the adversarial image.
    ax_adv.bar(x=bar_labels(adv_top3[0]), height=adv_top3[1])
    ax_adv.set_title('Adversarial Image')

show_p(examples[0][0])