# Demo of image classification protocols

This is an end-to-end demonstration of the MAITE protocols using "dummy" implementations of the various components.

It shows how a component implementer/provider might wrap a component to conform to MAITE protocols.

It also shows how components can be used in MAITE workflows (`predict` and `evaluate`).

## Set up

In [1]:
from __future__ import annotations
import copy
import torch
import torch.utils.data

from torch import nn
from typing import Any, Optional, Sequence

In [2]:
# import MAITE protocols that don't depend on machine learning task type
from maite.protocols import ArrayLike

# import versions of generic MAITE component protocols specialized to image classification task
from maite.protocols.image_classification import (
    Augmentation,
    Dataset,
    DataLoader,
    InputBatchType,
    MetadataBatchType,
    Metric,
    Model,
    OutputBatchType
)

In [3]:
# verify Pylance configured properly in vscode --> should see red squiggle under "abc"
an_int: int = "abc"

## Model

In [4]:
class ScaledSoftmax(nn.Module):
    """Softmax along dim 1 of the input divided by a fixed temperature."""

    def __init__(self, temperature: float):
        super().__init__()
        # divisor applied to the input before the softmax
        self.temperature = temperature

    def forward(self, x):
        """Return `softmax(x / temperature)` computed along dim 1."""
        scaled = x / self.temperature
        return scaled.softmax(dim=1)

class MyModel:
    """Dummy image classifier: flatten -> bias-free linear layer -> scaled softmax.

    The linear weights are overwritten with values drawn from a generator
    seeded with 42, so two instances built with the same `num_classes`
    produce identical outputs.

    Args:
        name: Identifier reported through `metadata`.
        num_classes: Number of output scores per image.
        device: Torch device string the network (and each input batch) is moved to.
    """

    def __init__(self, name: str, num_classes: int, device: str):
        self.name = name
        self.device = device

        # mimic a single-layer feedforward NN over flattened 3x32x32 inputs
        in_features = 3 * 32 * 32
        out_features = num_classes
        linear = nn.Linear(in_features, out_features, bias=False)
        self._model = nn.Sequential(
            nn.Flatten(),
            linear,
            # temperature-scaled softmax is used in place of a plain nn.Softmax(dim=1)
            ScaledSoftmax(0.1),
        )

        # overwrite weights in Linear layer to make deterministic
        # - copy in place under no_grad rather than assigning through `.data`
        with torch.no_grad():
            linear.weight.copy_(
                0.01 * torch.randn(
                    out_features,
                    in_features,
                    generator=torch.Generator().manual_seed(42),
                )
            )

        # set to evaluate mode
        self._model.eval()

        # move to device
        self._model.to(device)

    def __call__(self, input: ArrayLike) -> torch.Tensor:
        """Run the network on a batch and return the per-class scores.

        Args:
            input: Array-like batch; must be 4-dimensional once converted to a tensor.

        Returns:
            Tensor of scores with one row per batch element, detached from autograd.

        Raises:
            AssertionError: If the converted input is not 4-dimensional.
        """
        # tensor library bridging
        xb = torch.as_tensor(input)

        # make sure is batch
        assert xb.ndim == 4

        # move data to device
        xb = xb.to(self.device)

        # apply model; no_grad skips building the autograd graph, so the
        # result needs no separate detach
        with torch.no_grad():
            return self._model(xb)

    @property
    def metadata(self) -> dict[str, Any]:
        """Model-level metadata (currently just the name)."""
        return dict(name=self.name)

In [5]:
# verify that model conforms to `Model` protocol
model: Model = MyModel("mymodel1", 10, "cpu")

In [6]:
# test determinism
model = MyModel("mymodel1", 10, "cpu")
x = torch.rand(3, 32, 32, generator=torch.Generator().manual_seed(12345678))
output = model(x.unsqueeze(0)) # convert to batch, which model expects
output = torch.as_tensor(output)
torch.round(output * 1e3) / 1e3

tensor([[0.0150, 0.0270, 0.0060, 0.4170, 0.0040, 0.0000, 0.4390, 0.0000, 0.0000,
         0.0920]])

## Dataset

In [7]:
# Q: would a map-style dataset protocol be handy as well?
# - it is easier to implement directly (although I might just be missing some obvious way to do it)
# - and `predict` and `evaluate` can probably use either map-style or iterable-style datasets

In [8]:
"""
# dataset that mimicks CIFAR-10 w/inputs of shape (3, 32, 32)
# target class/label for image i is `i % num_classes`

# see: https://realpython.com/python-iterators-iterables/
class MyIterator(collections.abc.Iterator):
    def __init__(self, sequence: Sequence):
        self.sequence = sequence
        self.i = 0

    def __next__(self):
        if self.i < len(self.sequence):
            x = self.sequence[self.i]
            self.i += 1
            return x
        else:
            raise StopIteration

class MyDataset(Sequence):
    def __init__(self, name: str, num_classes: int, num_items: int):
        self.name = name
        self.num_classes = num_classes
        self.num_items = num_items

    def __len__(self) -> int:
        return self.num_items

    def __getitem__(
        self, i: int
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        assert i < self.num_items
        input = i * torch.ones(3, 32, 32)
        output = torch.zeros(self.num_classes)
        output[i % self.num_classes] = 1
        metadata = dict(uuid=i, gsd=i/10.0)

        return input, output, metadata
    
    def __iter__(self):
        return MyIterator(self)

    @property
    def metadata(self) -> dict[str, Any]:
        return dict(name=self.name)
""";        

In [9]:
class MyDataset:
    """Dummy map-style dataset of `(input, target, metadata)` triples.

    Inputs are pseudo-random tensors of shape (3, 32, 32); the target for
    item `i` is a one-hot vector with class `i % num_classes` set.
    """

    def __init__(self, name: str, num_classes: int, num_items: int):
        self.name = name
        self.num_classes = num_classes
        self.num_items = num_items

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return self.num_items

    def __getitem__(
        self, i: int
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Return the `(input, target, metadata)` triple for item `i`.

        Raises:
            AssertionError: If `i` is not below `num_items`.
        """
        assert i < self.num_items

        # seeding the generator with the index makes every item reproducible
        rng = torch.Generator().manual_seed(i)
        image = torch.rand(3, 32, 32, generator=rng)

        # one-hot target
        target = torch.zeros(self.num_classes)
        target[i % self.num_classes] = 1

        # datum-level metadata
        datum_metadata = {"uuid": i, "gsd": i / 10.0}

        return image, target, datum_metadata

    @property
    def metadata(self) -> dict[str, Any]:
        """Dataset-level metadata (currently just the name)."""
        return {"name": self.name}

In [10]:
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)

In [11]:
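# NOTE: this would trigger AssertionError
# - without __iter__, Python falls back to calling __getitem__ with 0, 1, 2, ... until IndexError is raised,
#   but MyDataset.__getitem__ asserts `i < num_items` instead of raising IndexError
# - i.e., don't seem to be able to iterate over any old thing with __len__ and __getitem__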

#for x, y, md in dataset:
#    print(md)

In [12]:
dataset[2]

(tensor([[[0.6147, 0.3810, 0.6371,  ..., 0.2571, 0.0458, 0.1755],
          [0.6177, 0.8291, 0.5246,  ..., 0.0727, 0.6463, 0.9804],
          [0.9441, 0.4921, 0.6659,  ..., 0.5409, 0.7992, 0.7677],
          ...,
          [0.8539, 0.6372, 0.7458,  ..., 0.1298, 0.6168, 0.3205],
          [0.2958, 0.9967, 0.1822,  ..., 0.8969, 0.2356, 0.5888],
          [0.0706, 0.0296, 0.8922,  ..., 0.5491, 0.0876, 0.3411]],
 
         [[0.4372, 0.4878, 0.6424,  ..., 0.5184, 0.8872, 0.9632],
          [0.5844, 0.6769, 0.5594,  ..., 0.6569, 0.2506, 0.8598],
          [0.7092, 0.0267, 0.8670,  ..., 0.9183, 0.4187, 0.3030],
          ...,
          [0.1170, 0.9725, 0.6277,  ..., 0.2707, 0.6050, 0.7176],
          [0.6282, 0.6714, 0.4452,  ..., 0.8159, 0.0394, 0.0110],
          [0.8284, 0.1825, 0.2938,  ..., 0.1049, 0.7608, 0.4508]],
 
         [[0.6736, 0.4308, 0.5341,  ..., 0.0463, 0.9253, 0.6669],
          [0.7646, 0.6069, 0.7050,  ..., 0.9626, 0.1037, 0.9806],
          [0.0337, 0.9157, 0.1781,  ...,

## Dataloader

Note: typical end users will probably not need to deal directly with dataloaders.

But it's handy to see how something conforming to the MAITE `DataLoader` protocol would work, and it also gives us batches to use in the demonstrations below.

In [13]:
# avoid default collate behavior for maps, e.g.,
# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
# default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
# --> {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

def collate_fn(batch):
    """Collate `(input, output, metadata)` triples into one batch triple.

    Inputs and outputs are each merged with `default_collate`; the metadata
    dicts are returned as a plain list (one dict per datum) instead of being
    merged key-by-key.
    """
    from torch.utils.data import default_collate

    inputs, outputs, metadata = [], [], []
    for datum in batch:
        inputs.append(datum[0])
        outputs.append(datum[1])
        metadata.append(datum[2])

    # collate inputs and outputs into single tensors; leave metadata as sequence of dicts
    return default_collate(inputs), default_collate(outputs), metadata

In [14]:
# wrap PyTorch DataLoader?
class MyDataLoader:
    """Iterable of batches backed by a `torch.utils.data.DataLoader`.

    Batches are assembled by `collate_fn`, so datum-level metadata stays a
    list of dicts.
    """

    def __init__(self, dataset: Dataset, batch_size: int):
        self.dataset = dataset
        self.batch_size = batch_size

        # the type ignore is needed because MAITE `Dataset` doesn't completely match
        # PyTorch `Dataset` or `IterableDataset`
        # - those have an `__add__()` method that i don't think is needed by PyTorch dataloader
        self.dataloader = torch.utils.data.DataLoader(
            dataset,  # type: ignore
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

    def __iter__(self):
        """Return a fresh iterator over the wrapped PyTorch dataloader."""
        return iter(self.dataloader)

In [15]:
dataloader: DataLoader = MyDataLoader(dataset, batch_size=4)

In [16]:
xb, yb, mdb = next(iter(dataloader))

# tensor bridging
xb_pt = torch.as_tensor(xb)
yb_pt = torch.as_tensor(yb)

xb_pt.shape, yb_pt.shape, mdb

(torch.Size([4, 3, 32, 32]),
 torch.Size([4, 10]),
 [{'uuid': 0, 'gsd': 0.0},
  {'uuid': 1, 'gsd': 0.1},
  {'uuid': 2, 'gsd': 0.2},
  {'uuid': 3, 'gsd': 0.3}])

In [17]:
# send batch through model to look at "probabilities"
model = MyModel("mymodel1", 10, "cpu")
preds = model(xb)
torch.round(preds * 1e3) / 1e3

tensor([[0.0030, 0.0130, 0.0000, 0.5930, 0.0010, 0.0030, 0.3740, 0.0000, 0.0010,
         0.0130],
        [0.0000, 0.0020, 0.0010, 0.0200, 0.0000, 0.0000, 0.9760, 0.0000, 0.0000,
         0.0010],
        [0.0140, 0.0140, 0.0000, 0.1050, 0.0010, 0.0010, 0.1860, 0.0000, 0.0000,
         0.6790],
        [0.0030, 0.6500, 0.0190, 0.1870, 0.0740, 0.0010, 0.0610, 0.0000, 0.0000,
         0.0040]])

## Augmentation

Example of an augmentation that:
- changes inputs in a way that depends on datum-level metadata
  - expects "gsd" datum-level metadata
- adds augmentation-generated datum-level metadata
  - similar to recording the particular rotation angle applied
- doesn't change targets/labels, since it doesn't change image geometry
  - and for image classification, the targets/labels don't have geometry anyway

In [18]:
class MyAugmentation:
    """Dummy augmentation driven by the "gsd" datum-level metadata.

    Every input is replaced by a pseudo-random (3, 32, 32) tensor whose
    upper-left element in each channel is set to `multiplier * gsd`. Targets
    are passed through unchanged, and each datum's metadata gains an entry
    keyed by the augmentation name.

    Args:
        name: Identifier; also the key under which augmentation metadata is stored.
        multiplier: Factor applied to each datum's "gsd" value.
    """

    def __init__(self, name: str, multiplier: float):
        self.name = name
        self.multiplier = multiplier

    def __call__(
        self,
        batch: tuple[InputBatchType, OutputBatchType, MetadataBatchType],
    ) -> tuple[torch.Tensor, OutputBatchType, MetadataBatchType]:
        """Augment one `(inputs, targets, metadata)` batch.

        Returns:
            Tuple of (stacked augmented inputs, the targets as a tensor,
            list of per-datum metadata dicts with the augmentation entry added).

        Raises:
            AssertionError: If the inputs are not 4-dimensional or a datum's
                metadata lacks "gsd".
        """
        xb, yb, metadata = batch

        # tensor bridging
        xb = torch.as_tensor(xb)
        yb = torch.as_tensor(yb)
        assert xb.ndim == 4

        # the seed below uses the sum over the WHOLE batch, which is the same
        # for every datum, so compute it once instead of once per iteration
        # NOTE(review): the per-datum input is never used, so a datum's augmentation
        # depends on which batch it lands in; if a per-image seed was intended,
        # this should be the sum of the individual input -- confirm before changing,
        # since it alters the recorded outputs
        batch_sum = xb.sum().item()

        x_augs = []  # list of individual augmented inputs
        md_augs = []  # list of individual image-level metadata

        # zip over all three so iteration stops at the shortest, as before
        for _x, _y, md in zip(xb, yb, metadata):
            assert "gsd" in md
            gsd = md["gsd"]
            aug_param = self.multiplier * gsd

            # "augment" by changing the input (with seed based on original input batch)
            seed = int(gsd * batch_sum)
            x_aug = torch.rand(3, 32, 32, generator=torch.Generator().manual_seed(seed))

            # replace upper-left element of every channel so the change is visible
            x_aug[:, 0, 0] = aug_param
            x_augs.append(x_aug)

            # copy the original metadata so the caller's dicts are not mutated
            md_aug = copy.deepcopy(md)

            # add new metadata under a "namespace" key to avoid collisions (and also for organization)
            md_aug[self.name] = {"aug_param": aug_param}
            md_augs.append(md_aug)

        # return batch of augmented inputs, original outputs (since unchanged), and updated metadata
        return torch.stack(x_augs), yb, md_augs

    @property
    def metadata(self) -> dict[str, Any]:
        """Augmentation-level metadata: name and multiplier."""
        return dict(
            name=self.name,
            multiplier=self.multiplier,
        )

In [19]:
# test typing
augmentation: Augmentation = MyAugmentation("aug1", 10.0)

In [20]:
# show application to batch of data
batch = next(iter(dataloader))
xb_aug, yb_aug, mdb_aug = augmentation(batch)

# tensor bridging
xb_aug = torch.as_tensor(xb_aug)
yb_aug = torch.as_tensor(yb_aug)

# inspect augmented batch
xb_aug.shape, yb_aug.shape, mdb_aug

(torch.Size([4, 3, 32, 32]),
 torch.Size([4, 10]),
 [{'uuid': 0, 'gsd': 0.0, 'aug1': {'aug_param': 0.0}},
  {'uuid': 1, 'gsd': 0.1, 'aug1': {'aug_param': 1.0}},
  {'uuid': 2, 'gsd': 0.2, 'aug1': {'aug_param': 2.0}},
  {'uuid': 3, 'gsd': 0.3, 'aug1': {'aug_param': 3.0}}])

In [21]:
# upperleft element of each channel for image 0 should be 0 (based on fake augmentation)
xb_aug = torch.as_tensor(xb_aug)
xb_aug[0]

tensor([[[0.0000, 0.7682, 0.0885,  ..., 0.3734, 0.3051, 0.9320],
         [0.1759, 0.2698, 0.1507,  ..., 0.7011, 0.2038, 0.6511],
         [0.7745, 0.4369, 0.5191,  ..., 0.6870, 0.0051, 0.1757],
         ...,
         [0.8787, 0.6569, 0.9944,  ..., 0.2269, 0.6664, 0.5225],
         [0.1427, 0.6076, 0.9553,  ..., 0.7924, 0.5431, 0.8903],
         [0.5937, 0.3392, 0.8387,  ..., 0.1876, 0.2099, 0.7210]],

        [[0.0000, 0.0278, 0.2117,  ..., 0.8647, 0.0605, 0.4548],
         [0.9106, 0.6936, 0.9212,  ..., 0.8531, 0.7173, 0.4575],
         [0.4692, 0.1864, 0.3191,  ..., 0.1398, 0.0620, 0.3074],
         ...,
         [0.9365, 0.3450, 0.3035,  ..., 0.2062, 0.6444, 0.6147],
         [0.7693, 0.4257, 0.7569,  ..., 0.6605, 0.8492, 0.5603],
         [0.4499, 0.8180, 0.1410,  ..., 0.5875, 0.8263, 0.2909]],

        [[0.0000, 0.3556, 0.1764,  ..., 0.2433, 0.6071, 0.2682],
         [0.3052, 0.1653, 0.0830,  ..., 0.8141, 0.5898, 0.3632],
         [0.5211, 0.9456, 0.5542,  ..., 0.9786, 0.0670, 0.

In [22]:
# hopefully model outputs are at least a little different
model = MyModel("mymodel1", 10, "cpu")
preds = model(xb_aug)
torch.round(preds * 1e3) / 1e3

tensor([[0.0030, 0.0120, 0.0000, 0.6330, 0.0010, 0.0030, 0.3360, 0.0000, 0.0000,
         0.0110],
        [0.0050, 0.0520, 0.0010, 0.5740, 0.0020, 0.0050, 0.3560, 0.0000, 0.0000,
         0.0050],
        [0.0140, 0.1460, 0.0010, 0.6380, 0.0070, 0.0120, 0.0820, 0.0000, 0.0000,
         0.1000],
        [0.0030, 0.0020, 0.0000, 0.0030, 0.0000, 0.0000, 0.9900, 0.0000, 0.0000,
         0.0020]])

## Metric

In [23]:
# weird metric that sees if score of correct class is >= threshold

class MyMetric:
    """Fraction of items whose score for the target class is >= a threshold.

    This is deliberately not ordinary argmax accuracy: an item counts as
    "correct" when the predicted score in the target's argmax column reaches
    `threshold`, regardless of the scores of the other classes.

    Args:
        threshold: Minimum score the target class must reach to count.
    """

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.total = 0
        self.correct = 0

    def reset(self) -> None:
        """Clear the accumulated counts."""
        self.total = 0
        self.correct = 0

    def update(self, preds: OutputBatchType, targets: OutputBatchType) -> None:
        """Accumulate counts for one batch of predictions and targets.

        Both arguments are indexed as 2-D (item, class) arrays; the target
        class of each item is the argmax of its target row.
        """
        # tensor bridging
        preds = torch.as_tensor(preds)
        targets = torch.as_tensor(targets)

        self.total += len(preds)

        # column of the target class for each item, moved to the predictions'
        # device so the gather below also works when the two differ
        target_cols = targets.argmax(dim=1).to(preds.device)

        # score each item assigned to its target class
        target_scores = preds.gather(1, target_cols.unsqueeze(1)).squeeze(1)

        # compare in float64 so the result matches comparing each score as a
        # Python float against the (Python float) threshold
        hits = target_scores.to(torch.float64) >= self.threshold
        self.correct += int(hits.sum().item())

    def compute(self) -> dict[str, Any]:
        """Return `{"threshold_exceedance": correct / total}`.

        The value is NaN when nothing has been accumulated yet, instead of
        failing with a ZeroDivisionError.
        """
        if self.total == 0:
            return {"threshold_exceedance": float("nan")}
        return {"threshold_exceedance": self.correct / self.total}

In [24]:
metric: Metric = MyMetric()

## Predict

In [25]:
from maite.workflows import predict

In [26]:
# test predict with dataloader and no augmentation

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
dataloader: DataLoader = MyDataLoader(dataset, batch_size=4)

pred_batches, data_batches = predict(
    model=model,
    dataloader=dataloader,
    augmentation=None

)

# look at first batch
(x, y, md) = data_batches[0]
pred = pred_batches[0]

# tensor bridging
x = torch.as_tensor(x)
y = torch.as_tensor(y)
pred = torch.as_tensor(pred)

x.shape, y.shape, pred.shape, md

(torch.Size([4, 3, 32, 32]),
 torch.Size([4, 10]),
 torch.Size([4, 10]),
 [{'uuid': 0, 'gsd': 0.0},
  {'uuid': 1, 'gsd': 0.1},
  {'uuid': 2, 'gsd': 0.2},
  {'uuid': 3, 'gsd': 0.3}])

In [27]:
# show all batch sizes
for (x, y, md), pred in zip(data_batches, pred_batches):
    pred = torch.as_tensor(pred)
    print(f"{pred.shape = }")

pred.shape = torch.Size([4, 10])
pred.shape = torch.Size([4, 10])
pred.shape = torch.Size([4, 10])
pred.shape = torch.Size([2, 10])


In [28]:
# test predict with augmentation

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
dataloader: DataLoader = MyDataLoader(dataset, batch_size=4)
augmentation: Augmentation = MyAugmentation("aug1", 10.0)

pred_batches, data_batches = predict(
    model=model,
    dataloader=dataloader,
    augmentation=augmentation
)

# look at first batch
(x, y, md), pred = data_batches[0], pred_batches[0]

# tensor bridging
x = torch.as_tensor(x)
y = torch.as_tensor(y)
pred = torch.as_tensor(pred)

x.shape, y.shape, pred.shape, md

(torch.Size([4, 3, 32, 32]),
 torch.Size([4, 10]),
 torch.Size([4, 10]),
 [{'uuid': 0, 'gsd': 0.0, 'aug1': {'aug_param': 0.0}},
  {'uuid': 1, 'gsd': 0.1, 'aug1': {'aug_param': 1.0}},
  {'uuid': 2, 'gsd': 0.2, 'aug1': {'aug_param': 2.0}},
  {'uuid': 3, 'gsd': 0.3, 'aug1': {'aug_param': 3.0}}])

In [29]:
# test predict with dataset and no augmentation
# - will use batch_size of 1

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)

pred_batches, data_batches = predict(
    model=model,
    dataset=dataset
)

# look at first batch
(x, y, md) = data_batches[0]
pred = pred_batches[0]

# tensor bridging
x = torch.as_tensor(x)
y = torch.as_tensor(y)
pred = torch.as_tensor(pred)

x.shape, y.shape, pred.shape, md

(torch.Size([1, 3, 32, 32]),
 torch.Size([1, 10]),
 torch.Size([1, 10]),
 [{'uuid': 0, 'gsd': 0.0}])

In [30]:
# show all batch sizes
for (x, y, md), pred in zip(data_batches, pred_batches):
    pred = torch.as_tensor(pred)
    print(f"{pred.shape = }")

pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])
pred.shape = torch.Size([1, 10])


## Evaluate

In [31]:
from maite.workflows import evaluate

In [32]:
# test evaluate with dataloader, no augmentation

threshold = 0.002

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
dataloader: DataLoader = MyDataLoader(dataset, batch_size=4)
metric: Metric = MyMetric(threshold=threshold)

results, _, _ = evaluate(
    model=model,
    dataloader=dataloader,
    metric=metric
)

results

{'threshold_exceedance': 0.6428571428571429}

In [33]:
# test evaluate with dataset, no augmentation
# - pass the dataset itself; the dataloader path is already covered by the previous cell

THRESHOLD = 0.002
model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
metric: Metric = MyMetric(threshold=THRESHOLD)

results, _, _ = evaluate(model=model, dataset=dataset, metric=metric)

results

{'threshold_exceedance': 0.6428571428571429}

In [34]:
# test evaluate with dataset, different batch size, no augmentation

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
metric: Metric = MyMetric(threshold=threshold)

results, _, _ = evaluate(
    model=model,
    dataset=dataset,
    batch_size=8,
    metric=metric
)

results

{'threshold_exceedance': 0.6428571428571429}

In [35]:
# test evaluate with dataloader, augmentation

model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)
dataloader: DataLoader = MyDataLoader(dataset, batch_size=4)
metric: Metric = MyMetric(threshold=threshold)
augmentation: Augmentation = MyAugmentation("aug1", 10.0)

results, _, _ = evaluate(
    model=model,
    dataloader=dataloader,
    metric=metric,
    augmentation=augmentation
)

results

{'threshold_exceedance': 0.5714285714285714}

## Evaluate that takes predict output

In [36]:
def evaluate_preds(
    preds: Sequence[OutputBatchType],
    targets: Sequence[OutputBatchType],
    metric: Metric,
) -> dict[str, Any]:
    """Compute `metric` over already-made prediction batches.

    The metric is reset first, then updated once per (prediction batch,
    target batch) pair, pairing the two sequences positionally.

    Returns:
        Whatever `metric.compute()` returns after the last update.
    """
    # discard any state left over from earlier use of the metric
    metric.reset()

    for batch_preds, batch_targets in zip(preds, targets):
        metric.update(batch_preds, batch_targets)

    return metric.compute()

In [37]:
# test evaluate that takes predict output

# get predictions (no augmentation)
model = MyModel("mymodel1", 10, "cpu")
dataset: Dataset = MyDataset("pseudo-cifar-10", num_classes=10, num_items=14)

# build the metric here (as the other evaluate cells do) rather than relying on
# whichever `metric` an earlier cell happened to leave behind
metric: Metric = MyMetric(threshold=threshold)

pred_batches, data_batches = predict(
    model=model,
    dataset=dataset,
    batch_size=4
)

# get predictions and targets/labels out of sequence of tuples
# - element 1 of each data batch is the batch of targets/labels
preds = list(pred_batches)
targets = [t[1] for t in data_batches]

# send through evaluate
result = evaluate_preds(
    preds=preds,
    targets=targets,
    metric=metric
)

result

{'threshold_exceedance': 0.6428571428571429}

## Ideas

- make version of `evaluate` that takes dataset, dataloader, OR, preds/targets?
- OR
- put overloaded versions of `predict` and `evaluate` on a class?
  - that's also generic so that can easily get image classification and object detection flavors of it?
- allow `predict` to return sequence of tuples of non-batch types? (e.g., if batch_size 0 or 1)
- return namedtuple from `predict` if that still conforms to protocol (so that more descriptive)?
- should we have aliases for Target and Prediction to make clearer?
  - e.g., `TargetBatchType`, `PredictionBatchType`
- could make the version of `evaluate` that takes predictions accept the output of `predict` directly (i.e., a sequence of tuples)