In [1]:
#内嵌画图
%matplotlib inline
#调试开关
import logging
#logging.basicConfig(level=logging.INFO,format="%(filename)s[line:%(lineno)d] %(levelname)s %(message)s")
#logger=logging.getLogger(__name__)

import sys

#添加系统路径
sys.path.append("../AdvBox/")

import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.utils.data.dataloader as Data
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.models.pytorch import PytorchModel
from tutorials.mnist_model_pytorch import Net

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

train_data: torch.Size([60000, 28, 28])
train_labels: torch.Size([60000])
test_data: torch.Size([10000, 28, 28])


In [8]:
import torch.nn as nn
import matplotlib.pyplot as plt

class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: two stages of 3x3 convolution (padding 1) + ReLU + 2x2 max-pool,
    taking the channels 1 -> 64 -> 128 and halving height/width at each stage.

    Decoder: two stages of nearest-neighbour x2 upsampling + 3x3 convolution
    (padding 1), taking the channels 128 -> 64 -> 1. A ReLU sits between the
    two decoder stages; nothing follows the last convolution, so the output
    values are not bounded by the network itself.

    For inputs whose height and width are multiples of 4, the output has the
    same shape as the input.
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Layers are listed in execution order; their positions inside each
        # Sequential determine the parameter names in the state dict.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decoder(self.encoder(x))

# Build the autoencoder and place it on the selected device.
autoencoder = AutoEncoder()
autoencoder = autoencoder.to(device)

# Reconstruction error is measured with mean squared error and minimised
# with Adam over all autoencoder parameters.
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)

# MNIST training split; ToTensor is the only transform applied.
train_data = datasets.MNIST(
    '../AdvBox/tutorials/mnist-pytorch/data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Shuffled mini-batches of 128 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=128,
    shuffle=True,
)

# Train the denoising autoencoder for 20 passes over the training set.
for epoch in range(20):

    for i, data in enumerate(train_loader):
        inputs, labels = data

        # Only the images are needed on the device: the clean image itself is
        # the regression target, so the labels are never used here.
        inputs = inputs.to(device)

        # Corrupt the input with Gaussian noise (std 0.1), sampled directly on
        # the input's device, then clamp the result to [0, 1].
        inputs_noise = inputs + 0.1 * torch.randn_like(inputs)
        inputs_noise = torch.clamp(inputs_noise, 0.0, 1.0)

        # Reconstruct from the noisy image and compare against the clean one.
        output = autoencoder(inputs_noise)
        loss = loss_func(output, inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 100 batches (batch 0 is skipped).
        if (i % 100 == 0) and (i > 0):
            print("Epoch={} batch={} loss={}".format(epoch, i, loss.item()))


    

Epoch=0 batch=100 loss=0.0360135696828
Epoch=0 batch=200 loss=0.0240466762334
Epoch=0 batch=300 loss=0.0190817210823
Epoch=0 batch=400 loss=0.016608748585
Epoch=1 batch=100 loss=0.0138044580817
Epoch=1 batch=200 loss=0.0128854429349
Epoch=1 batch=300 loss=0.0124944038689
Epoch=1 batch=400 loss=0.0109587656334
Epoch=2 batch=100 loss=0.0110200578347
Epoch=2 batch=200 loss=0.0106793874875
Epoch=2 batch=300 loss=0.00996408239007
Epoch=2 batch=400 loss=0.0106243239716
Epoch=3 batch=100 loss=0.00986938923597
Epoch=3 batch=200 loss=0.00985191483051
Epoch=3 batch=300 loss=0.00996092427522
Epoch=3 batch=400 loss=0.0096589429304
Epoch=4 batch=100 loss=0.00975240394473
Epoch=4 batch=200 loss=0.00950953178108
Epoch=4 batch=300 loss=0.00894741527736
Epoch=4 batch=400 loss=0.00893811229616
Epoch=5 batch=100 loss=0.00924315303564
Epoch=5 batch=200 loss=0.0088550504297
Epoch=5 batch=300 loss=0.00909268576652
Epoch=5 batch=400 loss=0.00888279732317
Epoch=6 batch=100 loss=0.00893933326006
Epoch=6 batch=

# 验证自编码器去噪对模型识别的影响

In [9]:
import numpy as np


# Number of test samples to evaluate.
TOTAL_NUM = 1000
pretrained_model="../AdvBox/tutorials/mnist-pytorch/net.pth"


# MNIST test split, shuffled and served one sample per batch, so the first
# TOTAL_NUM batches form a random subset. ToTensor is the only transform.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../AdvBox/tutorials/mnist-pytorch/data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
    batch_size=1, shuffle=True)

