## Imports

Note: jaxfluids uses a classical approach of domain decomposition and ghost cells. I hoped that, based on the computation-follows-data principle, sharding, and fusion of operations, JAX could do better without handling this manually – but I might well be wrong.

In [24]:
import os
# Expose only the four GPUs with ids 1-4 to this process (matches the
# 4-device mesh built further down).
# NOTE(review): presumably this must stay above `import jax` so the backend
# only ever sees these devices - confirm before reordering the cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4" 


# Optional JAX/XLA tuning switches, currently disabled.
# os.environ["JAX_ENABLE_PGLE"] = "true"
# os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"

from typing import Optional

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp

from jax.sharding import PartitionSpec as P, NamedSharding

## Finite Difference Test

In [25]:
@partial(jax.jit, static_argnames=['axis'])
def finite_difference(array, axis):
    """
    Difference of neighbouring entries along one axis.

    With n entries along `axis`, the result has n - 2 entries there:
        output_i <- array_{i + 1} - array_i    for i = 0, ..., n - 3

    Args:
        array: The array to operate on.
        axis: The axis along which to operate (a static argument under jit).

    Returns:
        The differences; same shape as `array` except along `axis`.
    """
    shifted = jax.lax.slice_in_dim(array, 1, -1, axis=axis)
    base = jax.lax.slice_in_dim(array, 0, -2, axis=axis)
    return shifted - base

@partial(jax.jit, static_argnames=['indices', 'axis', 'zero_pad'])
def _stencil_add(
        input_array: jnp.ndarray,
        indices,
        factors,
        axis: int,
        zero_pad: bool = True
) -> jnp.ndarray:
    """
    Combines elements of an array additively
        output_i <- sum_j factors_j * input_array_{i + indices_j}

    By default, the output is zero-padded to the same shape as 
    the input array (as we handle boundaries via ghost cells in 
    the overall simulation code). This behavior can be disabled,
    then the output will have a different shape along the specified
    axis.

    Args:
        input_array: The array to operate on.
        indices: output_i <- sum_j factors_j * input_array_{i + indices_j}
        factors: output_i <- sum_j factors_j * input_array_{i + indices_j}
        axis: The axis along which to operate.
        zero_pad: Whether to zero-pad the output to have the same shape as the input.
        
    Returns:
        output_i <- sum_j factors_j * input_array_{i + indices_j}
    """

    num_cells = input_array.shape[axis]

    first_write_index = -min(0, min(indices))
    last_write_index = num_cells - max(0, max(indices))

    # for the first write index, the elements considered are
    first_handled_indices = tuple(first_write_index + index for index in indices)

    # for the last write index, the elements considered are
    last_handled_indices = tuple(last_write_index + index for index in indices)

    output = (
        sum(
            factor * jax.lax.slice_in_dim(
                input_array,
                first_handled_index,
                last_handled_index,
                axis = axis
            )
            for factor, first_handled_index, last_handled_index in zip(
                factors, first_handled_indices, last_handled_indices
            )
        )
    )

    if zero_pad:
        result = jnp.zeros_like(input_array)
        selection = (
            (slice(None),) * axis +
            (slice(first_write_index, last_write_index),) +
            (slice(None),)*(input_array.ndim - axis - 1)
        )
        result = result.at[selection].set(output)
        return result
    else:
        return output

@jax.jit
def dummy_fluid_code(array):
    """
    Stand-in for a fluid solver update, used only for timing.

    Runs 10 "time steps". In each step, every axis except the first is
    visited in turn and the array is incremented by 0.001 times the
    zero-padded stencil a_{i+1} - a_{i-1} along that axis.

    Args:
        array: The array to update; axis 0 is never differenced.

    Returns:
        The array after all steps, same shape as the input.
    """
    stencil_offsets = (1, -1)
    stencil_weights = (1.0, -1.0)

    # "time steps"
    for _step in range(10):
        # "axes"
        for stencil_axis in range(1, array.ndim):
            increment = _stencil_add(
                array, stencil_offsets, stencil_weights, axis=stencil_axis
            )
            array = array + 0.001 * increment

        # some non-linear computation (currently disabled)
        # array = jnp.sin(array)

    return array

## Unsharded array

In [26]:
# Test data: standard-normal samples drawn with a fixed PRNG key (so the
# data is reproducible), shaped (5, n, n, n) with n = size_per_dim.
# n is even, so splitting an axis across 2 devices divides it evenly.
size_per_dim = 470
unsharded_array = jax.random.normal(jax.random.key(0), (5, size_per_dim, size_per_dim, size_per_dim))

In [27]:
# jnp.sin(unsharded_array)
# %timeit -n 10 -r 10 jnp.sin(unsharded_array).block_until_ready()

In [28]:
# finite_difference(unsharded_array, 1)
# %timeit  -n 10 -r 10 finite_difference(unsharded_array, 1).block_until_ready()

In [29]:
dummy_fluid_code(unsharded_array)
%timeit -n 5 -r 5 dummy_fluid_code(unsharded_array).block_until_ready()

94.3 ms ± 7.74 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


## Sharded array

In [30]:
# 4-device mesh with one named axis per array dimension. Only 'x' and 'y'
# have extent 2, so an array with this sharding is split 2 x 2 across those
# two dimensions and kept whole along 'vars' and 'z'.
mesh = jax.make_mesh((1, 2, 2, 1), ('vars', 'x', 'y', 'z'))
# Map array dimension k onto the k-th mesh axis.
sharding = NamedSharding(mesh, P('vars', 'x', 'y', 'z'))

In [31]:
# Distribute the array over the devices according to `sharding`.
sharded_array = jax.device_put(unsharded_array, sharding)
# Show which device holds which part of one 2D slice (index 0 along the
# first and last dimensions, everything along the two middle ones).
jax.debug.visualize_array_sharding(sharded_array[0, :, :, 0])

In [32]:
jnp.sin(sharded_array)
%timeit -n 10 -r 10 jnp.sin(sharded_array).block_until_ready()

1.15 ms ± 58.5 μs per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [33]:
finite_difference(sharded_array, 1)
%timeit -n 10 -r 10 finite_difference(sharded_array, 1).block_until_ready()

3.72 ms ± 300 μs per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [35]:
dummy_fluid_code(sharded_array)
%timeit -n 5 -r 5 dummy_fluid_code(sharded_array).block_until_ready()

93.9 ms ± 9.32 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
