# Coordinate Objects



**Key concepts:**
* `Coordinate` objects store coordinate-specific metadata and methods
* `Coordinate` classes are static [pytree nodes](https://docs.jax.dev/en/latest/working-with-pytrees.html#custom-pytree-nodes) which enable compile-time checks

While strings are convenient for simple labeling, `Coordinate` objects allow you to attach data (like tick values) to axes and enforce alignment similar to [Xarray](https://github.com/pydata/xarray).


## Using coordinates with fields

The primary way to use `Coordinate` objects is by passing them to `cx.wrap` when creating a `Field`.
This associates the coordinate with the corresponding dimension of the array.

In [None]:
import coordax as cx
import jax.numpy as jnp

# Define a coordinate
x_axis = cx.SizedAxis('x', 5)

# Wrap an array using the coordinate
# The coordinate 'x_axis' is associated with the first dimension (size 5)
f = cx.wrap(jnp.ones(5), x_axis)

print(f"Field dims: {f.dims}")
print(f"Field axes: {f.axes}")

## Standard coordinate types

Coordax provides several built-in coordinate types:
* {py:class}`~coordax.SizedAxis`: Minimal coordinate, only checks size.
* {py:class}`~coordax.LabeledAxis`: Stores tick values (e.g. grid points) and checks them for equality.
* {py:class}`~coordax.DummyAxis`: Placeholder for dimensions without associated coordinate values.
* {py:class}`~coordax.Scalar`: Zero-dimensional sentinel coordinate for scalars.


### SizedAxis

`SizedAxis` is the simplest coordinate type. It only checks that the dimension size matches. It does not carry any additional data.

In [None]:
import numpy as np

x = cx.SizedAxis('x', 5)
print(x)
f = cx.wrap(jnp.ones(5), x)
print(f.dims)

### LabeledAxis

`LabeledAxis` associates a dimension with a 1D array of tick values (e.g. grid points or labels). It checks for equality of these values when aligning fields.

In [None]:
ticks = np.linspace(0, 1, 5)
y = cx.LabeledAxis('y', ticks)
print(y)
print(y.fields['y'])  # Access the coordinate field

### Scalar

`Scalar` is a special coordinate for 0-dimensional data. It has no dimensions and no shape.

In [None]:
scalar = cx.Scalar()
print(scalar)
print(scalar.shape)

### DummyAxis

`DummyAxis` is the placeholder coordinate used when a coordinate is required but the `Field` has none for that dimension (e.g., because the dimension was only indicated with a string).

You can construct it explicitly, but generally don't need to. Note that `DummyAxis` coordinates are automatically dropped when used on a `Field`.

In [None]:
dummy = cx.DummyAxis('d', 5)
print(dummy)
f_dummy = cx.wrap(np.zeros(5), dummy)
print(f_dummy.dims)

### CartesianProduct

In general, a `Coordinate` can span multiple dimensions. One special case is {py:class}`~coordax.CartesianProduct`, which simply bundles multiple `Coordinate` primitives together. The {py:func}`~coordax.compose_coordinates` helper is the most convenient way to create a `CartesianProduct`.

In [None]:
x_axis, y_axis = cx.SizedAxis('x', 6), cx.SizedAxis('y', 7)
xy_coord = cx.compose_coordinates(x_axis, y_axis)
print(xy_coord)
print(f'{xy_coord.dims=}')
print(f'{xy_coord.shape=}')

`CartesianProduct` is also used to represent the full set of coordinates associated with a `Field`, available via {py:attr}`coordax.Field.coordinate <Field.coordinate>` or {py:func}`~coordax.get_coordinate`.

In [None]:
field = cx.wrap(np.zeros((6, 7)), x_axis, y_axis)
field.coordinate

Multi-dimensional coordinates can be passed to `wrap`, `tag`, `untag`, etc. in the same way as one-dimensional ones.

In [None]:
x_axis, y_axis = cx.SizedAxis('x', 6), cx.SizedAxis('y', 7)
xy_coord = cx.compose_coordinates(x_axis, y_axis)
f = cx.wrap(np.ones((3, 6, 7)), 'batch', xy_coord)
print(f'{f.dims=}')
print(f'{f.untag(xy_coord).dims=}')

## Coordinate checks

Most manipulations on `Field` objects require exact coordinate alignment. For coordinate objects that carry information beyond name and shape, this check catches alignment and coordinate mismatch errors. This section uses the standard {py:class}`~coordax.SizedAxis` and {py:class}`~coordax.LabeledAxis` coordinates to demonstrate alignment checks; the following section shows how to implement custom coordinates.

This coordinate equality rule is relaxed for arguments passed to `tag`, `untag`, `order_as`, etc. Passing a dimension name (`str`) is considered sufficient to express the user's intent (assuming a coordinate with a matching name is present).

In [None]:
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np

xc, yc = cx.SizedAxis('x', 2), cx.SizedAxis('y', 3)
f_xy = cx.wrap(np.arange(xc.size * yc.size).reshape((xc.size, yc.size)), xc, yc)

f_yx = f_xy.order_as(yc, xc)  # works - we use same coordinates.
also_f_yx = f_xy.order_as('y', 'x')  # also works

x_grid = cx.LabeledAxis('x', np.linspace(0, np.pi, 2))
y_grid = cx.LabeledAxis('y', np.linspace(0, 1, 3))
try:
  f_xy.order_as(y_grid, x_grid)  # raises, coordinates are different
except Exception as e:
  print(f'{type(e).__name__}: {e}')

The {py:class}`~coordax.LabeledAxis` equality includes a check on coordinate ticks, so if coordinates differ in the exact placement of the tick values, an error is raised. This is particularly relevant for numerical methods where fields could have offsets within computational cells.

In [None]:
x_bounds, dx = np.linspace(0, 2 * np.pi, 10, endpoint=False, retstep=True)
x_centers = np.linspace(dx / 2, 2 * np.pi - dx / 2, 10)
x_grid_bounds = cx.LabeledAxis('x', x_bounds)
x_grid_centers = cx.LabeledAxis('x', x_centers)

f = cx.wrap(np.ones(10), x_grid_centers)

assert x_grid_bounds.dims == x_grid_centers.dims  # dims are the same.
assert x_grid_bounds.shape == x_grid_centers.shape  # same, compatible shape.
assert x_grid_bounds != x_grid_centers  # not equal, since tick values differ.

try:
  f.untag(x_grid_bounds)  # raises, coordinates are different
except Exception as e:
  print(f'{type(e).__name__}: {e}')
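
As a rough illustration of why these two axes compare unequal, the tick arrays can be inspected directly with NumPy, independent of Coordax. This sketch only re-creates the two arrays from the cell above; it makes no assumption about how `LabeledAxis` implements its equality check internally.

```python
import numpy as np

# Same tick arrays as in the cell above.
x_bounds, dx = np.linspace(0, 2 * np.pi, 10, endpoint=False, retstep=True)
x_centers = np.linspace(dx / 2, 2 * np.pi - dx / 2, 10)

# Both arrays have the same length, so a size-only check cannot tell them apart.
same_shape = x_bounds.shape == x_centers.shape

# The values differ: each cell center sits half a cell width past the lower bound.
same_ticks = np.array_equal(x_bounds, x_centers)
half_cell_offset = np.allclose(x_centers - x_bounds, dx / 2)

print(same_shape, same_ticks, half_cell_offset)
```

The shapes match while the tick values are offset by `dx / 2`, which is the kind of staggered-grid mismatch that a tick-aware coordinate is meant to surface.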

## Custom coordinates

Users can also define custom coordinates by subclassing {py:class}`~coordax.Coordinate`.
This is useful for attaching additional metadata and custom methods that propagate automatically.

{py:class}`~coordax.Coordinate` objects represent one or more *Array* axes, specifying their names and shape and optionally providing additional values and methods associated with the coordinate. Their core properties include:
* `dims` - tuple of dimension names
* `shape` - shape of the coordinate object
* `fields` - holds named supporting values

Let's define a custom coordinate class:

In [None]:
import dataclasses

@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True)
class UniformAxis(cx.Coordinate):
  """Cell centered coordinate with uniform discretization of (0, `length`)."""

  name: str
  size: int
  length: float

  @property
  def dims(self):
    return (self.name,)

  @property
  def shape(self) -> tuple[int, ...]:
    return (self.size,)

  @property
  def fields(self):
    delta = self.length / self.size
    cell_centers = np.linspace(delta / 2, self.length - delta / 2, self.size)
    return {self.name: cx.wrap(cell_centers, self)}

In [None]:
z_centers = UniformAxis('z', 10, np.pi * 2)
print(z_centers)
z_centers.fields['z']  # access the supporting value exposed by the `fields` property.
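
The `@jax.tree_util.register_static` decorator treats instances as static values, which relies on them being hashable and comparable by equality; a frozen dataclass provides both. The sketch below illustrates this with plain Python only. `UniformAxisSketch` and its `cell_centers` method are hypothetical stand-ins that omit the `cx.Coordinate` base class and the `fields` property, so this shows the value semantics rather than Coordax behavior.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class UniformAxisSketch:
  """Hypothetical stand-in for `UniformAxis`, without the Coordax base class."""

  name: str
  size: int
  length: float

  def cell_centers(self) -> list[float]:
    # Centers of `size` equal cells covering (0, length).
    delta = self.length / self.size
    return [delta / 2 + i * delta for i in range(self.size)]


a = UniformAxisSketch('z', 4, 2.0)
b = UniformAxisSketch('z', 4, 2.0)
c = UniformAxisSketch('z', 4, 4.0)

# Equality and hashing are derived from the dataclass fields.
print(a == b, hash(a) == hash(b))
# A different `length` makes the axes unequal, even with the same name and size.
print(a == c)
print(a.cell_centers())
```

Because every field takes part in the comparison, two axes with the same name and size but a different `length` are treated as different coordinates.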