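First, import the libraries, load the ImageNet class names, load the pig image as a `1×3×224×224` tensor with values in `[0, 1]`, and create a Gaussian noise tensor of the same shape.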

In [6]:
from matplotlib import pyplot as plt
from pickle import loads, dumps
from PIL import Image
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
import json

# Class index -> class name (the second field of each JSON entry).
with open("imagenet_class_index.json") as f:
    imagenet_classes = {int(i): x[1] for i, x in json.load(f).items()}

# Seed torch's RNG so the noise tensor below is identical on every run; the
# attack results for it are cached on disk and must match the same noise.
SEED = 0
torch.manual_seed(SEED)

# Resize to the network's 224x224 input; ToTensor scales pixels to [0, 1].
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Close the file handle promptly; convert("RGB") loads the pixels and
# guarantees 3 channels even for grayscale / RGBA / CMYK JPEGs.
with Image.open("pig.jpg") as pig_img_file:
    pig_img = pig_img_file.convert("RGB")

pig_tensor = preprocess(pig_img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
gaussian_noise_tensor = torch.randn(1, 3, 224, 224)

Next, define a normalization layer, load a pretrained ResNet-50, and choose the perturbation budget `epsilon` and the target classes for the attack.

In [7]:
# simple Module to normalize an image
class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)

    def forward(self, x):
        return (x - self.mean.type_as(x)[None, :, None, None]) / self.std.type_as(x)[
            None, :, None, None
        ]


# Per-channel (RGB) mean/std used by torchvision's ImageNet-pretrained models.
norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Pretrained ResNet-50 in inference mode. The weights are frozen: the attack
# optimises only the input perturbation, so computing and accumulating
# gradients for every model weight on each backward pass would be wasted work.
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()
model.requires_grad_(False)

# L-infinity perturbation budget for the pig attack: 2 intensity levels of 255.
epsilon = 2.0 / 255

# ImageNet class indices to steer the classifier toward; their names are
# looked up in the mapping loaded from imagenet_class_index.json.
target_classes = [
    49, 272, 276, 286, 287, 288, 290, 291,
    292, 293, 340, 344, 366, 367, 372, 386,
]
target_class_names = [imagenet_classes[cls] for cls in target_classes]

Next, define the targeted PGD (projected gradient descent) attack, which caches its results on disk, and a helper that plots the perturbed images in a grid.

In [8]:
def pgd(
    image,
    target_classes,
    file_name,
    epsilon=epsilon,
    lr=1e-1,
    threshold=0.9,
    max_steps=None,
):
    """Run a targeted L-infinity PGD attack toward each class in ``target_classes``.

    For every target class a perturbation ``delta`` (same shape as ``image``,
    starting at zero) is optimised with SGD to minimise the cross-entropy of
    ``model(norm(clip(image + delta, 0, 1)))`` against that class. After every
    step ``delta`` is projected back into ``[-epsilon, epsilon]``. A class is
    finished as soon as the softmax probability of the target reaches
    ``threshold``, or after ``max_steps`` steps when a cap is given.

    Results are cached with pickle in ``file_name``. If that file exists, its
    contents are returned as-is and nothing is recomputed. The cache does not
    record ``image``/``epsilon``/``lr``/``threshold``, so delete the file after
    changing any of them.

    Args:
        image: image batch with values in [0, 1]; a batch of exactly one image
            is required (the target tensor passed to the loss has length 1).
        target_classes: iterable of class indices to attack toward.
        file_name: path of the pickle cache to read from / write to.
        epsilon: L-infinity bound on each perturbation.
        lr: SGD learning rate.
        threshold: target-class probability at which a class is done.
        max_steps: optional cap on optimiser steps per class. ``None`` (the
            default) iterates until ``threshold`` is reached; with a cap, the
            stored delta may not reach ``threshold``.

    Returns:
        ``(deltas, iters)``: one perturbation tensor and one optimiser-step
        count per target class, in the order of ``target_classes``. Freshly
        computed deltas are detached from autograd.
    """
    # NOTE(review): pickle.loads can execute arbitrary code; only point
    # file_name at cache files produced by this function.
    try:
        with open(file_name, "rb") as f:
            return loads(f.read())
    except FileNotFoundError:
        pass  # no cache yet: compute below

    deltas = []
    iters = []

    for cls in target_classes:
        target = torch.tensor([cls])
        delta = torch.zeros_like(image, requires_grad=True)
        opt = torch.optim.SGD([delta], lr=lr)
        step = 0

        while max_steps is None or step < max_steps:
            pred = model(norm((image + delta).clip(0, 1)))
            prob = pred.softmax(dim=1)[0, cls].item()
            print(f"Class: {cls} Prob: {prob} Step: {step}")
            # Stop before updating, so the stored delta is exactly the one that
            # reached the threshold rather than one SGD step past it.
            if prob >= threshold:
                break

            loss = torch.nn.functional.cross_entropy(pred, target)
            # Differentiate w.r.t. delta only, so no gradients are computed or
            # accumulated for the model weights.
            (grad,) = torch.autograd.grad(loss, delta)
            delta.grad = grad
            opt.step()
            # Project back onto the L-infinity ball of radius epsilon.
            with torch.no_grad():
                delta.clamp_(-epsilon, epsilon)

            step += 1

        deltas.append(delta.detach())
        iters.append(step)

    with open(file_name, "wb") as f:
        f.write(dumps((deltas, iters)))

    return deltas, iters


def plot(images, target_class_names, file_name=None, ncols=4, figsize=(6, 8)):
    """Draw ``images`` in a grid, each titled with its class name.

    Images and names are paired with ``zip``, so surplus entries in either
    sequence are ignored. The grid has ``ncols`` columns and as many rows as
    needed (4x4 for 16 images); unused cells are hidden.

    Args:
        images: iterable of arrays accepted by ``imshow`` (e.g. (H, W, 3)
            floats in [0, 1]).
        target_class_names: titles; lower-cased, with underscores shown as spaces.
        file_name: if given, save the figure there with a transparent
            background; otherwise display it with ``plt.show()``.
        ncols: number of grid columns.
        figsize: figure size in inches.
    """
    pairs = list(zip(images, target_class_names))
    nrows = max(1, -(-len(pairs) // ncols))  # ceiling division

    # Always draw into a brand-new figure so no other figure is touched.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    for ax, (image, target_class) in zip(axes.flat, pairs):
        ax.set_title(target_class.lower().replace("_", " "))
        ax.imshow(image, cmap="gray")
    for ax in axes.flat[len(pairs):]:
        ax.set_visible(False)

    fig.tight_layout()
    if file_name:
        fig.savefig(file_name, transparent=True)
    else:
        plt.show()

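Run the targeted attacks on the pig image and on Gaussian noise, then plot the resulting misclassified images and one example perturbation.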

In [10]:
# Targeted attacks on the pig image with the small budget epsilon = 2/255.
pig_deltas, pig_iters = pgd(pig_tensor, target_classes, "targeted_attack_pig_deltas_iters")
plot(
    [
        (pig_tensor + d).clip(0, 1)[0].detach().numpy().transpose(1, 2, 0)
        for d in pig_deltas
    ],
    target_class_names,
    "pig_misclassified.png",
)

# Visualise the first pig perturbation in its own figure. Map [-epsilon, epsilon]
# linearly onto [0, 1] (zero -> mid-grey) so imshow does not clip away the
# negative half of the perturbation.
plt.figure()
plt.imshow(
    (0.5 + pig_deltas[0][0] / (2 * epsilon))
    .clamp(0, 1)
    .detach()
    .numpy()
    .transpose(1, 2, 0)
)
plt.xticks([], [])
plt.yticks([], [])
plt.savefig("delta.png")
# L2 norm of the first pig delta. pgd clamps each entry to [-epsilon, epsilon],
# so it is at most epsilon * sqrt(3 * 224 * 224) ≈ 3.04 for epsilon = 2/255.
print(pig_deltas[0].norm())

# Targeted attacks on Gaussian noise with a much larger budget (epsilon = 1).
noise_deltas, noise_iters = pgd(
    gaussian_noise_tensor,
    target_classes,
    "targeted_attack_gaussian_noise_deltas_iters",
    epsilon=1,
    lr=10,  # Set lr=1 for uniformly distributed noise
)
plot(
    [
        (gaussian_noise_tensor + d).clip(0, 1)[0].detach().numpy().transpose(1, 2, 0)
        for d in noise_deltas
    ],
    target_class_names,
    "gaussian_noise_misclassified.png",
)
# Same bound with epsilon = 1 is ≈ 388, so the much larger norm here likely
# reflects the larger budget more than the higher learning rate.
print(noise_deltas[0].norm())

# Keep `deltas` / `iters` bound to the noise run, as this cell did before.
deltas, iters = noise_deltas, noise_iters

Class: 49 Prob: 0.0006990268593654037 Step: 0
Class: 49 Prob: 0.0012487226631492376 Step: 1
Class: 49 Prob: 0.0017028694273903966 Step: 2
Class: 49 Prob: 0.0019656638614833355 Step: 3
Class: 49 Prob: 0.002621763851493597 Step: 4


KeyboardInterrupt: 