# NDPointIndex Approach

xarray includes [`NDPointIndex`](https://xarray-indexes.readthedocs.io/blocks/ndpoint.html) for **unstructured point data** (e.g., irregular grids, scattered observations). It uses a KD-tree for spatial nearest-neighbor queries.

This notebook explores whether `NDPointIndex` can solve the same problem as `NDIndex` for trial-based data with derived coordinates.

## Setup

In [None]:
import numpy as np
import xarray as xr
from linked_indices.example_data import trial_based_dataset

## What NDPointIndex is designed for

`NDPointIndex` is designed for **curvilinear grids** and **unstructured point clouds** where you have multiple coordinate variables that together define a point in N-dimensional space.

The classic example is a 2D grid with latitude and longitude coordinates that vary in both dimensions:

In [None]:
# Create a curvilinear grid (like ocean model output)
# The lat/lon coordinates vary in BOTH dimensions
shape = (5, 10)
lon = xr.DataArray(np.random.uniform(-180, 180, size=shape), dims=("y", "x"))
lat = xr.DataArray(np.random.uniform(-90, 90, size=shape), dims=("y", "x"))
temperature = xr.DataArray(np.random.uniform(0, 30, size=shape), dims=("y", "x"))

ds_curvilinear = xr.Dataset(
    data_vars={"temperature": temperature}, coords={"lon": lon, "lat": lat}
)
ds_curvilinear

In [None]:
# Apply NDPointIndex - requires BOTH lon and lat together
ds_indexed = ds_curvilinear.set_xindex(["lon", "lat"], xr.indexes.NDPointIndex)
ds_indexed

In [None]:
# Now we can query: "Find the grid cell nearest to lat=45, lon=-120"
# This is a SPATIAL query - both coordinates together define a point
ds_indexed.sel(lat=45.0, lon=-120.0, method="nearest")
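
To make the mechanism concrete, here is a minimal sketch of the kind of lookup a KD-tree-backed index performs, written directly against `scipy.spatial.KDTree` rather than xarray. It illustrates the idea under stated assumptions and is not `NDPointIndex`'s actual implementation: the grid values are made up, and plain Euclidean distance in degrees is used.

```python
import numpy as np
from scipy.spatial import KDTree

# Hypothetical 2D curvilinear grid: lon and lat both vary along (y, x)
rng = np.random.default_rng(0)
shape = (5, 10)
lon = rng.uniform(-180, 180, size=shape)
lat = rng.uniform(-90, 90, size=shape)

# Each grid cell becomes one (lon, lat) point in a 2D coordinate space
points = np.column_stack([lon.ravel(), lat.ravel()])  # shape (50, 2)
tree = KDTree(points)

# Nearest cell to (lon=-120, lat=45), using Euclidean distance in degrees
distance, flat_idx = tree.query([-120.0, 45.0])

# Convert the flat point index back to (y, x) grid positions
y_idx, x_idx = np.unravel_index(flat_idx, shape)
print(
    f"nearest cell: y={y_idx}, x={x_idx}, "
    f"lon={lon[y_idx, x_idx]:.1f}, lat={lat[y_idx, x_idx]:.1f}"
)
```

The key point is that the tree is built from *pairs* of coordinate values, so a query only makes sense when both `lon` and `lat` are supplied together.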

## Trying NDPointIndex with trial-based data

Now let's see what happens when we try to use `NDPointIndex` with our trial-based dataset where we have a single 2D `abs_time` coordinate.

In [None]:
ds = trial_based_dataset(mode="stacked").drop_vars("trial_onset")
print(ds)

### Problem 1: NDPointIndex requires matching number of variables and dimensions

`NDPointIndex` expects one coordinate variable per dimension. Our `abs_time` is a single 2D variable holding one scalar per cell, not a pair of coordinate variables (like `lon` and `lat` above) that together define points in 2D space.

In [None]:
# This fails! NDPointIndex expects 2 variables for 2 dimensions
try:
    ds.set_xindex(["abs_time"], xr.indexes.NDPointIndex)
except ValueError as e:
    print(f"ValueError: {e}")

### Why this matters

The fundamental difference is:

| Aspect | NDPointIndex | NDIndex |
|--------|--------------|----------|
| **Coordinates** | Multiple N-D coords (one per dimension) that together define position | Single N-D coord with derived values |
| **Query type** | Spatial: "find point at (x, y)" | Value: "find cell where value ≈ target" |
| **Use case** | Curvilinear grids, scattered observations | Structured arrays with computed coordinates |

**NDPointIndex** answers: "Which grid cell is nearest to coordinates (lat=45, lon=-120)?"

**NDIndex** answers: "Which (trial, time) cell has `abs_time` closest to 7.5?"
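
The value-lookup question can be sketched in plain NumPy. This only illustrates the query semantics and is not how `NDIndex` is implemented; the onsets, sampling, and array sizes below are made-up stand-ins for the trial-based dataset.

```python
import numpy as np

# Hypothetical stand-in: 3 trials, 500 samples each
trial_onset = np.array([0.0, 5.0, 10.0])  # assumed onsets
rel_time = np.linspace(0.0, 5.0, 500, endpoint=False)  # assumed sampling

# Derived 2D coordinate: one absolute time per (trial, rel_time) cell
abs_time = trial_onset[:, None] + rel_time[None, :]  # shape (3, 500)

# Value lookup: which cell holds the value closest to the target?
target = 7.5
flat_idx = np.abs(abs_time - target).argmin()
trial_idx, time_idx = np.unravel_index(flat_idx, abs_time.shape)
print(trial_idx, time_idx, abs_time[trial_idx, time_idx])
```

Here there is only one quantity to match (`abs_time`), and the answer is a position along *both* array dimensions. That is the reverse of the spatial case, where several coordinates are matched at once.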

### Could we reshape the data to use NDPointIndex?

One might try to flatten the data and treat `(trial, rel_time)` as coordinate dimensions for NDPointIndex. Let's see what that looks like:

In [None]:
# Flatten the dataset to 1D
ds_flat = ds.stack(point=("trial", "rel_time"))
print(f"Original shape: {dict(ds.sizes)}")
print(f"Flattened shape: {dict(ds_flat.sizes)}")
ds_flat

In [None]:
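# Create separate coordinate arrays for trial index and rel_time
# to use with NDPointIndex (sizes are read from the dataset, not hardcoded)
n_trials, n_times = ds.sizes["trial"], ds.sizes["rel_time"]
trial_idx = xr.DataArray(np.repeat(np.arange(n_trials), n_times), dims=["point"])
rel_time_flat = xr.DataArray(np.tile(ds.rel_time.values, n_trials), dims=["point"])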

ds_for_ndpoint = xr.Dataset(
    data_vars={"data": (["point"], ds_flat.data.values)},
    coords={
        "trial_idx": trial_idx,
        "rel_time_flat": rel_time_flat,
        "abs_time": (["point"], ds_flat.abs_time.values),
    },
)
ds_for_ndpoint

In [None]:
# Now we could apply NDPointIndex with trial_idx and rel_time_flat
ds_ndpoint = ds_for_ndpoint.set_xindex(
    ["trial_idx", "rel_time_flat"], xr.indexes.NDPointIndex
)
ds_ndpoint

In [None]:
# Query: find point nearest to trial_idx=1, rel_time=2.5
result = ds_ndpoint.sel(trial_idx=1, rel_time_flat=2.5, method="nearest")
print(
    f"Found point at trial_idx={result.trial_idx.item()}, rel_time={result.rel_time_flat.item():.2f}"
)
print(f"abs_time at this point: {result.abs_time.item():.2f}")

### But this doesn't solve our problem!

With this approach:
1. **We can't select by `abs_time` directly** - NDPointIndex uses the indexed coordinates (trial_idx, rel_time_flat), not derived values like abs_time
2. **We lose the structured array** - the data is now 1D instead of (trial, rel_time)
3. **We lose trial labels** - trial_idx is numeric, not the original string labels

In [None]:
# We CANNOT do this - abs_time is not an indexed coordinate:
try:
    ds_ndpoint.sel(abs_time=7.5, method="nearest")
except KeyError as e:
    print(f"KeyError: {e}")

### Could we use abs_time with KDTree directly?

Another approach might be to build a KD-tree on the `abs_time` values directly. SciPy's `KDTree` expects an array of points in N-dimensional space, so each scalar value has to be treated as a point in 1D space:

In [None]:
from scipy.spatial import KDTree

# KDTree expects (n_points, n_dims) array
# Our abs_time is shape (3, 500) = 1500 scalar values
# Reshaping to (1500, 1) treats each value as a 1D point
abs_time_flat = ds.abs_time.values.ravel().reshape(-1, 1)
tree = KDTree(abs_time_flat)

# Query for abs_time ≈ 7.5
distance, flat_idx = tree.query([[7.5]])
# Convert the flat index back to (trial, time) positions without hardcoding sizes
trial_idx, time_idx = np.unravel_index(flat_idx[0], ds.abs_time.shape)

print(
    f"Found: trial={ds.trial.values[trial_idx]}, rel_time={ds.rel_time.values[time_idx]:.2f}"
)
print(f"abs_time at this point: {ds.abs_time.values[trial_idx, time_idx]:.2f}")

This works, but:
1. It's not integrated with xarray's indexing system, so `.sel()` can't use it
2. You have to manually convert between flat indices and (trial, time) indices
3. It doesn't support slices or other advanced indexing
4. The `(trial, rel_time)` structure is lost once the array is flattened

**This is essentially what `NDIndex` does internally, but with proper xarray integration.**
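
Point 3 in the list above is worth illustrating. One way to read "slice support" for a value range is a bounding box: find every cell whose value falls in the range, then take the smallest rectangular block of the array that contains them all. The sketch below does this in plain NumPy with made-up data. The helper `bounding_box` is hypothetical, and the approach is an assumption about the general idea, not a description of `NDIndex` internals.

```python
import numpy as np

# Hypothetical stand-in for abs_time: 3 trials x 500 samples
trial_onset = np.array([0.0, 5.0, 10.0])
rel_time = np.linspace(0.0, 5.0, 500, endpoint=False)
abs_time = trial_onset[:, None] + rel_time[None, :]


def bounding_box(values, lo, hi):
    """Return one slice per axis covering all cells with lo <= value <= hi."""
    mask = (values >= lo) & (values <= hi)
    if not mask.any():
        raise KeyError(f"no values in [{lo}, {hi}]")
    slices = []
    for axis in range(values.ndim):
        # Collapse all other axes to see which positions on this axis match
        other_axes = tuple(a for a in range(values.ndim) if a != axis)
        hits = np.flatnonzero(mask.any(axis=other_axes))
        slices.append(slice(int(hits[0]), int(hits[-1]) + 1))
    return tuple(slices)


box = bounding_box(abs_time, 7.0, 8.0)
print(box, abs_time[box].shape)
```

When the matching cells span several trials, a box like this can also include cells outside the requested range, so the result is a superset of the matches, not an exact mask.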

## Summary

| Feature | NDPointIndex | NDIndex |
|---------|--------------|----------|
| **Use case** | Unstructured point clouds, curvilinear grids | Structured arrays with derived coordinates |
| **Query type** | Spatial: find nearest (x, y) point | Value: find cell where `abs_time ≈ 7.5` |
| **Coordinates** | Multiple N-D coords (one per dimension) | Single N-D coord with computed values |
| **Data structure** | Points in N-D coordinate space | N-D array of scalar values |
| **Returns** | Single nearest point | Dimensional slices |
| **Slice support** | No | Yes (bounding box) |

`NDPointIndex` and `NDIndex` solve different problems:

```python
# NDPointIndex: "Find the grid cell nearest to lat=45.2, lon=-122.5"
ds.sel(lat=45.2, lon=-122.5, method="nearest")  # Spatial query

# NDIndex: "Find which (trial, time) has abs_time closest to 7.5"
ds.sel(abs_time=7.5, method="nearest")  # Value lookup in N-D array
```

Use `NDPointIndex` when your coordinates define positions in space (or similar multi-dimensional coordinate systems).

Use `NDIndex` when you have derived coordinates computed from dimension coordinates (like `abs_time = trial_onset + rel_time`).
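
For reference, a derived coordinate of this kind can be built with ordinary xarray broadcasting. The sketch below uses hypothetical trial labels, onsets, and sampling (not the values from `trial_based_dataset`) to show how a 1D `trial_onset` and a 1D `rel_time` combine into a 2D `abs_time`.

```python
import numpy as np
import xarray as xr

# Hypothetical trial labels, onsets, and sampling (not the real dataset's values)
trial_onset = xr.DataArray(
    [0.0, 10.0, 20.0], dims="trial", coords={"trial": ["a", "b", "c"]}
)
rel_values = np.arange(6, dtype=float)
rel_time = xr.DataArray(rel_values, dims="rel_time", coords={"rel_time": rel_values})

# Broadcasting a (trial,) array against a (rel_time,) array yields a
# 2D (trial, rel_time) coordinate: one derived value per cell
abs_time = trial_onset + rel_time

demo = xr.Dataset(
    data_vars={"data": (("trial", "rel_time"), np.zeros((3, 6)))},
    coords={"abs_time": abs_time},
)
print(demo.abs_time.dims, demo.abs_time.shape)
```

Because every `abs_time` value is fully determined by its `(trial, rel_time)` position, the array keeps its structure, and a value lookup only needs to map a target back to those positions.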