In [None]:
import sys
sys.path.append('..')

# Helper imports
from data import CIFAR10, IMAGENETTE
import utils

# Numerical computing and display imports
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import numpy as np

# Other imports
import multiprocess as mp
from datetime import datetime
import logging
import os

In [None]:
# Load and Prepare the Data
# Save 100 images of the chosen dataset to folder_path, then relocate that
# folder via utils.move_data_to_temp_ram (presumably a RAM-backed temp dir;
# confirm in utils). Later cells read the images from the returned folder_path.
folder_path = '../../images/'
dataset = 'imagenette'
dataset_loaders = {'cifar10': CIFAR10, 'imagenette': IMAGENETTE}
# Fail fast on a typo instead of raising NameError or, on a re-run, silently
# reusing a stale `images` object left in the kernel by a previous run.
if dataset not in dataset_loaders:
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_loaders)}")
images = dataset_loaders[dataset]()
x = images.load()
images.save_to_disk(x, folder_path, num_images=100)
folder_path = utils.move_data_to_temp_ram(folder_path, ram_size_mb=50)
# The in-memory copy is no longer needed once the images are on disk.
del x

In [None]:
# Threshold-hyperparameter mappings.
# Each row is keyed by str(attack_hamming_threshold) and holds the SimBA and
# HSJA thresholds tuned for that target Hamming distance.
threshold_hyperparam = {
    '1':  dict(simba_hamming_threshold=7,  simba_l2_threshold=20, hsja_l2_threshold=8),
    '5':  dict(simba_hamming_threshold=10, simba_l2_threshold=20, hsja_l2_threshold=16),
    '10': dict(simba_hamming_threshold=20, simba_l2_threshold=28, hsja_l2_threshold=24),
    '19': dict(simba_hamming_threshold=23, simba_l2_threshold=30, hsja_l2_threshold=24),
}

# Target Hamming distance: 10% of 96 (presumably the hash length in bits),
# rounded half-up -> 10. Must be one of the keys above.
attack_hamming_threshold = int(0.1 * 96 + 0.5)
selected_thresholds = threshold_hyperparam[str(attack_hamming_threshold)]

# SIMBA Hyperparams
simba_epsilon = 0.9
# Roughly: 5% target -> 10, 10% -> 20, 20% -> 23
simba_hamming_threshold = selected_thresholds['simba_hamming_threshold']
simba_l2_threshold = selected_thresholds['simba_l2_threshold']
simba_max_steps = 10000
fast = True

# HSJA Hyperparams
hsja_max_steps = 10
hsja_grad_queries = 20
hsja_l2_threshold = selected_thresholds['hsja_l2_threshold']

# Other params
ssim_threshold = 0.8
now = datetime.now()
# Run timestamp, used in the metrics CSV file name.
# NOTE(review): ':' is not allowed in Windows file names.
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S")

In [None]:
# NES threshold-hyperparameter mappings, keyed by str(attack_hamming_threshold).
# NOTE(review): the 'hsja_l2_threshold' entries in this table are not read
# anywhere in the visible notebook; HSJA gets hsja_l2_threshold from
# threshold_hyperparam instead. Confirm which value is intended for NES runs.
nes_threshold_hyperparam = {
    '1':  dict(nes_l2_threshold=25, nes_l2_tolerance=10, hsja_l2_threshold=10),
    '5':  dict(nes_l2_threshold=35, nes_l2_tolerance=12, hsja_l2_threshold=15),
    '10': dict(nes_l2_threshold=45, nes_l2_tolerance=15, hsja_l2_threshold=20),
    # Alternative values noted by the author for '19': l2_tolerance 15, hsja_l2_threshold 25
    '19': dict(nes_l2_threshold=55, nes_l2_tolerance=20, hsja_l2_threshold=24),
}

# Same target as in the SimBA/HSJA config: 10% of 96, rounded half-up -> 10.
attack_hamming_threshold = int(0.1 * 96 + 0.5)
nes_selected = nes_threshold_hyperparam[str(attack_hamming_threshold)]

# NES Hyperparams
nes_mean = -0.025
nes_std = 0.1
nes_sigma = 0.7
nes_eps = 0.1
nes_l2_threshold = nes_selected['nes_l2_threshold']
nes_l2_tolerance = nes_selected['nes_l2_tolerance']
nes_num_samples = 50

In [None]:
# Soft-Label Attack imports
from simba import SimBAttack
from nes import NESAttack

# Construct both soft-label attackers from the hyperparameters configured in
# the preceding cells.
simba_config = {
    'eps': simba_epsilon,
    'hamming_threshold': simba_hamming_threshold,
    'l2_threshold': simba_l2_threshold,
    'max_steps': simba_max_steps,
    'fast': fast,
}
simba = SimBAttack(**simba_config)

