# JAX Transformations



**Key concepts:**
* {py:class}`~coordax.Field` is a pytree and is compatible with shape-preserving JAX transformations
* `Field`s with leading positional (unlabeled) axes are also compatible with shape-altering transformations

Since `Field` is implemented as a pytree, transformations such as automatic differentiation (`jax.grad`) and JIT compilation (`jax.jit`) work out of the box:

In [None]:
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np

def loss(x, y):
  # `cmap(jnp.sum)` sums over the positional axis; `.data` extracts the scalar.
  return cx.cmap(jnp.sum)((x - y)**2).data

grad = jax.grad(loss)(cx.wrap(np.ones(5)), cx.wrap(np.zeros(5)))
also_grad = jax.jit(jax.grad(loss))(cx.wrap(np.ones(5)), cx.wrap(np.zeros(5)))
also_grad
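To see why this works, here is a minimal sketch of a labeled array registered as a JAX pytree. `ToyField` is a hypothetical class written for illustration only, not Coordax's implementation: the array is a pytree child that JAX traces and differentiates, while the dimension names travel as static auxiliary data, so `jax.grad` and `jax.jit` hand back results with the labels intact.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyField:
  """Hypothetical labeled array (illustration only, not Coordax's Field)."""

  def __init__(self, data, dims):
    self.data = data          # pytree child: traced and differentiated
    self.dims = tuple(dims)   # static metadata: hashed, never traced

  def tree_flatten(self):
    return (self.data,), self.dims

  @classmethod
  def tree_unflatten(cls, dims, children):
    return cls(children[0], dims)


def loss(x: ToyField, y: ToyField):
  return jnp.sum((x.data - y.data) ** 2)


x = ToyField(jnp.ones(5), ('y',))
y = ToyField(jnp.zeros(5), ('y',))

# The gradient has the same pytree structure as `x`, so it is a ToyField with
# the same labels; jit works because the static `dims` tuple is hashable.
g = jax.jit(jax.grad(loss))(x, y)
print(type(g).__name__, g.dims, g.data)
```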

Transformations that change the shape of pytree leaves (e.g., `jax.vmap`, `jax.lax.scan`) are also supported, as long as they do not invalidate the coordinate metadata. In practice, this restricts them to leading positional (unlabeled) axes of a `Field`, which covers the most common use cases:

1. `jax.vmap(fn_on_field)` - explicit batching
2. `jax.lax.scan(fn, fields_to_scan_over)` - scanning over `Field` entries

## Explicit vectorization with `jax.vmap`

{py:func}`~coordax.cmap` is the most convenient way to vectorize functions in Coordax,
but `Field` is also compatible with `jax.vmap`, as long as the leading dimension is unlabeled:

In [None]:
data = np.arange(10).reshape((2, 5))
f = cx.wrap(data, None, 'y')

def identity_with_checks(x: cx.Field) -> cx.Field:
  assert x.dims == ('y',)  # under vmap `x` will be slices of `f` above.
  return x

same_as_f = jax.vmap(identity_with_checks)(f)
same_as_f

`vmap` over a labeled dimension raises an error:

In [None]:
try:
  jax.vmap(identity_with_checks)(f.tag('x'))
except Exception as e:
  print(f'{type(e).__name__}: {e}')

## Scanning over `Field` entries

Similarly, `Field` supports `jax.lax.scan`, as long as the leading dimension being scanned over is unlabeled:

In [None]:
data = np.arange(10).reshape((2, 5))
f = cx.wrap(data, None, 'y')

def identity_body_with_checks(unused_c, x: cx.Field) -> tuple[None, cx.Field]:
  assert x.dims == ('y',)  # under scan `x` will be slices of `f` above.
  return None, x

_, same_as_f = jax.lax.scan(identity_body_with_checks, init=None, xs=f)
same_as_f

Likewise, scanning over a labeled dimension raises an error:

In [None]:
try:
  jax.lax.scan(identity_body_with_checks, init=None, xs=f.tag('x'))
except Exception as e:
  print(f'{type(e).__name__}: {e}')

## `jax.tree.map` and `Field`

There are two main ways in which `jax.tree.map` can interact with `Field` instances:
1. Mapping over whole `Field` objects, by passing `is_leaf=cx.is_field` to `jax.tree.map`
2. Mapping over the underlying array data, which is the default behavior

The former is the safe approach: the mapped function receives each `Field` intact and is explicitly responsible for its metadata.

The latter should be used with great care, especially when the mapped function changes array shapes.

Coordax performs shape checks that catch cases where the modified data is no longer compatible with the coordinate labels, but in general it cannot detect erroneous transformations that produce arrays of a compatible shape.

As with `jax.vmap` and `jax.lax.scan`, functions that insert or remove leading positional dimensions are supported, for example:

- Adding a new leading positional axis with `jax.tree.map`:

In [None]:
data = np.arange(10).reshape((2, 5))
f = cx.wrap(data, 'x', 'y')
# note that here tree.map operates on the underlying Array values.
with_leading_axis = jax.tree.map(lambda x: x[np.newaxis, ...], f)
with_leading_axis

- Removing a leading positional axis with `jax.tree.map`:

In [None]:
data = np.arange(10).reshape((1, 2, 5))
f = cx.wrap(data, None, 'x', 'y')
without_leading_axis = jax.tree.map(lambda x: x[0, ...], f)
without_leading_axis

- Modifying the data without changing its shape with `jax.tree.map`:

In [None]:
data = np.arange(4).reshape((1, 2, 2))
f = cx.wrap(data, 'b', 'x', 'y')
double_f = jax.tree.map(lambda x: x * 2, f)
double_f

A shape-preserving function like the one above, however, cannot be distinguished from an accidental transpose of the last two (equally sized) axes, which leaves the coordinates out of sync with the data (unless the transpose is what the computation intends):

In [None]:
data = np.arange(4).reshape((2, 2))
f = cx.wrap(data, 'x', 'y')
yx_f_labeled_as_xy = jax.tree.map(lambda x: x.transpose(), f)
yx_f_labeled_as_xy