In [1]:
import os
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from models import ResNet50

from PIL import Image
from skimage.measure import compare_ssim

from tqdm import tqdm_notebook

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
ssim_threshold = 0.95

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

In [4]:
# Clean input images are read from source_path; adversarial PNGs are written to target_path.
source_path = "./data/imgs"
target_path = "./data/to"
# FGSM.attack saves into target_path without creating it; create it up front so
# the first save does not fail with FileNotFoundError on a fresh checkout.
os.makedirs(target_path, exist_ok=True)

In [5]:
def get_model(checkpoint_path="best_model_chkpt-resnet50.t7", map_location='cpu'):
    """Build a ResNet50 and load trained weights from a checkpoint file.

    Args:
        checkpoint_path: checkpoint to read; it must be a mapping whose
            ``'net'`` entry is the model's state dict.
        map_location: forwarded to ``torch.load``; the default loads all
            tensors onto the CPU, so no GPU is required.

    Returns:
        The ``ResNet50`` instance with the checkpoint weights loaded.
    """
    net = ResNet50()
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(checkpoint['net'])
    return net

In [6]:
def tensor2img(tensor):
    """Convert a normalised image tensor back into a PIL image.

    Undoes ``transforms.Normalize(mean=MEAN, std=STD)`` channel by channel
    (``x * std + mean``), clips the result to [0, 1] and converts it with
    ``transforms.ToPILImage``.

    Args:
        tensor: a (C, H, W) tensor, or (1, C, H, W) with a singleton batch
            dimension. The input is not modified (a copy is denormalised).

    Returns:
        The corresponding ``PIL.Image.Image``.
    """
    tensor_copy = tensor.clone()
    # Drop a singleton batch dim *before* denormalising: otherwise the zip below
    # iterates over the batch axis and applies MEAN[0]/STD[0] to every channel.
    if tensor_copy.dim() == 4 and tensor_copy.size(0) == 1:
        tensor_copy = tensor_copy.squeeze(0)
    for channel, mean, std in zip(tensor_copy, MEAN, STD):
        channel.mul_(std).add_(mean)
    tensor_copy.clamp_(0, 1)
    # Kept from the original: a single-channel (1, H, W) tensor becomes (H, W).
    tensor_copy = tensor_copy.squeeze(0)
    return transforms.ToPILImage()(tensor_copy)

In [7]:
class FGSM:
    """Iterative, targeted sign-gradient attack on an embedding model.

    Each source image is nudged so that the model's output moves towards the
    outputs of a set of target images (MSE loss). An iteration takes one step
    of ``eps * sign(grad)`` per target, sums the steps and subtracts them from
    the input. Iteration stops early when SSIM between the perturbed crop and
    the clean crop falls below the module-level ``ssim_threshold``. The last
    image that still met the threshold is saved as a PNG in ``target_path``.

    Args:
        model: network applied to the normalised (1, 3, 112, 112) crop;
            switched to eval mode here.
        eps: step size per target per iteration, in normalised-tensor units.
        max_iter: maximum number of iterations per source image.
    """

    def __init__(self, model, eps, max_iter=50):
        self.model = model
        self.model.eval()

        self.eps = eps
        self.max_iter = max_iter

        # Centre 224x224 crop, downscaled to 112x112 -- the geometry the attack
        # works in and the size of the saved image. (transforms.Scale, used
        # previously, is a deprecated alias of Resize removed in newer torchvision.)
        self.cropping = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.Resize(112),
        ])
        self.img2tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD),
        ])
        # Full PIL -> normalised tensor pipeline (crop, then tensor + normalise).
        self.transform = transforms.Compose([self.cropping, self.img2tensor])

        self.loss = nn.MSELoss()

    @staticmethod
    def _load_rgb(img_name):
        """Read ``img_name`` from ``source_path`` as a 3-channel RGB image.

        The file handle is closed before returning; ``convert`` loads the pixel
        data and makes grayscale/RGBA files match the 3-channel normalisation.
        """
        with Image.open(os.path.join(source_path, img_name)) as img:
            return img.convert('RGB')

    @staticmethod
    def _output_path(img_name):
        """Path of the adversarial PNG written for ``img_name``."""
        # splitext instead of str.replace('.jpg', ...): non-.jpg names must also
        # end in .png, otherwise PIL would pick a (possibly lossy) format from
        # the original extension and destroy the perturbation.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(target_path, stem + '.png')

    def _embed(self, img):
        """Model output for a PIL image, as a (1, D) float tensor without grad."""
        with torch.no_grad():
            return self.model(self.transform(img).unsqueeze(0)).reshape(1, -1).float()

    def _step(self, input_var, target_outs):
        """Sum over targets of ``eps * sign(dLoss/dinput)``, to be subtracted."""
        # One forward pass shared by all targets; retain_graph lets us take a
        # separate gradient per target (the sign is applied per target, so the
        # gradients cannot simply be summed first). autograd.grad also avoids
        # accumulating gradients into the model parameters.
        out = self.model(input_var)
        step = torch.zeros_like(input_var)
        for target_out in target_outs:
            (grad,) = torch.autograd.grad(self.loss(out, target_out), input_var,
                                          retain_graph=True)
            step += self.eps * torch.sign(grad)
        return step

    def attack(self, attack_pairs):
        """Attack every source image towards all target embeddings and save it.

        Args:
            attack_pairs: mapping with ``'source'`` and ``'target'`` sequences
                of image file names, both relative to ``source_path``.

        Sources whose output PNG already exists in ``target_path`` are skipped,
        so an interrupted run can be resumed.
        """
        target_outs = [self._embed(self._load_rgb(name))
                       for name in attack_pairs['target']]

        for img_name in attack_pairs['source']:
            out_path = self._output_path(img_name)
            # Check the file that is actually written (.png), not the source name.
            if os.path.isfile(out_path):
                continue

            original_img = self.cropping(self._load_rgb(img_name))
            original_arr = np.array(original_img, dtype=np.float32)
            attacked_img = original_img
            # Equivalent to self.transform(img): the crop is already applied.
            input_var = self.img2tensor(original_img).unsqueeze(0).requires_grad_()
            # NOTE(review): input_var itself is never clipped to the valid pixel
            # range; only the copy converted by tensor2img is clipped.

            for _ in tqdm_notebook(range(self.max_iter)):
                step = self._step(input_var, target_outs)
                with torch.no_grad():
                    input_var -= step
                changed_img = tensor2img(input_var.detach().squeeze())

                # Pixels are 8-bit (0..255); without data_range skimage derives
                # the range from the float dtype ([-1, 1]) and mis-scales SSIM.
                ssim = compare_ssim(original_arr,
                                    np.array(changed_img, dtype=np.float32),
                                    multichannel=True,
                                    data_range=255)
                if ssim < ssim_threshold:
                    break
                attacked_img = changed_img

            attacked_img.save(out_path)

In [10]:
# Number of leading rows of pairs_list.csv to attack; set to None for all rows.
N_PAIRS = 10

attacker = FGSM(get_model(), eps=1e-3, max_iter=10)
img_pairs = pd.read_csv("pairs_list.csv")
pairs_to_attack = img_pairs if N_PAIRS is None else img_pairs.head(N_PAIRS)

# Each row holds '|'-separated file names. itertuples walks rows positionally,
# so duplicate index labels cannot turn a row lookup into a sub-frame the way
# img_pairs.loc[label] would.
for row in tqdm_notebook(pairs_to_attack.itertuples(index=False),
                         total=len(pairs_to_attack)):
    attacker.attack({
        'source': row.source_imgs.split('|'),
        'target': row.target_imgs.split('|'),
    })

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))


