In [11]:
import matplotlib.pyplot as plt
%matplotlib inline

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms

from advertorch.utils import predict_from_logits
from advertorch_examples.utils import get_mnist_test_loader
from advertorch_examples.utils import _imshow

from tqdm import tqdm
from time import sleep

import time

from sklearn.preprocessing import normalize
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

### Load the dataset

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

###########################
from torchvision import datasets, transforms

test_transforms = transforms.Compose([transforms.Resize((224,224)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                                                           std=[0.5, 0.5, 0.5])])

test_data = datasets.ImageFolder('data/data/test', test_transforms)

test_loader = torch.utils.data.DataLoader(attack_data, batch_size=1,shuffle=True)

print (test_data.classes)

class_names = attack_data.classes



### Load the model

In [13]:
filename = "models/resnet_model_acc_95.pt"
use_cuda=True
device = torch.device('cuda')

model_ft = models.resnet18(pretrained=True).to(device)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names)).to(device)
model_ft.load_state_dict(torch.load(filename, map_location='cpu'))
model_ft.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

### Load the FE only

In [14]:
model_ft_ext = models.resnet18(pretrained=True).to(device)

### Perform Attack

In [15]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

In [70]:
def create_attack(model,device,test_loader, epsilon = 0.2):
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        
        # If the initial prediction is wrong, dont bother attacking, just move on
        if init_pred.item() != target.item():
            continue
        
        # Calculate the loss
        loss = F.nll_loss(output, target)
        
        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect datagrad
        data_grad = data.grad.data

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classify the perturbed image
        output = model(perturbed_data)
        
        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            #correct += 1
            # Special case for saving 0 epsilon examples
            #if (epsilon == 0) and (len(adv_examples) < 5):
            adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            #adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
                
        else:
            # Save some adv examples for visualization later
            #if len(adv_examples) < 5:
            #adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            adv_ex = perturbed_data
            #adv_examples.append( (init_pred.item(), final_pred.item(), data, adv_ex) )
            #cln_data = data.cpu().detach().numpy()
            return init_pred.item(), final_pred.item(), data, adv_ex 

### Prediction Analysis

In [82]:
true_peed, final_pred, true_data, adv_data = create_attack(model_ft, device, test_loader, 0.1)

In [83]:
print ('True Prediction:', class_names[true_peed], ' || Attack Prediction:', class_names[final_pred])

True Prediction: go  || Attack Prediction: goLeft


### Extrack Features of Clean and attack data

In [84]:
x = model_ft_ext(true_data).cpu().detach().numpy() 
y = model_ft_ext(adv_data.to(device)).cpu().detach().numpy() 
print (x.shape)
print (y.shape)

(1, 1000)
(1, 1000)


### Calculate the KD value

In [85]:
dist = np.linalg.norm(x - y)

In [96]:
sigma = 0.8
KD = np.exp(-dist/sigma)

In [97]:
print (KD)

0.028987346283516804


In [100]:
def getKD(x, y, sigma=0.8):
    dist = np.linalg.norm(x - y)
    return np.exp(-dist/sigma)

### Perform Searching on database

In [101]:
## For this example say we have 'go' data
## and lets compare the distance between the Go and GoLeft dabase we have so far.

In [134]:
import glob
import time
import pandas as pd

In [135]:
non_attack_go_fts = glob.glob('feature_data/non-attack/go/*')
non_attack_goLeft_fts = glob.glob('feature_data/non-attack/goLeft/*')

####  || Calculate the  KD valuse with respect to y ||

In [138]:
start_time  = time.time()

go_count = 0
for item in non_attack_go_fts:
    df=pd.read_csv(item, sep=',',header=None)
    vals = np.array(df.values)
    
    kd = getKD(vals,y,0.9)
    
    if(kd>0.04): go_count+=1

end_time = time.time()

print ('Search time : ', (end_time-start_time), '(s) || match count : ', go_count)

Search time :  83.69822072982788 (s) || match count :  871


In [139]:
start_time  = time.time()
go_left_count = 0
for item in non_attack_goLeft_fts:
    df=pd.read_csv(item, sep=',',header=None)
    vals = np.array(df.values)
    
    kd = getKD(vals,y,0.9)
    
    if(kd>0.04): go_left_count+=1
        
end_time = time.time()

print ('Search time : ', (end_time-start_time), '(s) || match count : ', go_left_count)

Search time :  11.878114938735962 (s) || match count :  119


#### || Search the whole database ||

In [153]:
def FilterByKDValue(filename,y):
    df=pd.read_csv(filename, sep=',',header=None)
    vals = np.array(df.values)
    kd = getKD(vals,y,0.9)
    if(kd>0.04): return 1
    else: return 0

In [159]:
attack_fts_count = {'go':0, 'goForward':0, 'goLeft':0, 'stop':0, 'stopLeft':0, 'warning':0, 'warningLeft':0}
non_attack_fts_count = {'go':0, 'goForward':0, 'goLeft':0, 'stop':0, 'stopLeft':0, 'warning':0, 'warningLeft':0}

def search_db(y):
    global attack_fts_count,non_attack_fts_count
    
    all_fts = glob.glob('feature_data/*/*/*')
    
    for item in all_fts:
        db_type = item.split('/')[1]
        db_sign = item.split('/')[2]
        
        #print (db_type, db_sign)
        
        if(db_type=='attack'):
            attack_fts_count[db_sign] += FilterByKDValue(item,y)
        elif (db_type=='non-attack'):
            non_attack_fts_count[db_sign] += FilterByKDValue(item,y)

In [160]:
start_time = time.time()
search_db(y)
end_time = time.time()

In [164]:
print ('Total Databse Search Time: ', (end_time-start_time), ' s')

Total Databse Search Time:  527.206205368042  s


In [165]:
attack_fts_count

{'go': 1006,
 'goForward': 2,
 'goLeft': 33,
 'stop': 1899,
 'stopLeft': 725,

In [166]:
non_attack_fts_count

{'go': 871,
 'goForward': 0,
 'goLeft': 119,
 'stop': 256,
 'stopLeft': 128,