In [1]:
import xarray as xr
import numpy as np

# Plain-text reprs with collapsed data keep the displayed cell outputs compact.
xr.set_options(display_style="text", display_expand_data=False)
# Seeded generator so the random field created later is reproducible across runs.
rng = np.random.default_rng(seed=0)

In [2]:
# Regular 0.1-degree axes. `np.arange` is half-open, so the upper bounds
# (90 and 360) are nominally excluded.
lat = np.arange(-90, 90, 0.1)
lon = np.arange(0, 360, 0.1)

# With NumPy's default "xy" indexing both grids have shape (len(lon), len(lat)):
# the first axis follows `lon`, the second follows `lat`.
lat2d, lon2d = np.meshgrid(lat, lon)

# Random field on that grid, with `lat` / `lon` attached as 2-D coordinates
# along the ("x", "y") dimensions.
data = rng.standard_normal(lat2d.shape)
arr = xr.DataArray(
    data,
    dims=("x", "y"),
    coords={"lat": (("x", "y"), lat2d), "lon": (("x", "y"), lon2d)},
)
arr

In [3]:
# The recorded output lists no entries: the 2-D `lat` / `lon` coordinates
# are not backed by any index yet.
arr.xindexes

Indexes:

In [4]:
from scipy.spatial import KDTree


class KDTreeIndex(xr.core.indexes.Index):
    """Nearest-neighbour index over several coordinates that share dimensions.

    The coordinate values are stacked along a trailing axis, flattened to a
    ``(n_points, n_coords)`` table and stored in a ``scipy.spatial.KDTree``.
    Label-based selection then becomes a nearest-neighbour query whose flat
    result is converted back into positions along the original dimensions.

    Distances are computed by ``KDTree.query`` on the raw coordinate values
    (Euclidean with scipy's defaults), not on the sphere.
    """

    def __init__(self, data, names, dim, **options):
        """Build the tree.

        Parameters
        ----------
        data : array-like with a ``shape`` and ``reshape``
            Coordinate values; the last axis enumerates the coordinates and
            the leading axes are the dimensions of the indexed variables.
        names : sequence of hashable
            Coordinate names, in the same order as the last axis of ``data``.
        dim : sequence of hashable
            Names of the leading dimensions of ``data``.
        **options
            Forwarded unchanged to ``scipy.spatial.KDTree``.
        """
        self.names = names
        self.dim = dim
        # Kept so that flat tree indices can be unravelled again in `sel`.
        self.shape = data.shape
        self.index = KDTree(data.reshape(-1, self.shape[-1]), **options)

    @classmethod
    def from_variables(cls, variables, **options):
        """Create the index from a mapping of coordinate name to variable.

        Parameters
        ----------
        variables : mapping
            Objects exposing ``.dims`` and ``.data``; all of them must have
            exactly the same dimensions.
        **options
            Forwarded unchanged to ``scipy.spatial.KDTree``.

        Raises
        ------
        ValueError
            If the variables do not all share the same dimensions (this
            includes an empty mapping).
        """
        # Validate first so that mismatched variables produce this message
        # rather than a shape error from stacking the data.
        dims = {var.dims for var in variables.values()}
        if len(dims) != 1:
            raise ValueError("variables need to have the same dimensions")
        (dims,) = dims

        # One trailing axis entry per coordinate, in mapping order.
        data = np.stack([var.data for var in variables.values()], axis=-1)
        names = list(variables.keys())
        return cls(data, names, dims, **options)

    def sel(self, indexers):
        """Select the grid points nearest to the requested coordinate values.

        Parameters
        ----------
        indexers : mapping
            Maps every indexed coordinate name to the query values. Values
            may be scalars, sequences or 1-D arrays of equal length.

        Returns
        -------
        xr.core.indexes.IndexSelResult
            Wraps one positional indexer per indexed dimension. Each indexer
            is a ``DataArray`` along a new ``"points"`` dimension that carries
            the query distances as a ``"distance"`` coordinate.

        Raises
        ------
        ValueError
            If ``indexers`` names a coordinate this index does not cover.
        KeyError
            If ``indexers`` omits a coordinate this index covers.
        """
        unknown_dimensions = set(indexers) - set(self.names)
        if unknown_dimensions:
            # A single formatted message: passing two positional arguments to
            # ValueError renders as a tuple repr instead of readable text.
            raise ValueError(
                f"unknown dimensions: {sorted(unknown_dimensions, key=str)}"
            )

        missing = [name for name in self.names if name not in indexers]
        if missing:
            # Same exception type as the plain lookup would raise, but it
            # names every missing coordinate at once.
            raise KeyError(f"missing indexers for coordinates: {missing}")

        # Coerce to at-least-1-D arrays so lists and scalars are accepted too;
        # 1-D ndarray input is passed through unchanged.
        points = np.stack(
            [np.atleast_1d(np.asarray(indexers[name])) for name in self.names],
            axis=-1,
        )

        distances, flat_indices = self.index.query(points)
        # Convert positions in the flattened table back to per-dimension positions.
        indices = np.unravel_index(flat_indices, self.shape[:-1])

        isel_indexers = {
            dim: xr.DataArray(
                data, coords={"distance": ("points", distances)}, dims="points"
            )
            for dim, data in zip(self.dim, indices)
        }

        return xr.core.indexes.IndexSelResult(isel_indexers)

