In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.models import ResNet18_Weights
import torch.nn.functional as F

import numpy as np
import json

import torch.optim as optim
import torch.utils.data

import torchvision.utils
from torchvision import models
import matplotlib.pyplot as plt
import os
import random
import copy

In [2]:
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"


In [3]:
model = torchvision.models.__dict__['resnet18'](weights=ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
model = nn.DataParallel(model, [0,1])
model.cuda()

checkpoint = torch.load("attack_models/model_95.pth", map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [4]:
data_dir = '../dataset/test/'
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

inverse_normalize = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
)

In [5]:
normal_data = torchvision.datasets.ImageFolder(data_dir, transform = transform)

subset_indices = []
for class_idx in range(len(normal_data.classes)):
    class_indices = [i for i, (_, label) in enumerate(normal_data.imgs) if label == class_idx]
    subset_indices.extend(random.sample(class_indices, 20))

In [53]:
subset_dataset = torch.utils.data.Subset(normal_data, subset_indices)
data_loader = torch.utils.data.DataLoader(subset_dataset, batch_size=1, shuffle=False, 
    num_workers=16, pin_memory=True)

In [54]:
def pgd_attack(model, image, target, eps=0.5, alpha=100/255, iters=400) :
    image, target = image.cuda(), target.cuda()
    target = target.type(torch.float)
    criterion = nn.BCEWithLogitsLoss()
        
    ori_image = image.data
        
    for i in range(iters):
        image.requires_grad = True
        
        output = model(image)

        output = output.view(1)
        model.zero_grad()
        loss = criterion(output, target).to(device)
        loss.backward()

        adv_image = image + alpha*image.grad.sign()
        eta = torch.clamp(adv_image - ori_image, min=-eps, max=eps)
        image = torch.clamp(ori_image + eta, min=0, max=1).detach_()

            
    return image

In [55]:
print("Attack Image & Predicted Label")

model.eval()

correct_preds = 0
total = 0

for image, target in data_loader:
    
    to_save = copy.deepcopy(image)
    to_save = to_save.squeeze(0)
    to_save = inverse_normalize(to_save)
    to_save = transforms.ToPILImage()(to_save)
    to_save.save("results/orig_{}.png".format(total))

    image = pgd_attack(model, image, target)

    target = target.to(device)
    target = target.type(torch.float)

    output = model(image)
    output = output.view(1)
    
    y_pred = torch.round(torch.sigmoid(output))
    correct_preds += torch.sum(y_pred == target).cpu()

    total += 1
    print("Completed : {}".format(total))

    image = image.squeeze(0)
    image = inverse_normalize(image)
    image = transforms.ToPILImage()(image)
    image.save("results/attack_{}.png".format(total))
    
print('Accuracy of test text: %f %%' % (100 * float(correct_preds) / len(data_loader.dataset)))

Attack Image & Predicted Label
Completed : 1
Completed : 2
Completed : 3
Completed : 4
Completed : 5
Completed : 6
Completed : 7
Completed : 8
Completed : 9
Completed : 10
Completed : 11
Completed : 12
Completed : 13
Completed : 14
Completed : 15
Completed : 16
Completed : 17
Completed : 18
Completed : 19
Completed : 20
Completed : 21
Completed : 22
Completed : 23
Completed : 24
Completed : 25
Completed : 26
Completed : 27
Completed : 28
Completed : 29
Completed : 30
Completed : 31
Completed : 32
Completed : 33
Completed : 34
Completed : 35
Completed : 36
Completed : 37
Completed : 38
Completed : 39
Completed : 40
Accuracy of test text: 5.000000 %
