```
Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY<br/>
Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).<br/>
SPDX-License-Identifier: MIT
```

# Overview of MAITE Protocols

MAITE provides protocols for the following AI components:

* models
* datasets
* dataloaders
* augmentations
* metrics

MAITE protocols specify the expected interfaces of these components (i.e., a minimal set of required attributes, methods, and method type signatures) to promote interoperability in test and evaluation (T&E). This enables the creation of higher-level workflows (e.g., an `evaluate` utility) that can interact with any components that conform to the protocols.
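Conformance to a protocol of this kind is structural: a component only needs the right methods and signatures, not a particular base class. The minimal sketch below illustrates the idea with Python's `typing.Protocol`. The names `SupportsPredict`, `Doubler`, and `run_workflow` are hypothetical and are not part of MAITE.

```python
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupportsPredict(Protocol):
    """Hypothetical protocol: anything callable on a batch of inputs."""

    def __call__(self, batch: Sequence[float]) -> Sequence[float]: ...


class Doubler:
    # Conforms structurally; it does NOT inherit from SupportsPredict.
    def __call__(self, batch: Sequence[float]) -> Sequence[float]:
        return [2 * x for x in batch]


def run_workflow(model: SupportsPredict, batch: Sequence[float]) -> Sequence[float]:
    # Hypothetical higher-level utility written only against the protocol.
    return model(batch)


doubler = Doubler()
print(isinstance(doubler, SupportsPredict))  # True (runtime check only verifies __call__ exists)
print(run_workflow(doubler, [1.0, 2.0]))  # [2.0, 4.0]
```

A static type checker additionally verifies the full method signature, which the runtime `isinstance` check does not.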

## 1 Concept: Bridging ArrayLikes

MAITE defines a protocol called `ArrayLike` (inspired by NumPy's [interoperability approach](https://numpy.org/devdocs/user/basics.interoperability.html)) that helps components that natively use different flavors of tensors (e.g., NumPy ndarray, PyTorch Tensor, JAX ndarray) work together.

In this example, the functions "type narrow" from `ArrayLike` to the type they want to work with internally. Depending on the actual input type, this does not necessarily require a conversion.

```python
import numpy as np
import torch

from maite.protocols import ArrayLike

def my_numpy_fn(x: ArrayLike) -> np.ndarray:
    arr = np.asarray(x)
    # ...
    return arr

def my_torch_fn(x: ArrayLike) -> torch.Tensor:
    tensor = torch.as_tensor(x)
    # ...
    return tensor

# can apply NumPy function to PyTorch Tensor
np_out = my_numpy_fn(torch.rand(2, 3))

# can apply PyTorch function to NumPy array
torch_out = my_torch_fn(np.random.rand(2, 3))

# note: there is no performance hit from conversion when all `ArrayLike`s come from the same library
# or when they can share the same underlying memory
torch_out = my_torch_fn(torch.rand(2, 3))
```

By using bridging, MAITE permits implementers of the protocol to interact internally with their own types while exposing a more open interface to other MAITE-compliant components.
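The sketch below shows why a user-defined tensor type can join this bridge. It assumes, following NumPy's interoperability model, that an array-like is anything exposing an `__array__` method; `ArrayLikeSketch` is a hypothetical stand-in for MAITE's `ArrayLike`, and `MyTensor` is a hypothetical third-party type.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ArrayLikeSketch(Protocol):
    """Stand-in for an ArrayLike protocol (assumed: structural over __array__)."""

    def __array__(self) -> Any: ...


class MyTensor:
    """Hypothetical third-party tensor type that stores nested Python lists."""

    def __init__(self, data: list[list[float]]) -> None:
        self._data = data

    # Implementing __array__ lets NumPy (and code built on it) consume this type.
    # `copy` is accepted for NumPy 2 compatibility and ignored in this sketch.
    def __array__(self, dtype: Any = None, copy: Any = None) -> np.ndarray:
        return np.asarray(self._data, dtype=dtype)


def my_numpy_fn(x: ArrayLikeSketch) -> np.ndarray:
    # Narrow to the type this function works with internally.
    return np.asarray(x)


t = MyTensor([[1.0, 2.0], [3.0, 4.0]])
print(isinstance(t, ArrayLikeSketch))  # True
print(my_numpy_fn(t).sum())  # 10.0
```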

## 2 Data Types

MAITE represents an *individual* data item as a tuple of:

* input (e.g., an image),
* target (e.g., a label), and
* metadata (at the datum level)

and a *batch* of data items as a tuple of:

* input batches,
* target batches, and
* metadata batches.
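The relationship between the two tuple shapes can be sketched with plain Python. In this sketch, lists stand in for `ArrayLike` tensors, and `Datum`, `Batch`, and `collate` are hypothetical names used only for illustration.

```python
from typing import Any, Sequence

# Illustrative aliases (hypothetical names): lists stand in for ArrayLike tensors.
Datum = tuple[Any, Any, dict[str, Any]]
Batch = tuple[list[Any], list[Any], list[dict[str, Any]]]


def collate(data: Sequence[Datum]) -> Batch:
    """Hypothetical helper: turn N (input, target, metadata) tuples into one batch tuple."""
    inputs = [datum[0] for datum in data]
    targets = [datum[1] for datum in data]
    metadata = [datum[2] for datum in data]
    return inputs, targets, metadata


data = [
    ([0.1, 0.2], [1, 0], {"id": 0}),
    ([0.3, 0.4], [0, 1], {"id": 1}),
]
inputs, targets, metadata = collate(data)
print(metadata)  # [{'id': 0}, {'id': 1}]
```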

MAITE provides versions of `Model`, `Dataset`, `DataLoader`, `Augmentation`, and `Metric` protocols that correspond to different machine learning tasks (e.g., image classification, object detection) by parameterizing protocol interfaces on the particular input, target, and metadata types associated with that task.
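One way such parameterization could be expressed is a generic `Protocol` over type variables, specialized per task. This is an illustrative sketch, not taken from MAITE's source; `GenericModel`, `ToyModel`, and `Identity` are hypothetical.

```python
from typing import Protocol, Sequence, TypeVar

# Hypothetical type variables standing in for a task's batch types.
InputBatchT = TypeVar("InputBatchT", contravariant=True)
TargetBatchT = TypeVar("TargetBatchT", covariant=True)


class GenericModel(Protocol[InputBatchT, TargetBatchT]):
    """Hypothetical task-agnostic model protocol, parameterized on batch types."""

    def __call__(self, input_batch: InputBatchT, /) -> TargetBatchT: ...


# "Specialize" the generic protocol for a toy task whose inputs and targets are float lists.
ToyModel = GenericModel[Sequence[list[float]], Sequence[list[float]]]


class Identity:
    def __call__(self, input_batch: Sequence[list[float]], /) -> Sequence[list[float]]:
        return list(input_batch)


model: ToyModel = Identity()  # a static type checker verifies the signature matches
print(model([[1.0], [2.0]]))  # [[1.0], [2.0]]
```

The input type variable is contravariant and the output type variable covariant because they appear in argument and return positions, respectively.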

### 2.1 Image Classification

For image classification with `Cl` image classes, we have:

```python
from typing import Sequence, TypeAlias, TypedDict

from maite.protocols import ArrayLike

# define type to store an id of each datum (additional fields can be added by defining a structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with a single image
TargetType: TypeAlias = ArrayLike  # shape-(Cl,) tensor of one-hot encoded true class or predicted probabilities

InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
TargetBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (Cl,)
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
```

Notes:

* `TargetType` is used for both ground truth (coming from a dataset) and predictions (output from a model). So, for a problem with 4 classes,

  * the true label of class 2 would be one-hot encoded as `[0, 0, 1, 0]`
  * a prediction from a model would be a vector of pseudo-probabilities, e.g., `[0.1, 0.0, 0.7, 0.2]`
* `InputType` and `InputBatchType` are shown with shapes following PyTorch channels-first convention
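The two kinds of `TargetType` values described in these notes can be sketched with NumPy. Here `one_hot` and `softmax` are hypothetical helpers; softmax is just one common way a model might produce pseudo-probabilities, and the logits are made up for illustration.

```python
import numpy as np


def one_hot(class_index: int, num_classes: int) -> np.ndarray:
    """Encode a ground-truth class index as a shape-(num_classes,) one-hot vector."""
    target = np.zeros(num_classes)
    target[class_index] = 1.0
    return target


def softmax(logits: np.ndarray) -> np.ndarray:
    """Turn raw model scores into pseudo-probabilities that sum to 1."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


truth = one_hot(2, num_classes=4)
print(truth)  # [0. 0. 1. 0.]

pred = softmax(np.array([0.5, -1.0, 2.0, 1.0]))  # hypothetical model logits
print(int(pred.argmax()))  # 2: the predicted class matches the true class
```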

These type aliases, along with the versions of the component protocols that use them, can be imported from `maite.protocols.image_classification` (if necessary):

```python
# import protocol classes
from maite.protocols.image_classification import (
    Dataset,
    DataLoader,
    Model,
    Augmentation,
    Metric
)

# import type aliases
from maite.protocols.image_classification import (
    InputType,
    TargetType,
    DatumMetadataType,
    InputBatchType,
    TargetBatchType,
    DatumMetadataBatchType
)
```

Alternatively, image classification components and types can be accessed via the module directly:

```python
import maite.protocols.image_classification as ic

# model: ic.Model = load_model(...)
```

### 2.2 Object Detection

For object detection with `D_i` detections in an image `i`, we have:

```python
from typing import Protocol, Sequence, TypeAlias, TypedDict

from maite.protocols import ArrayLike

# define type to store an id of each datum (additional fields can be added by defining a structurally-assignable TypedDict)
class DatumMetadataType(TypedDict):
    id: str | int

class ObjectDetectionTarget(Protocol):
    @property
    def boxes(self) -> ArrayLike: ...  # shape-(D_i, 4) tensor of bounding boxes w/format X0, Y0, X1, Y1

    @property
    def labels(self) -> ArrayLike: ...  # shape-(D_i,) tensor of labels for each box

    @property
    def scores(self) -> ArrayLike: ...  # shape-(D_i,) tensor of scores for each box (e.g., probabilities)

InputType: TypeAlias = ArrayLike  # shape-(C, H, W) tensor with a single image
TargetType: TypeAlias = ObjectDetectionTarget

InputBatchType: TypeAlias = Sequence[ArrayLike]  # sequence of N ArrayLikes, each of shape (C, H, W)
TargetBatchType: TypeAlias = Sequence[TargetType]  # sequence of object detection "target" objects
DatumMetadataBatchType: TypeAlias = Sequence[DatumMetadataType]
```

Notes:

* `ObjectDetectionTarget` contains a single label and score per box
* `InputType` and `InputBatchType` are shown with shapes following PyTorch channels-first convention
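Because `ObjectDetectionTarget` is a protocol, any object with `boxes`, `labels`, and `scores` attributes of the right shapes fits it; for static type checkers, plain dataclass attributes satisfy read-only property members. The sketch below uses a hypothetical `DetectionTarget` dataclass with two made-up detections and computes box areas to illustrate the X0, Y0, X1, Y1 format.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class DetectionTarget:
    """Hypothetical class that structurally matches the ObjectDetectionTarget protocol."""

    boxes: np.ndarray  # shape-(D_i, 4), format X0, Y0, X1, Y1
    labels: np.ndarray  # shape-(D_i,), one label per box
    scores: np.ndarray  # shape-(D_i,), one score per box


# Two detections (D_i = 2) in one image.
target = DetectionTarget(
    boxes=np.array([[10.0, 20.0, 50.0, 80.0], [0.0, 0.0, 5.0, 5.0]]),
    labels=np.array([3, 1]),
    scores=np.array([0.9, 0.4]),
)

# Corner format makes widths and heights simple differences.
x0, y0, x1, y1 = target.boxes.T
areas = (x1 - x0) * (y1 - y0)
print(areas.tolist())  # [2400.0, 25.0]
```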

## 3 Models

All models implement a `__call__` method that takes the `InputBatchType` and produces the `TargetBatchType` appropriate for the given machine learning task.
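As a minimal sketch of the `__call__` behavior described above for image classification, the hypothetical `UniformClassifier` below maps a sequence of channels-first images to one pseudo-probability vector per image. It covers only `__call__`; the docstrings printed below list the full protocol requirements, which may include further attributes.

```python
from typing import Sequence

import numpy as np


class UniformClassifier:
    """Hypothetical image classifier: predicts equal probability for every class."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    def __call__(self, input_batch: Sequence[np.ndarray]) -> list[np.ndarray]:
        # One shape-(Cl,) vector of pseudo-probabilities per shape-(C, H, W) input image.
        return [np.full(self.num_classes, 1.0 / self.num_classes) for _ in input_batch]


model = UniformClassifier(num_classes=4)
batch = [np.zeros((3, 32, 32)) for _ in range(2)]  # N=2 channels-first images
preds = model(batch)
print(len(preds), preds[0].shape)  # 2 (4,)
```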

```python
import maite.protocols.image_classification as ic
print(ic.Model.__doc__)
```

```python
import maite.protocols.object_detection as od
print(od.Model.__doc__)
```

## 4 Datasets and DataLoaders

`Dataset`s provide access to single data items, and `DataLoader`s provide access to batches of data, with the input, target, and metadata types corresponding to the given machine learning task.
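The sketch below illustrates that split for image classification: a map-style dataset returning `(input, target, metadata)` tuples by index, and a loader that iterates over batch tuples. `PlaceholderImageDataset` and `SimpleDataLoader` are hypothetical; the exact methods each MAITE protocol requires should be checked against the docstrings printed below.

```python
from typing import Any, Iterator

import numpy as np

Datum = tuple[np.ndarray, np.ndarray, dict[str, Any]]
Batch = tuple[list[np.ndarray], list[np.ndarray], list[dict[str, Any]]]


class PlaceholderImageDataset:
    """Hypothetical map-style dataset of constant-valued images with one-hot targets."""

    def __init__(self, size: int, num_classes: int = 4) -> None:
        self._size = size
        self._num_classes = num_classes

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, index: int) -> Datum:
        image = np.full((3, 8, 8), float(index))  # shape-(C, H, W) placeholder image
        target = np.eye(self._num_classes)[index % self._num_classes]  # one-hot
        return image, target, {"id": index}


class SimpleDataLoader:
    """Hypothetical loader yielding (input batch, target batch, metadata batch)."""

    def __init__(self, dataset: PlaceholderImageDataset, batch_size: int) -> None:
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[Batch]:
        for start in range(0, len(self.dataset), self.batch_size):
            stop = min(start + self.batch_size, len(self.dataset))
            data = [self.dataset[i] for i in range(start, stop)]
            yield [d[0] for d in data], [d[1] for d in data], [d[2] for d in data]


loader = SimpleDataLoader(PlaceholderImageDataset(size=5), batch_size=2)
print([len(inputs) for inputs, _, _ in loader])  # [2, 2, 1]
```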

```python
print(ic.Dataset.__doc__)
```

```python
print(ic.DataLoader.__doc__)
```

```python
print(od.DataLoader.__doc__)
```

## 5 Augmentations

`Augmentation`s take in and return a batch of data with the `InputBatchType`, `TargetBatchType`, and `DatumMetadataBatchType` types corresponding to the given machine learning task.

Augmentations can access the datum-level metadata associated with each data item to potentially tailor the augmentation to individual items. Augmentations can also associate new datum-level metadata with each data item, for instance to document aspects of the change that was actually applied (e.g., the rotation angle sampled from a range of possible angles).
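A minimal sketch of a batch-in, batch-out augmentation that records what it actually did in the datum-level metadata follows. `RandomHorizontalFlip` and the `"hflip"` metadata key are hypothetical. Targets pass through unchanged, which suits classification; an object detection version would also need to transform the boxes.

```python
from typing import Any

import numpy as np

Batch = tuple[list[np.ndarray], list[np.ndarray], list[dict[str, Any]]]


class RandomHorizontalFlip:
    """Hypothetical augmentation: flips each image with probability p and records it."""

    def __init__(self, p: float = 0.5, seed: int = 0) -> None:
        self.p = p
        self._rng = np.random.default_rng(seed)

    def __call__(self, batch: Batch) -> Batch:
        inputs, targets, metadata = batch
        new_inputs, new_metadata = [], []
        for image, datum_md in zip(inputs, metadata):
            flipped = bool(self._rng.random() < self.p)
            # Flip along the width axis of a channels-first (C, H, W) image.
            new_inputs.append(np.flip(image, axis=-1).copy() if flipped else image)
            # Record the change actually applied, without mutating the caller's dict.
            new_metadata.append({**datum_md, "hflip": flipped})
        return new_inputs, targets, new_metadata


aug = RandomHorizontalFlip(p=1.0)
image = np.arange(6, dtype=float).reshape(1, 2, 3)
out_inputs, _, out_md = aug(([image], [np.array([1.0, 0.0])], [{"id": 7}]))
print(out_md)  # [{'id': 7, 'hflip': True}]
```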

```python
print(ic.Augmentation.__doc__)
```

```python
print(od.Augmentation.__doc__)
```

## 6 Metrics

The `Metric` protocol is inspired by the design of existing libraries like TorchMetrics and TorchEval. The `update` method operates on batches of predictions and ground-truth targets, either caching them for later computation of the metric (via `compute`) or updating sufficient statistics in an online fashion.
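The online style can be sketched as an accuracy metric that keeps only two running counts. `Accuracy` is hypothetical, the `reset` method is assumed here following TorchMetrics-style APIs, and returning a dictionary from `compute` is likewise an assumption of this sketch.

```python
from typing import Any, Sequence

import numpy as np


class Accuracy:
    """Hypothetical accuracy metric that keeps running counts (sufficient statistics)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self._correct = 0
        self._total = 0

    def update(self, preds: Sequence[np.ndarray], targets: Sequence[np.ndarray]) -> None:
        # Compare the argmax of each prediction vector with the argmax of its one-hot target.
        for pred, target in zip(preds, targets):
            self._correct += int(np.argmax(pred) == np.argmax(target))
            self._total += 1

    def compute(self) -> dict[str, Any]:
        # Returns 0.0 before any update; this is a design choice of the sketch.
        return {"accuracy": self._correct / self._total if self._total else 0.0}


metric = Accuracy()
metric.update(
    [np.array([0.1, 0.9]), np.array([0.8, 0.2])],  # predictions
    [np.array([0.0, 1.0]), np.array([0.0, 1.0])],  # one-hot ground truth
)
print(metric.compute())  # {'accuracy': 0.5}
```

Keeping counts rather than caching every prediction means memory use stays constant regardless of dataset size.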

```python
print(ic.Metric.__doc__)
```

```python
print(od.Metric.__doc__)
```

## 7 Workflows

MAITE provides high-level utilities for common workflows such as `evaluate` and `predict`. They can be called with either `Dataset`s or `DataLoader`s, and with an optional `Augmentation`.

The `evaluate` function can optionally return the model predictions and (potentially-augmented) data batches used during inference.

The `predict` function returns the model predictions and (potentially-augmented) data batches used during inference, essentially calling `evaluate` with a dummy metric.
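The behavior described above can be sketched as a simple loop. `simple_evaluate`, `simple_predict`, and `_NoOpMetric` are hypothetical and are NOT MAITE's `evaluate` and `predict`, whose actual signatures are shown in the printed docstrings below. For brevity the sketch accepts only an iterable of batches (i.e., a dataloader), not a dataset.

```python
from typing import Any, Callable, Iterable, Optional

Batch = tuple[list[Any], list[Any], list[dict[str, Any]]]


class _NoOpMetric:
    """Dummy metric used when only predictions are wanted."""

    def reset(self) -> None: ...
    def update(self, preds: Any, targets: Any) -> None: ...
    def compute(self) -> dict[str, Any]:
        return {}


def simple_evaluate(
    model: Callable[[list[Any]], list[Any]],
    dataloader: Iterable[Batch],
    metric: Any,
    augmentation: Optional[Callable[[Batch], Batch]] = None,
) -> tuple[dict[str, Any], list[list[Any]], list[Batch]]:
    """Hypothetical evaluation loop returning metric results, predictions, and batches."""
    metric.reset()
    all_preds, all_batches = [], []
    for batch in dataloader:
        if augmentation is not None:
            batch = augmentation(batch)  # potentially-augmented batch is what gets returned
        inputs, targets, _ = batch
        preds = model(inputs)
        metric.update(preds, targets)
        all_preds.append(preds)
        all_batches.append(batch)
    return metric.compute(), all_preds, all_batches


def simple_predict(
    model: Callable[[list[Any]], list[Any]],
    dataloader: Iterable[Batch],
    augmentation: Optional[Callable[[Batch], Batch]] = None,
) -> tuple[list[list[Any]], list[Batch]]:
    """Predict by evaluating with a dummy metric and discarding its (empty) result."""
    _, preds, batches = simple_evaluate(model, dataloader, _NoOpMetric(), augmentation)
    return preds, batches


loader = [([1, 2], [2, 4], [{"id": 0}, {"id": 1}])]
preds, batches = simple_predict(lambda xs: [2 * x for x in xs], loader)
print(preds)  # [[2, 4]]
```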

```python
from maite.workflows import evaluate, predict
# we can also import from the object_detection module,
# where the function call signature is the same
```

```python
print(evaluate.__doc__)
```

```python
print(predict.__doc__)
```