In [1]:
import xarray as xr
import numpy as np

In [2]:
# Build a synthetic dataset on a regular 0.1-degree grid whose lat/lon
# coordinates are 2-D (one value per (x, y) cell) rather than 1-D dimension
# coordinates.
lat = np.arange(-90, 90, 0.1)
lon = np.arange(0, 360, 0.1)

# With meshgrid's default "xy" indexing both outputs have shape
# (len(lon), len(lat)).
lat2d, lon2d = np.meshgrid(lat, lon)

# Use a seeded generator so that Restart & Run All reproduces the same values.
data = np.random.default_rng(0).standard_normal(lat2d.shape)
ds = xr.Dataset(
    {"data": (("x", "y"), data)},
    coords={"lat": (("x", "y"), lat2d), "lon": (("x", "y"), lon2d)},
)
ds

In [3]:
# Show the indexes attached to the dataset. lat/lon are 2-D coordinates on
# (x, y), not dimension coordinates, so no default index is expected for them.
ds.xindexes

Indexes:

In [4]:
from scipy.spatial import KDTree


class KDTreeIndex(xr.core.indexes.Index):
    """Nearest-neighbour index over several same-shaped coordinates.

    The coordinate arrays are stacked along a new trailing axis, flattened to
    a ``(n_points, n_coords)`` table and stored in a ``scipy.spatial.KDTree``.
    ``sel`` looks up the nearest stored point for each query point and
    converts the flat hit positions back into one positional indexer per
    dimension of the original coordinates.

    Distances are whatever ``KDTree.query`` computes on the raw coordinate
    values; no coordinate-specific handling (e.g. periodic wrap-around) is
    applied here.
    """

    def __init__(self, data, names, dims, **options):
        """Build the tree from already-stacked coordinate values.

        Parameters
        ----------
        data : numpy.ndarray
            Array whose last axis enumerates the coordinates; the leading
            axes are the shape of the indexed coordinates.
        names : sequence
            Coordinate names, in the same order as the last axis of ``data``.
        dims : sequence
            Dimension names matching the leading axes of ``data``.
        **options
            Forwarded unchanged to ``scipy.spatial.KDTree``.
        """
        self.names = names
        self.dims = dims
        self.shape = data.shape
        # KDTree wants a 2-D (n_points, n_coords) table.
        self.kdtree = KDTree(data.reshape(-1, self.shape[-1]), **options)

    @classmethod
    def from_variables(cls, variables, **options):
        """Create the index from a mapping of coordinate name to variable.

        Parameters
        ----------
        variables : mapping
            Name -> object exposing ``.dims`` and ``.data``. All entries must
            share the same dimensions.
        **options
            Forwarded to ``__init__`` (and from there to ``KDTree``).

        Raises
        ------
        ValueError
            If ``variables`` is empty or the variables do not all have the
            same dimensions.
        """
        if not variables:
            raise ValueError("at least one variable is required to build the index")

        # Validate before stacking so that mismatched input produces this
        # message rather than an opaque concatenation error.
        dims = {var.dims for var in variables.values()}
        if len(dims) != 1:
            raise ValueError("variables need to have the same dimensions")
        (dims,) = dims

        data = np.concatenate(
            [var.data[..., None] for var in variables.values()], axis=-1
        )
        names = list(variables.keys())
        return cls(data, names, dims, **options)

    def sel(self, indexers):
        """Find the nearest indexed point for each query point.

        Parameters
        ----------
        indexers : mapping
            Coordinate name -> query values. Every name in ``self.names``
            must be present; values may be 1-D arrays, sequences or scalars
            and must all have the same length.

        Returns
        -------
        dict
            Dimension name -> ``xarray.DataArray`` of integer positions along
            a ``"points"`` dimension, each carrying a ``"distance"``
            coordinate with the distance to the matched point. Suitable for
            passing to ``Dataset.isel``.

        Raises
        ------
        ValueError
            If ``indexers`` contains names that are not part of this index.
        KeyError
            If an indexed coordinate is missing from ``indexers``.
        """
        unknown_dimensions = set(indexers) - set(self.names)
        if unknown_dimensions:
            # A single formatted string: passing two positional args to
            # ValueError would render the message as a tuple.
            raise ValueError(
                f"unknown dimensions: {sorted(unknown_dimensions, key=str)}"
            )

        # atleast_1d accepts lists and scalars in addition to 1-D arrays.
        points = np.concatenate(
            [np.atleast_1d(indexers[name])[..., None] for name in self.names],
            axis=-1,
        )

        distances, indices_ = self.kdtree.query(points)
        # The tree was built on flattened data; recover per-dimension indices.
        indices = np.unravel_index(indices_, self.shape[:-1])

        isel_indexers = {
            dim: xr.DataArray(
                data, coords={"distance": ("points", distances)}, dims="points"
            )
            for dim, data in zip(self.dims, indices)
        }

        return isel_indexers

In [5]:
# Build the k-d tree index from the dataset's coordinates (lat and lon) and
# display the resulting object.
tree = KDTreeIndex.from_variables(ds.coords)
tree

<__main__.KDTreeIndex at 0x7fde516ec280>

In [6]:
# Three query points; element i of each array belongs to point i.
indexers = dict(
    lat=np.array([0.742, 10.213, 17.648]),
    lon=np.array([8.873, 3.12, 9.15]),
)
# Turn the coordinate queries into positional indexers, then select by position.
ds.isel(tree.sel(indexers))