In [None]:
import cf_xarray  # noqa: F401
import dask.array
import geoarrow.rust.core as geoarrow
import numpy as np
import shapely
import xarray as xr

import grid_indexing

In [None]:
from distributed import Client

# Start a dask.distributed client. By distributed's default
# (set_as_default=True) it becomes the scheduler used by the subsequent
# dask.compute calls in this notebook.
client = Client()
# Display the client as the cell's output.
client

In [None]:
def bounds_to_polygons(lon, lat):
    """Build one shapely polygon per grid cell from its vertex bounds.

    Parameters
    ----------
    lon, lat : xarray.DataArray
        Longitude / latitude of the cell vertices. Both must carry a
        ``"bounds"`` dimension enumerating the vertices of each cell; every
        other dimension is treated as a cell (grid) dimension.

    Returns
    -------
    xarray.DataArray
        Dask-backed, object-dtype array with the non-``"bounds"`` dimensions
        of ``lon``/``lat``, holding one polygon per cell.
    """
    # Stack lon/lat along a new "coords" dimension. apply_ufunc moves the
    # core dims ("bounds", "coords") to the end, so each cell reaches
    # shapely.polygons as an (n_vertices, 2) coordinate array.
    vertices = xr.concat([lon, lat], dim="coords")

    # dask="parallelized" requires every core dimension to be a single
    # chunk. concat may split "coords", and the input may be chunked along
    # "bounds", so merge both into one chunk each.
    vertices = vertices.chunk({"bounds": -1, "coords": -1})

    return xr.apply_ufunc(
        shapely.polygons,
        vertices,
        input_core_dims=[["bounds", "coords"]],
        output_core_dims=[[]],
        dask="parallelized",
        keep_attrs=False,
        output_dtypes=[object],
    )

In [None]:
# Source grid: the "large" 2d-curvilinear tutorial grid with latitude /
# longitude cell bounds added through the cf_xarray accessor, chunked into
# (y=300, x=600) blocks, plus a "geometry" coordinate holding one polygon
# per cell.
source_grid = (
    grid_indexing.tutorial.generate_grid("2d-curvilinear", resolution="large")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"x": 600, "y": 300})
    # Built from the already-chunked bounds, so the geometries are lazy and
    # follow the same x/y chunking.
    .assign_coords(
        geometry=lambda ds: bounds_to_polygons(ds["lon_bounds"], ds["lat_bounds"])
    )
)
source_grid

### procedure

1. Create the distributed R-tree
    - values:
        - the cell boundaries as geometries
        - derived from those, the chunk boundaries
    - create an index from the chunk boundaries and store it
    - for each chunk of cell boundaries, create an index (as a delayed function? That requires the index to be picklable, though)
2. Query the index
    - extract the chunk boundaries from the input
    - query the chunk boundary index to determine which source chunks each target chunk interacts with
    - query the indexes of those interacting source chunks
    - assemble the results into a sparse matrix

### issues

- for dask to work, the trees have to be picklable
- assembling a single concatenated sparse matrix from the 2D grid of per-chunk tasks may be tricky

In [None]:
# Target grid: the "small" 2d-rectilinear tutorial grid, prepared exactly
# like the source grid (cf_xarray bounds + lazy per-cell polygons) but with
# (y=30, x=60) chunks.
target_grid = (
    # Pass the resolution by keyword, matching the source-grid cell.
    grid_indexing.tutorial.generate_grid("2d-rectilinear", resolution="small")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"y": 30, "x": 60})
    .assign_coords(
        geometry=lambda ds: bounds_to_polygons(ds["lon_bounds"], ds["lat_bounds"])
    )
)
target_grid

In [None]:
def chunk_boundaries(chunks):
    """Lazily compute the outline of every chunk of cell geometries.

    Parameters
    ----------
    chunks : iterable
        Per-chunk inputs (e.g. the ``Delayed`` objects from
        ``dask.array.Array.to_delayed``) holding arrays of shapely geometries.

    Returns
    -------
    list
        One ``dask.delayed`` object per chunk, in input order, each wrapping
        ``shapely.coverage_union_all`` of that chunk.
    """
    delayed_union = dask.delayed(shapely.coverage_union_all)
    return [delayed_union(chunk) for chunk in chunks]


def index_from_shapely(chunk):
    """Build a ``grid_indexing.Index`` from an array of shapely geometries.

    ``chunk`` is flattened first, so it may have any number of dimensions.
    """
    flat_geometries = chunk.flatten()
    arrow_geometries = geoarrow.from_shapely(flat_geometries)
    return grid_indexing.Index(arrow_geometries)


