<a href="https://colab.research.google.com/github/sayakpaul/robustness-vit/blob/master/analysis/pgd_attacks/PGD_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the val.tar archive from Google Drive (about 6.75 GB per the recorded
# output), unpack it in the working directory, then delete the archive to free disk.
!gdown --id 1QtAJsTjBOf3CnrTzTTqP-nPnHcTc2g9E
!tar xf val.tar
!rm -rf val.tar
# Fetch the ImageNet class-index JSON that a later cell loads to build label mappings.
!wget https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json

Downloading...
From: https://drive.google.com/uc?id=1QtAJsTjBOf3CnrTzTTqP-nPnHcTc2g9E
To: /content/val.tar
6.75GB [01:08, 98.1MB/s]
--2021-04-12 06:07:30--  https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.215.128, 173.194.216.128, 173.194.217.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.215.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/json]
Saving to: ‘imagenet_class_index.json’


2021-04-12 06:07:30 (99.9 MB/s) - ‘imagenet_class_index.json’ saved [35363/35363]



In [None]:
# Download random_hundred_paths_val.npy (file name per the recorded output); it is
# loaded into HUNDRED_PATHS in a later cell and iterated as a list of image paths.
!gdown --id 1Wbn3yuBBR2KO8OEI38YkHYNu2mQ96E7N

Downloading...
From: https://drive.google.com/uc?id=1Wbn3yuBBR2KO8OEI38YkHYNu2mQ96E7N
To: /content/random_hundred_paths_val.npy
  0% 0.00/16.9k [00:00<?, ?B/s]100% 16.9k/16.9k [00:00<00:00, 7.73MB/s]


In [None]:
# Clone the ViT-pytorch repository; the import cell appends this directory to
# sys.path so that `models.modeling` can be imported from it.
!git clone https://github.com/jeonsworld/ViT-pytorch

