Trying out interfacing JAX with Pace

In [None]:
from gt4py.cartesian import gtscript
from gt4py.cartesian.gtscript import (
    __INLINED,
    PARALLEL,
    compile_assert,
    computation,
    horizontal,
    interval,
    region,
)

from pace.fv3core.stencils.xppm import compute_x_flux
import pace.util
from pace.dsl.stencil import StencilFactory
from pace.util import X_DIM, Y_DIM, Z_DIM
from pace.util.grid import DampingCoefficients, GridData
from pace.dsl.typing import FloatField, FloatFieldIJ, Index3D

from jax import config, jit, vjp


# JAX uses 32-bit floats by default; enable 64-bit types, presumably so JAX
# arrays can match the double-precision fields used by the Pace stencils
# (confirm against pace.dsl.typing). Set this before any JAX arrays are created.
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np


Set the initial parameters.

In [None]:
# Problem size for this experiment: a 12 x 12 horizontal grid with a single
# vertical level, plus the halo width (presumably the number of points the
# stencils may read outside the compute domain — confirm against pace).
nx, ny = 12, 12
nz = 1
nhalo = 3

Create the class we'll use, structured à la Pace: build the stencil once in `__init__` and run it in `__call__`.

In [None]:
class JaxXPPM:
    """
    Pace-style wrapper around the ``compute_x_flux`` stencil imported from
    ``pace.fv3core.stencils.xppm``.

    The stencil is built once in ``__init__`` (with its externals fixed) and
    executed on every ``__call__``. Despite the name, nothing in this class
    uses JAX yet; it is the starting point for the JAX interfacing experiment.
    """

    def __init__(
        self,
        stencil_factory: StencilFactory,
        dxa: FloatFieldIJ,
        grid_type: int,
        iord: int,
        origin: Index3D,
        domain: Index3D,
    ):
        """
        Build the x-flux stencil for the given origin/domain.

        Args:
            stencil_factory: builds the stencil via ``from_origin_domain``;
                its ``grid_indexing`` also supplies the ``i_start``/``i_end``
                axis offsets passed to the stencil as externals
            dxa: grid spacing field (``grid.dxa``); stored and passed to the
                stencil on every call
            grid_type: grid type (``namelist.grid_type``); must be < 3 or == 4
            iord: passed to the stencil as the ``iord`` external; the ``mord``
                external is set to ``abs(iord)``
            origin: origin handed to ``axis_offsets`` and ``from_origin_domain``
            domain: domain handed to ``axis_offsets`` and ``from_origin_domain``

        Raises:
            ValueError: if ``grid_type`` is neither < 3 nor equal to 4.
        """
        # Arguments come from:
        # namelist.grid_type
        # grid.dxa
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O (which strips assert statements).
        if not (grid_type < 3 or grid_type == 4):
            raise ValueError(
                f"unsupported grid_type {grid_type!r}: must be < 3 or == 4"
            )
        self._dxa = dxa
        ax_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain)
        self._compute_flux_stencil = stencil_factory.from_origin_domain(
            func=compute_x_flux,
            externals={
                "iord": iord,
                "mord": abs(iord),
                "xt_minmax": True,
                "i_start": ax_offsets["i_start"],
                "i_end": ax_offsets["i_end"],
                "grid_type": grid_type,
            },
            origin=origin,
            domain=domain,
        )

    def __call__(
        self,
        q_in: FloatField,
        c: FloatField,
        q_mean_advected_through_x_interface: FloatField,
    ):
        """
        Determine the mean value per area of q_in to be advected along x-interfaces.

        This is done by integrating a piecewise-parabolic subgrid reconstruction
        of q_in along the x-direction over the segment of gridcell which
        will be advected.

        Multiplying this mean value by the area to be advected through the interface
        would give the flux of q through that interface.

        The result is written into q_mean_advected_through_x_interface; nothing
        is returned.

        Args:
            q_in (in): scalar to be integrated
            c (in): Courant number (u*dt/dx) in x-direction defined on x-interfaces,
                indicates the fraction of the adjacent grid cell which will be
                advected through the interface in one timestep
            q_mean_advected_through_x_interface (out): defined on x-interfaces.
                mean value of scalar within the segment of gridcell to be advected
                through that interface in one timestep, in units of q_in
        """
        # in the Fortran version of this code, "x_advection" routines
        # were called "get_flux", while the routine which got the flux was called
        # fx1_fn. The final value was called xflux instead of q_out.
        self._compute_flux_stencil(
            q_in, c, self._dxa, q_mean_advected_through_x_interface
        )
        # bl and br are "edge perturbation values" as in equation 4.1
        # of the FV3 documentation