In [4]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import os
from torchvision.utils import save_image
from dataset import Adversarial_Dataset
# from util_defense_GAN import adjust_lr, get_z_sets, get_z_star, Resize_Image
from utils import adjust_lr, get_z_sets, get_z_star, Resize_Image
from model_mnist import CNN
from classifier_model import Model_A,Model_B,Model_C,Model_D,Model_E,Model_F
# from gan_model import Generator, Decoder, Discriminator
from gan_model_mnist import Generator,Discriminator,Decoder
from torchsummary import summary
import copy
import foolbox

In [5]:
batch_size = 64
in_channel = 1
height = 28
width = 28

display_steps = 20

In [6]:
# Instantiate the target classifier and print a layer-by-layer summary
# (computed on CPU, before the model is moved to its device).
model = Model_A()
summary(model, input_size=(in_channel, height, width), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]           1,664
              ReLU-2           [-1, 64, 28, 28]               0
            Conv2d-3           [-1, 64, 12, 12]         102,464
              ReLU-4           [-1, 64, 12, 12]               0
           Dropout-5           [-1, 64, 12, 12]               0
            Linear-6                  [-1, 128]       1,179,776
              ReLU-7                  [-1, 128]               0
           Dropout-8                  [-1, 128]               0
            Linear-9                   [-1, 10]           1,290
          Softmax-10                   [-1, 10]               0
Total params: 1,285,194
Trainable params: 1,285,194
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.98
Params size (MB): 4.90
Estimat

In [5]:
# Device for the classifier. Moving tensors to torch.device(0) fails on
# CPU-only machines, so fall back to CPU; on a GPU machine this is still cuda:0.
device_model = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Load the pretrained classifier weights. map_location lets a checkpoint that
# was saved from a GPU be restored on whatever device_model resolves to.
model.load_state_dict(
    torch.load('./checkpoints/F_mnist_model_A.pth', map_location=device_model)
)
model = model.to(device_model)

In [7]:
# Hyper-parameters for the latent-space (z*) search done by get_z_sets.
# NOTE(review): `learning_rate` is reassigned to 1e-3 by the decoder-training
# setup cells further down; re-running the z* evaluation cells after those
# would silently use the wrong step size.
learning_rate = 10.0
rec_iters = [200]          # values also tried: 200, 500, 1000
rec_rrs = [10]             # values also tried: 10, 15, 20
decay_rate = 0.1           # NOTE(review): not referenced by the visible cells
global_step = 3.0
generator_input_size = 28  # inputs with a different height are resized before the search

INPUT_LATENT = 64          # latent size passed to get_z_sets (matches Generator's 64-d input)

In [8]:
# Device for the GAN generator and the decoder. Falls back to CPU when CUDA
# is unavailable; on a GPU machine this is still cuda:0.
device_generator = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
# Build the GAN generator (latent code -> image) and the Decoder, which is
# trained below to map an image directly to a generator input code.
ModelG, Dec = Generator(), Decoder()

In [10]:
# Load the pretrained generator and a decoder checkpoint, then move both to
# device_generator. map_location makes GPU-saved checkpoints loadable on any
# device.
# NOTE(review): dec_noisy_0.3_mnist_20.pth is written by the FGSM decoder
# training cell later in this notebook, so a fresh Restart & Run All needs
# that file to already exist on disk.
generator_path = './defensive_models_mnist/gen_mnist10_gp_27899.pth'
ModelG.load_state_dict(torch.load(generator_path, map_location=device_generator))
dec_path = './dec_checkpoints/dec_noisy_0.3_mnist_20.pth'
Dec.load_state_dict(torch.load(dec_path, map_location=device_generator))
ModelG = ModelG.to(device_generator)
Dec = Dec.to(device_generator)
print(ModelG)
print(Dec)

