Arrays typing annotations

API

Function inputs and outputs can be annotated to help readers understand the intended shape and dtype of arrays.

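import numpy as np
from etils.array_types import Array, FloatArray, f32, ui8


def _normalize_image(img: ui8['h w c']) -> f32['h w c']:
  # Linearly map pixel values from [0, 255] to [-1, 1] (np.interp returns float64).
  return np.interp(img, (0, 255), (-1, 1)).astype(np.float32)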

This tells the reader that the function takes a 3D uint8 array and returns a 3D float32 array with the same shape.

Note: These typing annotations are not (yet) checked by static type checkers. However, they are already useful as documentation.

Annotation conventions

Shape annotations follow these conventions:

  • Valid symbols:
    • str: Named axis (e.g. f32['batch height width'])
    • int: Static axis (e.g. f32[28, 28], f32['h w 3'])
    • _: Anonymous axis (e.g. f32['batch _ _ c'], f32[None, 3])
    • ...: Anonymous zero-or-more axes (e.g. f32['... h w c'], f32[..., 3])
    • *name: Named zero-or-more axes (e.g. f32['*batch_dims h w c'])
    • +, -, /, * operators (e.g. f32['h/2 w/2 c1+c2'])
  • Named axes only need to be consistent within a single function call, so a function f32['h w'] -> f32['h w'] can be called twice with two different image sizes.
  • Passing multiple values is the same as concatenating the string (e.g. f32[..., 'h', 'w', 3] == f32['... h w 3'])
  • DType can be:
    • Array[...]: Any dtype accepted
    • FloatArray (accepts f32, bf16, ...), IntArray (accepts ui8, i32, i64, ...): each accepts a union of multiple dtypes
    • f32, ui8, ...: Specific type
  • ArrayLike[f32[...]] indicates that any value convertible to an array is accepted (list, tuple, ...).
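To make these conventions concrete, here is a minimal sketch of how a shape spec could be normalized and matched against concrete shapes. It is not etils' implementation: `normalize_spec`, `match_shape` and `check_call` are hypothetical helpers, and the arithmetic operators (`h/2`, `c1+c2`) are left out for brevity. Each `check_call` starts from fresh bindings, which is what makes named axes consistent per call only.

```python
from typing import Dict, List, Sequence, Tuple, Union

Token = Union[int, str]


def normalize_spec(*parts) -> List[Token]:
    """Join mixed parts into one token list (hypothetical helper).

    `..., 'h', 'w', 3` and `'... h w 3'` both give ['...', 'h', 'w', 3];
    `None` is treated as the anonymous axis '_'.
    """
    tokens: List[Token] = []
    for part in parts:
        if part is Ellipsis:
            tokens.append('...')
        elif part is None:
            tokens.append('_')
        elif isinstance(part, int):
            tokens.append(part)
        else:  # A string: split on whitespace, digits become static sizes.
            tokens.extend(int(t) if t.isdigit() else t for t in part.split())
    return tokens


def _bind(bindings: Dict[str, object], name: str, value: object) -> None:
    # The first occurrence of a name fixes its value; later ones must agree.
    if bindings.setdefault(name, value) != value:
        raise ValueError(f'{name!r}: expected {bindings[name]}, got {value}')


def _is_variadic(tok: Token) -> bool:
    return tok == '...' or (isinstance(tok, str) and tok.startswith('*'))


def match_shape(tokens: List[Token], shape: Tuple[int, ...],
                bindings: Dict[str, object]) -> None:
    """Check `shape` against `tokens`, recording named sizes in `bindings`."""
    var_idx = [i for i, t in enumerate(tokens) if _is_variadic(t)]
    if len(var_idx) > 1:
        raise ValueError('This sketch supports at most one variadic axis.')
    if var_idx:
        i = var_idx[0]
        head, tail = tokens[:i], tokens[i + 1:]
        n_var = len(shape) - len(head) - len(tail)
        if n_var < 0:
            raise ValueError(f'{shape} has too few axes for {tokens}')
        if tokens[i] != '...':  # '*name' binds the whole group of axes.
            _bind(bindings, str(tokens[i])[1:], tuple(shape[i:i + n_var]))
        pairs = list(zip(head, shape[:i])) + list(zip(tail, shape[i + n_var:]))
    else:
        if len(shape) != len(tokens):
            raise ValueError(f'{shape} does not have {len(tokens)} axes')
        pairs = list(zip(tokens, shape))
    for tok, size in pairs:
        if isinstance(tok, int):
            if tok != size:
                raise ValueError(f'Expected static axis {tok}, got {size}')
        elif tok != '_':
            _bind(bindings, tok, size)


def check_call(specs_and_shapes: Sequence[Tuple[List[Token], Tuple[int, ...]]]):
    """Validate all arrays of ONE call; bindings are fresh for every call."""
    bindings: Dict[str, object] = {}
    for tokens, shape in specs_and_shapes:
        match_shape(tokens, shape, bindings)
    return bindings
```

For example, `check_call([(normalize_spec('*batch h w c'), (2, 5, 28, 28, 3))])` returns `{'batch': (2, 5), 'h': 28, 'w': 28, 'c': 3}`, while two arrays annotated `'h w'` with shapes `(4, 4)` and `(5, 5)` in the same call raise a `ValueError`.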

Runtime shape/dtype checking

You can decorate your function with @enp.check_and_normalize_arrays so that array dtypes (and, once supported, shapes) are validated at runtime:

from etils import enp
from etils.array_types import FloatArray, IntArray


@enp.check_and_normalize_arrays
def add(x: IntArray, y: IntArray) -> IntArray:
  return x + y

TF / Jax / Numpy compatibility

Functions decorated with enp.check_and_normalize_arrays support np, jnp, and tnp:

  • If args mix jnp and tnp arrays, an error is raised.
  • If args mix np arrays with xnp arrays (jnp or tnp), the np arrays are auto-cast to xnp.
  • You can force TF / Jax / NumPy usage by passing an xnp= kwarg (the decorator adds this kwarg automatically).
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
add(np.array(1), jnp.array(2))  # np auto-cast to jnp
add(tf.constant(1), jnp.array(2))  # Error: jnp / TF conflict
add(tf.constant(1), jnp.array(2), xnp=jnp)  # Force jnp usage
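The framework-selection rules above amount to a small decision procedure. The sketch below models it under stated assumptions: `infer_xnp` is a hypothetical helper (not part of etils), and the strings 'np', 'jnp' and 'tnp' stand in for the real modules.

```python
from typing import Iterable, Optional


def infer_xnp(arg_frameworks: Iterable[str], xnp: Optional[str] = None) -> str:
    """Pick the framework for one call (hypothetical model of the rules above).

    `arg_frameworks` holds one tag per array argument: 'np', 'jnp' or 'tnp'.
    """
    if xnp is not None:
        return xnp  # An explicit xnp= kwarg always wins.
    # np arrays can be cast to any other framework, so they never force a choice.
    others = set(arg_frameworks) - {'np'}
    if len(others) > 1:
        raise ValueError(f'Conflicting array frameworks: {sorted(others)}')
    return others.pop() if others else 'np'
```

Applied to the three calls above, `infer_xnp(['np', 'jnp'])` gives `'jnp'`, `infer_xnp(['tnp', 'jnp'])` raises, and `infer_xnp(['tnp', 'jnp'], xnp='jnp')` gives `'jnp'`. Treating np as the "weakest" framework is what lets plain NumPy inputs flow into Jax or TF code without an explicit conversion.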

With strict=False, your function auto-converts lists, ints, ... to xnp.ndarray:

@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray):
  return x + y

add([1, 2, 3], 10)  # == np.array([11, 12, 13])
add([1, 2, 3], 10, xnp=jnp)  # == jnp.array([11, 12, 13])
add([1, 2, 3], tf.constant(10))  # == tnp.array([11, 12, 13])

You can add an xnp: enp.NpModule = ... kwarg to your function; it is automatically set to the inferred xnp module:

@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray, *, xnp: enp.NpModule = ...):
  return xnp.add(x, y)


add(1, [1, 2, 3])  # Inside the function, `xnp=np`
add(tf.constant(1), tf.constant(2))  # Inside the function, `xnp=tnp`

DType checking

There are two levels of dtype checking:

  • Type unions: IntArray (accepts ui8, i32, i64, ...), FloatArray (accepts f32, bf16, ...)
  • Specific types: f32, ui8, ...

Using type unions lets your functions accept several dtypes (e.g. to support quantization, ...).
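The two levels can be modeled as sets of accepted dtype names: a specific type accepts exactly one dtype, a union accepts a family. The sketch below is illustrative only; `DTypeSpec`, `check_dtype` and the constants are hypothetical, and the union members are a partial list chosen for the example, not etils' exact definitions.

```python
from dataclasses import dataclass
from typing import FrozenSet


@dataclass(frozen=True)
class DTypeSpec:
    """Hypothetical stand-in for a dtype annotation such as f32 or FloatArray."""

    name: str
    accepted: FrozenSet[str]


# Specific types: exactly one dtype each.
F32 = DTypeSpec('f32', frozenset({'float32'}))
UI8 = DTypeSpec('ui8', frozenset({'uint8'}))
# Type unions: a family of dtypes (illustrative, not exhaustive).
FLOAT_ARRAY = DTypeSpec(
    'FloatArray', frozenset({'bfloat16', 'float16', 'float32', 'float64'}))
INT_ARRAY = DTypeSpec(
    'IntArray', frozenset({'uint8', 'int8', 'int16', 'int32', 'int64'}))


def check_dtype(spec: DTypeSpec, dtype: str) -> None:
    """Raise TypeError if `dtype` is not accepted by `spec`."""
    if dtype not in spec.accepted:
        raise TypeError(f'{spec.name} does not accept dtype {dtype!r}')
```

With these definitions, `check_dtype(FLOAT_ARRAY, 'bfloat16')` passes while `check_dtype(F32, 'bfloat16')` raises, which is why annotating with a union lets the same function run on reduced-precision or quantized inputs.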

Shape checking

Shape checking is not yet supported (but is planned).