Cloning into 'ViT-pytorch'...
remote: Enumerating objects: 170, done.[K
remote: Counting objects: 100% (170/170), done.[K
remote: Compressing objects: 100% (123/123), done.[K
remote: Total 170 (delta 86), reused 121 (delta 43), pack-reused 0[K
Receiving objects: 100% (170/170), 21.21 MiB | 34.69 MiB/s, done.
Resolving deltas: 100% (86/86), done.


In [None]:
!pip install -q ml-collections

[?25l[K     |███▊                            | 10kB 19.2MB/s eta 0:00:01[K     |███████▍                        | 20kB 11.9MB/s eta 0:00:01[K     |███████████                     | 30kB 8.8MB/s eta 0:00:01[K     |██████████████▉                 | 40kB 7.9MB/s eta 0:00:01[K     |██████████████████▌             | 51kB 4.6MB/s eta 0:00:01[K     |██████████████████████▏         | 61kB 5.1MB/s eta 0:00:01[K     |█████████████████████████▉      | 71kB 5.2MB/s eta 0:00:01[K     |█████████████████████████████▋  | 81kB 5.6MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 3.8MB/s 
[?25h

In [None]:
import sys
if "ViT-pytorch" not in sys.path:
  sys.path.append("ViT-pytorch")

import os
import pickle
import json

import torch
import numpy as np
import matplotlib.pyplot as plt

from urllib.request import urlretrieve
from imutils import paths

from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS

In [None]:
# Load the class index; each value is a pair whose first item is used as the
# lookup key (presumably the WordNet synset id - confirm against the JSON) and
# whose second item is the human-readable class name.
with open("imagenet_class_index.json", "r") as read_file:
    imagenet_labels = json.load(read_file)

# First item of each pair -> integer class index.
MAPPING_DICT = {
    entry[0]: int(label_id) for label_id, entry in imagenet_labels.items()
}
# Integer class index -> class name.
LABEL_NAMES = {
    int(label_id): entry[1] for label_id, entry in imagenet_labels.items()
}

# Image paths that the evaluation loop iterates over.
HUNDRED_PATHS = np.load("random_hundred_paths_val.npy")

In [None]:
# Candidate per-element perturbation bounds; the attack clamps delta to
# [-EPS[1], EPS[1]], i.e. only the middle value is used by default.
EPS = [1e-3, 2e-3, 3e-3]
# Number of optimizer steps taken per attacked image.
ITERATIONS = 10
# Side length (in pixels) images are resized to and the model is built for.
RESIZE = 224

In [None]:
# Fetch the pretrained ViT-L/16 checkpoint once and cache it on disk.
WEIGHTS_URL = "https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16-224.npz"
WEIGHTS_PATH = "attention_data/ViT-L_16-224.npz"

os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)

if not os.path.isfile(WEIGHTS_PATH):
    # Download under a temporary name and rename only after success. Otherwise an
    # interrupted transfer leaves a truncated file that passes the isfile() guard
    # on the next run and then fails when the checkpoint is loaded.
    partial_path = WEIGHTS_PATH + ".part"
    urlretrieve(WEIGHTS_URL, partial_path)
    os.replace(partial_path, WEIGHTS_PATH)

In [None]:
# Device used for inference and the attack: CUDA when torch reports it is
# available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the ViT-L/16 model with a 1000-class head, load the downloaded
# checkpoint into it, and put it in evaluation mode.
config = CONFIGS["ViT-L_16"]
vit_model = VisionTransformer(
    config,
    num_classes=1000,
    zero_head=False,
    img_size=RESIZE,
)
vit_model.load_from(np.load("attention_data/ViT-L_16-224.npz"))
vit_model.eval()

# Preprocessing: resize to RESIZE x RESIZE, convert to a tensor, then normalize
# each channel with mean 0.5 and std 0.5 (the attack later clamps inputs to
# [-1, 1], matching that normalization).
_preprocessing_steps = [
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
def generate_adversaries(image_tensor, model, true_class_index,
                         eps=None, iterations=None, lr=1e-3):
    """Optimize an additive perturbation that pushes `model` away from the true class.

    A zero-initialized `delta` with the shape of `image_tensor` is updated with
    Adam to maximize the cross-entropy of the true class (by minimizing its
    negation). After every step `delta` is clamped element-wise to
    [-eps, eps]; the perturbed input is clamped to [-1, 1] before each forward
    pass.

    Args:
        image_tensor: Batched model input. The label tensor is built with a
            single entry, so a batch of one is expected.
        model: Callable returning a pair whose first item is the logits.
        true_class_index: Integer index of the ground-truth class.
        eps: Per-element bound on the perturbation. Defaults to ``EPS[1]``.
        iterations: Number of optimizer steps. Defaults to ``ITERATIONS``.
        lr: Adam learning rate.

    Returns:
        (delta, losses): the perturbation tensor (still a leaf that requires
        grad, as before) and the list of per-step loss values (negated
        cross-entropy) recorded before each update.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth for existing callers.
    if eps is None:
        eps = EPS[1]
    if iterations is None:
        iterations = ITERATIONS

    delta = torch.zeros_like(image_tensor, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=lr)

    # Loop-invariant objects: build the criterion and the target once, and put
    # the target on the same device as the input instead of relying on a global.
    criterion = torch.nn.CrossEntropyLoss()
    target = torch.tensor([true_class_index], dtype=torch.long,
                          device=image_tensor.device)
    losses = []

    for _step in range(iterations):
        inp = torch.clamp(image_tensor + delta, -1, 1)
        logits, _ = model(inp)
        loss = -criterion(logits, target)

        optimizer.zero_grad()
        # Differentiate with respect to delta only. loss.backward() would also
        # accumulate gradients into every model parameter on every step (and
        # across images), which nothing here consumes or ever zeroes.
        delta.grad = torch.autograd.grad(loss, delta)[0]
        optimizer.step()

        # Project back into the allowed perturbation range without recording
        # the in-place clamp in the autograd graph.
        with torch.no_grad():
            delta.clamp_(-eps, eps)

        losses.append(loss.item())

    return delta, losses

In [None]:
def show_image(images, labels, original_label, filename):
    """Save a three-panel figure of the input, the perturbation and the adversarial image.

    Args:
        images: Three items. ``images[0]`` is passed to ``imshow`` as-is;
            ``images[1]`` is a channel-first array (it is transposed to
            channel-last); ``images[2]`` is a torch tensor with a leading
            batch dimension in channel-first layout.
        labels: ``labels[0]`` is the prediction on the clean input and
            ``labels[1]`` the prediction on the perturbed input.
        original_label: Ground-truth label shown in the first panel's title.
        filename: Path the figure is written to.
    """
    fig, axes = plt.subplots(ncols=3, figsize=(10, 10))

    # Bring both derived images into the channel-last layout imshow expects.
    delta_picture = images[1].transpose(1, 2, 0)
    adversarial_picture = (
        images[2].squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
    )

    panels = (
        (
            images[0],
            "Input Image \n"
            f"Original Label: {original_label}\n"
            f"Prediction: {labels[0]}",
        ),
        (delta_picture, r"$\delta$ (Zoomed in)"),
        (
            adversarial_picture,
            "Perturbed Image \n"
            f"Prediction: {labels[1]}",
        ),
    )

    for axis, (picture, title) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(picture)
        axis.axis("off")

    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches="tight")
    # Figures are only saved, never displayed; release them after writing.
    plt.close("all")

In [None]:
def perturb_image(image_path, model):
    """Attack one image, save a visualization, and report the predictions.

    The path is parsed positionally: the second "/"-separated component is
    looked up in MAPPING_DICT to get the true class, and the text before the
    first "." is split on "_" with the last piece used in the output file name
    (so paths are assumed to look like "<root>/<class-id>/<name>_<idx>.<ext>"
    - confirm against the path list).

    Args:
        image_path: Path of the image to attack.
        model: Model returning ``(logits, extra)``; it is moved to DEVICE.

    Returns:
        (pred_label, adv_label, losses): the predicted class name on the clean
        image, the predicted class name on the perturbed image, and the
        per-step losses reported by ``generate_adversaries``.
    """
    image_idx = image_path.split(".")[0].split("_")[-1]
    synset_label = image_path.split("/")[1]
    true_class_index = MAPPING_DICT[synset_label]
    class_label = LABEL_NAMES[true_class_index]
    print("Original label:", class_label)

    image = Image.open(image_path)
    # The normalization in `transform` is defined for three channels, so bring
    # every non-RGB mode (grayscale, CMYK, RGBA, palette, ...) to RGB rather
    # than special-casing only "L".
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).to(DEVICE).unsqueeze(0)
    model = model.to(DEVICE)

    # Plain inference: no autograd graph is needed for the clean prediction.
    with torch.no_grad():
        logits, _ = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_label = LABEL_NAMES[probs.argmax().item()]
    print("Prediction before adv.:", pred_label)

    delta_tensor, losses = generate_adversaries(batch, model, true_class_index)
    # From here on the perturbation is a constant; drop its autograd history so
    # the tensors derived from it do not keep a graph alive.
    delta_tensor = delta_tensor.detach()

    # Amplify the small perturbation (x50, shifted by 0.5) so it is visible,
    # then map the clipped [-1, 1] range to [0, 1] for display.
    perturbation_viz = np.clip(50 * delta_tensor.cpu().numpy() + 0.5, -1, 1)
    perturbation_viz = (perturbation_viz + 1) / 2

    with torch.no_grad():
        perturbed_image = (batch + delta_tensor).clamp_(-1, 1)
        adv_logits, _ = model(perturbed_image)
        adv_probs = torch.nn.functional.softmax(adv_logits, dim=-1)
    adv_label = LABEL_NAMES[adv_probs.argmax().item()]
    print("Prediction after adv.:", adv_label)

    images = [
        image,
        perturbation_viz.squeeze(0),
        (perturbed_image + 1) / 2,  # back from [-1, 1] to [0, 1] for plotting
    ]
    labels = [pred_label, adv_label]
    show_image(images, labels, class_label, f"{image_idx}_vit.png")

    return pred_label, adv_label, losses

In [None]:
# Tallies over the sampled images: how many the model classifies correctly
# before the attack, and for how many of those the attack flips the prediction.
num_corrects = 0
adv_attacks = 0
# Per-image loss curves, kept only for images that were classified correctly.
all_losses = []

for i, image_path in enumerate(HUNDRED_PATHS):
    pred_label, adv_label, losses = perturb_image(image_path, vit_model)

    # Ground-truth class name, derived from the path's second component.
    class_label = LABEL_NAMES[MAPPING_DICT[image_path.split("/")[1]]]

    # Images the model already gets wrong are not counted.
    if class_label != pred_label:
        continue

    print(f"================{i}================")
    all_losses.append(losses)
    num_corrects += 1
    adv_attacks += int(pred_label != adv_label)

print(f"Total correct predictions: {num_corrects}")
print(f"Total successful attacks: {adv_attacks}")

Original label: bow
Prediction before adv.: bow
Prediction after adv.: croquet_ball
Original label: Komodo_dragon
Prediction before adv.: Komodo_dragon
Prediction after adv.: Komodo_dragon
Original label: harvester
Prediction before adv.: harvester
Prediction after adv.: thresher
Original label: langur
Prediction before adv.: langur
Prediction after adv.: langur
Original label: patio
Prediction before adv.: patio
Prediction after adv.: patio
Original label: speedboat
Prediction before adv.: speedboat
Prediction after adv.: lifeboat
Original label: jack-o'-lantern
Prediction before adv.: jack-o'-lantern
Prediction after adv.: jack-o'-lantern
Original label: go-kart
Prediction before adv.: go-kart
Prediction after adv.: racer
Original label: purse
Prediction before adv.: purse
Prediction after adv.: velvet
Original label: Dutch_oven
Prediction before adv.: Dutch_oven
Prediction after adv.: Dutch_oven
Original label: water_bottle
Prediction before adv.: water_bottle
Prediction after adv.:

In [None]:
import pickle
     
# Persist the recorded loss curves. The context manager guarantees the file is
# flushed and closed even if serialization raises part-way through.
with open("pgd_losses_vit.pkl", "wb") as f:
    pickle.dump(all_losses, f)