## Example from OpenAI

In [1]:
## Example from OpenAI
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]}: {100 * value.item():.2f}%")

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified

Top predictions:

snake: 65.19%
turtle: 12.44%
sweet_pepper: 3.85%
lizard: 1.88%
crocodile: 1.74%


In [2]:
import pandas as pd

In [10]:
def make_adverserial_inference_df(inference_path, fgsm = False):
    '''
    Create dataframe for our inference files
    inference_path: path to our inference folder 
    
    retunrs df with image pointing to image filename and caption to the correct label 
    '''
    #file_path_names = [f for f in os.listdir(inference_path) if os.path.isfile(f)]
    file_path_names = os.listdir(inference_path)
    d = {} 
    d['image'] = file_path_names
    labels = []
    captions = []
    for file_path in file_path_names:
        if fgsm:
            name = os.path.splitext(file_path)[0] # remove .JPEG from the end 
            name = ' '.join(name.split('_'))  # "01669191_box_turtle"  to "box turtle"
            label = name
            name = "This is a photo of a " + name  # "box turtle" to "This is box turtle"
        else:
            name = os.path.splitext(file_path)[0] # remove .JPEG from the end 
            name = ' '.join(name.split('_')[1:])  # "01669191_box_turtle"  to "box turtle"
            label = name
            name = "This is a photo of a " + name  # "box turtle" to "This is box turtle"
        captions.append(name)
        labels.append(label)
    d['caption'] = captions 
    d['label'] = labels
    df = pd.DataFrame(data=d)
    
    return df 

In [11]:
def find_match(image_input, text_inputs, model, classes, k=5, print_only=False):
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

    # Pick the top 5 most similar labels for the image
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(k)
    if print_only:
        for value, index in zip(values, indices):
            print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")
    matches = [classes[index] for index in indices]
    return matches 

In [51]:
from tqdm import tqdm
from PIL import Image
def top_k_inf(model, df, k=5, attack_mode=None):
    
    # appending the correct folder path
    if attack_mode is None:
        source_path = "data/originalImagenet/"
    elif attack_mode == "pgd":
        source_path = "data/pgdAttack/"
    elif attack_mode == "fgsm":
        source_path = "data/attack/"
    correct = 0
    
    text_inputs = df['caption'].values
    text_inputs = torch.cat([clip.tokenize(c) for c in text_inputs]).to(device)

    for index,row in tqdm(df.iterrows(), total=len(df)): 
        image_path = source_path + row['image']
        image = Image.open(image_path).convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)

        matches = find_match(image_input, text_inputs, model, df['caption'].values, k,  False)

        for match in matches:
            if row['label'] in match: 
                correct += 1
        if index % 100 == 0:
            print(f"Correct raw: {correct}, Percent: {round(correct/(index+1) *100, 2)}")

    print(f"Attack mode is {attack_mode}")
    print(f"Correct raw value: {correct}")
    print(f"Top {k} percent: {round(correct/len(df)* 100, 2)}")

In [13]:
# No attack - originalImagnet dataset 
model, preprocess = clip.load('ViT-B/32', device)
og_df = make_adverserial_inference_df("data/originalImagenet")

In [15]:
og_df

Unnamed: 0,image,caption,label
0,n02979186_cassette_player.JPEG,This is a photo of a cassette player,cassette player
1,n02113799_standard_poodle.JPEG,This is a photo of a standard poodle,standard poodle
2,n02437312_Arabian_camel.JPEG,This is a photo of a Arabian camel,Arabian camel
3,n04154565_screwdriver.JPEG,This is a photo of a screwdriver,screwdriver
4,n12998815_agaric.JPEG,This is a photo of a agaric,agaric
...,...,...,...
995,n01582220_magpie.JPEG,This is a photo of a magpie,magpie
996,n03874599_padlock.JPEG,This is a photo of a padlock,padlock
997,n03240683_drilling_platform.JPEG,This is a photo of a drilling platform,drilling platform
998,n01795545_black_grouse.JPEG,This is a photo of a black grouse,black grouse


In [43]:
top_k_inf(model, og_df, k=5)

Correct raw: 0, Percent: 0.0
Correct raw: 9, Percent: 81.82
Correct raw: 19, Percent: 90.48
Correct raw: 28, Percent: 90.32
Correct raw: 38, Percent: 92.68
Correct raw: 46, Percent: 90.2
Correct raw: 55, Percent: 90.16
Correct raw: 62, Percent: 87.32
Correct raw: 70, Percent: 86.42
Correct raw: 78, Percent: 85.71
Correct raw: 88, Percent: 87.13
Correct raw: 98, Percent: 88.29
Correct raw: 106, Percent: 87.6
Correct raw: 118, Percent: 90.08
Correct raw: 126, Percent: 89.36
Correct raw: 136, Percent: 90.07
Correct raw: 145, Percent: 90.06
Correct raw: 154, Percent: 90.06
Correct raw: 164, Percent: 90.61
Correct raw: 174, Percent: 91.1
Correct raw: 183, Percent: 91.04
Correct raw: 192, Percent: 91.0
Correct raw: 202, Percent: 91.4
Correct raw: 211, Percent: 91.34
Correct raw: 218, Percent: 90.46
Correct raw: 228, Percent: 90.84
Correct raw: 237, Percent: 90.8
Correct raw: 247, Percent: 91.14
Correct raw: 256, Percent: 91.1
Correct raw: 263, Percent: 90.38
Correct raw: 268, Percent: 89.04


