In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms as tt
import torchvision.models as models
import timm
from PIL import Image as pil
import os
import numpy as np
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F

resnet50 = models.resnet50(pretrained=True)
model = timm.create_model('vit_small_patch16_224', pretrained=True)
image_path="images"
old_label_path="images//old_labels"
image_list=os.listdir(image_path)

## 图片导入与数据集生成

In [101]:
def normalize(x, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]):
    mean_t = torch.Tensor(mean).reshape([1,3,1,1]).to(x.device)
    std_t = torch.Tensor(std).reshape([1,3,1,1]).to(x.device)
    y = (x-mean_t)/std_t
    return y

def get_preprocess(model_name):
    if model_name[:6]=="resnet":
        mean=[0.485, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
    elif model_name[:3]=="vit":
        mean=[0.5,0.5,0.5]
        std=[0.5,0.5,0.5]
    def preprocess(images):
        y = F.interpolate(images,224)
        print(images.size())
        return normalize(y, mean=mean,
                                 std=std)
    return preprocess

In [102]:
class mydata1(Dataset):
    def __init__(self,image_path):
        self.file=[]
        for i in open(old_label_path):
            self.file.append(i)
    
    def __getitem__(self,idx):
        img=pil.open("images//"+self.file[idx].split(' ')[0])
        label=np.array(int(self.file[idx].split(' ')[1]))
        img=preprocess1(tt.ToTensor()(img)).squeeze()
        return img, label
    
    def __len__(self):
        return len(self.file)
    
    
class mydata2(Dataset):
    def __init__(self,image_path):
        self.file=[]
        for i in open(old_label_path):
            self.file.append(i)
    
    def __getitem__(self,idx):
        img=pil.open("images//"+self.file[idx].split(' ')[0])
        label=np.array(int(self.file[idx].split(' ')[1]))
        img=preprocess2(tt.ToTensor()(img)).squeeze()
        return img, label
    
    def __len__(self):
        return len(self.file)
preprocess1=get_preprocess("resnet")
res_data_Tensor = torch.utils.data.DataLoader(dataset=mydata1(image_path))
preprocess2=get_preprocess("vit")
vit_data_Tensor=torch.utils.data.DataLoader(dataset=mydata2(image_path))

##  测试

In [103]:
def test(model,test_Tensor):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_Tensor:
            test_output = model(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            correct=correct+(pred_y == labels)
            total=total+1
        accuracy = correct / float(total)
            
    print('Test Accuracy of the model on the 1000 test images: %.3f' % accuracy)
    print(total)

In [82]:
test(resnet50,res_data_Tensor)

Test Accuracy of the model on the 1000 test images: 0.954
1000


In [104]:
test(vit,vit_data_Tensor)

AttributeError: dim

## 

## 对抗样本生成

### 方法1：梯度下降

In [None]:
epsilon =1/255.0 #最大扰动

loss_fn = torch.nn.CrossEntropyLoss()
def fgsm(input,epsilon,data_grad):
    grad_attack = data_grad.sign()
    result = input + epsilon*grad_attack
    result = torch.clamp(result, 0, 1)
    return result

def fgsm_attack(model,input_Tensor,epsilon):
    fgsm_sample_list=[]
    id=0
    for (data, target) in input_Tensor:
        data.requires_grad = True
        output = model(data)
        init_pred = output.max(1)[1].squeeze()
        loss = loss_fn(output, target.long())
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        perturbed_data = fgsm(data,epsilon,data_grad)
        fgsm_sample_list.append((perturbed_data,target))
        id=id+1
        if(id%100==0):
            print(("%d/"+"1000"+"\r")%(id))
        
    return fgsm_sample_list

In [None]:
fgsm_sample={}
fgsm_sample_path='./fgsm_sample.txt'

fgsm_sample["vit"]=fgsm_attack(vit, data_Tensor,epsilon)
fgsm_sample["res"]=fgsm_attack(resnet50, data_Tensor,epsilon)

with open(fgsm_sample_path,'wb') as f:
    content = pickle.dumps(fgsm_sample)
    f.write(content)

## 攻击结果测试

In [None]:
def test1(model,input_Tensor):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in input_Tensor:
            test_output= model(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
    return accuracy


In [None]:
test1(resnet50,fgsm_sample["res"])

In [None]:
print(tt.ToPILImage(fgsm_sample["vit"][0][0]))