# Coordinate Map (cmap)



**Key concepts:**
* {py:func}`~coordax.cmap` (coordinate map) plays a central role in connecting dimension-aware representations to array functions
* `cmap` transforms a function `fn` operating on `Array` inputs to support `Field` args, with `fn` applied over *positional axes* and vectorized over *named axes*
* Return values have axis order determined by the order of appearance or the `out_axes` kwarg
* Alternatively, {py:func}`~coordax.cpmap` is a convenient shortcut for coordinate _preserving_ mappings
* Built-in binary/unary operations on `Field` are implemented using `cmap`

If you're familiar with Xarray, you can think of {py:func}`~coordax.cmap` as a super-charged version of [`xarray.apply_ufunc`](https://tutorial.xarray.dev/advanced/apply_ufunc/simple_numpy_apply_ufunc.html) enabled by [`jax.vmap`](https://docs.jax.dev/en/latest/automatic-vectorization.html).

If we have a function that works on `Array` inputs and want to apply it to `Field` data, we can use `cmap` to transform it into a function that accepts `Field` arguments. Let's start with the standard `jnp.add`:

In [None]:
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(0)
f_rand = cx.wrap(rng.uniform(size=3), 'x')
f_arange = cx.wrap(np.arange(3), 'x')
fs_added = cx.cmap(jnp.add)(f_rand, f_arange)
fs_added

In the example above, `jnp.add` was vectorized over the named dimension 'x' and applied to the remaining positional axes (here there are none, so it operates on scalars). This is reflected in the 'x' annotation of the output: vectorization over named axes preserves their coordinates.

If we pass inputs with the `x` dimension untagged, the transformed function will be applied to 1D arrays with the positional shape. For `jnp.add` the numerical result is the same, because `jnp.add` is elementwise and therefore equivalent to `jax.vmap(jnp.add)` on 1D inputs.

In [None]:
added = cx.cmap(jnp.add)(f_rand.untag('x'), f_arange.untag('x'))
assert (added.data == fs_added.data).all()
added

Note that `added` no longer has an 'x' coordinate, since the inputs carried no coordinates.
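For intuition, the two calls above correspond roughly to the following plain-JAX computations: vectorizing over a named 'x' means `jnp.add` only ever sees scalars, while untagging 'x' means it receives whole 1D vectors. This is a sketch using only `jax.vmap` as an analogy, not the coordax implementation; `add_scalars`, `a` and `b` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(0)
a = jnp.asarray(rng.uniform(size=3))
b = jnp.arange(3.0)

# Named 'x' axis: the map runs over it, so the wrapped function sees scalars.
def add_scalars(u, v):
  assert u.ndim == 0 and v.ndim == 0  # checked once, at trace time
  return jnp.add(u, v)

vectorized = jax.vmap(add_scalars)(a, b)

# Untagged (positional) 'x' axis: jnp.add receives the full 1D vectors.
direct = jnp.add(a, b)

# For an elementwise function both views produce the same numbers.
print(np.allclose(vectorized, direct))  # True
```

The agreement is specific to elementwise functions; for something like an FFT the two views differ, which is what the next example exploits.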

Let's look at a more complex example: computing an FFT over a single axis of multi-dimensional inputs.

In [None]:
def fft(x):
  assert x.ndim == 1  # make sure method is applied to vectors.
  return jnp.fft.fft(x)

xc = cx.SizedAxis('x', 16)
rng = np.random.RandomState(0)
f_x = cx.wrap(rng.uniform(size=(3, 16, 4)), 'batch', xc, 'z')

f_kx = cx.cmap(fft)(f_x.untag(xc))
f_kx

The returned `f_kx` contains values where each 1D slice along the untagged 'x' axis was transformed by our `fft` function.

Note that `f_kx` has dimensions `(None, 'batch', 'z')`. By default, all named dimensions are vectorized and placed at the end of the result in the order in which they appear, with positional dimensions placed at the beginning of the result.

The resulting axis order can be explicitly controlled via the `out_axes` argument of `cmap`, which takes a dictionary mapping axis names to their output axis indices.
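The ordering rule can be written down as a few lines of plain Python. The sketch below uses a hypothetical helper, `cmap_out_dims` (not part of coordax), that only predicts the `dims` tuple of the result, with `None` standing for a positional axis. It assumes the mapped function returns `n_positional` positional axes and that `out_axes`, when given, lists every named axis; it mirrors the behaviour described above rather than the library's code.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Dim = Optional[str]  # None marks a positional axis; a string marks a named axis


def cmap_out_dims(
    input_dims: Sequence[Sequence[Dim]],
    n_positional: int,
    out_axes: Optional[Dict[str, int]] = None,
) -> Tuple[Dim, ...]:
  """Hypothetical helper (not a coordax API): predicts the dims of a cmap result."""
  # Named dimensions in order of first appearance across all inputs.
  named: List[str] = []
  for dims in input_dims:
    for d in dims:
      if d is not None and d not in named:
        named.append(d)
  if out_axes is None:
    # Default: positional axes first, then named axes in order of appearance.
    return (None,) * n_positional + tuple(named)
  # Explicit placement: each named axis goes to its requested index and the
  # positional axes fill the remaining slots.
  result: List[Dim] = [None] * (n_positional + len(named))
  for name in named:
    result[out_axes[name]] = name
  return tuple(result)


# Cases from this section, with fft returning one positional axis.
print(cmap_out_dims([('batch', None, 'z')], n_positional=1))
print(cmap_out_dims([('z', 'batch', None)], n_positional=1))
print(cmap_out_dims([('batch', None, 'z')], 1, out_axes={'batch': 0, 'z': 2}))
```

The three calls print `(None, 'batch', 'z')`, `(None, 'z', 'batch')` and `('batch', None, 'z')` respectively.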

Let's look at a few more examples to get a feel for this behavior:

In [None]:
out = cx.cmap(fft)(f_x.order_as('z', ...).untag(xc))
assert out.dims == (None, 'z', 'batch')  # 'z', 'batch' is order of appearance.

out = cx.cmap(fft, out_axes={'batch': 0, 'z': 2})(f_x.untag(xc))
assert out.dims == ('batch', None, 'z')  # explicitly indicate out_axes order.

In the examples above, `None` corresponds to the positional axis returned by `fft`. A different choice of function could reduce over the input dimension (e.g. replacing `jnp.fft.fft` with `jnp.sum`) or create additional positional axes (e.g. `jnp.outer(x, x)`). This would be reflected in the positional shape of the output: no positional axes for `jnp.sum` and two for `jnp.outer`.
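The effect of the function's output rank can be checked with plain `jax.vmap`. In this sketch, an analogy rather than the coordax implementation, the leading axis plays the role of a named 'batch' dimension and `out_axes=-1` moves it to the end, mirroring cmap's default of placing named axes after positional ones. `jnp.outer(v, v)` stands in for a function that adds a positional axis; `over_batch` is an illustrative helper.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3, 16))  # leading axis ~ named 'batch', trailing axis ~ positional


def over_batch(fn):
  # Map over axis 0 and place the mapped ('batch') axis last in the output.
  return jax.vmap(fn, in_axes=0, out_axes=-1)


reduced = over_batch(jnp.sum)(x)                     # fn returns a scalar
same_rank = over_batch(jnp.fft.fft)(x)               # fn returns a length-16 vector
expanded = over_batch(lambda v: jnp.outer(v, v))(x)  # fn returns a 16 x 16 matrix

print(reduced.shape, same_rank.shape, expanded.shape)
# (3,) (16, 3) (16, 16, 3)
```

Only the positional part of each shape changes with the function; the mapped axis is always appended at the end.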

These mechanics form the primary design pattern for computations over locally positional axes:
1. Desired axes are exposed using `untag`
2. Computation is performed using `cx.cmap(fn, ...)(untagged_inputs)`
3. New coordinates are added using `tag` (if needed)
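The same three steps can be mimicked on a plain JAX array by tracking axis names in a tuple by hand. The sketch below uses a hypothetical helper, `apply_along` (not a coordax function), which assumes `fn` maps a 1D vector to a 1D vector. Unlike cmap's default ordering, it keeps the remaining axes in place and appends the new axis last, which keeps the bookkeeping simple.

```python
import jax
import jax.numpy as jnp
import numpy as np


def apply_along(fn, x, dims, axis_name, new_name):
  """Hypothetical helper mimicking untag -> cmap -> tag on a plain array.

  `dims` names every axis of `x`; returns the result and its new axis names.
  """
  axis = dims.index(axis_name)
  # 1. Expose the target axis: move it last so fn can be fed 1D vectors.
  moved = jnp.moveaxis(x, axis, -1)
  other_dims = tuple(d for d in dims if d != axis_name)
  batch_shape = moved.shape[:-1]
  flat = moved.reshape((-1, moved.shape[-1]))
  # 2. Apply fn to every vector, vectorizing over all remaining axes.
  out = jax.vmap(fn)(flat).reshape(batch_shape + (-1,))
  # 3. Name the new trailing axis (the analogue of `tag`).
  return out, other_dims + (new_name,)


x = jnp.asarray(np.random.RandomState(0).uniform(size=(3, 16, 4)))
kx, kx_dims = apply_along(jnp.fft.fft, x, ('batch', 'x', 'z'), 'x', 'kx')
print(kx.shape, kx_dims)  # (3, 4, 16) ('batch', 'z', 'kx')
```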

In [None]:
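f_x_pos = f_x.untag(xc)  # 1. expose 'x' as a positional axis (keeps f_x intact on re-runs)
f_kx = cx.cmap(fft, out_axes=f_x_pos.named_axes)(f_x_pos)  # 2. apply fft, keep named axes in place
f_kx = f_kx.tag('kx')  # 3. label the new positional axis
f_kx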

## Transforming functions with multiple arguments

`cmap` also supports transforming functions with multiple inputs and outputs (we have already seen this with `jnp.add`). This provides a convenient mechanism for automatic vectorization driven by the function's inputs, but lack of care can lead to unexpected results.

Let's consider computing the log-likelihood (and log-pdf) of a sample predicted by a model. Our function needs to accept a prediction sample and the parameters of a distribution.

In [None]:
import jax.scipy.stats as jsp_stats

def log_likelihood(sample, mean, std_dev):
  log_pdf = jsp_stats.norm.logpdf(sample, loc=mean, scale=std_dev)
  return jnp.sum(log_pdf), log_pdf

samples_data = jax.random.normal(jax.random.key(0), shape=(10, 3))
samples_data = samples_data * jnp.arange(10)[:, None] + jnp.arange(10)[:, None]
samples = cx.wrap(samples_data, 'batch', 'x')

means = jnp.ones(3) * 2  # parameters of distribution
std_dev = jnp.array([1.9, 2.0, 2.1])
inputs = samples.untag('x')
ll, log_pdf = cx.cmap(log_likelihood)(inputs, means, std_dev)
ll, log_pdf

In the example above we passed plain `jax.Array` inputs for the distribution parameters. We could obtain the same result by wrapping those arguments in `Field` objects without named dimensions:

In [None]:
cx.cmap(log_likelihood)(inputs, cx.wrap(means), cx.wrap(std_dev))

However, if we accidentally pass an argument with a named dimension attached, the whole computation automatically vectorizes over it!

In [None]:
cx.cmap(log_likelihood)(inputs, means, cx.wrap(std_dev, 'x_axis'))
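In plain-JAX terms, the surprise above is the difference between sharing an argument across the map (`in_axes=None`) and mapping over it. The sketch below redefines `log_likelihood` so it is self-contained and uses nested `jax.vmap` calls as an analogy (not the coordax implementation) for what happens when `std_dev` carries a named dimension.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jsp_stats


def log_likelihood(sample, mean, std_dev):
  log_pdf = jsp_stats.norm.logpdf(sample, loc=mean, scale=std_dev)
  return jnp.sum(log_pdf), log_pdf


samples = jax.random.normal(jax.random.key(0), shape=(10, 3))
means = jnp.ones(3) * 2
std_dev = jnp.array([1.9, 2.0, 2.1])

# Intended: map over the sample batch only; parameters are shared (in_axes=None).
ll, _ = jax.vmap(log_likelihood, in_axes=(0, None, None))(samples, means, std_dev)

# Accidental: std_dev is mapped too, so every sample row is scored once per
# std_dev entry and that axis shows up as an extra output dimension.
per_std = jax.vmap(log_likelihood, in_axes=(None, None, 0))
ll_extra, _ = jax.vmap(per_std, in_axes=(0, None, None))(samples, means, std_dev)

print(ll.shape, ll_extra.shape)  # (10,) (10, 3)
```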

## Coordinate Preserving Map (`cpmap`)

{py:func}`~coordax.cpmap` is a shorthand for `cmap(..., out_axes='same_as_input')`. It applies a function over positional axes while preserving the dimensionality and the coordinate order of the inputs.

This is particularly useful when you want the output field to have the same structure as the input field, without having to manually specify `out_axes`.

In [None]:
x = cx.SizedAxis('x', 3)
y = cx.SizedAxis('y', 2)
# Create a field with mixed named and positional axes
f = cx.wrap(jnp.ones((3, 4, 2)), x, None, y)
print(f'Input dims: {f.dims}')

# Standard cmap puts named axes at the end
print(f'cmap dims: {cx.cmap(lambda x: x)(f).dims}')

# cpmap preserves the order
print(f'cpmap dims: {cx.cpmap(lambda x: x)(f).dims}')
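The relationship between the two layouts is just an axis permutation. The NumPy sketch below uses hypothetical helpers (`cmap_layout`, `to_input_order`; not coordax code) and assumes the function leaves the positional axes unchanged, as the identity does. It computes the permutation that turns cmap's default layout back into the input order, which is the effect `cpmap` has in the example above.

```python
import numpy as np


def cmap_layout(dims):
  """Input-axis order of a default cmap result: positional first, then named."""
  positional = [i for i, d in enumerate(dims) if d is None]
  named = [i for i, d in enumerate(dims) if d is not None]
  return positional + named


def to_input_order(result, dims):
  """Hypothetical helper: permute a cmap-ordered result back to `dims` order."""
  layout = cmap_layout(dims)  # layout[k] = input axis stored at result axis k
  inverse = np.argsort(layout)  # argsort of a permutation is its inverse
  return np.transpose(result, inverse)


dims = ('x', None, 'y')
data = np.arange(24).reshape(3, 4, 2)
# Simulate cmap(identity): the positional axis moves first, named axes follow.
cmap_result = np.transpose(data, cmap_layout(dims))
restored = to_input_order(cmap_result, dims)
print(cmap_result.shape, restored.shape)  # (4, 3, 2) (3, 4, 2)
```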

## Arithmetic operations on `Field`

The `Field` class implements standard arithmetic methods such as addition, multiplication and division. These methods are built by transforming the corresponding array functions with `cmap`, so they automatically align coordinates and broadcast data where necessary:

In [None]:
x_grid = cx.SizedAxis('x', 5)
y_grid = cx.SizedAxis('y', 3)

zero_field = cx.wrap(np.zeros((5, 3)), x_grid, y_grid)
ones_field = cx.wrap(np.ones((5, 3)), x_grid, y_grid)

zero_field + ones_field, zero_field * ones_field

In [None]:
zero_field + ones_field.order_as('y', 'x')  # still works due to auto alignment

In [None]:
zero_field.order_as('y', 'x') + ones_field  # works, but note y, x result order.
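The alignment rule can be illustrated with a small NumPy sketch for arrays whose axes are all named. `aligned_add` is a hypothetical helper, not coordax code: it orders output axes by first appearance (the first operand's axes, then any new ones from the second), transposes each operand to match, and inserts size-1 axes where a name is missing so that NumPy broadcasting fills them in.

```python
import numpy as np


def aligned_add(a, a_dims, b, b_dims):
  """Hypothetical name-aligned addition; returns (result, result_dims)."""
  out_dims = list(a_dims) + [d for d in b_dims if d not in a_dims]

  def expand(x, dims):
    # Reorder x's axes into out_dims order, then add size-1 axes for missing names.
    order = [dims.index(d) for d in out_dims if d in dims]
    x = np.transpose(x, order)
    sizes = iter(x.shape)
    return x.reshape([next(sizes) if d in dims else 1 for d in out_dims])

  return expand(a, a_dims) + expand(b, b_dims), tuple(out_dims)


zeros = np.zeros((5, 3))
ones_yx = np.ones((3, 5))  # same data layout as ones_field.order_as('y', 'x')
total, total_dims = aligned_add(zeros, ('x', 'y'), ones_yx, ('y', 'x'))
print(total.shape, total_dims)  # (5, 3) ('x', 'y')

# A missing 'y' axis is broadcast.
shifted, _ = aligned_add(zeros, ('x', 'y'), np.arange(5), ('x',))
print(shifted.shape)  # (5, 3)
```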

Vectorization over named dimensions effectively results in broadcasting:

In [None]:
zero_field + cx.wrap(np.arange(zero_field.named_shape['x']), 'x')
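Why vectorization amounts to broadcasting can be seen with nested `jax.vmap` calls. In this analogy (not the coordax implementation), the outer map runs over 'x', which both operands have, and the inner map runs over 'y', which only the first operand has. The second operand is therefore shared (`in_axes=None`) across 'y', exactly as broadcasting would do.

```python
import jax
import jax.numpy as jnp
import numpy as np

zeros = jnp.zeros((5, 3))   # axes ~ ('x', 'y')
x_values = jnp.arange(5.0)  # axis ~ ('x',)

# Inner map over 'y': only `zeros` has it, so the x value is shared.
over_y = jax.vmap(jnp.add, in_axes=(0, None))
# Outer map over 'x': both operands have it.
over_xy = jax.vmap(over_y, in_axes=(0, 0))
via_vmap = over_xy(zeros, x_values)  # jnp.add only ever sees two scalars

# Same numbers as explicit broadcasting along 'y'.
via_broadcast = zeros + x_values[:, None]
print(np.allclose(via_vmap, via_broadcast))  # True
```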

The same gotcha about output axis order applies to arithmetic operations. For instance, adding `Field` instances with a trailing positional axis moves that axis to the first dimension, just as with `cmap(jnp.add)`:

In [None]:
addition = zero_field.untag('y') + ones_field.untag('y')
addition  # Note how 'x' dim got moved to the end as the only vectorized axis.
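The axis move can be reproduced with `jax.vmap` by mapping over the named axis and placing it last with `out_axes=-1`. This is an analogy for cmap's default ordering, not its implementation; `a` and `b` stand in for the untagged fields above.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.zeros((5, 3))  # ~ zero_field.untag('y'): dims ('x', None)
b = jnp.ones((5, 3))   # ~ ones_field.untag('y'): dims ('x', None)

# Map over the named 'x' axis and, like cmap's default, put it last.
result = jax.vmap(jnp.add, in_axes=0, out_axes=-1)(a, b)
print(result.shape)  # (3, 5): positional axis first, then 'x'

# Transposing back recovers the ('x', None) layout of the inputs.
print(np.allclose(result.T, a + b))  # True
```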