In [1]:
import xarray as xr
import dask.array as dsa
import numpy as np
import xgcm

In [2]:
# Make a test dataset: two random fields `foo` and `bar` on a 360 x 180 grid
# with 365 time steps. Cell centres (lon_c, lat_c) sit half a grid step after
# the corner points (lon_g, lat_g), hence c_grid_axis_shift=-0.5 on the latter.

SEED = 0  # fixed seed so Restart & Run All always produces the same data

lon_g = np.arange(0, 360.)
lon_c = lon_g + 0.5

lat_g = np.arange(0, 180.)  # float, consistent with lon_g
lat_c = lat_g + 0.5

ny = len(lat_c)
nx = len(lon_c)
nt = 365

dims = ('time', 'lat_c', 'lon_c')
shape = (nt, ny, nx)
# One chunk per time step: every chunk holds the full lat/lon domain.
chunks = (1, ny, nx)

random_state = dsa.random.RandomState(SEED)
ds = xr.Dataset(
    {
        "foo": (dims, random_state.random_sample(shape, chunks=chunks)),
        "bar": (dims, random_state.random_sample(shape, chunks=chunks)),
    },
    coords={
        "lat_c": ("lat_c", lat_c, {"axis": "Y"}),
        "lat_g": ("lat_g", lat_g, {"axis": "Y", "c_grid_axis_shift": -0.5}),
        "lon_c": ("lon_c", lon_c, {"axis": "X"}),
        "lon_g": ("lon_g", lon_g, {"axis": "X", "c_grid_axis_shift": -0.5}),
    },
)
ds

Unnamed: 0,Array,Chunk
Bytes,180.45 MiB,506.25 kiB
Shape,"(365, 180, 360)","(1, 180, 360)"
Count,365 Tasks,365 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 180.45 MiB 506.25 kiB Shape (365, 180, 360) (1, 180, 360) Count 365 Tasks 365 Chunks Type float64 numpy.ndarray",360  180  365,

Unnamed: 0,Array,Chunk
Bytes,180.45 MiB,506.25 kiB
Shape,"(365, 180, 360)","(1, 180, 360)"
Count,365 Tasks,365 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,180.45 MiB,506.25 kiB
Shape,"(365, 180, 360)","(1, 180, 360)"
Count,365 Tasks,365 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 180.45 MiB 506.25 kiB Shape (365, 180, 360) (1, 180, 360) Count 365 Tasks 365 Chunks Type float64 numpy.ndarray",360  180  365,

Unnamed: 0,Array,Chunk
Bytes,180.45 MiB,506.25 kiB
Shape,"(365, 180, 360)","(1, 180, 360)"
Count,365 Tasks,365 Chunks
Type,float64,numpy.ndarray


In [3]:
# set up grid

# The Grid is built with xgcm's default settings; its repr below reports both
# axes as periodic. That keeps this comparison simple, because the NumPy
# version (np.roll) also wraps around. Non-periodic boundary conditions would
# complicate things.
grid = xgcm.Grid(ds)
grid

<xgcm.Grid>
X Axis (periodic, boundary=None):
  * center   lon_c --> left
  * left     lon_g --> center
Y Axis (periodic, boundary=None):
  * center   lat_c --> left
  * left     lat_g --> center

In [4]:
# Baseline graph size of the input array (one task per chunk, see the repr above).
input_ntasks = len(ds.foo.data.dask)
print(input_ntasks)

365


In [5]:
# Test problem: move both centre fields to the cell corners, multiply them
# there, then move the product back to the centres. First step: `foo` only.
foo_corner = grid.interp(ds.foo, ['X', 'Y'])
foo_corner_ntasks = len(foo_corner.data.dask)
print(foo_corner_ntasks)

14600


In [6]:
bar_corner = grid.interp(ds.bar, ['X', 'Y'])
product_corner = foo_corner * bar_corner
# Interpolate back using the grid's own (periodic) boundary, the same
# wrap-around that np.roll gives the NumPy version below. An explicit
# boundary='fill' would imply zero-padded edges, which the NumPy version does
# not do. Whether xgcm honours or ignores it on a periodic axis depends on
# the xgcm version, so it is left out.
product_center = grid.interp(product_corner, ['X', 'Y'])
print(len(product_center.data.dask))

70810


In [7]:
# Time the evaluation of the xgcm-based graph built in the previous cells.
%time product_center_computed = product_center.compute()

CPU times: user 11.8 s, sys: 1.87 s, total: 13.7 s
Wall time: 10.9 s


In [8]:
# faster implementation

def _interp(a, n, axis):
    return 0.5 * (a + np.roll(a, n, axis=axis))

def do_operations_with_numpy(foo, bar):
    foo_corner = _interp(_interp(foo, -1, -1), -1, -2)
    bar_corner = _interp(_interp(bar, -1, -1), -1, -2)
    product_corner = foo_corner * bar_corner
    product_center = _interp(_interp(product_corner, 1, -1), 1, -2)
    return product_center

In [9]:
# do_operations_with_numpy applies np.roll inside each block. Its wrap-around
# equals the periodic boundary only if every block spans the whole lat/lon
# domain, so check that the inputs are chunked along time only.
for var_name in ("foo", "bar"):
    spatial_chunks = ds[var_name].data.chunks[-2:]
    assert all(len(axis_chunks) == 1 for axis_chunks in spatial_chunks), (
        f"{var_name!r} is chunked along lat/lon {spatial_chunks}; "
        "rechunk so each block covers the full spatial domain"
    )

data_product_center = dsa.map_blocks(
    do_operations_with_numpy,
    ds.foo.data,
    ds.bar.data,
    meta=np.array((), dtype=ds.foo.dtype),
)
len(data_product_center.dask)

1095

In [10]:
# Wrap the raw dask result in a DataArray that reuses the xgcm result's dims
# and coords, so the two can be compared with xarray's testing helpers; then
# time its evaluation.
product_center_wrapped = xr.DataArray(data_product_center, dims=product_center.dims, coords=product_center.coords)
%time product_center_wrapped_computed = product_center_wrapped.compute()

CPU times: user 2.61 s, sys: 479 ms, total: 3.09 s
Wall time: 1.4 s


In [11]:
# The map_blocks/NumPy result must match the xgcm result to within
# xr.testing.assert_allclose's default tolerances.
xr.testing.assert_allclose(product_center_wrapped_computed, product_center_computed)