## Whitebox augmentation demo

We use a MAITE image-classification `Augmentation` to represent a simple
adversarial attack on model input data. Because the attack depends on
gradient information that is not guaranteed to be available for MAITE
`Model` objects, we must get the information in an application-specific
way. In this case, we write the `Augmentation` implementer class so that
it has access to the underlying (framework-specific) model. This way, the
implementer can compute model gradients internally within its `__call__`
method, and after construction it can be treated like any other implementer
of `Augmentation`.

In this example, we consider the image classification domain,
where both the inputs and the targets of the prediction model are tensors.

## Setup

In [1]:
from __future__ import annotations  # permit use of tuple/dict as generic typehints in 3.8

import copy
from dataclasses import dataclass
from typing import Any, Protocol, Sequence, Tuple

import torch
from torch import nn

from maite.protocols import ArrayLike
from maite.protocols.image_classification import Augmentation

## Define a simple protocol for a broad set of attacks
This isn't strictly necessary, but it is helpful for extensibility.

Any interface expecting an instance of this protocol class will be
able to handle any implementer (structural subtype), so implementations
of attacks can be modified/rewritten without modifying classes
that are expected to use those objects.
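Because conformance is structural, a class never has to inherit from the protocol to satisfy it. The following is a minimal sketch that uses a hypothetical, torch-free `NamedAttack` protocol as a stand-in for `ImageClassifierAttack`. The names `NamedAttack`, `ZeroAttack`, `UnnamedAttack`, and `run_attack` are illustrative only. Decorating the protocol with `typing.runtime_checkable` also enables `isinstance` checks, but these only verify that the members exist, not their signatures. Because `name` is a non-method member, `issubclass` checks on the protocol raise `TypeError`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class NamedAttack(Protocol):
    """Hypothetical, torch-free stand-in for ImageClassifierAttack."""

    def __call__(self, model: Any, input_batch: Any, target_batch: Any) -> Any: ...

    @property
    def name(self) -> str: ...


@dataclass
class ZeroAttack:
    """Satisfies NamedAttack structurally; it never inherits from it."""

    name: str

    def __call__(self, model: Any, input_batch: Any, target_batch: Any) -> Any:
        # a trivial "attack" that perturbs nothing
        return 0.0


class UnnamedAttack:
    """Missing the `name` member, so it is not a NamedAttack."""

    def __call__(self, model: Any, input_batch: Any, target_batch: Any) -> Any:
        return 0.0


def run_attack(attack: NamedAttack) -> str:
    # depends only on the protocol, never on a concrete attack class
    return f"{attack.name}: {attack(None, None, None)}"


print(isinstance(ZeroAttack(name="zero"), NamedAttack))  # True
print(isinstance(UnnamedAttack(), NamedAttack))  # False
print(run_attack(ZeroAttack(name="zero")))  # zero: 0.0

# `name` is a non-method member, so class-level checks are refused
try:
    issubclass(ZeroAttack, NamedAttack)
except TypeError as err:
    print(f"TypeError: {err}")
```

Calling code such as `run_attack` only ever sees the protocol, so attack implementations can be swapped or rewritten without touching it.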

In [2]:
class ImageClassifierAttack(Protocol):
    """
    Protocol defining an interface that might be satisfied by an attack on an
    image classifier.
    """

    def __call__(
        self,
        model: torch.nn.Module,
        input_batch: torch.Tensor,
        target_batch: torch.Tensor,
    ) -> torch.Tensor: ...

    @property
    def name(self) -> str: ...

## Define a simple implementer of the above protocol class

In [3]:
@dataclass
class DumbAttack:
    """
    Very basic implementer of the above ImageClassifierAttack protocol
    """

    name: str

    def __call__(
        self,
        model: torch.nn.Module,
        input_batch: ArrayLike,
        target_batch: ArrayLike,
    ) -> torch.Tensor:
        """
        Given a torch model, a model input batch, and a model target batch
        (i.e. ground truth) calculate an adversarial perturbation that can be
        added to the input tensor to form an adversarial input.
        """

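        # type-narrow inputs to type tensor; detach and clone so that enabling
        # gradient tracking does not modify the caller's tensor
        input_batch_tn = torch.as_tensor(input_batch).detach().clone()
        input_batch_tn.requires_grad_(True)

        # type-narrow targets to type tensor
        target_batch_tn = torch.as_tensor(target_batch)

        preds = model(input_batch_tn)

        # calculate some simple loss (mean binary cross-entropy over the batch)
        loss = torch.nn.functional.binary_cross_entropy(preds, target_batch_tn)

        # differentiate the loss w.r.t. the inputs only; unlike loss.backward(),
        # this does not accumulate gradients into the model's parameters
        (input_grad,) = torch.autograd.grad(loss, input_batch_tn)

        # a small step along the gradient increases the loss
        return input_grad * 1e-4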


# TODO: Try to demonstrate a standard approach to implementing protocols that permit
#       structural subclass checks at the class-object level (i.e., using
#       "issubclass(SomeUserClass, SomeProtocol)") without requiring instantiation
#       (i.e., "isinstance(SomeUserClass(...), SomeProtocol)"). Otherwise, introspection
#       and inference tools won't be able to verify protocol compatibility without
#       instantiating. This inference ability is a huge potential gain.
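`DumbAttack` returns a scaled raw gradient. A common alternative is the fast gradient sign method (FGSM), which steps a fixed distance `epsilon` along the sign of the input gradient: `x_adv = x + epsilon * sign(grad_x L)`. Below is a hedged sketch of a hypothetical `FGSMAttack` with the same structure (a `name` plus a `__call__` returning a perturbation), so it could be passed wherever an `ImageClassifierAttack` is expected. It assumes inputs scaled to [0, 1], a default `epsilon` of 8/255, and the binary cross-entropy loss used by `DumbAttack` rather than the categorical cross-entropy FGSM is usually paired with.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class FGSMAttack:
    """
    Hypothetical FGSM attack shaped like ImageClassifierAttack: a `name`
    plus a `__call__` that returns a perturbation to add to the inputs.
    """

    name: str
    epsilon: float = 8 / 255  # assumed L-infinity budget per pixel

    def __call__(
        self,
        model: nn.Module,
        input_batch: torch.Tensor,
        target_batch: torch.Tensor,
    ) -> torch.Tensor:
        # copy the inputs so the caller's tensor is left untouched
        x = input_batch.detach().clone().requires_grad_(True)

        # assumes the model outputs probabilities and targets are one-hot
        loss = nn.functional.binary_cross_entropy(model(x), target_batch)

        # gradient of the loss w.r.t. the inputs only (model .grad untouched)
        (grad,) = torch.autograd.grad(loss, x)

        # fixed-size step in the direction that increases the loss
        return self.epsilon * grad.sign()


# usage with a tiny stand-in classifier
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5), nn.Softmax(dim=1))
inputs = torch.rand(2, 3, 8, 8)
targets = torch.eye(2, 5)

attack = FGSMAttack(name="fgsm")
perturbation = attack(model, inputs, targets)
# keep the adversarial images in the assumed [0, 1] pixel range
adv_inputs = (inputs + perturbation).clamp(0.0, 1.0)
print(perturbation.abs().max().item())  # no larger than epsilon
```

The sign step bounds every pixel change by `epsilon`, which makes the perturbation budget explicit; the raw-gradient scaling in `DumbAttack` has no such bound.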

## Define a "whitebox augmentation" class
The class stores a framework-specific model while implementing the MAITE
`Augmentation` protocol.

In [4]:
# Create an Augmentation that takes any object satisfying the ImageClassifierAttack
# protocol in its constructor and uses it within its __call__ method. Once it is
# constructed, the user can treat it like any other implementer of the Augmentation
# protocol.
class WhiteboxAugmentation:
    """
    Apply an image classifier attack
    """

    def __init__(self, model: torch.nn.Module, attack: ImageClassifierAttack):
        # store the torch model and the attack as attributes of this augmentation
        self.attack = attack
        self.model = model

    def __call__(
        self, datum: tuple[ArrayLike, ArrayLike, Sequence[dict[str, Any]]]
    ) -> tuple[torch.Tensor, torch.Tensor, Sequence[dict[str, Any]]]:
        # unpack tuple input
        input_batch, target_batch, metadata_batch = datum

        # type-narrow inputs to type tensor
        input_batch_tn = torch.as_tensor(input_batch)

        # type-narrow Targets to type tensor
        target_batch_tn = torch.as_tensor(target_batch)

        # detach so the augmented batch carries no autograd history
        attack_perturbation = self.attack(
            self.model, input_batch_tn, target_batch_tn
        ).detach()
        input_batch_aug = input_batch_tn + attack_perturbation

        # Modify a copy of the metadata to record any important
        # aspects of this augmentation
        metadata_batch_aug = copy.deepcopy(metadata_batch)
        for i, datum_metadata in enumerate(metadata_batch_aug):
            # create the list on first use; always append so that chained
            # augmentations accumulate records rather than overwrite them
            if "augs_applied" not in datum_metadata:
                datum_metadata["augs_applied"] = list()

            datum_metadata["augs_applied"].append(
                {
                    "name": self.attack.name,
                    "mean_perturbation": torch.mean(attack_perturbation[i]).item(),
                }
            )

        return (input_batch_aug, target_batch_tn, metadata_batch_aug)
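The metadata bookkeeping in `__call__` amounts to appending one record per datum to an `augs_applied` list, so that chained augmentations accumulate history instead of overwriting it. Here is that pattern in isolation, as a minimal torch-free sketch; `record_augmentation` is a hypothetical helper, and `dict.setdefault` creates the list the first time a datum is augmented.

```python
from __future__ import annotations

import copy
from typing import Any, Sequence


def record_augmentation(
    metadata_batch: Sequence[dict[str, Any]],
    name: str,
    per_datum_values: Sequence[float],
) -> list[dict[str, Any]]:
    """Hypothetical helper: return a copy of the metadata with one record per datum appended."""
    # deep-copy so the caller's metadata is never mutated
    metadata_aug = copy.deepcopy(list(metadata_batch))
    for datum_md, value in zip(metadata_aug, per_datum_values):
        # create the list on first use, then always append
        datum_md.setdefault("augs_applied", []).append(
            {"name": name, "mean_perturbation": value}
        )
    return metadata_aug


md = [dict(), dict()]
once = record_augmentation(md, "silly_attack", [0.1, 0.2])
twice = record_augmentation(once, "silly_attack", [0.3, 0.4])

print(md)  # [{}, {}]
print(len(twice[0]["augs_applied"]))  # 2
```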

## Test the augmentation
Create dummy torch module and batch of input/target/metadata

In [5]:
# make a dummy model that takes N x C x H x W image batches and produces
# an N x N_CLASSES tensor of pseudoprobabilities
BATCH_SIZE = 4
H_IMG = 32
W_IMG = 32
C_IMG = 3
N_CLASSES = 5

dummy_model = nn.Sequential(
    nn.Flatten(), nn.Linear(H_IMG * W_IMG * C_IMG, N_CLASSES), nn.ReLU(), nn.Softmax(dim=1)
)
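Constructing `nn.Softmax()` without `dim` makes PyTorch choose the dimension implicitly and emit a deprecation `UserWarning`; for 2-D input it picks `dim=1`. Passing `dim=1` explicitly normalizes each row over the classes and silences the warning. As a quick sanity check, the following sketch rebuilds the same layer stack with fresh random weights (an assumption, not the notebook's model instance) and confirms the output is one probability vector per image.

```python
import torch
from torch import nn

torch.manual_seed(0)

# same layout as the dummy model: flatten N x C x H x W images, map to 5 classes
model = nn.Sequential(
    nn.Flatten(), nn.Linear(3 * 32 * 32, 5), nn.ReLU(), nn.Softmax(dim=1)
)
probs = model(torch.rand(4, 3, 32, 32))

print(probs.shape)  # torch.Size([4, 5])
# each row is a probability vector over the 5 classes
print(torch.allclose(probs.sum(dim=1), torch.ones(4)))  # True
```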

In [6]:
# Apply a WhiteboxAugmentation to a batch


# create instance of WhiteboxAugmentation class
wb_aug: Augmentation = WhiteboxAugmentation(
    model=dummy_model, attack=DumbAttack(name="silly_attack")
)

# create a 'dummy' datum batch
datum_batch: Tuple[torch.Tensor, torch.Tensor, Sequence[dict[str, Any]]] = (
    torch.rand((BATCH_SIZE, C_IMG, H_IMG, W_IMG)),
    torch.eye(BATCH_SIZE, N_CLASSES),
    [dict() for _ in range(BATCH_SIZE)],
)

# apply augmentation
datum_batch_aug = wb_aug(datum_batch)



## Print result of augmentation

In [7]:
# unpack datums
# TODO: consider whether a tuple of iterables or an iterable of tuples is more
#       convenient as a batch format. A tuple of iterables seems to require the
#       unpacking below.

model_input_batch_aug, model_target_batch_aug, md_batch_aug = datum_batch_aug
model_input_batch, model_target_batch, md_batch = datum_batch

print("Results of augmentation (by datum)")
for model_input_aug, model_target_aug, md_aug, model_input, model_target, md in zip(
    model_input_batch_aug,
    model_target_batch_aug,
    md_batch_aug,
    model_input_batch,
    model_target_batch,
    md_batch,
):
    print(f"model input:\n {model_input}")
    print(f"model input (augmented):\n {model_input_aug}")
    print(f"datum metadata:\n {md_aug}")
    print("\n")

Results of augmentation (by datum)
model input:
 tensor([[[0.7955, 0.1129, 0.7171,  ..., 0.2297, 0.3298, 0.5357],
         [0.7627, 0.5773, 0.6534,  ..., 0.8270, 0.6301, 0.5393],
         [0.0184, 0.2435, 0.0479,  ..., 0.8979, 0.2758, 0.5878],
         ...,
         [0.8196, 0.1781, 0.7873,  ..., 0.7249, 0.6869, 0.3866],
         [0.5243, 0.5222, 0.8653,  ..., 0.6402, 0.3306, 0.8031],
         [0.6204, 0.8290, 0.3151,  ..., 0.4511, 0.1459, 0.0962]],

        [[0.4418, 0.5967, 0.4572,  ..., 0.6024, 0.8852, 0.1500],
         [0.6860, 0.2836, 0.1049,  ..., 0.8003, 0.4525, 0.5554],
         [0.9274, 0.4137, 0.9634,  ..., 0.0191, 0.0723, 0.2648],
         ...,
         [0.5196, 0.3777, 0.3172,  ..., 0.9507, 0.9359, 0.0495],
         [0.5860, 0.4772, 0.5473,  ..., 0.3217, 0.1889, 0.0311],
         [0.5033, 0.2301, 0.1392,  ..., 0.2666, 0.1249, 0.5137]],

        [[0.1291, 0.7575, 0.9652,  ..., 0.4392, 0.9244, 0.2767],
         [0.3489, 0.8987, 0.5955,  ..., 0.4570, 0.4043, 0.5867],
         