In [None]:
import cf_xarray  # noqa: F401
import dask
import dask.array
import geoarrow.rust.core as geoarrow
import lonboard
import numpy as np
import xarray as xr
from distributed import Client

import grid_indexing
from grid_indexing.distributed import DistributedRTree

# Globally tell xarray to keep attributes on the results of its operations.
# The trailing semicolon suppresses the cell's output.
xr.set_options(keep_attrs=True);

In [None]:
def _visualize_grid(geoms, color, alpha=0.8, **layer_kwargs):
    """Build a filled, outlined ``lonboard.PolygonLayer`` from cell geometries.

    Parameters
    ----------
    geoms
        Cell geometries, passed to ``arro3.core.Array.from_arrow``.
    color : {"red", "green", "blue", "yellow"}
        Name of the fill color. Any other name raises ``KeyError``.
    alpha : float, default: 0.8
        Transparency of the fill. Note the inversion: the alpha channel handed
        to lonboard is ``int((1 - alpha) * 255)``, so *larger* values produce a
        *fainter* fill.
    **layer_kwargs
        Forwarded unchanged to ``lonboard.PolygonLayer``.

    Returns
    -------
    lonboard.PolygonLayer
        Layer backed by a two-column table: ``"geometry"`` and ``"value"``
        (a running index ``0..len(geoms) - 1``).
    """
    from arro3.core import Array, Schema, Table

    palette = {
        "red": [255, 0, 0],
        "green": [0, 255, 0],
        "blue": [0, 0, 255],
        "yellow": [255, 255, 0],
    }

    # RGBA fill color: the named RGB triple plus the inverted alpha channel.
    opacity = int((1 - alpha) * 255)
    rgba = [*palette[color], opacity]

    # Assemble the table column by column; each field is renamed to its column name.
    geometry = Array.from_arrow(geoms)
    running_index = Array.from_numpy(np.arange(len(geometry)))
    columns = {"geometry": geometry, "value": running_index}
    schema = Schema([column.field.with_name(name) for name, column in columns.items()])
    table = Table.from_arrays(list(columns.values()), schema=schema)

    return lonboard.PolygonLayer(
        table=table,
        filled=True,
        get_fill_color=rgba,
        get_line_color="black",
        auto_highlight=True,
        wireframe=True,
        **layer_kwargs,
    )


def visualize_result(source_cells, target_cells, result, index):
    """Show one target cell together with the source cells found for it.

    Parameters
    ----------
    source_cells
        Source cell geometries; converted with ``geoarrow.to_shapely`` for masking.
    target_cells
        Target cell geometries; ``target_cells[index]`` is the highlighted cell.
    result
        2D array indexed as ``result[index, :]``. The selected row is used as a
        mask over ``source_cells`` (presumably boolean, to be confirmed against
        the query output).
    index : int
        Row of ``result`` / position in ``target_cells`` to visualize.

    Returns
    -------
    lonboard.Map
        Map with three layers, bottom to top: the full source grid (red), the
        target cell (blue) and the selected source cells (yellow).

    Notes
    -----
    Prints the number of selected source cells as a side effect.
    """
    target = target_cells[index]
    row = result[index, :]
    print("cells found:", np.sum(row))

    # Mask on the shapely side, then convert the selection back to geoarrow.
    as_shapely = geoarrow.to_shapely(source_cells)
    matched = geoarrow.from_shapely(as_shapely[row])

    target_layer = _visualize_grid(target, color="blue", alpha=0.8)
    matched_layer = _visualize_grid(matched, color="yellow", alpha=0.6)
    background_layer = _visualize_grid(source_cells, color="red", alpha=0.9)

    return lonboard.Map([background_layer, target_layer, matched_layer])

In [None]:
# Start a dask.distributed client with default arguments (no scheduler address
# is given, so a local cluster is set up). `Client` is imported in the
# notebook's import cell. The trailing expression displays the client summary.
client = Client()
client

