<a href="https://colab.research.google.com/github/sayakpaul/robustness-vit/blob/master/analysis/pgd_attacks/PGD_BiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
# Download and unpack the ImageNet validation images; the path parsing below
# presumably expects them under val/<class id>/ -- verify after extraction.
# The full Drive URL works with every gdown version ("--id" was deprecated and
# later removed from the gdown CLI). Quoted so the shell never glob-expands "?".
!gdown "https://drive.google.com/uc?id=1QtAJsTjBOf3CnrTzTTqP-nPnHcTc2g9E"
!tar xf val.tar
!rm -f val.tar

Downloading...
From: https://drive.google.com/uc?id=1QtAJsTjBOf3CnrTzTTqP-nPnHcTc2g9E
To: /content/val.tar
6.75GB [01:13, 91.6MB/s]


In [None]:
# Class-index -> [class id, readable name] mapping, and the fixed list of image paths to attack.
!wget https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
# Full Drive URL instead of the deprecated/removed "--id" flag; quoted to avoid shell globbing of "?".
!gdown "https://drive.google.com/uc?id=1Wbn3yuBBR2KO8OEI38YkHYNu2mQ96E7N"

--2021-04-12 05:48:47--  https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.119.128, 108.177.126.128, 108.177.127.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.119.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/json]
Saving to: ‘imagenet_class_index.json’


2021-04-12 05:48:52 (154 MB/s) - ‘imagenet_class_index.json’ saved [35363/35363]

Downloading...
From: https://drive.google.com/uc?id=1Wbn3yuBBR2KO8OEI38YkHYNu2mQ96E7N
To: /content/random_hundred_paths_val.npy
100% 16.9k/16.9k [00:00<00:00, 7.73MB/s]


In [None]:
import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from imutils import paths

In [None]:
# imagenet_class_index.json maps "<class index as a string>" -> [class id, readable name].
with open("imagenet_class_index.json", "r") as read_file:
    imagenet_labels = json.load(read_file)

# MAPPING_DICT: class id string (matched against the class directory name of an
#               image path in preprocess_image) -> integer class index.
# LABEL_NAMES:  integer class index -> human-readable class name.
MAPPING_DICT = {}
LABEL_NAMES = {}
for label_id in imagenet_labels:
    MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)
    LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]

# Fixed set of image paths to attack (presumably 100 random val/ paths, per the file name).
HUNDRED_PATHS = np.load("random_hundred_paths_val.npy")

## Utilities

In [None]:
# Candidate L-infinity budgets for the perturbation, in the [0, 1] pixel scale
# produced by preprocess_image. NOTE(review): clip_eps only reads EPS[0] by default.
EPS = [1e-3, 2e-3, 3e-3]

# Number of optimisation steps generate_adversaries runs per image.
ITERATIONS = 10

# Images are resized to RESIZE x RESIZE before being fed to the model.
RESIZE = 224

In [None]:
# Function to preprocess an image for performing inference
def preprocess_image(image_path):
    """Load a JPEG and turn it into a single-image batch plus its ground truth.

    Args:
        image_path: Path whose parent directory name is a key of MAPPING_DICT
            (e.g. "val/<class id>/<file>.JPEG").

    Returns:
        Tuple ``(image, class_idx, class_label)``:
        ``image`` is a float32 tensor of shape ``(1, RESIZE, RESIZE, 3)`` with
        values in [0, 1]; ``class_idx`` is the integer class index from
        MAPPING_DICT; ``class_label`` is its name from LABEL_NAMES.
    """
    image = tf.io.read_file(image_path)
    # channels=3 always yields RGB output: grayscale JPEGs are expanded to three
    # identical channels (what the former manual tf.tile did) and no other
    # channel count can reach the model.
    image = tf.image.decode_jpeg(image, channels=3)
    # uint8 [0, 255] -> float32 [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (RESIZE, RESIZE))
    # Add the batch dimension expected by the model.
    image = tf.expand_dims(image, 0)

    # The class id is the name of the image's parent directory; indexing from the
    # end also works for "./val/..." or absolute paths.
    class_idx = MAPPING_DICT[image_path.split("/")[-2]]
    class_label = LABEL_NAMES[class_idx]
    return image, class_idx, class_label

# Clipping utility to project delta onto the L-infinity ball of radius eps
def clip_eps(delta_tensor, eps=None):
    """Clip every element of ``delta_tensor`` to ``[-eps, eps]``.

    Args:
        delta_tensor: Perturbation tensor of any shape.
        eps: L-infinity radius. When omitted, ``EPS[0]`` is used (looked up at
            call time), matching the original behaviour; pass another value,
            e.g. an entry of ``EPS``, to try a different budget.

    Returns:
        Tensor with the same shape as ``delta_tensor``, clipped element-wise.
    """
    if eps is None:
        eps = EPS[0]
    return tf.clip_by_value(delta_tensor,
                            clip_value_min=-eps,
                            clip_value_max=eps)

In [None]:
# m-r101x3 because it's somewhat comparable to ViT_L-16.
# The TF-Hub layer is wrapped in a Sequential so the attack code can use the
# Keras Model API on it (``model.predict`` and ``model(x, training=False)``).
BIT_URL = "https://tfhub.dev/google/bit/m-r101x3/ilsvrc2012_classification/1"
bit_module = tf.keras.Sequential(layers=[hub.KerasLayer(handle=BIT_URL)])

## Attack Utilities

In [None]:
def generate_adversaries(image, delta, model, true_class_index):
    """Optimise an additive perturbation that drives ``model`` away from the true class.

    Each of the ``ITERATIONS`` steps takes one Adam update (lr 1e-3) on ``delta``
    that *increases* the cross-entropy of the true class, then projects
    ``delta`` back onto the L-infinity ball used by ``clip_eps`` (radius
    ``EPS[0]`` by default).

    NOTE(review): unlike textbook PGD (signed-gradient steps of fixed size) the
    update rule here is Adam.

    Args:
        image: Clean input batch, expected in [0, 1].
        delta: ``tf.Variable`` shaped like ``image``; updated in place.
        model: Callable expected to return logits (the loss uses from_logits=True).
        true_class_index: Integer ground-truth class index.

    Returns:
        ``(delta, losses)``: the same variable, now holding the final
        perturbation, and a list with the loss tensor of every step.
    """
    # Fresh loss/optimizer per call, so no Adam moment state carries over between images.
    scc_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    # Loop-invariant batch-of-one label, built once.
    target = tf.convert_to_tensor([true_class_index])
    losses = []

    for _ in range(ITERATIONS):
        with tf.GradientTape() as tape:
            tape.watch(delta)
            # Keep the model input a valid image even if image + delta leaves [0, 1].
            inp = tf.clip_by_value(image + delta, 0, 1)
            predictions = model(inp, training=False)
            # Negated so the optimizer's descent step *raises* the true-class loss.
            loss = -scc_loss(target, predictions)

        # Get the gradients and update delta
        gradients = tape.gradient(loss, delta)
        optimizer.apply_gradients([(gradients, delta)])

        # Project delta back (l-infinite norm). This must *replace* delta:
        # adding clip_eps(delta) on top of delta would let |delta| keep growing
        # past the budget on every iteration.
        delta.assign(clip_eps(delta))
        losses.append(loss)

    return delta, losses

In [None]:
def show_image(images, labels, original_label, filename):
    """Save a three-panel figure: clean image, visualised delta, perturbed image.

    Args:
        images: ``[clean, delta_visualisation, perturbed]``, each with a leading
            batch dimension of size 1 (removed before plotting).
        labels: ``[prediction_before_attack, prediction_after_attack]``.
        original_label: Ground-truth label shown on the first panel.
        filename: Output path for ``Figure.savefig`` (300 dpi, tight bbox).
    """
    titles = [
        "Input Image \n"
        f"Original Label: {original_label}\n"
        f"Prediction: {labels[0]}",
        r"$\delta$ (Zoomed in)",
        "Perturbed Image \n"
        f"Prediction: {labels[1]}",
    ]

    fig, axes = plt.subplots(ncols=3, figsize=(10, 10))
    for ax, title, image in zip(axes, titles, images):
        ax.set_title(title)
        # Drop the size-1 batch dimension before plotting.
        ax.imshow(tf.squeeze(image, 0))
        ax.axis("off")

    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches="tight")
    # Close only this figure; closing "all" would also discard unrelated
    # figures the caller may still have open.
    plt.close(fig)

In [None]:
def perturb_image(image_path, model):
    """Attack a single image, save a visualisation and report the predictions.

    Args:
        image_path: Path to an image, in the layout ``preprocess_image`` expects.
        model: Keras model used both for the clean/adversarial predictions and
            for the attack itself.

    Returns:
        ``(label_before, label_after, losses)``: readable predictions on the
        clean and perturbed image, and the per-step losses from
        ``generate_adversaries``.

    Side effects:
        Prints the labels and saves ``<image_idx>_bit.png`` via ``show_image``.
    """
    preprocessed_image, true_class_index, class_label = preprocess_image(image_path)
    # Take the suffix after the last "_" in the file's base name, e.g.
    # "ILSVRC2012_val_00000293.JPEG" -> "00000293". Using the base name means
    # dots in directory components (e.g. "./val/...") cannot break the parse.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    image_idx = image_stem.split("_")[-1]
    print("Original label:", class_label)

    # Prediction on the clean image
    initial_label = LABEL_NAMES[model.predict(preprocessed_image).argmax()]
    print("Prediction before adv.:", initial_label)

    # Start the perturbation from zero
    image_tensor = tf.constant(preprocessed_image, dtype=tf.float32)
    delta = tf.Variable(tf.zeros_like(image_tensor), trainable=True)

    # Get the learned delta
    delta_tensor, losses = generate_adversaries(image_tensor, delta,
                                                model, true_class_index)

    # Perturbed image, kept in the valid [0, 1] range
    perturbed_image = tf.clip_by_value(image_tensor + delta_tensor, 0, 1)

    # Prediction on the perturbed image
    adv_label = LABEL_NAMES[model.predict(perturbed_image).argmax()]
    print("Prediction after adv.:", adv_label)

    # delta is tiny: amplify it 50x around mid-grey (0.5) so it is visible.
    delta_vis = tf.clip_by_value(50 * delta_tensor.numpy() + 0.5, 0, 1)
    show_image([preprocessed_image, delta_vis, perturbed_image],
               [initial_label, adv_label],
               class_label, f"{image_idx}_bit.png")

    return initial_label, adv_label, losses

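## Assessment

Every image in `HUNDRED_PATHS` is attacked. Only images the model classifies correctly *before* the attack are scored. Each such image counts towards "correct predictions". An attack on such an image counts as "successful" when the post-attack prediction differs from the clean one.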

In [None]:
# Counters are only updated for images the model gets right before the attack.
num_corrects, adv_attacks, all_losses = 0, 0, []

for i, image_path in enumerate(HUNDRED_PATHS):
    pred_label, adv_label, losses = perturb_image(image_path, bit_module)

    # Ground truth from the image's class directory.
    class_idx = MAPPING_DICT[image_path.split("/")[1]]
    class_label = LABEL_NAMES[class_idx]

    # Skip images that were already misclassified before the perturbation.
    if class_label != pred_label:
        continue

    print(f"================{i}================")
    all_losses.append(losses)
    num_corrects += 1
    # The attack succeeds when the adversarial prediction differs from the clean one.
    adv_attacks += int(pred_label != adv_label)

print(f"Total correct predictions: {num_corrects}")
print(f"Total successful attacks: {adv_attacks}")

Original label: bow
Prediction before adv.: croquet_ball
Prediction after adv.: croquet_ball
Original label: Komodo_dragon
Prediction before adv.: Komodo_dragon
Prediction after adv.: sea_lion
Original label: harvester
Prediction before adv.: thresher
Prediction after adv.: thresher
Original label: langur
Prediction before adv.: langur
Prediction after adv.: coffee_mug
Original label: patio
Prediction before adv.: patio
Prediction after adv.: park_bench
Original label: speedboat
Prediction before adv.: speedboat
Prediction after adv.: canoe
Original label: jack-o'-lantern
Prediction before adv.: jack-o'-lantern
Prediction after adv.: television
Original label: go-kart
Prediction before adv.: go-kart
Prediction after adv.: racer
Original label: purse
Prediction before adv.: wool
Prediction after adv.: wool
Original label: Dutch_oven
Prediction before adv.: Dutch_oven
Prediction after adv.: stove
Original label: water_bottle
Prediction before adv.: water_bottle
Prediction after adv.: sal

In [None]:
# Persist the per-step loss curves collected for correctly classified images.
# The context manager closes the file even if serialisation fails part-way;
# pickle.dump writes the same bytes as f.write(pickle.dumps(...)).
with open("pgd_losses_bit.pkl", "wb") as f:
    pickle.dump(all_losses, f)