Generator(
  (block1): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
  )
  (block2): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
  )
  (deconv_out): ConvTranspose2d(64, 1, kernel_size=(8, 8), stride=(2, 2))
  (preprocess): Sequential(
    (0): Linear(in_features=64, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
  )
  (sigmoid): Sigmoid()
)
Decoder(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): ReLU(inplace=True)
    (4): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (5): ReLU(inplace=True)
  )
  (linear): Sequential(
    (0): Linear(in_features=4096, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=64, bias=T

In [8]:
# Root directory of the pre-generated adversarial MNIST sets; a specific set
# is picked by name (e.g. 'FGSM_0.3') when building Adversarial_Dataset.
root_dir = './adversarial_F_mnist/'

In [7]:
# Test-time preprocessing: only convert images to float tensors in [0, 1].
# No normalization or augmentation is applied (the foolbox wrapper below
# assumes bounds (0, 1)).
transform_test = transforms.Compose([transforms.ToTensor()])

In [13]:
# Clean MNIST test split, iterated in a fixed (unshuffled) order.
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [14]:
# MNIST training split (same ToTensor-only transform, no augmentation),
# shuffled each epoch; used for decoder training below.
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform_test,
)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [15]:
# Sanity check: number of training images.
print('%d' % len(train_set))

60000


### 原方法 (original method: Defense-GAN z* search on the clean test set)

In [17]:
# Baseline Defense-GAN evaluation on `test_loader`: per batch, search the
# generator's latent space for z* (get_z_sets / get_z_star), regenerate the
# images from z*, and classify the reconstructions.
# NOTE(review): the saved output of this cell reports rec_iter 800, which is
# not in rec_iters; that output is stale.
model.eval()

is_input_size_diff = False

save_test_results = []
save_test_rec_loss = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        # Reset per (rec_iter, rec_rr) setting. Previously the counters were
        # shared across settings, so every accuracy after the first one was a
        # running mix of several configurations.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # Resize to the generator's size if needed. The flag is recomputed
            # per batch so it cannot stay stuck at True for later batches.
            is_input_size_diff = inputs.size(2) != generator_input_size
            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z* (the search needs gradients, so no torch.no_grad here)
            z_orig_sets, z_rec_sets, rec_loss = get_z_sets(ModelG, data, learning_rate,
                                                           device_generator, rec_iter=rec_iter,
                                                           rec_rr=rec_rr, input_latent=INPUT_LATENT,
                                                           global_step=global_step)
            save_test_rec_loss.append(rec_loss)

            z_star = get_z_star(ModelG, data, z_rec_sets, rec_loss, device_generator, rec_rr)

            # generate data
            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()

            # size back
            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier
            data_hat = data_hat.to(device_model)
            labels = labels.to(device_model)

            # Classification only: no autograd graph needs to be recorded.
            with torch.no_grad():
                outputs = model(data_hat)
            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                save_image(data.data, './rec_img_mnist/new_lgd_L{}_R{}_orig_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
                save_image(data_hat.data, './rec_img_mnist/new_lgd_L{}_R{}_rec_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat, z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

  1/200 average acc 0.9400
 21/200 average acc 0.9210
 41/200 average acc 0.9180
 61/200 average acc 0.9151
 81/200 average acc 0.9158
101/200 average acc 0.9176
121/200 average acc 0.9240
141/200 average acc 0.9309
161/200 average acc 0.9335
181/200 average acc 0.9367
rec_iter : 800, rec_rr : 10, Test Acc: 0.9377


In [19]:
# Drop the clean-test DataLoader; `test_loader` is rebound to an
# adversarial-example loader below.
del test_loader

### 原方法测试攻击 (original method evaluated on adversarial examples)

In [20]:
# Round-trip each stored adversarial sample through PIL and back to a tensor.
# NOTE(review): ToPILImage -> ToTensor yields a float tensor in [0, 1]; the
# exact effect depends on what Adversarial_Dataset returns — confirm.
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

### FGSM

In [24]:
# Adversarial set selected by the name 'FGSM_0.3' (presumably FGSM with
# eps = 0.3 — confirm against the generation script).
sample = Adversarial_Dataset(
    root_dir,
    'FGSM_0.3',
    adversarial_transform,
)

In [25]:
# Sequential (unshuffled), single-process loader over the adversarial samples.
test_loader = DataLoader(sample, batch_size=batch_size, num_workers=0)

In [26]:
# Baseline Defense-GAN evaluation on the FGSM adversarial loader: per batch,
# search for z*, regenerate the images from z*, and classify them.
model.eval()

is_input_size_diff = False

save_test_results = []
# Reconstruction losses are intentionally not recorded for this run.

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        # Reset per (rec_iter, rec_rr) so each reported accuracy belongs to a
        # single configuration.
        running_corrects = 0
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # Resize to the generator's size if needed; recomputed per batch
            # so the flag cannot stay stuck at True.
            is_input_size_diff = inputs.size(2) != generator_input_size
            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs).to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z* (the search needs gradients, so no torch.no_grad here)
            z_orig_sets, z_rec_sets, rec_loss = get_z_sets(ModelG, data, learning_rate,
                                                           device_generator, rec_iter=rec_iter,
                                                           rec_rr=rec_rr, input_latent=INPUT_LATENT,
                                                           global_step=global_step)

            z_star = get_z_star(ModelG, data, z_rec_sets, rec_loss, device_generator, rec_rr)

            # generate data
            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()

            # size back
            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier
            data_hat = data_hat.to(device_model)
            labels = labels.to(device_model)

            # Classification only: no autograd graph needs to be recorded.
            with torch.no_grad():
                outputs = model(data_hat)
            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                save_image(data.data, './rec_img_mnist/lgd_FGSM_0.3_L{}_R{}_orig_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
                save_image(data_hat.data, './rec_img_mnist/lgd_FGSM_0.3_L{}_R{}_rec_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat, z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

  1/200 average acc 0.7000
 21/200 average acc 0.6495
 41/200 average acc 0.6371
 61/200 average acc 0.6370
 81/200 average acc 0.6380
101/200 average acc 0.6414
121/200 average acc 0.6575
141/200 average acc 0.6760
161/200 average acc 0.6846
181/200 average acc 0.6945
rec_iter : 200, rec_rr : 10, Test Acc: 0.6955


### 改造前 (before modification: decoder trained on clean images)

In [20]:
# Decoder-training setup (baseline, clean inputs): MSE between the generator's
# reconstruction and the input, optimized with Adam over Dec's parameters only.
# NOTE(review): this overwrites the global `learning_rate` (10.0) that the z*
# search cells above pass to get_z_sets.
img_rec_saves, img_inputs, save_losses = [], [], []
learning_rate = 1e-3
criterion_dec = nn.MSELoss()
optim_dec = optim.Adam(Dec.parameters(), lr=learning_rate, betas=(0.5, 0.99))

In [21]:
# Baseline decoder training: Dec maps a clean image to a latent code, ModelG
# decodes it, and only Dec's parameters are stepped (optim_dec was built over
# Dec.parameters()) to minimise the pixel-wise MSE to the input.
for epoch in range(1000):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Dec and ModelG live on device_generator; labels are not needed here.
        inputs = inputs.to(device_generator)

        optim_dec.zero_grad()
        img_rec_code = Dec(inputs)
        img_rec = ModelG(img_rec_code)
        loss_rec = criterion_dec(img_rec, inputs)
        loss_rec.backward()
        optim_dec.step()
        # NOTE(review): backward also accumulates into ModelG's .grad buffers,
        # which optim_dec never zeroes; harmless while ModelG is not stepped.

        running_loss += loss_rec.item()
        if batch_idx % 100 == 99:
            print('[%d, %4d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
            save_losses.append(running_loss / 100)
            running_loss = 0.0

    if epoch % 50 == 0:
        torch.save(Dec.state_dict(), './dec_checkpoints/dec_mnist_{}.pth'.format(epoch))
        save_image(img_rec.data, './rec_img_mnist/rec_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        save_image(inputs.data, './rec_img_mnist/orig_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        # Keep a detached copy so the history does not hold the autograd graph
        # of every saved batch alive across epochs.
        img_rec_saves.append(img_rec.detach())
        img_inputs.append(inputs)

[1,  100] loss:0.067
[1,  200] loss:0.040
[1,  300] loss:0.032
[1,  400] loss:0.028
[1,  500] loss:0.026
[1,  600] loss:0.025
[1,  700] loss:0.024
[1,  800] loss:0.023
[1,  900] loss:0.023
[1, 1000] loss:0.022
[1, 1100] loss:0.022
[1, 1200] loss:0.021
[2,  100] loss:0.020
[2,  200] loss:0.021
[2,  300] loss:0.020
[2,  400] loss:0.020
[2,  500] loss:0.020
[2,  600] loss:0.019
[2,  700] loss:0.019
[2,  800] loss:0.019
[2,  900] loss:0.019
[2, 1000] loss:0.019
[2, 1100] loss:0.019
[2, 1200] loss:0.018
[3,  100] loss:0.018
[3,  200] loss:0.018
[3,  300] loss:0.018
[3,  400] loss:0.018
[3,  500] loss:0.018
[3,  600] loss:0.017
[3,  700] loss:0.018
[3,  800] loss:0.018
[3,  900] loss:0.018
[3, 1000] loss:0.017
[3, 1100] loss:0.017
[3, 1200] loss:0.017
[4,  100] loss:0.017
[4,  200] loss:0.017
[4,  300] loss:0.017
[4,  400] loss:0.017
[4,  500] loss:0.017
[4,  600] loss:0.016
[4,  700] loss:0.016
[4,  800] loss:0.016
[4,  900] loss:0.016
[4, 1000] loss:0.017
[4, 1100] loss:0.016
[4, 1200] los

KeyboardInterrupt: 

### 改造后 (after modification: decoder trained on perturbed images)

In [15]:
# Generic compute device, used by the foolbox model wrapper below.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [16]:
# Put the classifier in inference mode (disables its Dropout layers) before
# wrapping it for foolbox; the module is returned and displayed.
# NOTE(review): the saved output shows a `CNN` repr, but `model` is Model_A
# in this notebook — that output is stale from an earlier kernel state.
model.train(False)

CNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=3136, out_features=10, bias=True)
)

In [17]:
# Wrap the classifier for foolbox attacks; inputs are expected in [0, 1].
fmodel = foolbox.models.PyTorchModel(
    model,
    bounds=(0, 1),
    device=device,
)

In [17]:
# Three independent, empty collection lists for FGSM results.
# NOTE(review): none of them is used by the visible cells.
fgsm_adv, fgsm_index, fgsm_label = [], [], []

In [18]:
# Decoder-training setup for the perturbed-input variants below: empty logging
# lists, an MSE reconstruction loss, and a fresh Adam optimizer over Dec's
# parameters with a small L2 weight decay.
# NOTE(review): this overwrites the global `learning_rate` used by the z*
# search cells.
img_rec_saves, img_inputs, save_losses = [], [], []
learning_rate = 1e-3
criterion_dec = nn.MSELoss()
optim_dec = optim.Adam(Dec.parameters(), lr=learning_rate, weight_decay=1e-5)

In [19]:
def add_noise(inputs, noise_factor=0.3):
    """Return `inputs` plus zero-mean Gaussian noise, clipped to [0, 1].

    Args:
        inputs: image tensor (typically values in [0, 1]).
        noise_factor: standard deviation of the added Gaussian noise.

    Returns:
        A new tensor with the same shape, dtype and device as `inputs`.
    """
    gaussian = torch.randn_like(inputs)
    return torch.clamp(inputs + gaussian * noise_factor, min=0., max=1.)

In [20]:
noise_factor = 5e-2  # std of the Gaussian noise passed to add_noise in the Gaussian-noise training cell

### 原始fgsm噪声 (raw FGSM perturbations, epsilons = 0.3)

In [21]:
# Decoder training on FGSM (epsilons=0.3) adversarial inputs: Dec sees the
# adversarial image, and the generator's reconstruction is compared against
# the corresponding clean image.
attack = foolbox.attacks.FGSM()  # configuration object; build once instead of per batch
for epoch in range(100):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # The attack queries the classifier, so feed it on the classifier's device.
        inputs = inputs.to(device_model)
        labels = labels.to(device_model)
        raw_advs, clipped_advs, success = attack(fmodel, inputs, labels, epsilons=0.3)
        # The adversarial batch is only the decoder's input: cut any link to
        # the attack computation and move everything to the GAN's device.
        raw_advs = raw_advs.detach().to(device_generator)
        inputs = inputs.to(device_generator)

        optim_dec.zero_grad()
        img_rec_code = Dec(raw_advs)
        img_rec = ModelG(img_rec_code)
        loss_rec = criterion_dec(img_rec, inputs)
        loss_rec.backward()
        optim_dec.step()

        running_loss += loss_rec.item()
        if batch_idx % 100 == 99:
            print('[%d, %4d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
            save_losses.append(running_loss / 100)
            running_loss = 0.0

    if epoch % 10 == 0:
        torch.save(Dec.state_dict(), './dec_checkpoints/dec_noisy_0.3_mnist_{}.pth'.format(epoch))
        save_image(img_rec.data, './rec_img_mnist/noisy_0.3_rec_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        save_image(raw_advs.data, './rec_img_mnist/noisy_0.3_orig_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        # Detached copy: do not keep the autograd graph alive across epochs.
        img_rec_saves.append(img_rec.detach())
        img_inputs.append(inputs)

[1,  100] loss:0.084
[1,  200] loss:0.053
[1,  300] loss:0.041
[1,  400] loss:0.035
[1,  500] loss:0.032
[1,  600] loss:0.031
[1,  700] loss:0.030
[1,  800] loss:0.028
[1,  900] loss:0.027
[1, 1000] loss:0.027
[1, 1100] loss:0.026
[1, 1200] loss:0.025
[2,  100] loss:0.024
[2,  200] loss:0.024
[2,  300] loss:0.024
[2,  400] loss:0.023
[2,  500] loss:0.023
[2,  600] loss:0.023
[2,  700] loss:0.023
[2,  800] loss:0.022
[2,  900] loss:0.023
[2, 1000] loss:0.022
[2, 1100] loss:0.022
[2, 1200] loss:0.022
[3,  100] loss:0.021
[3,  200] loss:0.021
[3,  300] loss:0.021
[3,  400] loss:0.021
[3,  500] loss:0.021
[3,  600] loss:0.021
[3,  700] loss:0.021
[3,  800] loss:0.021
[3,  900] loss:0.021
[3, 1000] loss:0.021
[3, 1100] loss:0.021
[3, 1200] loss:0.021
[4,  100] loss:0.020
[4,  200] loss:0.020
[4,  300] loss:0.020
[4,  400] loss:0.020
[4,  500] loss:0.020
[4,  600] loss:0.020
[4,  700] loss:0.020
[4,  800] loss:0.020
[4,  900] loss:0.020
[4, 1000] loss:0.020
[4, 1100] loss:0.020
[4, 1200] los

KeyboardInterrupt: 

### 高斯噪声 (Gaussian noise)

In [21]:
# Decoder training on Gaussian-noised inputs (std = noise_factor): Dec sees the
# noisy image and the generator's reconstruction is compared against the clean
# image.
for epoch in range(1000):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # Dec and ModelG live on device_generator; labels are not needed here.
        inputs = inputs.to(device_generator)
        # add_noise draws its noise with randn_like, so the result is already
        # on the same device as `inputs`.
        image_noisy = add_noise(inputs, noise_factor)

        optim_dec.zero_grad()
        img_rec_code = Dec(image_noisy)
        img_rec = ModelG(img_rec_code)
        loss_rec = criterion_dec(img_rec, inputs)
        loss_rec.backward()
        optim_dec.step()

        running_loss += loss_rec.item()

        if batch_idx % 100 == 99:
            print('[%d, %4d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
            save_losses.append(running_loss / 100)
            running_loss = 0.0

    if epoch % 10 == 0:
        torch.save(Dec.state_dict(), './dec_checkpoints/dec_GS_noisy005_mnist_{}.pth'.format(epoch))
        save_image(img_rec.data, './rec_img_mnist/GS_noisy005_rec_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        save_image(image_noisy.data, './rec_img_mnist/GS_noisy005_orig_mnist_{}.png'.format(epoch), nrow=5, normalize=True)
        # Detached copy: do not keep the autograd graph alive across epochs.
        img_rec_saves.append(img_rec.detach())
        img_inputs.append(inputs)

[1,  100] loss:0.091
[1,  200] loss:0.049
[1,  300] loss:0.036
[1,  400] loss:0.030
[1,  500] loss:0.028
[1,  600] loss:0.027
[1,  700] loss:0.025
[1,  800] loss:0.024
[1,  900] loss:0.023
[1, 1000] loss:0.023
[1, 1100] loss:0.022
[1, 1200] loss:0.022
[2,  100] loss:0.021
[2,  200] loss:0.021
[2,  300] loss:0.020
[2,  400] loss:0.021
[2,  500] loss:0.020
[2,  600] loss:0.020
[2,  700] loss:0.020
[2,  800] loss:0.020
[2,  900] loss:0.020
[2, 1000] loss:0.020
[2, 1100] loss:0.019
[2, 1200] loss:0.019
[3,  100] loss:0.020
[3,  200] loss:0.019
[3,  300] loss:0.019
[3,  400] loss:0.018
[3,  500] loss:0.018
[3,  600] loss:0.019
[3,  700] loss:0.019
[3,  800] loss:0.019
[3,  900] loss:0.018
[3, 1000] loss:0.018
[3, 1100] loss:0.018
[3, 1200] loss:0.018
[4,  100] loss:0.018
[4,  200] loss:0.018
[4,  300] loss:0.018
[4,  400] loss:0.017
[4,  500] loss:0.018
[4,  600] loss:0.018
[4,  700] loss:0.018
[4,  800] loss:0.018
[4,  900] loss:0.018
[4, 1000] loss:0.018
[4, 1100] loss:0.017
[4, 1200] los

KeyboardInterrupt: 

### 重建图片准确度 (classification accuracy on reconstructed images)

In [51]:
# Classifier accuracy on Dec -> ModelG reconstructions of the clean MNIST test
# split (a single forward pass, no z* search).
# A dedicated loader over `test_set` is built here because, in a top-to-bottom
# run, `test_loader` at this point is the adversarial FGSM_0.3 loader rather
# than the clean split this section measures.
clean_test_loader = DataLoader(
    test_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

model.eval()
epoch_size = 0
running_corrects = 0
for batch_idx, (inputs, labels) in enumerate(clean_test_loader):
    # Dec/ModelG live on device_generator, the classifier on device_model.
    inputs = inputs.to(device_generator)
    labels = labels.to(device_model)
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img.to(device_model))
        _, preds = torch.max(outputs, 1)
        # statistics
        running_corrects += torch.sum(preds == labels)
        epoch_size += inputs.size(0)

        if batch_idx % 50 == 0:
            print(labels)
            print(preds)
            save_image(rec_img.data, './rec_img_mnist/test_mix1_rec_{}.png'.format(batch_idx), nrow=5, normalize=True)
            save_image(inputs.data, './rec_img_mnist/test_mix1_orig_{}.png'.format(batch_idx), nrow=5, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(clean_test_loader), running_corrects.double() / epoch_size))

        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4], device='cuda:0')
tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 1, 3, 5, 1, 2,
        4, 4], device='cuda:0')
  1/200 average acc 0.9600
tensor([2, 3, 3, 2, 1, 7, 0, 7, 6, 4, 1, 3, 8, 7, 4, 5, 9, 2, 5, 1, 8, 7, 3, 7,
        1, 5, 5, 0, 9, 1, 4, 0, 6, 3, 3, 6, 0, 4, 9, 7, 5, 1, 6, 8, 9, 5, 5, 7,
        9, 3], device='cuda:0')
tensor([2, 3, 3, 2, 1, 7, 0, 7, 6, 4, 1, 3, 8, 7, 4, 5, 9, 2, 5, 1, 8, 7, 3, 7,
        1, 5, 8, 0, 9, 1, 4, 0, 1, 3, 3, 6, 0, 4, 9, 7, 5, 1, 6, 8, 9, 5, 5, 7,
        9, 3], device='cuda:0')
 51/200 average acc 0.9365
tensor([3, 9, 9, 8, 4, 1, 0, 6, 0, 9, 6, 8, 6, 1, 1, 9, 8, 9, 2, 3, 5, 5, 9, 4,
        2, 1, 9, 4, 3, 9, 6, 0, 4, 0, 6, 0, 1, 2, 3, 4, 7, 8, 9, 0, 1, 2, 3, 4,
        7, 8], dev

In [23]:
# Drop the current `test_loader`; it is rebound to an adversarial-example
# loader below.
del test_loader

### 测试fgsm攻击 (evaluation on FGSM adversarial examples)

In [53]:
# Round-trip each stored adversarial sample through PIL and back to a tensor
# (float values in [0, 1]).
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

In [69]:
# Adversarial set selected by the name 'FGSM_0.002' (presumably FGSM with
# eps = 0.002 — confirm against the generation script).
sample = Adversarial_Dataset(
    root_dir,
    'FGSM_0.002',
    adversarial_transform,
)

In [70]:
# Sequential (unshuffled), single-process loader over the adversarial samples.
test_loader = DataLoader(sample, batch_size=batch_size, num_workers=0)

In [71]:
# Classifier accuracy on Dec -> ModelG reconstructions of the adversarial
# samples in `test_loader` (the FGSM_0.002 set loaded above).
model.eval()
epoch_size = 0
running_corrects = 0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    # Dec/ModelG live on device_generator, the classifier on device_model;
    # move each tensor to the device of the module that consumes it.
    inputs = inputs.to(device_generator)
    labels = labels.to(device_model)
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img.to(device_model))
        _, preds = torch.max(outputs, 1)
        # statistics
        running_corrects += torch.sum(preds == labels)
        epoch_size += inputs.size(0)

        if batch_idx % 50 == 0:
            print(labels)
            print(preds)
            save_image(rec_img.data, './rec_img_mnist/mix1_test_FGSM_0.002_rec_{}.png'.format(batch_idx), nrow=5, normalize=True)
            save_image(inputs.data, './rec_img_mnist/mix1_test_FGSM_0.002_orig_{}.png'.format(batch_idx), nrow=5, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))

tensor([9, 5, 2, 3, 6, 2, 2, 5, 5, 9, 2, 6, 3, 8, 8, 2, 2, 2, 2, 7, 7, 2, 6, 5,
        7, 5, 3, 1, 2, 1, 6, 7, 5, 7, 2, 3, 8, 6, 7, 9, 9, 7, 5, 8, 5, 5, 1, 9,
        2, 2], device='cuda:0')
tensor([9, 5, 2, 3, 0, 5, 1, 5, 5, 9, 2, 0, 3, 0, 3, 2, 6, 1, 8, 6, 5, 2, 6, 5,
        7, 3, 3, 1, 7, 1, 5, 7, 5, 4, 2, 3, 8, 8, 2, 4, 0, 7, 7, 3, 6, 3, 1, 8,
        2, 6], device='cuda:0')
  1/  5 average acc 0.5000
Test Acc: 0.5167


### 测试CW攻击 (evaluation on CW adversarial examples)

In [24]:
# Round-trip each stored adversarial sample through PIL and back to a tensor
# (float values in [0, 1]).
adversarial_transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]
)

