In [None]:
from dataclasses import dataclass

import cf_xarray  # noqa: F401
import dask.array
import geoarrow.rust.core as geoarrow
import numpy as np
import shapely
import sparse
import xarray as xr

import grid_indexing

In [None]:
from distributed import Client

# Start a local dask.distributed cluster: the per-chunk index builds and
# overlap queries further down run as tasks on its workers. Displaying the
# client shows the cluster summary / dashboard link.
client = Client()
client

In [None]:
def bounds_to_polygons(lon, lat, bounds_dim="bounds"):
    """Build one shapely polygon per grid cell from its vertex coordinates.

    Parameters
    ----------
    lon, lat : xarray.DataArray
        Longitude / latitude of the cell vertices. Both share a ``bounds_dim``
        dimension that enumerates the vertices of each cell (for example the
        ``*_bounds`` variables created by cf_xarray's ``add_bounds``).
    bounds_dim : str, default "bounds"
        Name of the vertex dimension.

    Returns
    -------
    xarray.DataArray
        Lazily computed (dask-backed) object array of polygons over the
        remaining, non-vertex dimensions of the inputs.
    """
    # Stack lon/lat along a new "coords" dim so that each cell's vertices form
    # a (n_vertices, 2) array once the core dims are moved to the end.
    vertices = xr.concat([lon, lat], dim="coords")

    return xr.apply_ufunc(
        shapely.polygons,
        # dask="parallelized" requires every core dimension to be one chunk:
        # concat yields one chunk per input along "coords", and the vertex
        # dimension may also be chunked depending on how the input was chunked.
        vertices.chunk({bounds_dim: -1, "coords": -1}),
        input_core_dims=[[bounds_dim, "coords"]],
        output_core_dims=[[]],
        dask="parallelized",
        keep_attrs=False,
        output_dtypes=[object],
    )

In [None]:
# Source grid: a small curvilinear tutorial grid, chunked in blocks of
# 30 (y) x 60 (x) cells, with a polygon per cell attached as a coordinate.
source_grid = (
    grid_indexing.tutorial.generate_grid("2d-curvilinear", resolution="small")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"y": 30, "x": 60})
    .assign_coords(
        geometry=lambda grid: bounds_to_polygons(
            lon=grid["lon_bounds"], lat=grid["lat_bounds"]
        )
    )
)
source_grid

### Procedure

1. Create the distributed R-tree
    - inputs:
        - the cell boundaries as geometries
        - and, derived from them, the chunk boundaries
    - create an index from the chunk boundaries and store it
    - for each chunk of cell boundaries, create an index (as a delayed function — this requires the index to be picklable)
2. Query the index
    - extract the chunk boundaries from the input grid
    - query the chunk-boundary index to find which source chunks each target chunk interacts with
    - query the indexes of those interacting source chunks
    - assemble the results into a sparse matrix

### Issues

- for dask to work, the trees have to be picklable
- going from a grid of tasks to a single concatenated sparse matrix may be tricky

In [None]:
# Target grid: a small rectilinear tutorial grid with the same chunking and
# per-cell polygons as the source grid.
target_grid = (
    grid_indexing.tutorial.generate_grid("2d-rectilinear", "small")
    .cf.add_bounds(["latitude", "longitude"])
    .chunk({"x": 60, "y": 30})
    .assign_coords(
        geometry=lambda grid: bounds_to_polygons(
            lon=grid["lon_bounds"], lat=grid["lat_bounds"]
        )
    )
)
target_grid

In [None]:
@dataclass
class ChunkGrid:
    """Sizes of every chunk of a chunked N-d array, laid out on the chunk grid.

    Attributes
    ----------
    shape : tuple
        Shape of the full array.
    chunks : numpy.ndarray
        Integer array of shape ``(*grid_shape, ndim)``: ``chunks[i, j, ...]``
        holds the extent of chunk ``(i, j, ...)`` along each array axis.
    """

    shape: tuple
    chunks: np.ndarray

    @classmethod
    def from_dask(cls, arr):
        """Build the chunk grid from an object with ``shape`` and ``chunks``
        attributes in dask's format (one tuple of chunk sizes per axis)."""
        # indexing="ij" keeps grid axis k aligned with array axis k, which is
        # the C order used by ``arr.to_delayed().flatten()``. The default "xy"
        # swaps the first two axes and mislabels chunks when sizes differ.
        grid = np.stack(np.meshgrid(*arr.chunks, indexing="ij"), axis=-1)

        return cls(arr.shape, grid)

    @property
    def grid_shape(self):
        """Number of chunks along each array axis."""
        return self.chunks.shape[:-1]

    def __repr__(self):
        name = type(self).__name__

        return f"{name}(shape={self.shape}, chunks={int(np.prod(self.grid_shape))})"

    def chunk_size(self, flattened_index):
        """Number of elements in the chunk at C-order position ``flattened_index``.

        The position matches the order of ``arr.to_delayed().flatten()``.
        """
        indices = np.unravel_index(flattened_index, self.grid_shape)

        # indexing with the tuple selects the per-axis extents of that chunk
        return int(np.prod(self.chunks[indices]))

In [None]:
def chunk_boundaries(chunks):
    """Return one lazy geometry per chunk: the union of all its cell polygons.

    ``chunks`` is an iterable of (delayed) geometry arrays; the result is a
    list of ``dask.delayed`` objects, in the same order.
    """
    delayed_union = dask.delayed(shapely.unary_union)

    return [delayed_union(chunk) for chunk in chunks]


