This notebook implements an adaptive attack against Integrated Gradients (IG) attributions on CIFAR-10.

In [None]:
import torch 
import torch.nn as nn
from torchvision.transforms import ToTensor, Normalize
from torchvision.datasets import CIFAR10 
from torch.utils.data import DataLoader 
from captum.attr import IntegratedGradients
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import numpy as np 
from captum.attr import *
import quantus
import torch.autograd as autograd
import torchvision.transforms as transforms
import torchvision

In [None]:
# Load the CIFAR-10 test split (downloaded to ./data on first use) as tensors
# and serve it in shuffled mini-batches of two images.
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=ToTensor())
test_loader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=True)

In [None]:
# Request the GPU, but fall back to the CPU when CUDA is not actually
# available so later `.to(device)` calls do not fail on CPU-only hosts.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

In [None]:
from resnet_srelu import resnet50 as resnet50

def load_model(path, device=None):
    """Build a ResNet-50 (from ``resnet_srelu``) and restore weights from *path*.

    Args:
        path: Location of a checkpoint containing a ``state_dict`` that
            ``model.load_state_dict`` accepts.
        device: Where to place the model. Defaults to CUDA when it is
            available and to the CPU otherwise (the previous version always
            used ``'cuda'`` and failed on CPU-only machines).

    Returns:
        The model in evaluation mode on *device*.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50()
    # Deserialize onto the CPU first so the checkpoint loads regardless of the
    # device it was saved from; the model is moved afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt_dict)
    model.to(device)
    model.eval()
    return model

# Checkpoint to restore; the loaded network is placed on `device` and kept in
# evaluation mode for all attacks below.
modelpath = "/data/virtual environments/adv detection by robustness/adv_detection/Adaptive attacks/Models/CIFAR10/resnet50/cifar.ckpt"
model = load_model(modelpath).to(device)
model.eval()

In [None]:
def _linf_ascent_step(adv_images, images, grad, eps, alpha):
    """Take one signed-gradient ascent step, then project into the eps-ball and [0, 1]."""
    stepped = adv_images.detach() + alpha * grad.sign()
    delta = torch.clamp(stepped - images, min=-eps, max=eps)
    return torch.clamp(images + delta, min=0, max=1).detach()


def adaptive_attack(model, images, labels, eps=16/255, alpha=8/255, iters=40,
                    schedule=((5, 301), (10, 200), (20, 100), (30, 50))):
    """Craft L-inf bounded adversarial images in two phases.

    Phase 1 runs ``iters`` PGD steps that ascend the cross-entropy loss.
    Phase 2 continues from that result and, for every ``(c, num_step)`` pair
    in ``schedule``, runs ``num_step`` steps that ascend
    ``cross_entropy + c * ||IG(adv, predicted) - IG(clean, labels)||_2``,
    where IG is Captum's ``IntegratedGradients`` attribution and the norm is
    taken over the whole batch tensor.

    Args:
        model: Classifier mapping an image batch to logits.
        images: Clean image batch with values in [0, 1].
        labels: Ground-truth class indices for ``images``.
        eps: Maximum L-inf perturbation.
        alpha: Step size of each signed-gradient update.
        iters: Number of phase-1 (plain PGD) steps.
        schedule: ``(c, num_step)`` pairs for phase 2. The default reproduces
            the pairs the original hard-coded lists actually produced
            (``cs`` had a fifth entry, 50, that ``zip`` silently discarded
            because ``steps`` only had four entries).

    Returns:
        Detached adversarial images on the model's device, within ``eps`` of
        ``images`` and clipped to [0, 1].
    """
    # Work on the device the model's weights live on instead of relying on a
    # notebook-level global; fall back to the input's device otherwise.
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    run_device = first_param.device if first_param is not None else images.device

    images = images.clone().detach().to(run_device)
    labels = labels.clone().detach().to(run_device)

    loss = nn.CrossEntropyLoss()
    adv_images = images.clone().detach()

    # Reference attribution of the clean images w.r.t. their true labels.
    integrated_gradients = IntegratedGradients(model)
    feature_attr_orig = integrated_gradients.attribute(images, target=labels)

    # Phase 1: plain PGD on the classification loss.
    for _ in range(iters):
        adv_images.requires_grad_(True)
        cost = loss(model(adv_images), labels)
        grad = torch.autograd.grad(cost, adv_images)[0]
        adv_images = _linf_ascent_step(adv_images, images, grad, eps, alpha)

    adv_images = adv_images.detach().clone()

    # Phase 2: add the attribution-distance term with an increasing weight c.
    for c, num_step in schedule:
        for _ in range(num_step):
            adv_images.requires_grad_(True)
            outputs = model(adv_images)
            # Attribute w.r.t. whatever class the model currently predicts.
            predicted = outputs.detach().argmax(dim=1)

            cost_pgd = loss(outputs, labels)
            # NOTE(review): the gradient that reaches adv_images through this
            # term depends on how Captum builds the attribution graph; it is
            # presumably not a full second-order gradient -- confirm.
            feature_attr_perturbed = integrated_gradients.attribute(adv_images, target=predicted)
            l2_distance = torch.norm(feature_attr_perturbed - feature_attr_orig, p=2)

            cost_total = cost_pgd + c * l2_distance

            # The graph is rebuilt every iteration, so it need not be retained.
            grad = torch.autograd.grad(cost_total, adv_images)[0]
            adv_images = _linf_ascent_step(adv_images, images, grad, eps, alpha)

    return adv_images

In [None]:
def pgd_attack(model, images, labels, eps=16/255, alpha=8/255, iters=40):
    """Run an untargeted L-inf PGD attack that ascends the cross-entropy loss.

    Args:
        model: Classifier mapping an image batch to logits.
        images: Clean image batch with values in [0, 1].
        labels: Ground-truth class indices for ``images``.
        eps: Maximum L-inf perturbation.
        alpha: Step size of each signed-gradient update.
        iters: Number of PGD steps; ``0`` returns a copy of ``images``.

    Returns:
        Detached adversarial images on the model's device, within ``eps`` of
        ``images`` and clipped to [0, 1].
    """
    # Work on the device the model's weights live on instead of relying on a
    # notebook-level global; fall back to the input's device otherwise.
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    run_device = first_param.device if first_param is not None else images.device

    images = images.clone().detach().to(run_device)
    labels = labels.clone().detach().to(run_device)

    loss = nn.CrossEntropyLoss()
    adv_images = images.clone().detach()

    for _ in range(iters):
        adv_images.requires_grad_(True)
        cost = loss(model(adv_images), labels)

        # Signed-gradient ascent step, then projection back into the eps-ball
        # around the clean images and into the valid pixel range.
        grad = torch.autograd.grad(cost, adv_images)[0]
        adv_images = adv_images.detach() + alpha * grad.sign()
        delta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()

    return adv_images

In [None]:
adversarial_images = []
adversarial_labels = []
benign_images = []
benign_labels = []
pgd_images = []
pgd_labels = []


def _predicted_labels(net, batch):
    """Return the arg-max class of every sample without building an autograd graph."""
    with torch.no_grad():
        return net(batch).argmax(dim=1)


for step, (images, labels) in enumerate(test_loader):
    # Adaptive attack and the labels the model assigns to its output.
    perturbed_images = adaptive_attack(model, images, labels)
    new = _predicted_labels(model, perturbed_images)
    adversarial_images.append(perturbed_images.detach().cpu().numpy())
    adversarial_labels.append(new.cpu().numpy())

    # Keep the clean batch alongside its ground-truth labels.
    benign_images.append(images.numpy())
    benign_labels.append(labels.numpy())

    # Plain PGD on the same batch, for comparison.
    pgdimages = pgd_attack(model, images, labels)
    pgdlabel = _predicted_labels(model, pgdimages)
    pgd_images.append(pgdimages.detach().cpu().numpy())
    pgd_labels.append(pgdlabel.cpu().numpy())

    # Progress report every ten batches.
    if len(adversarial_images) % 10 == 0:
        print(len(adversarial_images))

    # Stops once 251 batches have been collected.
    if len(adversarial_images) > 250:
        break

In [None]:
# Merge the per-batch lists into one NumPy array each (stacked along axis 0).
import os

img, label, b_img, b_lbl, pgd_img, pgd_lbl = (
    np.concatenate(batches)
    for batches in (
        adversarial_images,
        adversarial_labels,
        benign_images,
        benign_labels,
        pgd_images,
        pgd_labels,
    )
)

In [None]:
# Write all six arrays into a single .npz archive.
save_dir = '/data/virtual environments/adv detection by robustness/adv_detection/Adaptive attacks/adaptive_attack_images_cifar/srelu'
# np.savez does not create missing directories, so make sure the target exists.
os.makedirs(save_dir, exist_ok=True)
np.savez(
    os.path.join(save_dir, '16255IGAttackImages.npz'),
    adaptive_images=img,
    adaptive_labels=label,
    benign_images=b_img,
    benign_labels=b_lbl,
    pgd_images=pgd_img,
    pgd_labels=pgd_lbl,
)

In [None]:
# Measure the wall-clock time of the adaptive attack on the first five batches.
import time

adv_time = []

for step, (images, labels) in enumerate(test_loader):
    # CUDA kernels run asynchronously; synchronize so the timer brackets the
    # actual GPU work rather than just the kernel launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    perturbed_images = adaptive_attack(model, images, labels)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    exec_time = time.perf_counter() - start_time
    adv_time.append(exec_time)

    # Steps 0..4 are timed, i.e. five batches in total.
    if step > 3:
        break

In [None]:
# Show the recorded per-batch attack times (seconds).
adv_time

In [None]:
# Average attack time per image: total time divided by the number of images
# actually timed (batches timed x loader batch size) instead of a hard-coded
# 5*10 that no longer matches the loader's batch size.
sum(adv_time) / (len(adv_time) * test_loader.batch_size)