In [25]:
# Adversarial set selected by the name 'CW_0.3_real' (presumably Carlini-Wagner
# examples — confirm against the generation script).
sample = Adversarial_Dataset(
    root_dir,
    'CW_0.3_real',
    adversarial_transform,
)

In [26]:
# Sequential (unshuffled), single-process loader over the adversarial samples.
test_loader = DataLoader(sample, batch_size=batch_size, num_workers=0)

In [27]:
# Sanity check: number of adversarial samples in this set.
print('%d' % len(sample))

581


In [28]:
# Classifier accuracy on Dec -> ModelG reconstructions of the adversarial
# samples in `test_loader` (the CW_0.3_real set loaded above).
model.eval()
epoch_size = 0
running_corrects = 0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    # Dec/ModelG live on device_generator, the classifier on device_model;
    # move each tensor to the device of the module that consumes it.
    inputs = inputs.to(device_generator)
    labels = labels.to(device_model)
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img.to(device_model))
        _, preds = torch.max(outputs, 1)
        # statistics
        running_corrects += torch.sum(preds == labels)
        epoch_size += inputs.size(0)

        if batch_idx % 50 == 0:
            print(labels)
            print(preds)
            save_image(rec_img.data, './rec_img_mnist/FGSM_noise_test_CW_0.3_rec_{}.png'.format(batch_idx), nrow=5, normalize=True)
            save_image(inputs.data, './rec_img_mnist/FGSM_noise_test_CW_0.3_orig_{}.png'.format(batch_idx), nrow=5, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))

tensor([5, 3, 5, 9, 7, 9, 6, 2, 3, 5, 7, 5, 5, 2, 9, 9, 3, 4, 6, 2, 9, 2, 0, 2,
        5, 6, 5, 9, 3, 2, 9, 2, 6, 9, 3, 5, 8, 3, 2, 8, 7, 8, 2, 2, 1, 6, 2, 2,
        2, 2], device='cuda:0')
tensor([5, 3, 5, 9, 7, 9, 6, 2, 3, 5, 7, 5, 5, 2, 9, 8, 3, 6, 5, 3, 7, 7, 0, 2,
        8, 6, 3, 9, 5, 2, 9, 2, 0, 8, 3, 3, 3, 3, 2, 8, 7, 3, 3, 8, 1, 6, 6, 1,
        2, 7], device='cuda:0')
  1/ 12 average acc 0.6200
Test Acc: 0.6368
