In [1]:
import glob

import numpy as np
from matplotlib import image as mpimg
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import TensorDataset
from torch import optim

In [2]:
@torch.no_grad()
def calculate_accuracy(model, dataloader, device='cpu'):
    model.eval()

    correct = 0
    total = 0
    for data in dataloader:
        images, labels = data

        images = images.unsqueeze(1).to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return correct / total

In [3]:
class AlexNet(nn.Module):
    def __init__(self, num_classes: int = 2, dropout: float = 0.2) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
def create_classification_dataset(path,
                                  crops_per_negative=3,
                                  include_mirrors=False):

    def extract_patch(image, x_d=64, y_d=128):
        x = np.random.randint(0, image.shape[1] - x_d)
        y = np.random.randint(0, image.shape[0] - y_d)
        patch = image[y: y + y_d, x: x + x_d]

        return patch

    x = []
    y = []

    for filepath in glob.iglob(f'{path}/negatives/*.png'):
        img = mpimg.imread(filepath)
        for i in range(crops_per_negative):
            p = extract_patch(img)
            x.append(p)
            y.append(0)

    for filepath in glob.iglob(f'{path}/positives/*.png'):
        if 'mirr0r' in filepath and not include_mirrors:
            pass
        img = mpimg.imread(filepath)
        x.append(img)
        y.append(1)

    return np.asarray(x), np.asarray(y)

In [5]:
x, y = create_classification_dataset('dataset/asl_eth_flir/flir_17_Sept_2013/train', 100)

In [6]:
x = x / x.max()

In [7]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

In [8]:
x_train = torch.tensor(x_train)
x_test = torch.tensor(x_test)

y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

In [9]:
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

In [10]:
batch_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [11]:
model = AlexNet().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        model.train()
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.unsqueeze(1).to(device)
        labels = labels.to(device)
        # print(inputs.shape)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        # print(labels, outputs.argmax(-1), loss)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
            print(calculate_accuracy(model, trainloader, device=device))
            print(calculate_accuracy(model, testloader, device=device))

print('Finished Training')

[1,   100] loss: 0.032
0.7428523092539333
0.7324086603518268
[1,   200] loss: 0.023
0.8917272881069193
0.8778755074424899
[1,   300] loss: 0.015
0.8934190492302487
0.8900541271989174
[1,   400] loss: 0.016
0.8555236000676705
0.8521650879566982
[1,   500] loss: 0.015
0.9189646421925224
0.9191474966170501
[1,   600] loss: 0.010
0.9328370834038234
0.9316644113667117
[1,   700] loss: 0.013
0.8602605312129927
0.8616373477672531
[2,   100] loss: 0.015
0.8993402131619015
0.892083897158322
[2,   200] loss: 0.012
0.8995939773304009
0.9018944519621109
[2,   300] loss: 0.011
0.9510235154796143
0.9502706359945873
[2,   400] loss: 0.009
0.8743021485366266
0.8633288227334236
[2,   500] loss: 0.007
0.9613432583319236
0.9573748308525034
[2,   600] loss: 0.011
0.9300456775503299
0.9272665764546685
[2,   700] loss: 0.010
0.9566063271866012
0.9566982408660352
Finished Training


In [12]:
from itertools import chain

import numpy as np
import torch
from torch.nn.functional import softmax
from torchattacks.attack import Attack


