Trying out Jax interfacing with Pace

In [1]:
from jax import config, jit, vjp

config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np

from gt4py.cartesian import gtscript
from gt4py.cartesian.gtscript import (
    __INLINED,
    PARALLEL,
    compile_assert,
    computation,
    horizontal,
    interval,
    region,
)

from pace.fv3core.stencils.xppm import compute_x_flux
import pace.util
from pace.dsl.stencil import StencilFactory
from pace.util import X_DIM, Y_DIM, Z_DIM
from pace.util.grid import DampingCoefficients, GridData
from pace.dsl.typing import FloatField, FloatFieldIJ, Index3D, Float

from pace.dsl.stencil import GridIndexing, StencilConfig, StencilFactory
from pace.dsl.stencil_config import CompilationConfig

2023-11-06 22:30:45|INFO|rank 0|pace.util.logging:Constant selected: ConstantVersions.GFS


set initial parameters

In [8]:
nx = 6
ny = 6
nz = 1
nhalo = 1
backend = "numpy"

Start with laplacian

In [9]:
@gtscript.function
def lap(field: FloatField):
    return field[1,0,0] + field[0,1,0] + field[-1,0,0] + field[0,-1,0] - 4.*field[0,0,0]

def laplacian(
    field: FloatField,
    result: FloatField
):
    with computation(PARALLEL), interval(...):
        result = lap(field)

def laplap(
    field: FloatField,
    result: FloatField
):
    with computation(PARALLEL), interval(...):
        tmp = lap(field)
        result = lap(tmp)

Create the class we'll use a la Pace

In [10]:
class Laplacian:
    def __init__(
        self,
        stencil_factory: StencilFactory,
    ):
        idx = stencil_factory.grid_indexing
        self._lap_stencil = stencil_factory.from_origin_domain(
            func=laplacian,
            origin=idx.origin_compute(),
            domain=idx.domain_compute(),
        )
    def __call__(
        self,
        q_in: FloatField,
        q_out: FloatField,
    ):
        q_in, q_out = self._lap_stencil(q_in, q_out)
        return q_in, q_out

In [11]:
compilation_config = CompilationConfig(backend=backend)

stencil_config = StencilConfig(
    compare_to_numpy=False,
    compilation_config=compilation_config,
)

grid_indexing = GridIndexing(
    (nx, ny, nz),
    nhalo,
    False,
    False,
    False,
    False,
)

stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing)

lap = Laplacian(stencil_factory)

In [12]:
sizer = pace.util.SubtileGridSizer.from_tile_params(nx, ny, nz, nhalo, extra_dim_lengths={}, layout=(1,1))
quantity_factory = pace.util.QuantityFactory.from_backend(sizer=sizer, backend=backend)

field_in = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], units="unknown", dtype=float)
field_out = quantity_factory.zeros([X_DIM, Y_DIM, Z_DIM], units="unknown", dtype=float)

raw_data = np.array([[0., 0., 1., 2., 2., 1., 0., 0.],
                     [0., 1., 2., 3., 3., 2., 1., 0.],
                     [1., 2., 3., 4., 4., 3., 2., 1.],
                     [2., 3., 4., 5., 5., 4., 3., 2.],
                     [2., 3., 4., 5., 5., 4., 3., 2.],
                     [1., 2., 3., 4., 4., 3., 2., 1.],
                     [0., 1., 2., 3., 3., 2., 1., 0.],
                     [0., 0., 1., 2., 2., 1., 0., 0.]])

field_in.data[nhalo:-nhalo-1,nhalo:-nhalo-1, 0] = raw_data

ValueError: could not broadcast input array from shape (8,8) into shape (6,6)

In [13]:
ret1, ret2 = lap(field_in, field_out)
print(np.squeeze(field_in.data[:,:,0]))
print(np.squeeze(field_out.data[:,:,0]))

ValueError: Origin for field field too small. Must be at least (1, 1, 0), is (0, 0, 0)

Now do XPPM

In [None]:
class JaxXPPM:
    def __init__(
        self,
        stencil_factory: StencilFactory,
        dxa,
        grid_type: int,
        iord,
        origin: Index3D,
        domain: Index3D,
    ):
        # Arguments come from:
        # namelist.grid_type
        # grid.dxa
        assert (grid_type < 3) or (grid_type == 4)
        self._dxa = dxa
        ax_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain)
        self._compute_flux_stencil = stencil_factory.from_origin_domain(
            func=compute_x_flux,
            externals={
                "iord": iord,
                "mord": abs(iord),
                "xt_minmax": True,
                "i_start": ax_offsets["i_start"],
                "i_end": ax_offsets["i_end"],
                "grid_type": grid_type,
            },
            origin=origin,
            domain=domain,
        )

    def __call__(
        self,
        q_in: FloatField,
        c: FloatField,
        q_mean_advected_through_x_interface: FloatField,
    ):
        """
        Determine the mean value per area of q_in to be advected along x-interfaces.

        This is done by integrating a piecewise-parabolic subgrid reconstruction
        of q_in along the x-direction over the segment of gridcell which
        will be advected.

        Multiplying this mean value by the area to be advected through the interface
        would give the flux of q through that interface.

        Args:
            q_in (in): scalar to be integrated
            c (in): Courant number (u*dt/dx) in x-direction defined on x-interfaces,
                indicates the fraction of the adjacent grid cell which will be
                advected through the interface in one timestep
            q_mean_advected_through_x_interface (out): defined on x-interfaces.
                mean value of scalar within the segment of gridcell to be advected
                through that interface in one timestep, in units of q_in
        """
        # in the Fortran version of this code, "x_advection" routines
        # were called "get_flux", while the routine which got the flux was called
        # fx1_fn. The final value was called xflux instead of q_out.
        self._compute_flux_stencil(
            q_in, c, self._dxa, q_mean_advected_through_x_interface
        )
        # bl and br are "edge perturbation values" as in equation 4.1
        # of the FV3 documentation