def index_from_shapely(chunk):
    """Build a ``grid_indexing.Index`` over all geometries of one chunk.

    The (possibly multi-dimensional) geometry array is flattened in C order
    and converted to geoarrow before indexing.
    """
    flat_geometries = chunk.flatten()
    arrow_geometries = geoarrow.from_shapely(flat_geometries)

    return grid_indexing.Index(arrow_geometries)


def _query_overlap(index, chunk, shape):
    """Query ``index`` with the flattened geometries of ``chunk``.

    When the query finds no overlaps at all, an all-False GCXS array of
    ``shape`` (from ``_empty_chunk``) is returned instead of the query result.
    NOTE(review): presumably so empty blocks always have the same layout and
    shape as the explicitly empty ones — confirm.
    """
    overlaps = index.query_overlap(geoarrow.from_shapely(chunk.flatten()))

    if overlaps.nnz:
        return overlaps

    return _empty_chunk(index, chunk, shape)


def _empty_chunk(index, chunk, shape):
    """Return an all-False sparse GCXS array of the given ``shape``.

    ``index`` and ``chunk`` are ignored; they only exist so this function has
    the same signature as ``_query_overlap`` and the two can be swapped.
    """
    all_false = sparse.full(shape=shape, fill_value=False, dtype=bool)

    return sparse.GCXS.from_coo(all_false)


class DistributedRTree:
    """Two-level spatial index over the cells of a dask-chunked grid.

    The top level is an in-memory index over one outline geometry per chunk
    (the union of that chunk's cell polygons), built eagerly on construction.
    The bottom level is one index per chunk, wrapped in ``dask.delayed`` and
    therefore only built when a query graph that needs it is computed. With a
    distributed scheduler these indexes are passed between workers, so
    ``grid_indexing.Index`` has to be picklable.

    Parameters
    ----------
    grid
        Dataset-like object whose ``"geometry"`` variable is a dask array of
        shapely polygons (one per cell).
    """

    def __init__(self, grid):
        geoms = grid["geometry"].data
        self.chunk_grid = ChunkGrid.from_dask(geoms)

        # one Delayed per chunk, in C order over the chunk grid
        self.chunks = geoms.to_delayed().flatten()
        [boundaries] = dask.compute(chunk_boundaries(self.chunks))

        # per-chunk indexes: lazy, built as part of a query graph
        self.chunk_indexes = list(map(dask.delayed(index_from_shapely), self.chunks))
        # top-level index over the chunk outlines, built eagerly
        self.index = grid_indexing.Index.from_shapely(np.array(boundaries))

    def query_overlap(self, grid):
        """Lazily find which source cells overlap each cell of ``grid``.

        Only the chunk outlines of ``grid`` are computed eagerly (to select the
        interacting chunk pairs); the per-cell queries stay lazy.

        Returns
        -------
        dask.array.Array
            Boolean array with GCXS sparse chunks and shape
            ``(n_target_cells, n_source_cells)``. Each block is the result of
            one (target chunk, source chunk) pair; pairs whose outlines do not
            overlap are all-False blocks.

            NOTE(review): rows and columns are ordered chunk by chunk (C order
            over the chunk grid, then C order within each chunk). That differs
            from the global ``geometry.data.ravel()`` order whenever a grid is
            chunked along any axis other than the first — confirm intended.
        """
        geoms = grid["geometry"].data
        query_grid = ChunkGrid.from_dask(geoms)
        input_chunks = geoms.to_delayed().flatten()

        # coarse query: which source chunks can interact with each target chunk
        [boundaries] = dask.compute(chunk_boundaries(input_chunks))
        boundary_geoms = geoarrow.from_shapely(np.array(boundaries))
        overlapping_chunks = self.index.query_overlap(boundary_geoms).todense()

        # fine query: one lazy block per (target chunk, source chunk) pair
        blocks = np.full_like(overlapping_chunks, dtype=object, fill_value=None)
        meta = sparse.GCXS.from_coo(sparse.empty((), dtype=bool))

        for target_index, input_chunk in enumerate(input_chunks):
            n_target_cells = query_grid.chunk_size(target_index)

            for source_index, mask in enumerate(overlapping_chunks[target_index]):
                shape = (n_target_cells, self.chunk_grid.chunk_size(source_index))

                if mask:
                    task = dask.delayed(_query_overlap)(
                        self.chunk_indexes[source_index],
                        input_chunk,
                        shape=shape,
                    )
                else:
                    # _empty_chunk ignores its index/chunk arguments. Passing the
                    # delayed objects would make dask build that source index and
                    # compute the target geometries for every empty block too.
                    task = dask.delayed(_empty_chunk)(None, None, shape=shape)

                blocks[target_index, source_index] = dask.array.from_delayed(
                    task, shape=shape, dtype=bool, meta=meta
                )

        return dask.array.concatenate(
            [dask.array.concatenate(row.tolist(), axis=1) for row in blocks],
            axis=0,
        )


# Build the two-level index over the source grid; the chunk outlines and the
# top-level index are computed right away, the per-chunk indexes stay lazy.
dtree = DistributedRTree(grid=source_grid)
dtree

In [None]:
# Only the target chunk outlines are computed here; the per-cell overlap
# queries are returned as a lazy dask array and run in the next cell.
result = dtree.query_overlap(grid=target_grid)
type(result)

In [None]:
%%time
dask.compute(result)[0]