# Patch Removal Demo

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from train_main import CIFAR10Classifier
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import math
from scipy.ndimage.interpolation import rotate
from AddPatch import AddPatch

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


  from scipy.ndimage.interpolation import rotate


## Load CIFAR 10 Dataset with Adversarial Patches

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    AddPatch(32, 0.05)
    ])

batch_size = 16

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img, out_path):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    
    npimg = np.transpose(npimg * 255, (1, 2, 0))
    npimg = npimg.astype("uint8")
    # plt.imshow(npimg)
    img = Image.fromarray(npimg, "RGB")
    img.save(out_path)

Files already downloaded and verified


### Plot Sample images

In [3]:
dataiter = iter(testloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images), "./plots/original_images.png")

: 

: 

## Load ResNet18 model

In [3]:

# load model
model = CIFAR10Classifier.load_from_checkpoint("experiment_results/resnet18-2022-12-08-17-10/version_0/checkpoints/epoch=49-step=54700.ckpt")
model.eval()
model.to(device)
# print(model)



CIFAR10Classifier(
  (net): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

## Test models Accuracy Without Defense

In [5]:
total_count = 0
correct_count = 0

# Here, we use enumerate(training_loader) instead of
# iter(training_loader) so that we can track the batch
# index and do some intra-epoch reporting
for i, data in enumerate(testloader):
    # Every data instance is an input + label pair
    inputs, labels = data
    labels = labels.squeeze()

    inputs = inputs.to(device)
    logits = model(inputs)
    preds = torch.argmax(logits, dim=1)

    correct = preds == labels[0]

    total_count += len(labels)
    correct_count += torch.sum(correct)


print(total_count, correct_count, correct_count / total_count)



10000 tensor(1440, device='cuda:0') tensor(0.1440, device='cuda:0')


## Test Patch Removal Defense  

In [6]:
from local_gradients_smoothing.configs.configure import Configuration
from local_gradients_smoothing.lgs.local_gradients_smoothing import LocalGradientsSmoothing

cfg = {'smoothing_factor': 2.3,
        'window_size': 5,
        'overlap': 1,
        'threshold': 0.1,
        'grad_method': "Gradient"
        }

loc_grad_smooth = LocalGradientsSmoothing(**cfg)

# threshold = 0.5

acc_list = []
threshold_list = np.linspace(0.9, 1, num=10)

for threshold in threshold_list:

    # show images
    # dataiter = iter(testloader)
    # images, labels = next(dataiter)

    # masks = torch.zeros_like(images)

    # for i in range(len(images)):
    #     grad_mask = loc_grad_smooth(images[i]).squeeze(0)

    #     grad_mask[grad_mask >= threshold] = 1
    #     grad_mask[grad_mask < threshold] = 0
        
    #     grad_mask = grad_mask.repeat((3, 1, 1))
        
    #     images[i] = images[i] * (1 - grad_mask)
    #     masks[i] = grad_mask

    # imshow(torchvision.utils.make_grid(images), "./plots/filtered_images" + str(int(threshold)).replace(".", "_") + ".png")
    # imshow(torchvision.utils.make_grid(masks), "./plots/grad_mask_images" + str(int(threshold)).replace(".", "_") + ".png")

    total_count = 0
    correct_count = 0

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    with torch.no_grad(): 
        for i, data in enumerate(testloader):
            # Every data instance is an input + label pair
            inputs, labels = data
            labels = labels.squeeze()

            for j in range(len(inputs)):
                
                grad_mask = loc_grad_smooth(inputs[j]).squeeze(0)

                grad_mask[grad_mask >= threshold] = 1
                grad_mask[grad_mask < threshold] = 0
                
                grad_mask = grad_mask.repeat((3, 1, 1))
                
                inputs[j] = inputs[j] * (1 - grad_mask)

                # # img_t = transforms.ToTensor()(inputs)
                # collage_t = torch.cat([inputs[0], grad_mask, inputs[0] * (1 - grad_mask)], dim=-1)
                # collage = transforms.ToPILImage()(collage_t)
                # collage.show()
                # result_path = cfg.get('TESTING')['result_path']
                # collage.save(result_path)
            inputs = inputs.to(device)
            logits = model(inputs).to("cpu")
            preds = torch.argmax(logits, dim=1)

            correct = preds == labels

            
            total_count += len(labels)
            correct_count += torch.sum(correct)

        print(total_count, correct_count, correct_count / total_count)
        acc_list.append(float(correct_count / total_count))

np.savez("./acc.npz", threshold_list, acc_list)

10000 tensor(7116) tensor(0.7116)
10000 tensor(7136) tensor(0.7136)
10000 tensor(7126) tensor(0.7126)
10000 tensor(7193) tensor(0.7193)
10000 tensor(7201) tensor(0.7201)
10000 tensor(7224) tensor(0.7224)
10000 tensor(7281) tensor(0.7281)
10000 tensor(7273) tensor(0.7273)
10000 tensor(7312) tensor(0.7312)
10000 tensor(7350) tensor(0.7350)


In [7]:
import numpy as np
import matplotlib.pyplot as plt

npzfile = np.load("./acc.npz")
print(npzfile["arr_1"])

plt.plot(npzfile["arr_0"], npzfile["arr_1"])
plt.show()

[0.1        0.19769999 0.2141     0.23890001 0.26390001 0.28479999
 0.30610001 0.32609999 0.3497     0.36489999]


: 

: 