# Coordax and Xarray


**Key concepts:**
* {py:meth}`~coordax.Field.to_xarray` and {py:func}`~coordax.from_xarray` allow conversion between `Field` and `xarray.DataArray` objects
* Custom coordinates need to implement serialization methods, such as a `from_xarray` classmethod, to be restored from xarray

Xarray is a tool of choice for serializing and visualizing labeled arrays. To make tapping into this ecosystem easy, Coordax provides lossless (round-trip) conversion between `xarray.DataArray` and `Field`.

Let's start with a simple example:


In [None]:
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np

xc = cx.LabeledAxis('x', np.linspace(0, 1, 5))
yc = cx.LabeledAxis('y', np.linspace(0, 1, 10))

fn = lambda x, y: jnp.exp(-(x-0.5)**2) * jnp.sin(15 * x * y)
field = cx.cmap(fn)(xc.fields['x'], yc.fields['y'])
dataarray = field.to_xarray()

print(f'{field=}')
print()
print(f'{dataarray=}')

We can restore a `Field` object from a `DataArray` using {py:func}`~coordax.from_xarray`.

In [None]:
restored = cx.from_xarray(dataarray)

assert field.coordinate == restored.coordinate
np.testing.assert_allclose(restored.data, field.data)

(serialization-custom-coordinates)=
## Serialization of custom coordinates

How do we get back coordinates other than {py:class}`~coordax.LabeledAxis`?

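To support restoring other coordinate types, a `Coordinate` class can implement a custom `from_xarray` classmethod that returns either a coordinate instance or a `cx.NoCoordinateMatch` explaining why it does not apply. Then, when candidate classes are passed to {py:func}`~coordax.from_xarray` via its `coord_types` argument, the appropriate coordinate is instantiated from whichever candidate matches.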
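The matching step can be pictured as a greedy loop over the remaining dimensions. The sketch below is a minimal pure-Python illustration under stated assumptions, not Coordax's implementation: it assumes candidates are tried in the order given, that the first match wins, and that a match consumes as many leading dimensions as the coordinate has. `NoMatch`, `AnyTicks`, `UniformTicks`, `from_coords` and `resolve_coordinates` are hypothetical names, and a plain dict of NumPy arrays stands in for `xarray.Coordinates`.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class NoMatch:
  """Hypothetical analogue of cx.NoCoordinateMatch: carries a reason."""
  reason: str


@dataclass
class AnyTicks:
  """Hypothetical 1D coordinate that accepts any 1D tick values."""
  name: str
  ticks: np.ndarray
  ndim = 1  # unannotated, so a class attribute rather than a dataclass field

  @classmethod
  def from_coords(cls, dims, coords):
    name = dims[0]
    if name not in coords:
      return NoMatch(f'no values for {name!r}')
    return cls(name, np.asarray(coords[name]))


class UniformTicks(AnyTicks):
  """Hypothetical coordinate that only accepts uniformly spaced ticks."""

  @classmethod
  def from_coords(cls, dims, coords):
    name = dims[0]
    if name not in coords:
      return NoMatch(f'no values for {name!r}')
    steps = np.diff(coords[name])
    if steps.size == 0 or not np.allclose(steps.max(), steps.min(), rtol=1e-6):
      return NoMatch('spacing is not uniform')
    return cls(name, np.asarray(coords[name]))


def resolve_coordinates(dims, coords, candidates):
  """Match candidates against the leading remaining dims; first match wins."""
  resolved = []
  remaining = tuple(dims)
  while remaining:
    reasons = []
    for candidate in candidates:
      result = candidate.from_coords(remaining, coords)
      if isinstance(result, NoMatch):
        reasons.append(f'{candidate.__name__}: {result.reason}')
        continue
      resolved.append(result)
      remaining = remaining[result.ndim:]  # consume the matched dims
      break
    else:  # no candidate matched the leading dimension
      raise ValueError(f'no candidate matched {remaining[0]!r}: {reasons}')
  return resolved


coords = {'x': np.linspace(0, 1, 5), 'y': np.array([0.0, 1.0, 10.0])}
axes = resolve_coordinates(('x', 'y'), coords, [UniformTicks, AnyTicks])
print([type(a).__name__ for a in axes])  # ['UniformTicks', 'AnyTicks']
```

In this sketch, swapping the candidate order to `[AnyTicks, UniformTicks]` lets `AnyTicks` claim both axes, so under the first-match assumption the more specific candidates should be listed first.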

As an example, here is a `LabeledAxis` subclass that only matches uniformly spaced xarray coordinates:

In [None]:
from typing import Self
import xarray

@jax.tree_util.register_static
class UniformAxis(cx.LabeledAxis):
  """UniformAxis with from_xarray implemented."""

  @classmethod
  def from_xarray(
      cls, dims: tuple[str, ...], coords: xarray.Coordinates
  ) -> Self | cx.NoCoordinateMatch:
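    dim = dims[0]
    name = dim  # attempt to use the leading dimension name.
    if name not in coords:
      return cx.NoCoordinateMatch(f'no coordinate values found for {name!r}')
    if coords[name].ndim != 1:
      return cx.NoCoordinateMatch('UniformAxis coordinate is not a 1D array')

    got = coords[name].data
    if got.size < 2:
      # Spacing is undefined for fewer than two ticks (and np.max would fail).
      return cx.NoCoordinateMatch('UniformAxis needs at least two ticks')
    steps = np.diff(got)
    if not np.allclose(np.max(steps), np.min(steps), rtol=1e-6):
      return cx.NoCoordinateMatch(
          f'UniformAxis should have uniform spacing, got {steps=}'
      )
    return cls(name=name, ticks=got)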
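Pulling the spacing test out of `UniformAxis.from_xarray` makes it easy to check which tick arrays it accepts. This is a minimal NumPy sketch: `has_uniform_spacing` is a hypothetical helper, not part of Coordax, and treating fewer than two ticks as non-uniform is an assumption (with a single tick, `np.diff` returns an empty array and `np.max` raises).

```python
import numpy as np


def has_uniform_spacing(ticks, rtol=1e-6):
  """Standalone version of the spacing test used in UniformAxis.from_xarray."""
  ticks = np.asarray(ticks)
  if ticks.ndim != 1 or ticks.size < 2:
    # Spacing is undefined here, and np.max of an empty np.diff result raises.
    return False
  steps = np.diff(ticks)
  # Uniform if the largest and smallest steps agree to within rtol.
  return bool(np.allclose(steps.max(), steps.min(), rtol=rtol))


print(has_uniform_spacing(np.linspace(0, 1, 5)))  # True
print(has_uniform_spacing(np.logspace(0, 2, 5)))  # False
print(has_uniform_spacing([0.5]))                 # False
```

The relative tolerance absorbs floating-point round-off, so ticks such as `np.arange(10) * 0.1` still count as uniform.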

In [None]:
def make_uniform_axis(name, size, length):
  delta = length / size
  centers = np.linspace(delta / 2, length - delta / 2, size)
  return UniformAxis(name, centers)

xc = make_uniform_axis('x', 20, 1)
yc = make_uniform_axis('y', 30, 1)

fn = lambda x, y: (jnp.exp(-(x-0.5)**2) * jnp.sin(15 * x * y))
f = cx.cmap(fn)(xc.fields['x'], yc.fields['y'])
da = f.to_xarray()

da.plot(x='x', y='y')
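For reference, the tick placement in `make_uniform_axis` can be written out explicitly. Writing $N$ for `size` and $L$ for `length`, the ticks are the centers of $N$ equal cells covering $[0, L]$:

$$
\Delta = \frac{L}{N}, \qquad x_i = \left(i + \tfrac{1}{2}\right)\Delta, \quad i = 0, \dots, N-1.
$$

`np.linspace(delta / 2, length - delta / 2, size)` has step $(L - \Delta)/(N - 1) = L(N-1)/\big(N(N-1)\big) = \Delta$, so consecutive ticks differ by exactly $\Delta$, which is what lets them pass the uniform-spacing check. For `make_uniform_axis('x', 20, 1)`, $\Delta = 0.05$ and the ticks run from $0.025$ to $0.975$.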

To create a `Field` with our custom coordinate from an `xarray.DataArray`, we need to pass `UniformAxis` explicitly in the `coord_types` argument of {py:func}`~coordax.from_xarray`.

In [None]:
cx.from_xarray(da, coord_types=[cx.SizedAxis, UniformAxis])