class Pixle(Attack):
    r"""
    Pixle: a fast and effective black-box attack based on rearranging pixels'
    [https://arxiv.org/abs/2202.02236]

    Distance Measure : L0

    Arguments:
        model (nn.Module): model to attack.
        x_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther x side for each iteration. The integers are considered as fixed number of size,
        while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10))
        y_dimensions (int or float, or a tuple containing a combination of those): size of the sampled patch along ther y side for each iteration. The integers are considered as fixed number of size,
        while the float as parcentage of the size. A tuple is used to specify both under and upper bound of the size. (Default: (2, 10))
        pixel_mapping (str): the type of mapping used to move the pixels. Can be: 'random', 'similarity', 'similarity_random', 'distance', 'distance_random' (Default: random)
        restarts (int): the number of restarts that the algortihm performs. (Default: 20)
        max_iterations (int): number of iterations to perform for each restart. (Default: 100)
        update_each_iteration (bool): if the attacked images must be modified after each iteration (True) or after each restart (False).  (Default: False)
    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`,        `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.Pixle(model, x_dimensions=(0.1, 0.2), restarts=100, iteratsion=50)
        >>> adv_images = attack(images, labels)
    """
    def __init__(self, model, x_dimensions=(2, 10), y_dimensions=(2, 10),
                 pixel_mapping='random', restarts=20,
                 max_iterations=100, update_each_iteration=False):
        super().__init__("Pixle", model)

        if restarts < 0 or not isinstance(restarts, int):
            raise ValueError('restarts must be and integer >= 0 '
                             '({})'.format(restarts))

        self.update_each_iteration = update_each_iteration
        self.max_patches = max_iterations

        self.restarts = restarts
        self.pixel_mapping = pixel_mapping.lower()

        if self.pixel_mapping not in ['random', 'similarity',
                                      'similarity_random', 'distance',
                                      'distance_random']:
            raise ValueError('pixel_mapping must be one of [random, similarity,'
                             'similarity_random, distance, distance_random]'
                             ' ({})'.format(self.pixel_mapping))

        if isinstance(y_dimensions, (int, float)):
            y_dimensions = [y_dimensions, y_dimensions]

        if isinstance(x_dimensions, (int, float)):
            x_dimensions = [x_dimensions, x_dimensions]

        if not all([(isinstance(d, (int)) and d > 0)
                    or (isinstance(d, float) and 0 <= d <= 1)
                    for d in chain(y_dimensions, x_dimensions)]):
            raise ValueError('dimensions of first patch must contains integers'
                             ' or floats in [0, 1]'
                             ' ({})'.format(y_dimensions))

        self.p1_x_dimensions = x_dimensions
        self.p1_y_dimensions = y_dimensions

        self._supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        if not self.update_each_iteration:
            return self.restart_forward(images, labels)
        else:
            return self.iterative_forward(images, labels)

    def restart_forward(self, images, labels):
        assert len(images.shape) == 3 or \
               (len(images.shape) == 4 and images.size(0) == 1)

        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        if self._targeted:
            labels = self._get_target_label(images, labels)

        x_bounds = tuple(
            [max(1, d if isinstance(d, int) else round(images.size(3) * d))
             for d in self.p1_x_dimensions])

        y_bounds = tuple(
            [max(1, d if isinstance(d, int) else round(images.size(2) * d))
             for d in self.p1_y_dimensions])

        adv_images = []

        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        bs, _, _, _ = images.shape

        for idx in range(bs):
            image, label = images[idx:idx + 1], labels[idx:idx + 1]

            best_image = image.clone()
            pert_image = image.clone()

            loss, callback = self._get_fun(image, label,
                                           target_attack=self._targeted)
            best_solution = None

            best_p = loss(solution=image, solution_as_perturbed=True)
            image_probs = [best_p]

            it = 0

            for r in range(self.restarts):
                stop = False

                for it in range(self.max_patches):

                    (x, y), (x_offset, y_offset) = \
                        self.get_patch_coordinates(image=image,
                                                   x_bounds=x_bounds,
                                                   y_bounds=y_bounds)

                    destinations = self.get_pixel_mapping(image, x, x_offset,
                                                          y, y_offset,
                                                          destination_image=
                                                          best_image)

                    solution = [x, y, x_offset, y_offset] + destinations

                    pert_image = self._perturb(source=image,
                                               destination=best_image,
                                               solution=solution)

                    p = loss(solution=pert_image,
                             solution_as_perturbed=True)

                    if p < best_p:
                        best_p = p
                        best_solution = pert_image

                    image_probs.append(best_p)

                    if callback(pert_image, None, True):
                        best_solution = pert_image
                        stop = True
                        break

                if best_solution is None:
                    best_image = pert_image
                else:
                    best_image = best_solution

                if stop:
                    break

            adv_images.append(best_image)

        adv_images = torch.cat(adv_images)

        return adv_images

    def iterative_forward(self, images, labels):
        assert len(images.shape) == 3 or \
               (len(images.shape) == 4 and images.size(0) == 1)

        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        if self._targeted:
            labels = self._get_target_label(images, labels)

        x_bounds = tuple(
            [max(1, d if isinstance(d, int) else round(images.size(3) * d))
             for d in self.p1_x_dimensions])

        y_bounds = tuple(
            [max(1, d if isinstance(d, int) else round(images.size(2) * d))
             for d in self.p1_y_dimensions])

        adv_images = []

        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        bs, _, _, _ = images.shape

        for idx in range(bs):
            image, label = images[idx:idx + 1], labels[idx:idx + 1]

            best_image = image.clone()

            loss, callback = self._get_fun(image, label,
                                           target_attack=self._targeted)

            best_p = loss(solution=image, solution_as_perturbed=True)
            image_probs = [best_p]

            for it in range(self.max_patches):

                (x, y), (x_offset, y_offset) = \
                    self.get_patch_coordinates(image=image,
                                               x_bounds=x_bounds,
                                               y_bounds=y_bounds)

                destinations = self.get_pixel_mapping(image, x, x_offset,
                                                      y, y_offset,
                                                      destination_image=best_image)

                solution = [x, y, x_offset, y_offset] + destinations

                pert_image = self._perturb(source=image,
                                           destination=best_image,
                                           solution=solution)

                p = loss(solution=pert_image, solution_as_perturbed=True)

                if p < best_p:
                    best_p = p
                    best_image = pert_image

                image_probs.append(best_p)

                if callback(pert_image, None, True):
                    best_image = pert_image
                    break

            adv_images.append(best_image)

        adv_images = torch.cat(adv_images)

        return adv_images

    def _get_prob(self, image):
        out = self.model(image.to(self.device))
        prob = softmax(out, dim=1)
        return prob.detach().cpu().numpy()

    def loss(self, img, label, target_attack=False):

        p = self._get_prob(img)
        p = p[np.arange(len(p)), label]

        if target_attack:
            p = 1 - p

        return p.sum()

    def get_patch_coordinates(self, image, x_bounds, y_bounds):
        c, h, w = image.shape[1:]

        x, y = np.random.uniform(0, 1, 2)

        x_offset = np.random.randint(x_bounds[0],
                                     x_bounds[1] + 1)

        y_offset = np.random.randint(y_bounds[0],
                                     y_bounds[1] + 1)

        x, y = int(x * (w - 1)), int(y * (h - 1))

        if x + x_offset > w:
            x_offset = w - x

        if y + y_offset > h:
            y_offset = h - y

        return (x, y), (x_offset, y_offset)

    def get_pixel_mapping(self, source_image, x, x_offset, y, y_offset,
                          destination_image=None):
        if destination_image is None:
            destination_image = source_image

        destinations = []
        c, w, h = source_image.shape[1:]
        source_image = source_image[0]

        if self.pixel_mapping == 'random':
            for i in range(x_offset):
                for j in range(y_offset):
                    dx, dy = np.random.uniform(0, 1, 2)
                    dx, dy = int(dx * (w - 1)), int(dy * (h - 1))
                    destinations.append([dx, dy])
        else:
            for i in np.arange(y, y + y_offset):
                for j in np.arange(x, x + x_offset):
                    pixel = source_image[:, i: i + 1, j: j + 1]
                    diff = destination_image - pixel
                    diff = diff[0].abs().mean(0).view(-1)

                    if 'similarity' in self.pixel_mapping:
                        diff = 1 / (1 + diff)
                        diff[diff == 1] = 0

                    probs = torch.softmax(diff, 0).cpu().numpy()

                    indexes = np.arange(len(diff))

                    pair = None

                    linear_iter = iter(sorted(zip(indexes, probs),
                                              key=lambda pit: pit[1],
                                              reverse=True))

                    while True:
                        if 'random' in self.pixel_mapping:
                            index = np.random.choice(indexes, p=probs)
                        else:
                            index = next(linear_iter)[0]

                        _y, _x = np.unravel_index(index, (h, w))

                        if _y == i and _x == j:
                            continue

                        pair = (_x, _y)
                        break

                    destinations.append(pair)

        return destinations

    def _get_fun(self, img, label, target_attack=False):
        img = img.to(self.device)

        if isinstance(label, torch.Tensor):
            label = label.cpu().numpy()

        @torch.no_grad()
        def func(solution,
                 destination=None,
                 solution_as_perturbed=False, **kwargs):

            if not solution_as_perturbed:
                pert_image = self._perturb(source=img,
                                           destination=destination,
                                           solution=solution)
            else:
                pert_image = solution

            p = self._get_prob(pert_image)
            p = p[np.arange(len(p)), label]

            if target_attack:
                p = 1 - p

            return p.sum()

        @torch.no_grad()
        def callback(solution,
                     destination=None,
                     solution_as_perturbed=False,
                     **kwargs):

            if not solution_as_perturbed:
                pert_image = self._perturb(source=img,
                                           destination=destination,
                                           solution=solution)
            else:
                pert_image = solution

            p = self._get_prob(pert_image)[0]
            mx = np.argmax(p)

            if target_attack:
                return mx == label
            else:
                return mx != label

        return func, callback

    def _perturb(self, source, solution, destination=None):
        if destination is None:
            destination = source

        c, h, w = source.shape[1:]

        x, y, xl, yl = solution[:4]
        destinations = solution[4:]

        source_pixels = np.ix_(range(c),
                               np.arange(y, y + yl),
                               np.arange(x, x + xl))

        indexes = torch.tensor(destinations)
        destination = destination.clone().detach().to(self.device)

        s = source[0][source_pixels].view(c, -1)

        destination[0, :, indexes[:, 0], indexes[:, 1]] = s

        return destination


