This notebook reproduces a combined adaptive attack on both Integrated Gradients and the logits. Beyond changing the predicted label, the adversary also tries to minimize two differences between the benign and adversarial images: the difference in their logits and the difference in their feature attributions.

In [None]:
import torch 
import torch.nn as nn
#from torchvision.models import resnet50 
from torchvision.transforms import ToTensor, Normalize
from torchvision.datasets import CIFAR10 
from torch.utils.data import DataLoader 
from captum.attr import IntegratedGradients
import torch.nn.functional as F
#plot new and old images 
import matplotlib.pyplot as plt 
import numpy as np 
from captum.attr import *
import quantus
import torch.autograd as autograd
import torchvision.transforms as transforms
import torchvision

In [None]:
# CIFAR-10 test split converted with ToTensor, served in shuffled batches of two.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)

In [None]:
# Use the GPU when one is present and fall back to the CPU otherwise, instead of
# unconditionally requesting CUDA (which fails on CPU-only machines).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
from resnet_srelu import resnet50 as resnet50

def load_model(path, device=None):
    """Build the SReLU ResNet-50 and load a saved ``state_dict`` into it.

    Args:
        path: Path to a checkpoint holding a ``state_dict`` for ``resnet50()``.
        device: Device to place the model on. Defaults to ``"cuda"`` when CUDA
            is available and ``"cpu"`` otherwise, so loading no longer crashes
            on machines without a GPU.

    Returns:
        The model in evaluation mode on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = resnet50()
    # Deserialize every tensor on the CPU first; the model is moved afterwards.
    ckpt_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt_dict)
    model.to(device)
    model.eval()
    return model

# Load the trained checkpoint, move it to the selected device and switch to inference mode.
modelpath = "/data/virtual environments/adv detection by robustness/adv_detection/Adaptive attacks/Models/CIFAR10/resnet50/cifar.ckpt"
model = load_model(modelpath).to(device)
model.eval()

In [None]:
def _linf_project_step(adv_images, grad, images, alpha, eps):
    """Take one signed-gradient ascent step, then project into the eps-ball around ``images`` and [0, 1]."""
    stepped = adv_images.detach() + alpha * grad.sign()
    delta = torch.clamp(stepped - images, min=-eps, max=eps)
    return torch.clamp(images + delta, min=0, max=1).detach()


def adaptive_attack_with_pgd_match(model, images, labels, eps=16/255, alpha=8/255, iters=40,
                                   steps=(301, 200, 100, 50), cs=(5, 10, 20, 30),
                                   logit_weight=10):
    """Adaptive L-inf attack that flips the label while mimicking clean IG maps and logits.

    Phase 1 is plain PGD: ``iters`` signed-gradient steps of size ``alpha``
    that maximise cross-entropy w.r.t. the true ``labels``.

    Phase 2 runs one stage per ``(c, num_step)`` pair from ``zip(cs, steps)``.
    Each step ascends

        CE(model(x_adv), labels)
        - c * ||IG(x_adv) - IG(x)||_2
        - logit_weight * MSE(model(x), model(x_adv))

    so the example keeps moving away from the true label while its
    Integrated-Gradients attribution and logits are pulled back toward those of
    the clean image. Every step is projected into the ``eps`` L-inf ball around
    ``images`` and clipped to [0, 1].

    Args:
        model: Classifier returning logits.
        images: Clean input batch.
        labels: True labels for ``images``.
        eps: L-inf perturbation budget.
        alpha: Step size for both phases.
        iters: Number of phase-1 PGD steps.
        steps: Number of phase-2 steps per stage.
        cs: Weight of the attribution-distance term per stage. Must have the
            same length as ``steps``. The original schedule also listed 50, but
            ``zip`` never reached it, so it is omitted here.
        logit_weight: Weight of the logit-matching MSE term.

    Returns:
        Detached adversarial images on the global ``device``.

    Raises:
        ValueError: If ``steps`` and ``cs`` differ in length.
    """
    steps = tuple(steps)
    cs = tuple(cs)
    if len(steps) != len(cs):
        raise ValueError(
            f"steps and cs must have the same length, got {len(steps)} and {len(cs)}"
        )

    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)

    loss = nn.CrossEntropyLoss()
    integrated_gradients = IntegratedGradients(model)

    # Clean-image references are constants of the optimisation, so do not keep
    # an autograd graph alive for them across all iterations.
    with torch.no_grad():
        clean_logits = model(images)
    feature_attr_orig = integrated_gradients.attribute(images, target=labels).detach()

    # Phase 1: plain PGD on cross-entropy.
    adv_images = images.clone().detach()
    for _ in range(iters):
        adv_images.requires_grad = True
        cost = loss(model(adv_images), labels)
        grad = torch.autograd.grad(cost, adv_images)[0]
        adv_images = _linf_project_step(adv_images, grad, images, alpha, eps)

    # Phase 2: keep attacking the label while matching clean attributions/logits.
    for c, num_step in zip(cs, steps):
        for _ in range(num_step):
            adv_images.requires_grad = True
            # One forward pass serves both the CE term and the logit-matching term.
            adv_logits = model(adv_images)
            # Explain the class the model currently predicts for the adversarial input.
            target2 = adv_logits.detach().argmax(dim=1)

            cost_pgd = loss(adv_logits, labels)
            # NOTE(review): Captum presumably computes the IG path gradients without
            # create_graph, so the gradient of this term w.r.t. adv_images likely flows
            # only through the (input - baseline) factor of the attribution -- confirm.
            feature_attr_perturbed = integrated_gradients.attribute(adv_images, target=target2)
            l2_distance = torch.norm(feature_attr_perturbed - feature_attr_orig, p=2)
            logit_loss = F.mse_loss(clean_logits, adv_logits)

            # The update is gradient ASCENT, so the matching terms are subtracted:
            # the attack maximises CE while minimising the attribution and logit gaps.
            cost_total = cost_pgd - c * l2_distance - logit_weight * logit_loss

            grad = torch.autograd.grad(cost_total, adv_images)[0]
            adv_images = _linf_project_step(adv_images, grad, images, alpha, eps)

    return adv_images

In [None]:
def pgd_attack(model, images, labels, eps=16/255, alpha=8/255, iters=40):
    """Untargeted L-inf PGD (no random start) that maximises cross-entropy.

    Each of the ``iters`` steps moves ``alpha`` along the sign of the loss
    gradient. The result is then projected back into the ``eps`` ball around
    ``images`` and clipped to [0, 1]. Returns detached images on ``device``.
    """
    x_clean = images.clone().detach().to(device)
    y_true = labels.clone().detach().to(device)
    criterion = nn.CrossEntropyLoss()

    x_adv = x_clean.clone().detach()
    for _ in range(iters):
        x_adv.requires_grad = True
        ce = criterion(model(x_adv), y_true)
        (g,) = torch.autograd.grad(ce, x_adv, retain_graph=False, create_graph=False)
        stepped = x_adv.detach() + alpha * g.sign()
        perturbation = (stepped - x_clean).clamp(min=-eps, max=eps)
        x_adv = (x_clean + perturbation).clamp(min=0, max=1).detach()
    return x_adv

In [None]:
# Collect adaptive-attack, benign and plain-PGD versions of test batches,
# stored per batch as NumPy arrays.
adversarial_images = []
adversarial_labels = []
benign_images = []
benign_labels = []
pgd_images = []
pgd_labels = []

# The loop stops once MORE than this many batches are stored (i.e. after
# MAX_BATCHES + 1 batches), matching the original `> 250` check.
MAX_BATCHES = 250

for step, (images, labels) in enumerate(test_loader):
    perturbed_images = adaptive_attack_with_pgd_match(model, images, labels)
    # Predictions only: no autograd graph is needed.
    with torch.no_grad():
        new = model(perturbed_images).argmax(dim=1)
    adversarial_images.append(perturbed_images.detach().cpu().numpy())
    adversarial_labels.append(new.cpu().numpy())

    # Keep the benign batch alongside its adversarial counterparts.
    benign_images.append(images.numpy())
    benign_labels.append(labels.numpy())

    # Plain PGD on the same batch. eps/alpha are passed explicitly because a
    # later cell redefines `pgd_attack` with smaller defaults (8/255, 2/255).
    pgdimages = pgd_attack(model, images, labels, eps=16/255, alpha=8/255)
    with torch.no_grad():
        pgdlabel = model(pgdimages).argmax(dim=1)
    pgd_images.append(pgdimages.detach().cpu().numpy())
    pgd_labels.append(pgdlabel.cpu().numpy())

    if len(adversarial_images) % 10 == 0:
        print(len(adversarial_images))

    if len(adversarial_images) > MAX_BATCHES:
        break

In [None]:
# Flatten the per-batch lists into single (N, ...) NumPy arrays for saving.
import os
img, label = np.concatenate(adversarial_images), np.concatenate(adversarial_labels)
b_img, b_lbl = np.concatenate(benign_images), np.concatenate(benign_labels)
pgd_img, pgd_lbl = np.concatenate(pgd_images), np.concatenate(pgd_labels)

In [None]:
# Save adaptive, benign and PGD images/labels to a single .npz archive.
import os

save_dir = '/data/virtual environments/adv detection by robustness/adv_detection/Adaptive attacks/adaptive_attack_images_cifar/srelu' 
save_path = os.path.join(save_dir, 'attack both ig and model/16255.npz')
# np.savez does not create missing directories, so make sure the target folder exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
np.savez(save_path, adaptive_images=img, adaptive_labels=label, benign_images=b_img, benign_labels=b_lbl, pgd_images=pgd_img, pgd_labels=pgd_lbl)

In [None]:
# Path of the .npz archive inspected by the following cells.
# NOTE(review): this is `adaptive2.npz` in the parent results folder, not the
# `srelu/attack both ig and model/16255.npz` file written by the save cell
# above -- confirm which run these cells are meant to inspect.
data_path = '/data/virtual environments/adv detection by robustness/adv_detection/Adaptive attacks/adaptive_attack_images_cifar/adaptive2.npz' 

In [None]:
# Load the saved arrays. Indexing an .npz reads each array into memory, so the
# archive can be closed right afterwards; the context manager releases the file handle.
with np.load(data_path) as npobj:
    adaptive_image = npobj['adaptive_images']
    adaptive_label = npobj['adaptive_labels']
    ben_image = npobj['benign_images']
    ben_label = npobj['benign_labels']
    pgd_image = npobj['pgd_images']
    pgd_label = npobj['pgd_labels']

In [None]:
# Display the predicted labels after the adaptive and the plain PGD attacks.
(adaptive_label, pgd_label)

In [None]:
# Display the ground-truth labels of the benign images.
ben_label

In [None]:
# Show the first two benign images titled with their stored labels.
# CIFAR-10 class names, indexed by label id.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 5))

# Loop names are chosen so they do not overwrite the global `label` array
# created in the concatenation cell (still needed by the save cell).
for idx, ax in enumerate(axes.flatten()):
    # Stored as (C, H, W); imshow expects (H, W, C).
    ax.imshow(np.transpose(ben_image[idx], (1, 2, 0)))
    ax.set_title(class_names[ben_label[idx]])
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Show the first two adaptive-attack images titled with their predicted labels.
# CIFAR-10 class names, indexed by label id.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 5))

# Loop names are chosen so they do not overwrite the global `label` array
# created in the concatenation cell (still needed by the save cell).
for idx, ax in enumerate(axes.flatten()):
    # Stored as (C, H, W); imshow expects (H, W, C).
    ax.imshow(np.transpose(adaptive_image[idx], (1, 2, 0)))
    ax.set_title(class_names[adaptive_label[idx]])
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Show the first two plain-PGD images titled with their predicted labels.
# CIFAR-10 class names, indexed by label id.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 5))

# Loop names are chosen so they do not overwrite the global `label` array
# created in the concatenation cell (still needed by the save cell).
for idx, ax in enumerate(axes.flatten()):
    # Stored as (C, H, W); imshow expects (H, W, C).
    ax.imshow(np.transpose(pgd_image[idx], (1, 2, 0)))
    ax.set_title(class_names[pgd_label[idx]])
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Integrated-Gradients heatmaps (all-zero baseline, summed over the channel axis,
# normalised with quantus' normalise_by_negative). The first is for the current
# clean batch w.r.t. its labels; the second is for the adaptive adversarial batch
# w.r.t. the labels the model predicts for it (`new`).
images = images.to(device)
labels = labels.to(device)

clean_attr = IntegratedGradients(model).attribute(
    inputs=images, target=labels, baselines=torch.zeros_like(images)
)
intgrad1 = quantus.normalise_func.normalise_by_negative(clean_attr.sum(axis=1).cpu().numpy())

adv_attr = IntegratedGradients(model).attribute(
    inputs=perturbed_images, target=new, baselines=torch.zeros_like(perturbed_images)
)
intgrad2 = quantus.normalise_func.normalise_by_negative(adv_attr.sum(axis=1).cpu().numpy())

In [None]:
# Plain PGD baseline (eps=8/255, alpha=2/255) for the comparison figures.
# NOTE: this redefines `pgd_attack` with smaller default eps/alpha than the
# earlier definition; later calls that rely on defaults get these values.

def pgd_attack(model, images, labels, eps=8/255, alpha=2/255, iters=40):
    """Untargeted L-inf PGD (no random start) that maximises cross-entropy.

    Each step moves ``alpha`` along the sign of the loss gradient. The result is
    then projected into the ``eps`` ball around ``images`` and clipped to [0, 1].
    Returns detached images on ``device``.
    """
    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)

    loss = nn.CrossEntropyLoss()
    adv_images = images.clone().detach()

    for _ in range(iters):
        adv_images.requires_grad = True
        cost = loss(model(adv_images), labels)
        grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0]
        adv_images = adv_images.detach() + alpha * grad.sign()
        delta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()

    return adv_images


# Attack the SAME batch the adaptive heatmaps were computed for. Drawing a fresh
# batch from the shuffled loader would rebind `images`/`labels`, and the "normal"
# column of the comparison plots would no longer match `intgrad1` and
# `perturbed_images`.
pgd_images = pgd_attack(model, images, labels)
print('Original label:', labels)
with torch.no_grad():
    pgdlabel = model(pgd_images).argmax(dim=1)
print('pgdlabel:', pgdlabel)

# IG attributions for the PGD images w.r.t. their predicted labels.
intgrad3 = quantus.normalise_func.normalise_by_negative(
    IntegratedGradients(model)
    .attribute(inputs=pgd_images, target=pgdlabel, baselines=torch.zeros_like(pgd_images))
    .sum(axis=1)
    .cpu()
    .numpy()
)

In [None]:
# Results when using larger constants for c:


In [None]:
# Comparison grid per row: clean | IG(clean) | PGD | IG(PGD) | adaptive | IG(adaptive).


def _to_hwc_uint8(t):
    """Convert a (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array."""
    return (np.moveaxis(t.cpu().numpy(), 0, -1) * 255).astype(np.uint8)


panels = [
    (images, intgrad1, "normal"),
    (pgd_images, intgrad3, "pgd"),
    (perturbed_images, intgrad2, "adv"),
]
# Never index past the smallest batch (test_loader yields batches of 2).
nr_images = min(5, *(len(a) for batch, heat, _ in panels for a in (batch, heat)))
# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(nrows=nr_images, ncols=6, figsize=(nr_images*2.5, int(nr_images*3)), squeeze=False)

for row in range(nr_images):
    for col, (batch, heatmaps, title) in enumerate(panels):
        img_ax = axes[row, 2 * col]
        img_ax.imshow(_to_hwc_uint8(batch[row]))
        img_ax.title.set_text(title)
        img_ax.axis("off")

        # Heatmaps are already 2-D (the channel axis was summed when they were
        # computed), so show them as-is: moving axes would transpose them.
        # Values are assumed to lie in [-1, 1] after normalise_by_negative --
        # fixed limits keep white at zero for the diverging colormap.
        hm_ax = axes[row, 2 * col + 1]
        hm_ax.imshow(heatmaps[row], cmap="seismic", vmin=-1.0, vmax=1.0)
        hm_ax.title.set_text("")
        hm_ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Results when the loss function had no c weighting (smaller c values):

In [None]:
# Comparison grid per row: clean | IG(clean) | PGD | IG(PGD) | adaptive | IG(adaptive).


def _to_hwc_uint8(t):
    """Convert a (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array."""
    return (np.moveaxis(t.cpu().numpy(), 0, -1) * 255).astype(np.uint8)


panels = [
    (images, intgrad1, "normal"),
    (pgd_images, intgrad3, "pgd"),
    (perturbed_images, intgrad2, "adv"),
]
# Never index past the smallest batch (test_loader yields batches of 2).
nr_images = min(5, *(len(a) for batch, heat, _ in panels for a in (batch, heat)))
# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(nrows=nr_images, ncols=6, figsize=(nr_images*2.5, int(nr_images*3)), squeeze=False)

for row in range(nr_images):
    for col, (batch, heatmaps, title) in enumerate(panels):
        img_ax = axes[row, 2 * col]
        img_ax.imshow(_to_hwc_uint8(batch[row]))
        img_ax.title.set_text(title)
        img_ax.axis("off")

        # Heatmaps are already 2-D (the channel axis was summed when they were
        # computed), so show them as-is: moving axes would transpose them.
        # Values are assumed to lie in [-1, 1] after normalise_by_negative --
        # fixed limits keep white at zero for the diverging colormap.
        hm_ax = axes[row, 2 * col + 1]
        hm_ax.imshow(heatmaps[row], cmap="seismic", vmin=-1.0, vmax=1.0)
        hm_ax.title.set_text("")
        hm_ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Results when the loss function included c, using smaller c values:

In [None]:
# Comparison grid per row: clean | IG(clean) | PGD | IG(PGD) | adaptive | IG(adaptive).


def _to_hwc_uint8(t):
    """Convert a (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array."""
    return (np.moveaxis(t.cpu().numpy(), 0, -1) * 255).astype(np.uint8)


panels = [
    (images, intgrad1, "normal"),
    (pgd_images, intgrad3, "pgd"),
    (perturbed_images, intgrad2, "adv"),
]
# Never index past the smallest batch (test_loader yields batches of 2).
nr_images = min(5, *(len(a) for batch, heat, _ in panels for a in (batch, heat)))
# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(nrows=nr_images, ncols=6, figsize=(nr_images*2.5, int(nr_images*3)), squeeze=False)

for row in range(nr_images):
    for col, (batch, heatmaps, title) in enumerate(panels):
        img_ax = axes[row, 2 * col]
        img_ax.imshow(_to_hwc_uint8(batch[row]))
        img_ax.title.set_text(title)
        img_ax.axis("off")

        # Heatmaps are already 2-D (the channel axis was summed when they were
        # computed), so show them as-is: moving axes would transpose them.
        # Values are assumed to lie in [-1, 1] after normalise_by_negative --
        # fixed limits keep white at zero for the diverging colormap.
        hm_ax = axes[row, 2 * col + 1]
        hm_ax.imshow(heatmaps[row], cmap="seismic", vmin=-1.0, vmax=1.0)
        hm_ax.title.set_text("")
        hm_ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import time

In [None]:
# Display the per-batch adaptive-attack timings (seconds).
# NOTE: `adaptive_time` is created by the timing cell at the end of the
# notebook; run that cell before this one.
adaptive_time

In [None]:
# Average adaptive-attack time per image: total measured time divided by the
# number of images timed (batches x batch size). Deriving the count from the
# data keeps it correct if the timing loop or the batch size changes.
t1 = sum(adaptive_time) / (len(adaptive_time) * test_loader.batch_size)
t1

In [None]:
# Time the adaptive attack on a few batches (steps 0..6, i.e. 7 batches).
import time

adaptive_time = []

for step, (images, labels) in enumerate(test_loader):
    # CUDA kernels run asynchronously: drain pending work before starting the
    # clock and wait for the attack's kernels before stopping it.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    perturbed_images = adaptive_attack_with_pgd_match(model, images, labels)
    if device.type == "cuda":
        torch.cuda.synchronize()
    exec_time = time.perf_counter() - start_time
    adaptive_time.append(exec_time)

    if step > 5:
        break