In [None]:
# Run this cell to install DiffeRT and its dependencies, e.g., on Google Colab

try:
    import differt  # noqa: F401
except ImportError:
    import sys  # noqa: F401

    !{sys.executable} -m pip install differt[all] beartype

# Runtime type checking

To avoid common pitfalls with function arguments,
such as using the wrong data type or array shape,
functions in this library are annotated with type hints,
and can be type checked at runtime.

For that, we rely on the [`jaxtyping`](https://pypi.org/project/jaxtyping/)
and [`beartype`](https://pypi.org/project/beartype/) modules.

By default, no type checking is performed, to avoid additional overhead on every
function call.
To enable runtime type checking, you can use {func}`jaxtyping.install_import_hook`.

## Input arguments checking

Let's take the example of the {func}`perpendicular_vectors<differt.geometry.perpendicular_vectors>` function:

In [None]:
import inspect

import jax
import jax.numpy as jnp
from jaxtyping import install_import_hook

with install_import_hook("differt", "beartype.beartype"):
    from differt.geometry import perpendicular_vectors

inspect.signature(perpendicular_vectors)

As we can see, its signature expects an array of 3D vectors as input and returns an array of 3D vectors
with a matching shape.

In [None]:
key = jax.random.key(1234)

arr = jax.random.normal(key, (10, 3))
arr

Hence, if we provide an array of 3D vectors as input, everything works just fine:

In [None]:
perpendicular_vectors(arr)

However, if anything other than an array of 3D vectors is provided, an error is raised:

In [None]:
arr = jax.random.normal(key, (2, 10, 4))  # 4D vectors
perpendicular_vectors(arr)

The error message is a bit verbose,
but at the end we can see that the function expected `Shaped[Array, '*batch 3']`
and received `f32[2,10,4]` (i.e., `Float32[Array, "2 10 4"]`).
{class}`Float32<jaxtyping.Float32>` is a subclass of {class}`Shaped<jaxtyping.Shaped>`, so the dtype is accepted,
but `*batch 3` cannot be matched to `2 10 4`, because the last dimension is `4` instead of `3`, hence the error.

## Output checking

Return values are also checked by the type checker.
When you call one of the functions from our library with valid inputs,
the outputs are guaranteed to have the correct types.

In other words, type checking the outputs should **never fail**.
If you encounter a case where your input is valid, but the returned output is not,
please report it via the [GitHub issues](https://github.com/jeertmans/DiffeRT/issues).

If you define custom functions yourself, it is always a good idea to use type
annotations and runtime checking:

In [None]:
from beartype import beartype as typechecker
from jaxtyping import Array, Num, jaxtyped


@jaxtyped(typechecker=typechecker)
def my_custom_transpose(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose().transpose()  # Oops, transposed one too many times


x = jnp.arange(70).reshape(10, 7)
x

In [None]:
my_custom_transpose(x)

Here, the error message tells us that it inferred `m=10` and `n=7` from the input arguments,
but that the expected output shape, `(n, m) = (7, 10)`, does not match the actual output shape `(10, 7)`.

Thanks to the type checker, we quickly caught the bug and can now fix `my_custom_transpose`:

In [None]:
@jaxtyped(typechecker=typechecker)
def my_custom_transpose_fixed(x: Num[Array, "m n"]) -> Num[Array, "n m"]:
    return x.transpose()  # Now this is all good


my_custom_transpose_fixed(x)