In [13]:
len(test_dataset)

2956

In [15]:
from collections import defaultdict
from torch import softmax

model.eval()
pixle = Pixle(x_dimensions=1, y_dimensions=1, model=model, max_iterations=100, restarts=100)

correct = 0
total = 0

classes_total = defaultdict(int)
classes_attacked = defaultdict(int)

with torch.no_grad():
    for img, y in torch.utils.data.DataLoader(test_dataset,
                                               batch_size=1,
                                               shuffle=False):
        if total == 50:
            break

        img = img.unsqueeze(1).to(device)

        probs = softmax(model(img), dim=1)[0].cpu()
        pred = np.argmax(probs).item()

        if pred != y.item():
            print('skipped')
            continue

        pert_images = pixle(img, y)
        final_prob = softmax(model(pert_images), dim=1)[0].cpu()

        pred = np.argmax(final_prob).item()

        if pred == y.item():
            correct += 1
        else:
            classes_attacked[y.item()] += 1

        classes_total[y.item()] += 1
        total += 1

        print(correct, total, correct / total,  y.item(), pred)

    print(correct, total)

0 1 0.0 1 0
0 2 0.0 1 0
1 3 0.3333333333333333 0 0
2 4 0.5 0 0
2 5 0.4 1 0
3 6 0.5 0 0
3 7 0.42857142857142855 1 0
4 8 0.5 1 1
4 9 0.4444444444444444 1 0
5 10 0.5 1 1
5 11 0.45454545454545453 1 0
6 12 0.5 0 0
7 13 0.5384615384615384 0 0
7 14 0.5 0 1
8 15 0.5333333333333333 0 0
9 16 0.5625 0 0
10 17 0.5882352941176471 0 0
10 18 0.5555555555555556 1 0
11 19 0.5789473684210527 0 0
11 20 0.55 1 0
11 21 0.5238095238095238 0 1
12 22 0.5454545454545454 0 0
12 23 0.5217391304347826 1 0
13 24 0.5416666666666666 0 0
skipped
14 25 0.56 0 0
15 26 0.5769230769230769 0 0
16 27 0.5925925925925926 0 0
16 28 0.5714285714285714 1 0
16 29 0.5517241379310345 1 0
17 30 0.5666666666666667 0 0
18 31 0.5806451612903226 0 0
19 32 0.59375 0 0
20 33 0.6060606060606061 0 0
20 34 0.5882352941176471 1 0
21 35 0.6 0 0
21 36 0.5833333333333334 1 0
22 37 0.5945945945945946 0 0
skipped
23 38 0.6052631578947368 0 0
24 39 0.6153846153846154 1 1
24 40 0.6 1 0
skipped
25 41 0.6097560975609756 0 0
25 42 0.5952380952380952 0

In [None]:
classes_total

In [None]:
classes_attacked