In [None]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab

try:
    import differt  # noqa: F401
except ImportError:
    import sys  # noqa: F401

    !{sys.executable} -m pip install differt[all]

# Runtime type checking

To avoid common pitfalls with function arguments,
such as using the wrong data type or array shape,
most functions in this library are wrapped with a
runtime type checker that uses their type annotations
to check the inputs they receive and the outputs they return.

For that, we rely on the [`jaxtyping`](https://pypi.org/project/jaxtyping/)
and [`beartype`](https://pypi.org/project/beartype/) modules.

## Input arguments checking

Let's take the example of the {func}`sorted_array2<differt.utils.sorted_array2>` function:

In [None]:
import inspect

import jax
import jax.numpy as jnp

from differt.utils import sorted_array2

inspect.signature(sorted_array2)

As we can see, its signature expects a 2D array as input and returns a 2D array as output,
with matching shapes.

In [None]:
key = jax.random.PRNGKey(1234)

arr = jax.random.randint(key, (10, 4), 0, 2)
arr

Hence, if we provide a 2D array as input, everything works just fine:

In [None]:
sorted_array2(arr)

However, if anything other than a 2D array is provided, an error is raised:

In [None]:
arr = jax.random.randint(key, (2, 10, 4), 0, 2)  # 3D array
sorted_array2(arr)

The error message is a bit verbose,
but we can see at the end that we expected `Shaped[Array, "m n"]`
and received `i32[2,10,4]` (i.e., `Int32[Array, "2 10 4"]`).
The dtype is fine, as `Shaped` accepts any dtype, including `Int32`,
but `m n` cannot be matched to `2 10 4` because the input has
one extra dimension, hence the error.

## Output checking

The output values are also checked by the type checker.
If you use one of the functions from our library, you are guaranteed to
have correct output types if you provided valid inputs.

In other words, type checking the outputs should **never fail**.
If you encounter a case where your input is valid, but the returned output is not,
please report it via the [GitHub issues](https://github.com/jeertmans/DiffeRT/issues).

If you define your own functions, it is good practice to use type
annotations and runtime type checking as well:

In [None]:
from beartype import beartype as typechecker
from jaxtyping import Array, Num, jaxtyped


@jaxtyped(typechecker=typechecker)
def my_custom_transpose(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose().transpose()  # Oops, transposed one too many times


x = jnp.arange(70).reshape(10, 7)
x

In [None]:
my_custom_transpose(x)

Here, the error message tells us that it inferred `m=10` and `n=7` from the input argument,
but that this does not match the expected output shape, i.e., `(n, m) = (7, 10) != (10, 7)`.

Thanks to the type checker, we quickly caught the error and can now fix the function:

In [None]:
@jaxtyped(typechecker=typechecker)
def my_custom_transpose_fixed(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose()  # Now this is all good


my_custom_transpose_fixed(x)