# CRSIndex experiment

Following https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

Notes:
1. This requires the `pydata/xarray:scipy22` branch.
    - I pushed a commit to that branch so that it prints a warning instead of raising an error when an index already exists.
    - We need a drop_indexes
1. Indexes are created using `Index.from_variables`. This means we have to stick everything in Variable objects.
    - Can we support passing a CRS object directly using `**kwargs`? Not at the moment.
    - This is the data model: indexes are constructed from a subset of coordinate variables, so all necessary information should be in a coordinate variable.
    - This demo uses `spatial_ref`, an arbitrary choice.
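
The data model in the notes above can be illustrated without xarray. This is a minimal sketch, not xarray's API: `SimpleVariable` and `split_index_variables` are hypothetical names, the coordinate values are made up, and the WKT string is a placeholder. It only shows the CRS travelling as an attribute on a scalar variable next to the `x` and `y` coordinate variables.

```python
from dataclasses import dataclass, field


@dataclass
class SimpleVariable:
    """Hypothetical stand-in for xarray.Variable: dims, data, attrs."""

    dims: tuple
    data: object
    attrs: dict = field(default_factory=dict)


def split_index_variables(variables, axis_names=("x", "y"), crs_name="spatial_ref"):
    """Separate the axis coordinate variables from the scalar CRS variable."""
    missing = [name for name in (*axis_names, crs_name) if name not in variables]
    if missing:
        raise KeyError(f"missing coordinate variables: {missing}")
    crs_var = variables[crs_name]
    if crs_var.dims != ():
        raise ValueError(f"{crs_name!r} must be a scalar (0-d) variable")
    axes = {name: variables[name] for name in axis_names}
    return axes, crs_var.attrs["crs_wkt"]


# Made-up coordinate values and a placeholder WKT string.
variables = {
    "x": SimpleVariable(("x",), [0.0, 3.0, 6.0]),
    "y": SimpleVariable(("y",), [9.0, 6.0]),
    "spatial_ref": SimpleVariable((), 0, {"crs_wkt": "PLACEHOLDER_WKT"}),
}
axes, crs_wkt = split_index_variables(variables)
print(sorted(axes), crs_wkt)
```

Everything the index needs arrives through the `variables` mapping, which is why a CRS object cannot simply be passed in on the side.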
    
Potential Extensions:
- We (optionally?) want lat, lon in addition to x, y.
- Potentially fancier tree-based indexing instead of the simple pandas-based indexing used here.
- What do we do for `newds.sel(x=46670, method="nearest")`: allow a CRSIndex with only `y` and no `x`?
- rioxarray could assign a new index automatically when reprojecting etc.
- Better handling of various CRS options.

Bug reports TODO:
1. `Index.create_variables`: `set_xindex` tries to pass a `variables` kwarg but other methods don't. Seems like a bug.
2. How do we delete an existing index?
3. GroupBy doesn't propagate the index.
4. The error message with `join='exact'` is very generic.

In [1]:
%load_ext watermark

import cf_xarray  # to show off
import rioxarray
import xarray as xr  # need 2022.10.0
%watermark -iv

# indexes only shown in text repr at the moment
xr.set_options(display_style="text")

ds = xr.open_dataset(
    "/Users/dcherian/python/rioxarray/test/test_data/input/PLANET_SCOPE_3D.nc",
    # decode_coords="all",
    engine="rasterio",
)

ds

cf_xarray: 0.7.4
xarray   : 2022.10.0
rioxarray: 0.12.2



In [2]:
from typing import Any, Hashable

from xarray.core.indexes import Index, PandasIndex, get_indexer_nd
from xarray.core.indexing import IndexSelResult, merge_sel_results


def create_spatial_ref(crs_wkt):
    """Because I don't know what I'm doing"""
    return xr.Variable((), 0, attrs={"crs_wkt": crs_wkt})


class CRSIndex(Index):
    # based off Benoit's RasterIndex in
    # https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

    def __init__(self, variables):
        # TODO: hardcoded variable names

        # assert len(xy_indexes) == 2
        assert "x" in variables
        assert "y" in variables
        assert "spatial_ref" in variables

        # TODO: Instead do whatever the rio accessor /
        # rioxarray.open_dataset is doing.
        spatial_ref = variables.pop("spatial_ref")
        self._crs = rioxarray.crs.CRS.from_wkt(spatial_ref.attrs["crs_wkt"])

        # must have two distinct dimensions
        # Assumes x, y for index are never scalar. Is that correct?
        dim = [idx.dim for key, idx in variables.items()]
        assert dim[0] != dim[1]

        self._indexes = variables

    # TODO: what goes in options?
    @classmethod
    def from_variables(cls, variables, options):
        # assert len(variables) == 2

        xy_indexes = {
            k: PandasIndex.from_variables({k: v}, options=options)
            for k, v in variables.items()
            if k in ["x", "y"]
        }
        xy_indexes["spatial_ref"] = variables["spatial_ref"]

        return cls(xy_indexes)

    # TODO: variables=None?
    # set_xindex tries to pass variables; this seems like a bug
    def create_variables(self, variables=None):
        idx_variables = {}

        for index in self._indexes.values():
            idx_variables.update(index.create_variables(variables))

        idx_variables["spatial_ref"] = create_spatial_ref(self.as_wkt)
        return idx_variables

    # TODO: see notes about IndexSelResult
    #    The latter is a small class that stores positional indexers (indices)
    #    and that could also store new variables, new indexes,
    #    names of variables or indexes to drop,
    #    names of dimensions to rename, etc.
    def sel(self, labels, **kwargs):

        # sel needs to only handle keys in labels
        # since it delegates to isel.
        # we handle all entries in ._indexes there
        results = []
        for k, index in self._indexes.items():
            if k in labels:
                # defer to pandas type indexing.
                # This is where we would implement KDTree and friends
                results.append(index.sel({k: labels[k]}, **kwargs))
        return merge_sel_results(results)

    def isel(self, indexers):
        # TODO: check dim names in indexes
        results = {}
        for k, index in self._indexes.items():
            if k in indexers:
                # again possible KDTree / friends here.
                results[k] = index.isel({k: indexers[k]})
            else:
                results[k] = index
        # AGAIN!
        results["spatial_ref"] = create_spatial_ref(self.as_wkt)
        return type(self)(results)

    def __repr__(self):
        string = f"CRSIndex: {self._crs.to_string()}"
        return string

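    def equals(self, other):
        if not isinstance(other, CRSIndex):
            return False
        # Equal only if the CRS matches AND both coordinate indexes match.
        # (A shared CRS object alone does not imply equal x, y labels.)
        return (
            self._crs == other._crs
            and self._indexes["x"].equals(other._indexes["x"])
            and self._indexes["y"].equals(other._indexes["y"])
        )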

    def join(self, other, how="inner"):
        if self._crs != other._crs:
            raise ValueError(
                "Cannot align or join objects with different CRS. "
                f"Received {self._crs.name!r} and {other._crs.name!r}"
            )

        new_indexes = {
            k: v.join(other._indexes[k], how=how) for k, v in self._indexes.items()
        }
        # create new spatial_ref here.
        new_indexes["spatial_ref"] = create_spatial_ref(self.as_wkt)
        return type(self)(new_indexes)

    def reindex_like(self, other, method=None, tolerance=None):
        # TODO: different method, tolerance for x, y?
        return {
            k: get_indexer_nd(
                self._indexes[k].index, other._indexes[k].index, method, tolerance
            )
            for k in self._indexes.keys()
        }

    @property
    def as_crs(self):
        return self._crs

    @property
    def as_wkt(self):
        return self._crs.to_wkt()

In [3]:
index = CRSIndex.from_variables(
    {
        "x": ds.cf["projection_x_coordinate"].variable,
        "y": ds.cf["projection_y_coordinate"].variable,
        "spatial_ref": ds["spatial_ref"].variable,
    },
    options={},
)
index

CRSIndex: EPSG:32722

🎉

## Assign CRSIndex to a new dataset

First drop the existing default PandasIndex along x, y

In [4]:
newds = ds.drop_indexes(["x", "y"])
newds

Now set the new CRSIndex; note new entry under *Indexes*

In [5]:
names = ds.cf.standard_names
newds = newds.set_xindex(
    (
        *names["projection_x_coordinate"],
        *names["projection_y_coordinate"],
        "spatial_ref",
    ),
    CRSIndex,
)
newds

# Seems like we should delete spatial_ref
# But that is not allowed by set_xindex
# del newds["spatial_ref"] # doesn't work

In [6]:
with xr.set_options(display_style="html"):
    display(newds)

^ `spatial_ref` is now bolded in the HTML repr, so it is an indexed variable even though it's not associated with a dimension.

## Selection

### Vectors

In [7]:
#!!!!
newds.sel(x=[46670, 46675], method="nearest")
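
For reference, here is a minimal stdlib sketch of what a `method="nearest"` lookup along one axis amounts to. It is not the pandas implementation that `CRSIndex.sel` defers to: `nearest_positions` is a hypothetical helper, the coordinate values are made up, the axis is assumed to be sorted ascending, and the tie-breaking rule is an arbitrary choice.

```python
from bisect import bisect_left


def nearest_positions(sorted_coords, targets):
    """Return, for each target, the position of the closest coordinate value.

    Assumes sorted_coords is non-empty and sorted in ascending order.
    """
    positions = []
    for target in targets:
        right = bisect_left(sorted_coords, target)
        if right == 0:
            # Target is at or before the first label.
            positions.append(0)
        elif right == len(sorted_coords):
            # Target is past the last label.
            positions.append(len(sorted_coords) - 1)
        else:
            left = right - 1
            # Ties go to the left neighbour in this sketch.
            if target - sorted_coords[left] <= sorted_coords[right] - target:
                positions.append(left)
            else:
                positions.append(right)
    return positions


x = [10.0, 13.0, 16.0, 19.0]  # made-up coordinate values
print(nearest_positions(x, [11.0, 15.5, 100.0]))
```

A tree-based index (the KDTree idea mentioned in the code comments) would replace this per-axis lookup with a joint lookup over both axes.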

### TODO: Scalar selection

This fails at the moment.

What is the expected behaviour here?

In [8]:
# Should scalar be supported?
newds.sel(x=46670, method="nearest")

AttributeError: 'NoneType' object has no attribute 'dim'
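
The traceback is consistent with one of the per-axis indexes returning `None` from `isel` for a scalar indexer (the dimension disappears), and that `None` then reaching `CRSIndex.__init__`, which reads `idx.dim`. One possible policy, sketched below with hypothetical stand-ins (`AxisIndex` and `isel_axes` are not xarray classes or functions), is to give up on the combined index as soon as one axis is dropped. Whether that is the desired behaviour is exactly the open question here.

```python
class AxisIndex:
    """Hypothetical 1-D index: a dimension name plus its labels."""

    def __init__(self, dim, labels):
        self.dim = dim
        self.labels = list(labels)

    def isel(self, indexer):
        # A scalar indexer removes the dimension, so nothing is left to index.
        if isinstance(indexer, int):
            return None
        return AxisIndex(self.dim, [self.labels[i] for i in indexer])


def isel_axes(axis_indexes, indexers):
    """Index each axis; return None if any axis is dropped by a scalar indexer."""
    results = {}
    for name, index in axis_indexes.items():
        new_index = index.isel(indexers[name]) if name in indexers else index
        if new_index is None:
            # One possible policy: drop the combined index entirely.
            return None
        results[name] = new_index
    return results


# Made-up labels for illustration.
axes = {"x": AxisIndex("x", [10.0, 13.0, 16.0]), "y": AxisIndex("y", [5.0, 2.0])}
print(isel_axes(axes, {"x": 1}))  # scalar selection: index dropped
print(isel_axes(axes, {"x": [0, 2]})["x"].labels)  # vector selection: index kept
```

The alternative raised under Potential Extensions, keeping a CRSIndex with only `y`, would instead skip the dropped axis and relax the two-dimension assertion in `__init__`.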

## Reduction

In [9]:
newds.mean()

We lose `x` and `y`, which makes sense. `spatial_ref` is propagated, so that's great.
- Note: the bolded `spatial_ref` is confusing. A scalar index?

In [10]:
with xr.set_options(display_style="html"):
    display(newds.mean())

## groupby with flox needs to propagate indexes.

The flox code path is worse than the default one. Not surprising.

In [11]:
with xr.set_options(use_flox=True):
    result = newds.groupby("time.month").mean()
result

In [12]:
with xr.set_options(use_flox=False):
    result = newds.groupby("time.month").mean()
result

## alignment

### Create a reprojected dataset that should not align

In [13]:
# oops lost index
reprojected = newds.rio.reproject("EPSG:4326")
reprojected

In [14]:
# set CRSIndex again; RIO could do this automatically
reprojected = (
    ds.rio.reproject("EPSG:4326")
    .drop_indexes(["x", "y"])
    .set_xindex(("x", "y", "spatial_ref"), CRSIndex)
)
reprojected

### default join="inner"

Note nice error message!

In [15]:
xr.align(reprojected, newds)

ValueError: Cannot align or join objects with different CRS. Received 'WGS 84' and 'WGS 84 / UTM zone 22S'
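
The logic behind this error can be sketched in plain Python: check the CRS first, and only then join the labels along each axis. This is an illustration under simplifying assumptions, not what `CRSIndex.join` does internally (that delegates to the per-axis index `join`). `join_axis` and `join_grids` are hypothetical helpers, grids are plain dicts, and the coordinate values are made up.

```python
def join_axis(left, right, how="inner"):
    """Join two lists of coordinate labels, returning an ascending sorted list."""
    if how == "inner":
        joined = set(left) & set(right)
    elif how == "outer":
        joined = set(left) | set(right)
    else:
        raise ValueError(f"unsupported join: {how!r}")
    # Note: sorting ascending discards any descending order of the input.
    return sorted(joined)


def join_grids(left, right, how="inner"):
    """Join two grids given as {"crs": str, "x": [...], "y": [...]} dicts."""
    if left["crs"] != right["crs"]:
        # Labels in different CRS are not comparable, so refuse early.
        raise ValueError(
            "Cannot align or join objects with different CRS. "
            f"Received {left['crs']!r} and {right['crs']!r}"
        )
    return {
        "crs": left["crs"],
        "x": join_axis(left["x"], right["x"], how),
        "y": join_axis(left["y"], right["y"], how),
    }


utm = {"crs": "EPSG:32722", "x": [0.0, 3.0, 6.0], "y": [0.0, 3.0]}
subset = {"crs": "EPSG:32722", "x": [3.0, 6.0], "y": [0.0, 3.0]}
wgs = {"crs": "EPSG:4326", "x": [-51.3, -51.2], "y": [-17.3, -17.2]}

print(join_grids(utm, subset, how="inner")["x"])
try:
    join_grids(utm, wgs)
except ValueError as err:
    print(err)
```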

### join="exact" tests out `equals`

Could have a nicer error message.

In [16]:
xr.align(reprojected, newds, join="exact")

ValueError: cannot align objects with join='exact' where index/labels/sizes are not equal along these coordinates (dimensions): 'x' ('x',), 'y' ('y',), 'spatial_ref' ()
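
One shape a more specific message could take is sketched below. This is not how xarray builds its error: `describe_mismatch` and `check_exact` are hypothetical helpers operating on plain dicts with made-up values. The point is only that an index which knows about the CRS can say which part of the comparison failed.

```python
def describe_mismatch(left, right):
    """List human-readable reasons why two grids are not exactly equal."""
    reasons = []
    if left["crs"] != right["crs"]:
        reasons.append(f"CRS differs ({left['crs']!r} vs {right['crs']!r})")
    for axis in ("x", "y"):
        if left[axis] != right[axis]:
            reasons.append(
                f"{axis!r} labels differ "
                f"(lengths {len(left[axis])} and {len(right[axis])})"
            )
    return reasons


def check_exact(left, right):
    """Raise a ValueError naming every mismatch; do nothing if grids are equal."""
    reasons = describe_mismatch(left, right)
    if reasons:
        raise ValueError("cannot align with join='exact': " + "; ".join(reasons))


a = {"crs": "EPSG:32722", "x": [0.0, 3.0, 6.0], "y": [0.0, 3.0]}
b = {"crs": "EPSG:4326", "x": [0.0, 3.0], "y": [0.0, 3.0]}

try:
    check_exact(a, b)
except ValueError as err:
    print(err)
```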

### BROKEN: Successfully align with a copy of itself


In [17]:
copy = newds.copy(deep=True)
copy

In [19]:
xr.align(copy, newds)

AttributeError: 'PandasIndex' object has no attribute 'index'

### Align with subsets


In [20]:
xr.align(newds.isel(x=[5, 6]), newds, join="outer")

(<xarray.Dataset>
 Dimensions:      (time: 2, x: 10, y: 10)
 Coordinates:
   * time         (time) object 2016-12-19 10:27:29.687763 2016-12-29 12:52:42...
   * x            (x) float64 4.663e+05 4.663e+05 ... 4.663e+05 4.663e+05
   * y            (y) float64 8.085e+06 8.085e+06 ... 8.085e+06 8.085e+06
   * spatial_ref  int64 0
 Data variables:
     blue         (time, y, x) float64 nan nan nan nan nan ... 1.888 nan nan nan
     green        (time, y, x) float64 nan nan nan nan nan ... 12.96 nan nan nan
 Indexes:
     x            CRSIndex: EPSG:32722
     y            CRSIndex: EPSG:32722
     spatial_ref  CRSIndex: EPSG:32722
 Attributes:
     coordinates:  spatial_ref,
 <xarray.Dataset>
 Dimensions:      (time: 2, x: 10, y: 10)
 Coordinates:
   * time         (time) object 2016-12-19 10:27:29.687763 2016-12-29 12:52:42...
   * x            (x) float64 4.663e+05 4.663e+05 ... 4.663e+05 4.663e+05
   * y            (y) float64 8.085e+06 8.085e+06 ... 8.085e+06 8.085e+06
   * spatial_re

In [21]:
xr.align(newds.isel(x=[5, 6]), newds, join="inner")

(<xarray.Dataset>
 Dimensions:      (time: 2, x: 2, y: 10)
 Coordinates:
   * time         (time) object 2016-12-19 10:27:29.687763 2016-12-29 12:52:42...
   * x            (x) float64 4.663e+05 4.663e+05
   * y            (y) float64 8.085e+06 8.085e+06 ... 8.085e+06 8.085e+06
   * spatial_ref  int64 0
 Data variables:
     blue         (time, y, x) float64 4.76 5.078 3.228 ... 3.218 4.871 1.888
     green        (time, y, x) float64 11.54 14.43 44.91 ... 50.6 48.73 12.96
 Indexes:
     x            CRSIndex: EPSG:32722
     y            CRSIndex: EPSG:32722
     spatial_ref  CRSIndex: EPSG:32722
 Attributes:
     coordinates:  spatial_ref,
 <xarray.Dataset>
 Dimensions:      (time: 2, x: 2, y: 10)
 Coordinates:
   * time         (time) object 2016-12-19 10:27:29.687763 2016-12-29 12:52:42...
   * x            (x) float64 4.663e+05 4.663e+05
   * y            (y) float64 8.085e+06 8.085e+06 ... 8.085e+06 8.085e+06
   * spatial_ref  int64 0
 Data variables:
     blue         (time, y, 