def _query_overlap(index, chunk):
    """Query ``index`` for overlaps with the (flattened) geometries of ``chunk``.

    Returns the query result, or ``None`` when that result has no stored
    entries (``nnz == 0``).
    """
    overlaps = index.query_overlap(geoarrow.from_shapely(chunk.flatten()))
    return None if overlaps.nnz == 0 else overlaps


class DistributedRTree:
    """Two-level spatial index over a chunked (dask-backed) grid.

    The top level is a single index over the outline (coverage union) of the
    cells in each chunk. The second level is one lazily built index per
    chunk. A query first uses the top level to find which source chunks a
    target chunk can overlap, and only schedules per-chunk queries for those
    pairs.

    Parameters
    ----------
    grid : xarray.Dataset
        Grid with a dask-backed ``"geometry"`` variable holding one shapely
        polygon per cell.
    """

    def __init__(self, grid):
        chunk_grid = grid["geometry"].data.to_delayed()

        # Shape of the dask chunk grid. ``self.chunks`` is its flattened
        # (C-order) form, which is the order used for all chunk indices.
        self.chunk_grid_shape = chunk_grid.shape
        self.chunks = chunk_grid.flatten()

        # The chunk outlines are computed eagerly; they are needed to build
        # the top-level index right away.
        [boundaries] = dask.compute(chunk_boundaries(self.chunks))

        # Per-chunk indexes stay lazy (one delayed task per chunk) and are
        # only built when a query that needs them is computed.
        self.chunk_indexes = [
            dask.delayed(index_from_shapely)(chunk) for chunk in self.chunks
        ]
        # Top-level index over the chunk outlines, built the same way as the
        # per-chunk indexes (geoarrow conversion + ``grid_indexing.Index``).
        self.index = index_from_shapely(np.array(boundaries))

    def query_overlap(self, grid):
        """Schedule overlap queries of ``grid``'s cells against this index.

        Parameters
        ----------
        grid : xarray.Dataset
            Target grid with a dask-backed ``"geometry"`` variable.

        Returns
        -------
        numpy.ndarray
            Object array of shape ``(n_target_chunks, n_source_chunks)``,
            with both axes in flattened chunk order. Entries for chunk pairs
            whose outlines do not overlap are ``None``. All other entries are
            ``dask.delayed`` tasks that evaluate to the per-chunk overlap
            result, or to ``None`` if that result is empty.
        """
        chunks = grid["geometry"].data.to_delayed().flatten()

        # Coarse pass (eager): match target chunk outlines against the
        # source chunk outlines to find candidate chunk pairs.
        [boundaries] = dask.compute(chunk_boundaries(chunks))
        geoms = geoarrow.from_shapely(np.array(boundaries))
        overlapping_chunks = self.index.query_overlap(geoms).todense()

        # Fine pass (lazy): one delayed per-chunk query per candidate pair.
        query = dask.delayed(_query_overlap)
        tasks = np.full_like(overlapping_chunks, dtype=object, fill_value=None)
        for target_index, target_chunk in enumerate(chunks):
            [source_indices] = np.nonzero(overlapping_chunks[target_index])
            # Assign element-wise. Passing numpy a list of Delayed objects
            # makes it probe them as sequences during array coercion.
            for source_index in source_indices:
                tasks[target_index, source_index] = query(
                    self.chunk_indexes[source_index], target_chunk
                )

        return tasks


# Build the two-level index over the source grid. This eagerly computes the
# chunk outlines and the top-level index; the per-chunk indexes stay lazy.
dtree = DistributedRTree(source_grid)
dtree

In [None]:
# Runs the coarse chunk-outline matching now. The returned object array
# holds None for non-candidate chunk pairs and delayed per-chunk queries
# otherwise; those per-chunk queries have not been executed yet.
result = dtree.query_overlap(target_grid)
result

In [None]:
%%time
# Execute all scheduled per-chunk queries (the nested list keeps the
# target x source chunk layout; None entries are passed through as-is).
dask.compute(result.tolist())

In [None]:
%%time
# Non-distributed path: infer the geometries of all source cells in one call
# (presumably as a baseline for the timings above — confirm intent).
geoms = grid_indexing.infer_cell_geometries(source_grid)
geoms

In [None]:
%%time
# Non-distributed path: a single index over every source cell.
index = grid_indexing.Index(geoms)
index

In [None]:
%%time
# Non-distributed path: infer the geometries of all target cells.
target_geoms = grid_indexing.infer_cell_geometries(target_grid)
target_geoms

In [None]:
%%time
# Non-distributed path: query all target cells against the single index at once.
overlaps = index.query_overlap(target_geoms)
overlaps