In [5]:
# Build the index directly from the coordinates of `arr` (`lat` and `lon`).
tree = KDTreeIndex.from_variables(arr.coords)
tree

<__main__.KDTreeIndex at 0x7f4606b8b8b0>

In [6]:
def _range_variable(dim, size):
    """Return a fresh 1-D variable along ``dim`` holding ``0 .. size - 1``."""
    return xr.Variable(dim, np.arange(size))


n_x, n_y = lat2d.shape

# Re-create the array with explicit coordinates and indexes: `lat` / `lon`
# share the single KD-tree index, while `x` / `y` get positional pandas indexes.
new_arr = xr.DataArray(
    arr.variable,
    coords={
        "lat": xr.Variable(("x", "y"), lat2d),
        "lon": xr.Variable(("x", "y"), lon2d),
        "x": _range_variable("x", n_x),
        "y": _range_variable("y", n_y),
    },
    indexes={
        "lat": tree,
        "lon": tree,
        "x": xr.core.indexes.PandasIndex.from_variables(
            {"x": _range_variable("x", n_x)}
        ),
        "y": xr.core.indexes.PandasIndex.from_variables(
            {"y": _range_variable("y", n_y)}
        ),
    },
    fastpath=True,
)
new_arr

In [7]:
# In the recorded output `lat` and `lon` point at the same KDTreeIndex object,
# while `x` and `y` each have their own PandasIndex.
new_arr.xindexes

Indexes:
lat: <__main__.KDTreeIndex object at 0x7f4606b8b8b0>
lon: <__main__.KDTreeIndex object at 0x7f4606b8b8b0>
x: <xarray.core.indexes.PandasIndex object at 0x7f4606b8aea0>
y: <xarray.core.indexes.PandasIndex object at 0x7f45fd3ddc20>

# try the possible operations

In [8]:
# Three query points, given as matching `lat` / `lon` value arrays.
indexers = dict(
    lat=np.array([0.742, 10.213, 17.648]),
    lon=np.array([8.873, 3.12, 9.15]),
)

`sel`

In [9]:
# Label-based selection on `lat` / `lon`; presumably dispatched to
# `KDTreeIndex.sel` for the nearest grid points — TODO confirm.
new_arr.sel(indexers)

`isel`


In [10]:
# Positional slicing; the recorded output lists only the `x` and `y` indexes
# afterwards, i.e. the KD-tree index on `lat` / `lon` is not carried over.
sliced = new_arr.isel(x=slice(5, 10), y=slice(10, 20))
sliced.xindexes

Indexes:
x: <xarray.core.indexes.PandasIndex object at 0x7f4606b25590>
y: <xarray.core.indexes.PandasIndex object at 0x7f4606b25220>

`roll`

In [11]:
# Shift the array by 5 positions along `x` and 10 along `y`.
rolled = new_arr.roll(x=5, y=10)
# Select the same query points as the `sel` cell, for easier comparison.
rolled.sel(indexers)  # for easier comparison

`stack`

In [12]:
# Collapse `x` and `y` into one stacked dimension `z`. In the recorded output
# `z`, `x` and `y` share a single PandasMultiIndex and no `lat` / `lon`
# index is listed.
stacked = new_arr.stack(z=("x", "y"))
stacked.xindexes

Indexes:
z: <xarray.core.indexes.PandasMultiIndex object at 0x7f45fd3e39e0>
x: <xarray.core.indexes.PandasMultiIndex object at 0x7f45fd3e39e0>
y: <xarray.core.indexes.PandasMultiIndex object at 0x7f45fd3e39e0>

`copy`

In [13]:
# Copying currently fails: the recorded output is a NotImplementedError,
# presumably because the custom index does not implement copying — TODO confirm.
new_arr.copy()

NotImplementedError: 