# Define what device we are using
logging.info("CUDA Available: {}".format(torch.cuda.is_available()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()

# Number of samples evaluated so far.
total_count = 0

# Samples classified correctly from the raw input.
pre_count = 0

# Samples classified correctly after passing through the autoencoder.
decoded_count = 0

# Pure inference: disable autograd so no graph is built for each sample.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # Plain Python int, so the comparisons below never mix a NumPy scalar
        # with a (possibly CUDA) tensor.
        true_label = labels[0].item()

        total_count += 1

        # Prediction on the untouched input.
        pre_label = int(np.argmax(model(inputs).cpu().numpy()))
        if pre_label == true_label:
            pre_count += 1

        # Prediction on the autoencoder's reconstruction of the same input.
        output = autoencoder(inputs)
        # assumes a single 1x28x28 image per batch -- the view then only
        # pins the shape expected by the classifier
        output = output.view(1, 1, 28, 28)
        decoded_label = int(np.argmax(model(output).cpu().numpy()))
        if decoded_label == true_label:
            decoded_count += 1

        if total_count >= TOTAL_NUM:
            break

# Report after the loop so a summary is printed even when the loader holds
# fewer than TOTAL_NUM samples; skip it only if nothing was evaluated.
if total_count > 0:
    print(
        "[TEST_DATASET]: pre_count=%d, total_count=%d, pre_count_rate=%f  decoded_count=%d decoded_count_rate=%f"
        % (pre_count, total_count, float(pre_count) / total_count, decoded_count, float(decoded_count) / total_count)
    )



[TEST_DATASET]: pre_count=991, total_count=1000, pre_count_rate=0.991000  decoded_count=982 decoded_count_rate=0.982000


# 使用自编码器过滤噪音

In [10]:

import numpy as np


# Number of test samples to attack.
TOTAL_NUM = 1000
pretrained_model="../AdvBox/tutorials/mnist-pytorch/net.pth"
# Loss handed to the AdvBox model wrapper below.
loss_func = torch.nn.CrossEntropyLoss()

# MNIST test split, shuffled and served one sample per batch, so the first
# TOTAL_NUM batches form a random subset. ToTensor is the only transform.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../AdvBox/tutorials/mnist-pytorch/data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
    batch_size=1, shuffle=True)

# Define what device we are using
logging.info("CUDA Available: {}".format(torch.cuda.is_available()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()

# advbox demo
m = PytorchModel(
    model, loss_func,(0, 1),
    channel_axis=1)

# Instantiate FGSM on the wrapped model.
attack = FGSM(m)
# Keyword arguments forwarded to the attack. The value configured here is
# 0.01. NOTE(review): presumably the attack step size -- confirm against AdvBox.
attack_config = {"epsilons": 0.01}

# Number of samples attacked so far.
total_count = 0
# Attacks that succeeded against the classifier on the raw adversarial image.
fooling_count = 0
# Successful attacks that still fool the classifier after denoising.
decoded_fooling_count = 0

# NOTE(review): reset here but never updated in this cell; kept only in case
# a later cell reads it.
decoded_count = 0

for inputs, labels in test_loader:
    inputs, labels = inputs.numpy(), labels.numpy()
    # Scalar label; `labels` itself is a one-element array and must not be
    # fed to %d formatting.
    true_label = labels[0]

    total_count += 1
    adversary = Adversary(inputs, true_label)

    # FGSM non-targeted attack
    adversary = attack(adversary, **attack_config)

    if adversary.is_successful():
        fooling_count += 1
        print(
            'attack success, original_label=%d, adversarial_label=%d, count=%d'
            % (true_label, adversary.adversarial_label, total_count))

        # The adversarial example is stored in adversary.adversarial_example;
        # copy it so the tensor does not share memory with AdvBox's array.
        adversary_image = torch.from_numpy(np.copy(adversary.adversarial_example)).to(device).float()

        # Pure inference from here on: no autograd graph is needed.
        with torch.no_grad():
            pre_label = np.argmax(model(adversary_image).cpu().numpy())

            # Denoise the adversarial image with the autoencoder.
            output = autoencoder(adversary_image)
            # assumes a single 1x28x28 image -- the view then only pins the
            # shape expected by the classifier
            output = output.view(1, 1, 28, 28)

            decoded_label = np.argmax(model(output).cpu().numpy())

        # The attack survives denoising when the label is still wrong.
        if decoded_label != true_label:
            print("orig_label={} adv_label={} decoded_label={}".format(true_label, pre_label, decoded_label))
            decoded_fooling_count += 1

    else:
        print('attack failed, original_label=%d, count=%d' %
              (true_label, total_count))

    if total_count >= TOTAL_NUM:
        break

# Report after the loop so a summary is printed even when the loader holds
# fewer than TOTAL_NUM samples; skip it only if nothing was attacked.
if total_count > 0:
    print(
        "[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f decoded_fooling_count=%d  decoded_fooling_count_rate=%f"
        % (fooling_count, total_count, float(fooling_count) / total_count, decoded_fooling_count, float(decoded_fooling_count) / total_count))
print("fgsm attack done")


cuda
attack success, original_label=5, adversarial_label=0, count=1
attack success, original_label=9, adversarial_label=4, count=2
attack success, original_label=2, adversarial_label=7, count=3
attack success, original_label=3, adversarial_label=5, count=4
attack success, original_label=4, adversarial_label=7, count=5
attack success, original_label=5, adversarial_label=8, count=6
attack success, original_label=2, adversarial_label=8, count=7
attack success, original_label=4, adversarial_label=9, count=8
orig_label=4 adv_label=9 decoded_label=9
attack success, original_label=4, adversarial_label=9, count=9
attack success, original_label=1, adversarial_label=4, count=10
attack success, original_label=9, adversarial_label=4, count=11
attack success, original_label=8, adversarial_label=5, count=12
attack success, original_label=0, adversarial_label=6, count=13
attack success, original_label=6, adversarial_label=0, count=14
attack success, original_label=2, adversarial_label=7, count=15
att

attack success, original_label=2, adversarial_label=7, count=126
attack success, original_label=2, adversarial_label=7, count=127
attack success, original_label=2, adversarial_label=4, count=128
attack success, original_label=7, adversarial_label=9, count=129
attack success, original_label=7, adversarial_label=9, count=130
attack success, original_label=7, adversarial_label=0, count=131
attack success, original_label=2, adversarial_label=8, count=132
attack success, original_label=4, adversarial_label=8, count=133
attack success, original_label=4, adversarial_label=7, count=134
attack success, original_label=5, adversarial_label=0, count=135
attack success, original_label=8, adversarial_label=9, count=136
attack success, original_label=3, adversarial_label=5, count=137
attack success, original_label=7, adversarial_label=5, count=138
attack success, original_label=7, adversarial_label=9, count=139
attack success, original_label=9, adversarial_label=4, count=140
attack success, original_

attack success, original_label=2, adversarial_label=0, count=249
attack success, original_label=5, adversarial_label=0, count=250
attack success, original_label=1, adversarial_label=4, count=251
attack success, original_label=2, adversarial_label=7, count=252
attack success, original_label=4, adversarial_label=6, count=253
orig_label=4 adv_label=6 decoded_label=6
attack success, original_label=8, adversarial_label=5, count=254
attack success, original_label=3, adversarial_label=5, count=255
attack success, original_label=9, adversarial_label=4, count=256
attack success, original_label=6, adversarial_label=0, count=257
attack success, original_label=6, adversarial_label=8, count=258
attack success, original_label=3, adversarial_label=9, count=259
attack success, original_label=2, adversarial_label=8, count=260
attack success, original_label=7, adversarial_label=9, count=261
attack success, original_label=1, adversarial_label=7, count=262
attack success, original_label=4, adversarial_lab

attack success, original_label=4, adversarial_label=7, count=373
attack success, original_label=2, adversarial_label=8, count=374
attack success, original_label=4, adversarial_label=7, count=375
attack success, original_label=4, adversarial_label=6, count=376
attack success, original_label=1, adversarial_label=7, count=377
attack success, original_label=1, adversarial_label=7, count=378
attack success, original_label=2, adversarial_label=8, count=379
attack success, original_label=5, adversarial_label=9, count=380
attack success, original_label=7, adversarial_label=2, count=381
attack success, original_label=7, adversarial_label=5, count=382
attack success, original_label=8, adversarial_label=9, count=383
attack success, original_label=1, adversarial_label=4, count=384
attack success, original_label=6, adversarial_label=0, count=385
attack success, original_label=1, adversarial_label=7, count=386
attack success, original_label=1, adversarial_label=7, count=387
attack success, original_

attack success, original_label=6, adversarial_label=5, count=498
attack success, original_label=3, adversarial_label=5, count=499
attack success, original_label=6, adversarial_label=5, count=500
attack success, original_label=6, adversarial_label=5, count=501
attack success, original_label=4, adversarial_label=9, count=502
attack success, original_label=8, adversarial_label=7, count=503
attack success, original_label=4, adversarial_label=9, count=504
attack success, original_label=1, adversarial_label=4, count=505
attack success, original_label=1, adversarial_label=7, count=506
attack success, original_label=1, adversarial_label=4, count=507
attack success, original_label=3, adversarial_label=9, count=508
attack success, original_label=7, adversarial_label=4, count=509
attack success, original_label=4, adversarial_label=9, count=510
orig_label=4 adv_label=9 decoded_label=9
attack success, original_label=6, adversarial_label=0, count=511
attack success, original_label=6, adversarial_lab

attack success, original_label=6, adversarial_label=5, count=624
attack success, original_label=4, adversarial_label=8, count=625
attack success, original_label=1, adversarial_label=4, count=626
attack success, original_label=8, adversarial_label=9, count=627
attack success, original_label=5, adversarial_label=8, count=628
attack success, original_label=9, adversarial_label=4, count=629
attack success, original_label=5, adversarial_label=8, count=630
attack success, original_label=5, adversarial_label=8, count=631
attack success, original_label=1, adversarial_label=7, count=632
attack success, original_label=1, adversarial_label=4, count=633
attack success, original_label=0, adversarial_label=1, count=634
orig_label=0 adv_label=1 decoded_label=1
attack success, original_label=3, adversarial_label=8, count=635
attack success, original_label=0, adversarial_label=6, count=636
attack success, original_label=8, adversarial_label=9, count=637
attack success, original_label=0, adversarial_lab

attack success, original_label=2, adversarial_label=8, count=751
attack success, original_label=5, adversarial_label=9, count=752
attack success, original_label=2, adversarial_label=0, count=753
attack success, original_label=2, adversarial_label=1, count=754
attack success, original_label=3, adversarial_label=0, count=755
attack success, original_label=4, adversarial_label=9, count=756
attack success, original_label=0, adversarial_label=9, count=757
attack success, original_label=7, adversarial_label=3, count=758
attack success, original_label=1, adversarial_label=4, count=759
attack success, original_label=3, adversarial_label=9, count=760
attack success, original_label=6, adversarial_label=4, count=761
attack success, original_label=1, adversarial_label=4, count=762
attack success, original_label=8, adversarial_label=9, count=763
orig_label=8 adv_label=9 decoded_label=9
attack success, original_label=3, adversarial_label=7, count=764
attack success, original_label=4, adversarial_lab

attack success, original_label=3, adversarial_label=8, count=874
attack success, original_label=3, adversarial_label=0, count=875
attack success, original_label=0, adversarial_label=9, count=876
attack success, original_label=7, adversarial_label=9, count=877
attack success, original_label=7, adversarial_label=9, count=878
attack success, original_label=0, adversarial_label=5, count=879
attack success, original_label=9, adversarial_label=4, count=880
orig_label=9 adv_label=4 decoded_label=4
attack success, original_label=0, adversarial_label=9, count=881
attack success, original_label=9, adversarial_label=8, count=882
attack success, original_label=0, adversarial_label=6, count=883
attack success, original_label=7, adversarial_label=5, count=884
attack success, original_label=8, adversarial_label=0, count=885
attack success, original_label=1, adversarial_label=4, count=886
attack success, original_label=3, adversarial_label=8, count=887
attack success, original_label=9, adversarial_lab

attack success, original_label=5, adversarial_label=8, count=1000
[TEST_DATASET]: fooling_count=1000, total_count=1000, fooling_rate=1.000000 decoded_fooling_count=37  decoded_fooling_count_rate=0.037000
fgsm attack done
