In [None]:
import cf_xarray  # noqa: F401
import dask
import shapely
import xarray as xr
from distributed import Client

import grid_indexing
from grid_indexing.distributed import DistributedRTree

In [None]:
# Start a dask.distributed client (with no arguments it launches a local
# cluster); it becomes the default scheduler for the `dask.compute` call below.
client = Client()
client

In [None]:
def bounds_to_polygons(lon, lat, bounds_dim="bounds"):
    """Build one shapely polygon per cell from its vertex coordinates.

    Parameters
    ----------
    lon, lat : xarray.DataArray
        Longitude and latitude of the cell vertices. Both must carry a
        ``bounds_dim`` dimension that enumerates the vertices of each cell.
    bounds_dim : str, default "bounds"
        Name of the vertex dimension; "bounds" is the name used by the grids
        in this notebook.

    Returns
    -------
    xarray.DataArray
        Dask-backed array of polygon objects (dtype ``object``) over the
        remaining, non-vertex dimensions of the inputs.
    """
    # Stack lon and lat along a new "coords" dimension (lon first, lat second).
    vertices = xr.concat([lon, lat], dim="coords")

    return xr.apply_ufunc(
        shapely.polygons,
        # dask="parallelized" requires every core dimension to be a single
        # chunk; rechunk both of them, not only "coords", so inputs chunked
        # along the vertex dimension are handled too.
        vertices.chunk({bounds_dim: -1, "coords": -1}),
        # apply_ufunc moves the core dims to the end in this order, so
        # shapely.polygons receives arrays shaped (..., n_vertices, 2).
        input_core_dims=[[bounds_dim, "coords"]],
        output_core_dims=[[]],
        dask="parallelized",
        keep_attrs=False,
        output_dtypes=[object],
    )

In [None]:
# Curvilinear source grid: add cell bounds, chunk it for dask, then attach one
# polygon per cell as the "geometry" coordinate.
source_grid = (
    grid_indexing.tutorial.generate_grid("2d-curvilinear", resolution="small")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"x": 60, "y": 30})
    .pipe(
        lambda grid: grid.assign_coords(
            geometry=bounds_to_polygons(
                lon=grid["lon_bounds"], lat=grid["lat_bounds"]
            )
        )
    )
)
source_grid

### procedure

1. create the distributed R-tree
    - inputs:
        - the cell boundaries as geometries
        - and, derived from those, the chunk boundaries
    - create an index from the chunk boundaries and save it
    - for each chunk of cell boundaries, create an index (as a delayed function? that requires the index to be picklable, though)
2. query the index
    - extract the chunk boundaries from the query (target) grid
    - query the chunk boundary index to find which source chunks each target chunk interacts with
    - query the index of each interacting chunk
    - assemble the results into a sparse matrix

### issues

- for dask to work, the trees have to be picklable
- assembling the per-chunk results of a grid of tasks into a single sparse matrix may be tricky

In [None]:
# Rectilinear target grid, prepared exactly like the source grid: cell bounds,
# the same dask chunking, and one polygon per cell as the "geometry" coordinate.
target_grid = (
    grid_indexing.tutorial.generate_grid("2d-rectilinear", resolution="small")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"x": 60, "y": 30})
    .assign_coords(
        geometry=lambda ds: bounds_to_polygons(ds["lon_bounds"], ds["lat_bounds"])
    )
)
target_grid

In [None]:
# Build the distributed R-tree from the source grid (presumably indexing its
# "geometry" coordinate chunk by chunk, as outlined in the procedure above).
dtree = DistributedRTree(source_grid)
dtree

In [None]:
# Query which source cells overlap each target cell. The result is evaluated
# lazily; the next cell materializes it with `dask.compute`.
result = dtree.query_overlap(target_grid)
result

In [None]:
%%time
dask.compute(result)[0]