In [None]:
# Source grid: xarray's "air_temperature" tutorial dataset.
source_grid = (
    xr.tutorial.open_dataset("air_temperature")
    # wrap longitudes into the [-180, 180) range
    .assign_coords(lon=lambda ds: (ds["lon"] + 180) % 360 - 180)
    # drop the last longitude column
    # NOTE(review): the reason for dropping it is not visible here — confirm
    .isel(lon=slice(None, -1))
)
# Cell geometries inferred by grid_indexing from the dataset's coordinates.
source_geoms = grid_indexing.infer_cell_geometries(source_grid)

In [None]:
# Extent of the source grid's coordinate values, as plain Python scalars.
min_lon = source_grid["lon"].min().item()
max_lon = source_grid["lon"].max().item()
min_lat = source_grid["lat"].min().item()
max_lat = source_grid["lat"].max().item()

# Target coordinates: 100 longitudes x 50 latitudes, evenly spaced over the same
# extent (np.linspace includes both endpoints).
lon = np.linspace(min_lon, max_lon, 100)
lat = np.linspace(min_lat, max_lat, 50)

# Coordinate-only dataset describing the target grid.
# NOTE(review): the standard_name attributes are presumably what lets
# cf_xarray / grid_indexing identify the coordinates — confirm.
target_grid = xr.Dataset(
    coords={
        "lon": ("lon", lon, {"standard_name": "longitude"}),
        "lat": ("lat", lat, {"standard_name": "latitude"}),
    }
)
# Cell geometries inferred by grid_indexing from the target coordinates.
target_geoms = grid_indexing.infer_cell_geometries(target_grid)

### procedure

1. creation of the distributed rtree
    - values:
        - the cell boundaries as geometries
        - and, derived from those, the chunk boundaries
    - create an index from the chunk boundaries and save it
    - for each chunk of cell boundaries, create an index (as a delayed function? the index needs to be picklable for that, though)
2. query the index
    - extract the chunk boundaries from the input
    - query the chunk boundary index to figure out which source chunks each target chunk interacts with
    - query the index of each interacting chunk
    - assemble the result as a sparse matrix

### issues

- for dask to work, the trees have to be picklable
- going from a grid of tasks to a concatenated sparse matrix may be tricky

In [None]:
# Arrange the source cell geometries (as shapely objects) on a 2D array in
# (lon, lat) order. The shape is taken from the dataset's dimension sizes
# instead of hard-coded numbers, so it stays correct if the source grid changes.
source_geoms_ = np.reshape(
    geoarrow.to_shapely(source_geoms),
    (source_grid.sizes["lon"], source_grid.sizes["lat"]),
)
# Split into dask chunks of at most 13 x 5 cells; the trailing expression
# displays the chunk layout.
chunked_source_geoms = dask.array.from_array(source_geoms_, chunks=(13, 5))
chunked_source_geoms

In [None]:
# Arrange the target cell geometries (as shapely objects) on a 2D array in
# (lon, lat) order. The shape is taken from the dataset's dimension sizes
# instead of hard-coded numbers, so it follows the target grid's resolution.
target_geoms_ = np.reshape(
    geoarrow.to_shapely(target_geoms),
    (target_grid.sizes["lon"], target_grid.sizes["lat"]),
)
# Split into dask chunks of at most 5 x 5 cells; the trailing expression
# displays the chunk layout.
chunked_target_geoms = dask.array.from_array(target_geoms_, chunks=(5, 5))
chunked_target_geoms

In [None]:
# Build the distributed R-tree from the chunked source geometries.
# NOTE(review): presumably this creates one index per chunk plus an index over
# the chunk boundaries, as outlined in the procedure notes — confirm.
dtree = DistributedRTree(chunked_source_geoms)
dtree

In [None]:
# Query the tree for overlaps with the chunked target geometries. The result is
# not materialized here; `.compute()` is called on it further down.
result = dtree.query_overlap(chunked_target_geoms)
result

In [None]:
# Materialize the query result, densify it and flatten it to a 2D
# (target cell, source cell) matrix. The shape is derived from the geometry
# arrays rather than hard-coded, so it tracks the grid sizes.
result_ = (
    result.compute()
    .todense()
    .reshape((target_geoms_.size, source_geoms_.size))
)

In [None]:
# Inspect a single target cell: 123 is an arbitrary example row of `result_`
# (i.e. a position in the flattened target grid); change it to look at others.
visualize_result(source_geoms, target_geoms, result_, 123)