# Overview of Protocols

> Justin Goodwin, MIT Lincoln Laboratory

> Feb. 14, 2023


## Why Type Annotations and Protocols?

Using Python type annotations and protocols in API design helps improve code quality, readability, and collaboration, while also enabling better tooling and more flexible code.

  - Improved readability: Type annotations provide clear and concise documentation for what each function or method is supposed to receive as inputs and return as outputs. This makes the code more readable and easier to understand, especially for developers who are new to the project.
  - Enhanced code quality: Type annotations help catch errors early in the development process, making it easier to detect and fix bugs. This can result in improved code quality and reduced maintenance costs over time.
  - Better tooling: Type annotations enable powerful tools such as type checkers, linters, and IDEs to provide more accurate and helpful insights into the code. For example, with type annotations, a type checker can verify that a function is being passed the correct types of arguments and that it is returning the expected type of output.
  - Improved collaboration: Type annotations serve as a shared understanding between developers about the expected inputs and outputs of functions and methods. This helps to ensure that everyone is on the same page, reducing the chance of misunderstandings and increasing collaboration efficiency.
  - Protocols: Python's support for protocols allows for the creation of highly flexible and reusable code. Protocols allow you to define a set of methods and properties that a class must implement, without specifying a particular implementation. This means that different classes can implement the same protocol in different ways, making it easier to write reusable code that works with a variety of types.


## Summary of Review
This review will highlight the protocols in the `jatic_toolbox.protocols` package, which have been developed to:

  1. Specify objects that can be converted into arrays: `ArrayLike` and `ShapedArray`.
  2. Define collections of homogeneous objects: `TypedCollection`.
  3. Provide a standard interface for augmenting images, videos, bounding boxes, keypoints, etc.: `Augmentation`.
  4. Establish a standard interface for machine learning models: `Model`, which includes `Classifier` and `ObjectDetector`.
     - Define the output of ML models: `ModelOutput` with specific interfaces for logits, probabilities, and object detections: `HasLogits`, `HasProbs`, and `HasObjectDetections`, respectively.

## Getting The Most from this Notebook

To get the most out of this notebook, open it in an integrated development environment (IDE) with static type checking, such as Visual Studio Code (VSCode). VSCode's Python language server, Pylance, is built on Pyright and type-checks Jupyter notebook cells directly in the editor. To see type errors, set `python.analysis.typeCheckingMode` to `basic` or `strict`.

This matters because many cells below deliberately contain calls that violate an interface. Most of these violations are reported only by the static type checker, and the cells still run. You therefore need a checker such as Pyright to confirm the "passes" and "fails" comments in the examples.

## Table of Contents

1. [Protocols](#Protocols)
2. [Arrays](#arrays)
3. [TypedCollection](#typed-collection)
4. [Augmentation](#augmentation)
5. [Model](#model)

## Install and Import

In [1]:
# no additional packages needed for this
#!pip install jatic_toolbox
#!pip install torch --extra-index-url https://download.pytorch.org/whl/cu116


In [2]:
from typing import Tuple, List, NamedTuple, Sequence
from typing_extensions import Protocol, reveal_type

import numpy as np
from numpy.typing import NDArray

## Protocols

### Motivation
Protocols are a recent addition to the Python language, first introduced in version 3.8 (however, backports exist for version 3.6+). They are a way to define a set of behaviors that a Python object should have, without specifying a concrete class hierarchy or using mixins. A key feature of Python is its "duck-typed" capability, meaning that "if it walks like a duck and quacks like a duck, then we call it a duck".

Protocols embrace this philosophy by checking objects for compatibility based on the methods and attributes they have, rather than on their inheritance. For example, you do not need to require an object to inherit from a specific class or mixin before it counts as "array-like". Instead, you can define a protocol that specifies what "array-like" means, such as being convertible to a NumPy array or PyTorch tensor. Any object that has the methods and attributes the protocol requires is then "array-like", regardless of its class hierarchy.

Protocols offer several benefits over traditional inheritance-based approaches to defining behaviors in Python. Frameworks no longer need to publish concrete class hierarchies and mixins that users must inherit from. Protocols also enable static type checking, plus limited runtime type checking, based on structure rather than inheritance. The result can be more flexible and maintainable code, with better type checking and earlier error detection.
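The "limited runtime type checking" mentioned above refers to decorating a protocol with `typing.runtime_checkable`. Here is a minimal sketch; the `SupportsQuack` protocol and the three classes are hypothetical. `isinstance` only checks that the required members exist, not their signatures or return types. That limitation is why a static checker such as Pyright remains the main enforcement mechanism.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsQuack(Protocol):
    """Anything with a `quack()` method counts as a duck."""

    def quack(self) -> str:
        ...


class Duck:
    # No inheritance from SupportsQuack: the match is purely structural.
    def quack(self) -> str:
        return "Quack!"


class Robot:
    # Has `quack`, but with the wrong signature; a runtime check cannot tell.
    def quack(self, volume: int) -> int:
        return volume


class Dog:
    def bark(self) -> str:
        return "Woof!"


def make_noise(animal: SupportsQuack) -> str:
    return animal.quack()


print(isinstance(Duck(), SupportsQuack))   # True
print(isinstance(Robot(), SupportsQuack))  # True: only the method's presence is checked
print(isinstance(Dog(), SupportsQuack))    # False: no `quack` member
print(make_noise(Duck()))                  # Quack!
```

A static checker would reject `make_noise(Robot())` because the signature does not match, even though the runtime `isinstance` check accepts it.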

### Example: Designing an Interface with Concrete Types vs. Protocols



In [3]:
# Approach 1: an abstract base class (mixin) that every detector must inherit from
from abc import ABC, abstractmethod

class DetectMixin(ABC):
    @abstractmethod
    def detect(self, x: NDArray) -> List[NDArray]:
        raise NotImplementedError()

def evaluate_concrete_detector(any_detector: DetectMixin):
    ...


# from third_party import ThirdPartyDetector
class ThirdPartyDetector:
    def detect(self, x: NDArray) -> List[NDArray]:
        ...

# type-checker error: `ThirdPartyDetector` is not
#                     a subclass of `DetectMixin`
evaluate_concrete_detector(ThirdPartyDetector())

# Create a wrapper class just to satisfy the interface

class DetectWrapper(DetectMixin):
    def __init__(self, third_party_detector: ThirdPartyDetector) -> None:
        self.detector = third_party_detector

    def detect(self, x: NDArray) -> List[NDArray]:
        return self.detector.detect(x)

wrapped_detector = DetectWrapper(ThirdPartyDetector())

# type-checker OK (but at what cost!)
evaluate_concrete_detector(wrapped_detector)  

In [4]:
class SupportsDetection(Protocol):
    def detect(self, x: NDArray) -> List[NDArray]:
        ...

def evaluate_detector(any_detector: SupportsDetection):
    ...


# from third_party import ThirdPartyDetector
class ThirdPartyDetector:
    def detect(self, x: NDArray) -> List[NDArray]:
        ...

# type-check OK: `ThirdPartyDetector` structurally matches `SupportsDetection`
#  because it has `.detect(self, x: NDArray) -> List[NDArray]`
evaluate_detector(ThirdPartyDetector())

## Arrays

### Motivation

An `ArrayLike` protocol defines a common interface for objects that can be manipulated as arrays, regardless of their specific implementation. Code can then be written generically, so it works with different array-like objects without depending on the details of any one implementation. With an `ArrayLike` protocol, vendors can write functions and algorithms that operate on arrays without JATIC prescribing a specific array implementation.

This improves code readability and maintainability, and it makes it easier to switch array implementations when needed. For example, a vendor can write a function that takes an `ArrayLike` object as input and performs some mathematical operation on its elements. That function works with any object that satisfies the `ArrayLike` protocol, such as a NumPy `ndarray`, a PyTorch `Tensor`, or a custom object that implements the methods and attributes the protocol requires. An `ArrayLike` protocol also improves type hints and code safety. Used with a static type checker, it ensures that the correct types of objects are passed as arguments, so errors are caught before they cause problems at runtime.
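Here is a minimal, hypothetical sketch of what such a protocol can look like. `ArrayLikeSketch` is an illustration, not the `jatic_toolbox` definition. Its only structural requirement is an `__array__` method, which is the hook NumPy calls when converting foreign objects (for example, in `np.asarray`). The example avoids third-party dependencies, so the checks are done with `isinstance` on a `runtime_checkable` protocol.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class ArrayLikeSketch(Protocol):
    """Hypothetical stand-in for an `ArrayLike` protocol."""

    def __array__(self) -> Any:
        ...


class MyArray:
    """Satisfies the protocol structurally; no base class needed."""

    def __init__(self, values: List[float]) -> None:
        self.values = values

    def __array__(self) -> Any:
        # A real implementation should return a NumPy ndarray here;
        # a list keeps this sketch free of third-party dependencies.
        return self.values


class MyBadArray:
    def not_array(self) -> List[int]:
        return [1, 2, 3]


def is_array_like(obj: object) -> bool:
    # Runtime check: only verifies that an `__array__` member exists.
    return isinstance(obj, ArrayLikeSketch)


print(is_array_like(MyArray([1.0, 2.0])))  # True
print(is_array_like(MyBadArray()))         # False
print(is_array_like([1.0, 2.0]))           # False: plain lists have no `__array__`
```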


In [5]:
from PIL import Image
import torch as tr
import numpy as np
from jatic_toolbox.protocols import ArrayLike

# a function that takes ArrayLike as an input
def f(data: ArrayLike):
    ...

# type check errors: cannot be assigned to parameter "data" of type "ArrayLike" in function "f"
f(1)
f(True) 
f(1 + 1j)
f(dict(x=1))
f([dict(x=1)]) 
f(["string"])
f([1 + 1j])
f((1,))
f([True, False])

class MyBadArray:
    def not_array(self) -> List[int]: ...

f(MyBadArray())

# passes
f(np.asarray([1., 2.]))
f(np.zeros(2))
f(tr.zeros(2))

class MyArray:
    def __array__(self) -> List[int]: ...

f(MyArray())

# fails for a Pillow image
x: Image.Image = Image.fromarray(np.zeros((10, 10, 3)).astype(np.uint8))
f(x)

In [6]:
from jatic_toolbox.protocols import ImageType

# a function that takes ImageType as an input
def g(data: ImageType):
    ...

# type check errors: cannot be assigned to parameter "data" of type "ImageType" in function "g"
g(1)
g(True) 
g(1 + 1j)
g(dict(x=1))
g([dict(x=1)]) 
g(["string"])
g([1 + 1j])
g((1,))
g([True, False])

# passes
g(np.asarray([1., 2.]))
g(np.zeros(2))
g(tr.zeros(2))

# passes for a Pillow image
x: Image.Image = Image.fromarray(np.zeros((10, 10, 3)).astype(np.uint8))
g(x)



In [7]:
from jatic_toolbox.protocols import ShapedArray

# a function that takes ShapedArray as an input
def h(data: ShapedArray):
    ...

# type check errors: cannot be assigned to parameter "data" of type "ShapedArray" in function "h"
h(1)
h(True) 
h(1 + 1j)
h(dict(x=1))
h([dict(x=1)]) 
h("string")  
h((1,))
h([1., 2.])
h([True, False])
h([1 + 1j])

class HasShape:
    @property
    def shape(self) -> Tuple[int,...]: ...

h(HasShape())

# passing
h(np.zeros(2))

class MyArray:
    def __array__(self) -> List[int]: ...
    @property
    def shape(self) -> Tuple[int,...]: ...

h(MyArray())
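Judging from the pass/fail cases above, `ShapedArray` appears to require both the `__array__` hook and a `shape` property. `HasShape` fails because it has no `__array__`, and plain lists fail because they have no `shape`. Here is a hypothetical standard-library sketch of that requirement; `ShapedArraySketch` is an assumption inferred from the examples, not the toolbox's definition.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ShapedArraySketch(Protocol):
    """Hypothetical stand-in for `ShapedArray`: array-convertible *and* shaped."""

    @property
    def shape(self) -> Tuple[int, ...]:
        ...

    def __array__(self) -> Any:
        ...


class HasShape:
    @property
    def shape(self) -> Tuple[int, ...]:
        return (2,)


class MyShapedArray:
    def __init__(self, rows: int, cols: int) -> None:
        self._shape = (rows, cols)

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    def __array__(self) -> Any:
        # A real implementation would return a NumPy ndarray of this shape.
        return [[0.0] * self._shape[1] for _ in range(self._shape[0])]


print(isinstance(MyShapedArray(2, 3), ShapedArraySketch))  # True
print(isinstance(HasShape(), ShapedArraySketch))           # False: no `__array__`
print(isinstance([1.0, 2.0], ShapedArraySketch))           # False: no `shape` or `__array__`
```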


## Typed Collection

A `TypedCollection` is a homogeneous collection of Python objects that can be used to define a consistent type in an interface. For example, if the inputs and outputs of the interface are expected to be a NumPy `NDArray` or PyTorch `Tensor`, a `TypedCollection` type can be used to define this consistency.

Consider the following example where the function `foo` only accepts inputs as individual float values:

```python
def foo(*inputs: float) -> Tuple[float, ...]:
    ...
```

However, what if we want the function to accept inputs in the form of a list or a dictionary of float values? This can be achieved by defining a `TypedCollection` type alias as follows:


```python
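from typing import Any, Mapping, Sequence, TypeVar, Union
from typing_extensions import TypeAlias

T = TypeVar("T")

TypedCollection: TypeAlias = Union[T, Sequence[T], Mapping[Any, T], Mapping[Any, Sequence[T]]]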
```

With this definition, we can redefine the function `foo` to accept inputs in a nested collection form:

```python
def foo_nested(*inputs: TypedCollection[float]) -> TypedCollection[float]:
    ...
```

Now, the function `foo_nested` can accept inputs in the form of individual float values, a list of float values, or a dictionary with float values as its values. This makes the interface more flexible and easier to use. For example, calling `foo(1.0, [1.0, 2.0])` would fail type checking, whereas `foo_nested(1.0, [1.0, 2.0])` would pass.  See more examples below for passing and failing interfaces.
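Type aliases are erased at runtime, so a function annotated with `TypedCollection[float]` still has to walk the nested structure itself. Here is a minimal sketch that assumes the alias definition above. The `iter_leaves` helper is hypothetical and handles exactly the four cases in the `Union`.

```python
import collections.abc as cabc
from typing import Any, Iterator, Mapping, Sequence, TypeVar, Union

T = TypeVar("T")

# Same definition as the alias above.
TypedCollection = Union[T, Sequence[T], Mapping[Any, T], Mapping[Any, Sequence[T]]]


def iter_leaves(data: TypedCollection[float]) -> Iterator[float]:
    """Yield every float stored in one `TypedCollection[float]` argument."""
    if isinstance(data, cabc.Mapping):
        for value in data.values():
            # mapping values may be floats or sequences of floats
            yield from iter_leaves(value)
    elif isinstance(data, cabc.Sequence) and not isinstance(data, str):
        yield from data
    else:
        yield data


def foo_nested(*inputs: TypedCollection[float]) -> float:
    # Sum every float across all inputs, whatever collection form they take.
    return sum((x for item in inputs for x in iter_leaves(item)), 0.0)


print(foo_nested(1.0, [1.0, 2.0]))                  # 4.0
print(foo_nested(dict(x=1.0), dict(y=[2.0, 3.0])))  # 6.0
```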

In [8]:
from jatic_toolbox.protocols import TypedCollection

# on its own, it works like any other type annotation
def f(data: TypedCollection[float]):
    ...

# failing type check
f("string")
f(["string"])
f(("string",))
f(dict(x="string"))

f(1+1j)
f([1+1j])
f((1+1j,))
f(dict(x=1+1j))

# passing
f(1.0)
f([1.0])
f((1.0,))
f(dict(x=1.0))

In [9]:
def f(*inputs: TypedCollection[float]):
    ...      

# failing type check
f(1.0, "string")
f(1.0, ["string"])
f([1.0], ("string",))
f(1.0, dict(x="string"))

# passing
f(1.0, 1.0)
f([1.0], dict(x=1.0))
f((1.0,), [1.0, 2.0], 1.0)
f(dict(x=1.0))
f(dict(x=[1., 2.0]))

### Example Use Case with `TypedCollection`

This example shows how to use `TypedCollection` with PyTrees to support a variable number of arguments to a Python function (variadic arguments). Several libraries implement the PyTree concept, including:

- [PyTorch](https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py) 
- [JAX PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html)
- [Optimized PyTrees](https://github.com/metaopt/optree)

This explanation focuses on PyTorch's implementation of `PyTree`. Note that `torch.utils._pytree` is a private module, as its leading underscore indicates. None of these implementations require special tensor or container types: they operate on ordinary Python containers such as lists, tuples, and dicts. The two functions most relevant to this example are `torch.utils._pytree.tree_flatten` and `torch.utils._pytree.tree_unflatten`. For demonstration purposes, the following example uses simple Python containers:
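To make the mechanics concrete, here is a dependency-free sketch of the flatten/unflatten round trip for plain lists, tuples, and dicts. The `flatten` and `unflatten` names and the tuple-based spec format are hypothetical. PyTorch's `tree_flatten` returns a `TreeSpec` object instead and supports more container types.

```python
from typing import Any, List, Optional, Tuple

# A spec is either LEAF or (container_type, dict_keys_or_None, [child_specs]).
LEAF = ("leaf",)


def flatten(tree: Any) -> Tuple[List[Any], Tuple[Any, ...]]:
    """Split plain lists/tuples/dicts into (leaves, spec); anything else is a leaf."""
    keys: Optional[List[Any]]
    if type(tree) in (list, tuple):
        keys, children = None, list(tree)
    elif type(tree) is dict:
        keys, children = list(tree.keys()), list(tree.values())
    else:
        return [tree], LEAF

    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (type(tree), keys, child_specs)


def unflatten(leaves: List[Any], spec: Tuple[Any, ...]) -> Any:
    """Rebuild the structure described by `spec`, consuming `leaves` in order."""
    remaining = iter(leaves)

    def build(s: Tuple[Any, ...]) -> Any:
        if s == LEAF:
            return next(remaining)
        container, keys, child_specs = s
        children = [build(c) for c in child_specs]
        if container is dict:
            return dict(zip(keys, children))
        return container(children)  # list(...) or tuple(...)

    return build(spec)


inputs = ([1, 2, 3], (4, 5, 6), dict(a=[7, 8, 9]))
flat, spec = flatten(inputs)
print(flat)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(unflatten([x + 10 for x in flat], spec))
# ([11, 12, 13], (14, 15, 16), {'a': [17, 18, 19]})
```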


In [10]:
from torch.utils._pytree import tree_flatten, tree_unflatten

input1: TypedCollection[int] = [1, 2, 3]
input2: TypedCollection[int] = (4, 5, 6)
input3: TypedCollection[int] = dict(a=[7,8,9])


flat_inputs, tree_spec = tree_flatten((input1, input2, input3))
# flat_inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9]
flat_inputs


[1, 2, 3, 4, 5, 6, 7, 8, 9]

In [11]:
# add 10 to each value
augment = [x + 10 for x in flat_inputs]

# move back to original
aug_inputs = tree_unflatten(augment, tree_spec)
# aug_inputs = ([11, 12, 13], (14, 15, 16), {"a": [17, 18, 19]})
aug_inputs

([11, 12, 13], (14, 15, 16), {'a': [17, 18, 19]})

Here we put it all together using a simple interface example:

In [12]:
def foo(*inputs: TypedCollection[int]) -> Tuple[TypedCollection[int], ...]:
    flat_inputs, tree_spec = tree_flatten(inputs)
    augment = [x + 10 for x in flat_inputs]
    return tree_unflatten(augment, tree_spec)

aug_input1, aug_input2, aug_input3 = foo(input1, input2, input3)

Now let's demonstrate with NumPy arrays instead of integers:

In [13]:
x = np.zeros((2, 2))
y = np.ones((2, 2))

def foo_numpy(*inputs: TypedCollection[np.ndarray]) -> Tuple[TypedCollection[np.ndarray], ...]:
    flat_inputs, tree_spec = tree_flatten(inputs)
    augment = [x + 10 for x in flat_inputs]
    return tree_unflatten(augment, tree_spec)

aug_input1, aug_input2 = foo_numpy(x, dict(a=y))


# flat_inputs
# [array([[0., 0.],
#        [0., 0.]]),
# array([[1., 1.],
#        [1., 1.]])]

# aug_inputs
# (array([[10., 10.],
#        [10., 10.]]),
# {'a': array([[11., 11.],
#             [11., 11.]])})

It should be evident how this concept is beneficial for common ML tasks in object detection and segmentation, where the same augmentations are often required to be applied to collections of input and target variables. 
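For instance, a random horizontal flip has to flip an image and its segmentation mask together, or the mask no longer lines up with the pixels. Here is a minimal standard-library sketch that assumes images and masks are nested lists rather than arrays; `random_hflip` is hypothetical. The random decision is drawn once and then applied to every input.

```python
import random
from typing import List, Optional, Tuple

Grid = List[List[int]]


def random_hflip(
    *inputs: Grid, p: float = 0.5, rng: Optional[random.Random] = None
) -> Tuple[Grid, ...]:
    """Flip every input left-to-right, or none of them, with probability `p`."""
    rng = rng if rng is not None else random.Random()
    flip = rng.random() < p  # a single draw shared by all inputs
    if not flip:
        return tuple([row[:] for row in grid] for grid in inputs)
    return tuple([row[::-1] for row in grid] for grid in inputs)


image = [[1, 2, 3], [4, 5, 6]]
mask = [[0, 0, 1], [0, 1, 1]]
aug_image, aug_mask = random_hflip(image, mask, p=1.0)
print(aug_image)  # [[3, 2, 1], [6, 5, 4]]
print(aug_mask)   # [[1, 0, 0], [1, 1, 0]]
```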


## Augmentation

A data augmentation interface should let transformations be applied to a collection of objects that share a single type, such as PyTorch Tensors. This provides a versatile API for augmenting many kinds of data, including images, sequences, videos, bounding boxes, and segmentation masks. It also gives developers and users more control over how the processing pipeline is assembled and modified. Because the interface accepts an arbitrary number of (possibly nested) inputs rather than a fixed argument list, the same transformation can be applied to different combinations of collections, such as images together with their bounding boxes.

The data augmentation interface should also provide a clear way to set the random number generator (RNG) state without relying on a global random state. Passing the RNG explicitly gives greater control and more reproducible results, as described in this [post on good practices for random number generators with NumPy](https://albertcthomas.github.io/good-practices-random-number-generators/). With explicit RNG support, augmentation pipelines are isolated from global entropy, so results are consistent and easy to reproduce, especially during testing and evaluation.
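Here is a minimal sketch of the difference between global and explicit random state. It uses the standard library's `random.Random` so it runs without extra dependencies; in this notebook's NumPy code, `numpy.random.default_rng(seed)` plays the same role. `add_jitter` is a hypothetical augmentation that draws only from the generator it is given. Reseeding or consuming the global state elsewhere therefore cannot change its output.

```python
import random
from typing import List, Optional, Union


def add_jitter(
    values: List[float], rng: Optional[Union[int, random.Random]] = None
) -> List[float]:
    """Add small uniform noise to each value, drawing only from the RNG passed in."""
    if not isinstance(rng, random.Random):
        # None -> OS-seeded generator; int -> reproducible seeded generator.
        rng = random.Random(rng)
    return [v + rng.uniform(-0.1, 0.1) for v in values]


data = [1.0, 2.0, 3.0]
first = add_jitter(data, rng=0)

# Unrelated code that reseeds and consumes the *global* random state...
random.seed(12345)
random.random()

# ...does not change the result, because `add_jitter` never touches it.
second = add_jitter(data, rng=0)
print(first == second)  # True
```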


### Protocol Interface Overview

The `Augmentation` protocol is defined using `TypedCollection` as follows:

```python
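from typing import Optional, Protocol, Tuple, TypeVar, Union
T = TypeVar("T")

# `TypedCollection` is defined above; `RandomStates` is defined elsewhere (not shown)
class Augmentation(Protocol[T]):
    def __call__(
        self, *inputs: TypedCollection[T], rng: Optional[RandomStates] = None,
    ) -> Union[TypedCollection[T], Tuple[TypedCollection[T], ...]]:
        ...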
```

This protocol allows for a variety of augmentations, including:

```python
from typing import Dict, Sequence, Tuple
from torch import Tensor

def aug1(data: Tensor) -> Tensor:
    # Augment a single data tensor and return the augmented tensor.
    ...

def aug2(data: Tensor, boxes: Tensor) -> Tuple[Tensor, Tensor]:
    # Augment data and boxes and output augmented data and boxes.
    ...

def aug3(data: Sequence[Tensor], mapped_data: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Dict[str, Tensor]]:
    # Augment data that is a sequence (e.g., video) and a dictionary of tensor values.
    # The return value has the data and mapped_data augmented.
    ...
```

Because of limitations in how Python type hints handle variadic arguments, none of the above functions satisfies the `Augmentation` protocol. To satisfy it, a callable must itself accept variadic `*inputs` and an optional `rng` keyword argument. The example below shows this: `A` passes, while `AA` (no `rng` keyword) and `B` (fixed positional arguments) fail.


In [14]:
from typing import Optional
from jatic_toolbox.protocols import Augmentation

def auger(x: Augmentation[np.ndarray]):
    ...


def A(*inputs: TypedCollection[np.ndarray], rng: Optional[int] = None) -> TypedCollection[np.ndarray]:
    ...


def AA(*inputs: TypedCollection[np.ndarray]) -> TypedCollection[np.ndarray]:
    ...


def B(data: np.ndarray, boxes: np.ndarray, rng: Optional[int] = None) -> TypedCollection[np.ndarray]:
    ...


# passes pyright
auger(A)


# fails pyright
auger(AA)
auger(B)
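An existing augmentation may be written with a fixed signature, like `aug2(data, boxes)` above. One hypothetical way to reuse it is a small adapter that exposes the variadic `(*inputs, rng=None)` shape the protocol expects. The `as_variadic` helper and `shift_pair` function below are illustrative only and are not part of `jatic_toolbox`. The trade-off is that annotating the wrapper with `Callable[..., Any]` stops the type checker from verifying how many inputs are passed or what their types are.

```python
import random
from typing import Any, Callable, Optional, Tuple


def as_variadic(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical adapter: wrap `fn` in a `(*inputs, rng=None)` signature."""

    def wrapped(*inputs: Any, rng: Optional[random.Random] = None) -> Any:
        # `fn` takes no RNG, so `rng` is accepted only to match the interface
        # and is otherwise ignored.
        return fn(*inputs)

    return wrapped


def shift_pair(data: float, boxes: float) -> Tuple[float, float]:
    # Stand-in for a fixed-arity augmentation such as `aug2(data, boxes)`.
    return data + 1.0, boxes + 1.0


aug = as_variadic(shift_pair)
print(aug(1.0, 2.0))                        # (2.0, 3.0)
print(aug(1.0, 2.0, rng=random.Random(0)))  # (2.0, 3.0)
```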


### Implementation Based on TorchVision V2 Prototype Interface

The torchvision library has recently developed a new prototype interface for its transforms. It lives in the `torchvision/prototype` directory on GitHub and is described in a recent blog post ([source](https://github.com/pytorch/vision/tree/main/torchvision/prototype), [blog post](https://pytorch.org/blog/extending-torchvisions-transforms-to-object-detection-segmentation-and-video-tasks/)). The new interface is a major advance because it supports transforms not only on images but also on videos, bounding boxes, labels, and segmentation masks. It does this by introducing a `Datapoint` class, a subclass of `Tensor`, to describe the different kinds of data. It also provides a flexible API for mapping transformations over collections of tensors, so each transformation can handle the various features it receives. The `Transform` class offers a simple API for implementing transforms: developers draw random parameters and then transform individual tensors or features. The interface continues to support batches of tensors and features, as well as GPU execution.
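The core idea behind the `Datapoint` design is that one transform behaves differently depending on the *kind* of data it receives. That idea can be sketched with `functools.singledispatch`. This is only an illustration of type-based dispatch: `Image`, `BoundingBoxes`, and `hflip` below are hypothetical stand-ins, not torchvision classes or functions.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import List, Tuple


@dataclass
class Image:
    pixels: List[List[int]]  # rows of pixel values


@dataclass
class BoundingBoxes:
    boxes: List[Tuple[int, int, int, int]]  # (x_min, y_min, x_max, y_max)
    image_width: int


@singledispatch
def hflip(item: object) -> object:
    """Fallback: leave types the transform does not understand (e.g. labels) unchanged."""
    return item


@hflip.register
def _(item: Image) -> Image:
    # Reverse each row of pixels.
    return Image([row[::-1] for row in item.pixels])


@hflip.register
def _(item: BoundingBoxes) -> BoundingBoxes:
    # Mirror x-coordinates about the image width; y-coordinates are unchanged.
    w = item.image_width
    return BoundingBoxes([(w - x1, y0, w - x0, y1) for (x0, y0, x1, y1) in item.boxes], w)


sample = (Image([[1, 2, 3]]), BoundingBoxes([(0, 0, 1, 1)], image_width=3), "cat")
flipped = [hflip(item) for item in sample]
print(flipped)
```

A single call site (`hflip(item)`) handles every element of the sample, and each type decides how the flip applies to it.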

The prototype transform library has not been officially released, so `jatic_toolbox` does not yet include example interfaces for `torchvision` V2.


### Example Augmentation Implementation

Let's consider a more concrete `Augmentation` example: implementing and applying a `RandomCrop` augmentation.

In [15]:
from typing import Any, Dict, List, Optional, Tuple
from numpy.random import Generator, default_rng
from torch.utils._pytree import tree_flatten, tree_unflatten

import numpy as np


class RandomCrop(Augmentation):
    def __init__(self, size: Tuple[int, int]):
        """
        Randomly crop the last two dimensions of an array (e.g., height and width of an image).

        Parameters
        ----------
        size : Tuple[int, int]
            The output size (height, width) of the last two dimensions of the array.
        """
        super().__init__()
        self.output_size = size

    def _get_params(
        self, flat_inputs: List[ArrayLike], rng: Generator
    ) -> Dict[str, int]:
        """
        Calculate the parameters of the random crop to be applied for all inputs.

        Parameters
        ----------
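        flat_inputs : List[ArrayLike]
            The flattened inputs; the crop window is computed from the first one.

        rng : Generator
            The random number generator used to draw the crop offsets.
        """
        assert len(flat_inputs) > 0
        # use the last two dimensions so that inputs such as (C, H, W) also work
        h, w = np.asarray(flat_inputs[0]).shape[-2:]
        th, tw = self.output_size

        # `Generator.integers` excludes its upper bound, so add 1 to allow
        # every valid offset in [0, w - tw] (and [0, h - th] below)
        dw = 0
        if tw < w:
            dw = rng.integers(0, w - tw + 1)

        dh = 0
        if th < h:
            dh = rng.integers(0, h - th + 1)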

        return dict(bottom=dh, top=dh + th, left=dw, right=dw + tw)

    def _transform(self, inpt: ArrayLike, params: Dict[str, int]) -> np.ndarray:
        """Crop the last two dimensions of `inpt` using the shared parameters."""
        return np.asarray(inpt)[
            ..., params["bottom"] : params["top"], params["left"] : params["right"]
        ]

    def __call__(
        self, *inputs: TypedCollection[ArrayLike], rng: Optional[Any] = None
    ) -> Tuple[TypedCollection[np.ndarray], ...]:
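        # `default_rng` accepts None (fresh entropy), an integer seed,
        # or an existing Generator (which it returns unaltered)
        rng = default_rng(rng)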

        # flatten the inputs
        flat_inputs, spec = tree_flatten(inputs)

        # calculate the parameters for cropping all inputs
        params = self._get_params(flat_inputs, rng=rng)

        # apply the augmentation
        flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

        # return augmented objects in original format
        return tree_unflatten(flat_outputs, spec)


# this should pass pyright
cropper = RandomCrop((2, 1))

a, b = cropper(np.zeros((10, 10)), dict(val=[np.ones((10, 10))]))
reveal_type(a)  # prints: Runtime type is 'ndarray'
reveal_type(b)  # prints: Runtime type is 'dict'

auger(cropper)


Runtime type is 'ndarray'
Runtime type is 'dict'


Here are a few examples of executing `RandomCrop`:

In [16]:
import numpy as np
from numpy.random import default_rng

random_crop = RandomCrop((2, 1))

# Example 1: Simple Array
x = np.arange(16).reshape(4, 4)
xout, = random_crop(x, rng=default_rng(0))
print("x:", xout.shape)
# prints "x: (2, 1)"

# Example 2: Multiple Arrays
y = np.arange(100, 120).reshape(1, 5, 4)
xout, yout = random_crop(x, y, rng=default_rng(0))
print("y:", yout.shape)
# prints "y: (1, 2, 1)"

# Example 3: List of Arrays
xy, = random_crop([x, y], rng=default_rng(0))

# Example 4: List of Arrays and a Dict
z = dict(val=np.arange(1000, 1032).reshape(1, 2, 4, 4))
xyz, = random_crop([x, y, z], rng=default_rng(0))
print("z:", xyz[2]["val"].shape)
# prints "z: (1, 2, 2, 1)"

x: (2, 1)
y: (1, 2, 1)
z: (1, 2, 2, 1)


## Model

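The `Model` protocols define a standard interface for machine learning models. `Classifier` and `ObjectDetector` are generic over the array type (for example, `Classifier[Tensor]`), and in the examples below both take a `TypedCollection` of that type as input. Their outputs are also described structurally. A classifier's output must expose logits or probabilities (`HasLogits`, `HasProbs`), and an object detector's output must expose detections (`HasObjectDetections`). The cell below shows which functions satisfy these protocols and which do not.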


In [17]:
from jatic_toolbox.protocols import Classifier, ObjectDetector
from torch import Tensor


class LogitsOutput(NamedTuple):
    logits: Tensor


class ProbsOutput(NamedTuple):
    probs: Tensor


class BadClassifierOutput(NamedTuple):
    outputs: ArrayLike


class Detections(NamedTuple):
    scores: Sequence[Sequence[dict]]
    boxes: Tensor
    foo: int


class BadDetections(NamedTuple):
    boxes: ArrayLike


def classifier(data: TypedCollection[Tensor]) -> LogitsOutput:
    ...


def classifier2(data: TypedCollection[Tensor]) -> ProbsOutput:
    ...


def bad_output_classifier(data: ArrayLike) -> BadClassifierOutput:
    ...


def bad_input_classifier(data: Sequence[int]) -> BadClassifierOutput:
    ...


def detector(data: TypedCollection[Tensor]) -> Detections:
    ...


def bad_output_detector(data: TypedCollection[Tensor]) -> BadDetections:
    ...


#####################
# Check function call
#####################

# passes
classifier(tr.zeros(2))
classifier(dict(x=tr.zeros(2)))
detector(tr.zeros(2))
detector(dict(x=tr.zeros(2)))

# fails
classifier([1, 2])
detector([1, 2])
classifier("hi")
detector(True)


#################
# Check interface
#################


# test functions
def eval_classifier(f: Classifier[Tensor]):
    ...


def eval_detector(f: ObjectDetector[Tensor]):
    ...


# passes
eval_classifier(classifier)
eval_classifier(classifier2)
eval_detector(detector)


# does not pass
eval_classifier(bad_output_classifier)
eval_classifier(bad_input_classifier)
eval_classifier(detector)
eval_detector(bad_output_detector)
eval_detector(classifier)
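The pass/fail pattern above is structural. A classifier's output only needs a `logits` or `probs` attribute, and a detector's output needs detection fields, whatever other fields it carries (such as `foo` in `Detections`). Here is a hypothetical standard-library sketch of the output side. The `*Sketch` protocols and `My*` classes are not the toolbox's definitions, and the detection fields `scores` and `boxes` are inferred from the examples above.

```python
from typing import Any, NamedTuple, Protocol, runtime_checkable


@runtime_checkable
class HasLogitsSketch(Protocol):
    logits: Any


@runtime_checkable
class HasProbsSketch(Protocol):
    probs: Any


@runtime_checkable
class HasDetectionsSketch(Protocol):
    scores: Any
    boxes: Any


class MyLogits(NamedTuple):
    logits: list


class MyDetections(NamedTuple):
    scores: list
    boxes: list
    foo: int  # extra fields do not prevent a structural match


class MyBadDetections(NamedTuple):
    boxes: list  # missing `scores`


def is_classifier_output(out: object) -> bool:
    # Either kind of classifier output is acceptable.
    return isinstance(out, (HasLogitsSketch, HasProbsSketch))


dets = MyDetections(scores=[[0.9]], boxes=[[0, 0, 1, 1]], foo=3)
print(is_classifier_output(MyLogits(logits=[0.1, 0.9])))              # True
print(is_classifier_output(dets))                                     # False
print(isinstance(dets, HasDetectionsSketch))                          # True
print(isinstance(MyBadDetections([[0, 0, 1, 1]]), HasDetectionsSketch))  # False
```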