nes_config = {
    'mean': nes_mean,
    'std': nes_std,
    'sigma': nes_sigma,
    'eps': nes_eps,
    'l2_threshold': nes_l2_threshold,
    'l2_tolerance': nes_l2_tolerance,
    'num_samples': nes_num_samples,
}
nes = NESAttack(**nes_config)

In [None]:
# Hard-label Attack imports
from hsja import HSJAttack

# HSJA uses the shared target Hamming distance and the HSJA L2 threshold
# selected in the SimBA/HSJA config cell.
hsja_config = {
    'max_iters': hsja_max_steps,
    'grad_queries': hsja_grad_queries,
    'l2_threshold': hsja_l2_threshold,
    'hamming_threshold': attack_hamming_threshold,
}
hsja = HSJAttack(**hsja_config)

In [None]:
# Joint Attack: combines a soft-label attack with the hard-label HSJA attack.
from joint_attack import JointAttack

# Swap `nes` for `simba` here to run the SimBA variant instead.
soft_label_attack, hard_label_attack = nes, hsja
joint_attack = JointAttack(soft_label_attack, hard_label_attack)

In [None]:
# Attack each exported image and append one row of metrics per image to a CSV.
# The CSV location depends only on configuration, so resolve it once up front.
# isinstance (not type() ==) so subclasses of SimBAttack are labelled 'simba'.
sl_attack = 'simba' if isinstance(joint_attack.sl_algo, SimBAttack) else 'nes'
eps = simba_epsilon if sl_attack == 'simba' else nes_eps
file_path = f'metrics/{sl_attack}/{eps}/hamm_{attack_hamming_threshold}_l2_{hsja_l2_threshold}_{dt_string}.csv'
# DataFrame.to_csv does not create parent directories; without this the first
# write fails with FileNotFoundError when metrics/<attack>/<eps>/ is missing.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

for i in range(100):
    # Images were saved as 1.bmp ... N.bmp and moved to folder_path
    img_path = os.path.join(folder_path, f'{i+1}.bmp')

    # Attack NeuralHash
    orig_img, sl_img, adv_img, sl_num_queries, hl_num_queries = joint_attack.attack(img_path)

    # Save the soft-label image and the final image
    utils.save_img(f'../../images/{i+1}_simba.bmp', sl_img)
    utils.save_img(f'../../images/{i+1}_final.bmp', adv_img)

    # Attack Metrics
    orig_hash = utils.compute_hash(orig_img)
    sl_hash = utils.compute_hash(sl_img)
    adv_hash = utils.compute_hash(adv_img)
    simba_hamming_dist = utils.distance(orig_hash, sl_hash, "hamming")
    final_hamming_dist = utils.distance(orig_hash, adv_hash, "hamming")
    sl_l2_dist = utils.distance(orig_img, sl_img, 'l2')
    final_l2_dist = utils.distance(orig_img, adv_img, 'l2')
    final_ssim = utils.distance(orig_img, adv_img.astype(np.uint8), 'ssim')
    total_queries = sl_num_queries + hl_num_queries
    # Success requires both a large enough hash change and a similar-looking image
    success = (final_hamming_dist >= attack_hamming_threshold) and (final_ssim >= ssim_threshold)

    attack_metrics = {
        'Image Path':         [img_path],
        'Success':            [success],
        'Queries':            [total_queries],
        'Soft-Label L2':      [sl_l2_dist],
        'Final L2':           [final_l2_dist],
        'Final SSIM':         [final_ssim],
        'Soft-Label Hamming': [simba_hamming_dist],
        'Final Hamming':      [final_hamming_dist]
    }

    # Append this image's row; write the header only when creating the file
    df = pd.DataFrame.from_dict(attack_metrics)
    if os.path.exists(file_path):
        df.to_csv(file_path, mode='a', index=False, header=False)
    else:
        df.to_csv(file_path, index=False, header=True)

In [None]:
def plot_results(orig_img, sl_img, adv_img, hamming_threshold, sl_label='Simba'):
    """Show the original, soft-label and final adversarial images side by side.

    Args:
        orig_img: the original input image (presumably as returned by
            joint_attack.attack).
        sl_img: the image produced by the soft-label attack stage.
        adv_img: the final adversarial image; cast to uint8 before display.
        hamming_threshold: target Hamming distance. Not drawn on the figure;
            accepted so existing calls keep working.
        sl_label: name of the soft-label attack, used in the middle panel's
            title (e.g. 'NES' when the joint attack used NES).
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    panels = [
        (orig_img, 'Original Image'),
        (sl_img, f'Noisy {sl_label} Image'),
        (adv_img.astype(np.uint8), 'Denoised Final Image'),
    ]
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)

In [None]:
# Show the images/results for the last image processed by the attack loop
plot_results(
    orig_img=orig_img,
    sl_img=sl_img,
    adv_img=adv_img,
    hamming_threshold=attack_hamming_threshold,
);