In [46]:
pgd_df = make_adverserial_inference_df("data/pgdAttack")

In [47]:
pgd_df

Unnamed: 0,image,caption,label
0,n02979186_cassette_player.JPEG,This is a photo of a cassette player,cassette player
1,n02113799_standard_poodle.JPEG,This is a photo of a standard poodle,standard poodle
2,n02437312_Arabian_camel.JPEG,This is a photo of a Arabian camel,Arabian camel
3,n04154565_screwdriver.JPEG,This is a photo of a screwdriver,screwdriver
4,n12998815_agaric.JPEG,This is a photo of a agaric,agaric
...,...,...,...
995,n01582220_magpie.JPEG,This is a photo of a magpie,magpie
996,n03874599_padlock.JPEG,This is a photo of a padlock,padlock
997,n03240683_drilling_platform.JPEG,This is a photo of a drilling platform,drilling platform
998,n01795545_black_grouse.JPEG,This is a photo of a black grouse,black grouse


In [53]:
top_k_inf(model, pgd_df, k=5, attack_mode='pgd')

  0%|▏                                                                                                                                             | 1/1000 [00:01<21:13,  1.27s/it]

Correct raw: 0, Percent: 0.0


 10%|██████████████▏                                                                                                                             | 101/1000 [02:06<18:51,  1.26s/it]

Correct raw: 68, Percent: 67.33


 20%|████████████████████████████▏                                                                                                               | 201/1000 [04:12<16:45,  1.26s/it]

Correct raw: 140, Percent: 69.65


 30%|██████████████████████████████████████████▏                                                                                                 | 301/1000 [06:18<14:41,  1.26s/it]

Correct raw: 204, Percent: 67.77


 40%|████████████████████████████████████████████████████████▏                                                                                   | 401/1000 [08:24<12:34,  1.26s/it]

Correct raw: 262, Percent: 65.34


 50%|██████████████████████████████████████████████████████████████████████▏                                                                     | 501/1000 [10:30<10:28,  1.26s/it]

Correct raw: 320, Percent: 63.87


 60%|████████████████████████████████████████████████████████████████████████████████████▏                                                       | 601/1000 [12:36<08:22,  1.26s/it]

Correct raw: 389, Percent: 64.73


 70%|██████████████████████████████████████████████████████████████████████████████████████████████████▏                                         | 701/1000 [14:42<06:16,  1.26s/it]

Correct raw: 447, Percent: 63.77


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                           | 801/1000 [16:48<04:10,  1.26s/it]

Correct raw: 501, Percent: 62.55


 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏             | 901/1000 [18:54<02:04,  1.26s/it]

Correct raw: 566, Percent: 62.82


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [20:59<00:00,  1.26s/it]

Attack mode is pgd
Correct raw value: 622
Top 5 percent: 62.2





In [54]:
fgsm_df = make_adverserial_inference_df("data/attack", fgsm=True)

In [None]:
top_k_inf(model, fgsm_df, k=5, attack_mode='fgsm')

  0%|▏                                                                                                                                              | 1/997 [00:01<21:02,  1.27s/it]

Correct raw: 0, Percent: 0.0


 10%|██████████████▎                                                                                                                              | 101/997 [02:06<18:47,  1.26s/it]

Correct raw: 44, Percent: 43.56


 20%|████████████████████████████▍                                                                                                                | 201/997 [04:12<16:42,  1.26s/it]

Correct raw: 93, Percent: 46.27


 30%|██████████████████████████████████████████▌                                                                                                  | 301/997 [06:18<14:36,  1.26s/it]

Correct raw: 148, Percent: 49.17


 40%|████████████████████████████████████████████████████████▋                                                                                    | 401/997 [08:24<12:29,  1.26s/it]

Correct raw: 193, Percent: 48.13


 50%|██████████████████████████████████████████████████████████████████████▊                                                                      | 501/997 [10:30<10:23,  1.26s/it]

Correct raw: 237, Percent: 47.31


 60%|████████████████████████████████████████████████████████████████████████████████████▉                                                        | 601/997 [12:35<08:17,  1.26s/it]

Correct raw: 276, Percent: 45.92


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                                         | 701/997 [14:41<06:12,  1.26s/it]

Correct raw: 319, Percent: 45.51


 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                           | 801/997 [16:47<04:06,  1.26s/it]

Correct raw: 359, Percent: 44.82


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 901/997 [18:53<02:00,  1.26s/it]

Correct raw: 413, Percent: 45.84


 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌   | 973/997 [20:23<00:30,  1.26s/it]