# XGCM Grid Index

In [None]:
%load_ext watermark

import xarray as xr
import xgcm
from xarray.testing import _assert_internal_invariants

%watermark -iv

In [None]:
import xgcm.test.datasets

# Build xgcm's grid-metric test dataset (called with "C"). The call returns the
# dataset together with the `coords` and `metrics` objects; the same call is
# repeated in the index-construction cell further down, where they are passed
# to `xgcm.Grid`.
ds, coords, metrics = xgcm.test.datasets.datasets_grid_metric("C")

In [None]:
import numpy as np
from xarray.core.indexes import Index, PandasIndex
from xarray.core.indexing import IndexSelResult


def grid_isel_indexers(grid: "xgcm.Grid", indexers: dict) -> dict:
    """Expand positional indexers to the other dimensions of the same grid axis.

    For every indexed dimension that is a coordinate of one of ``grid.axes``,
    the sibling dimensions of that axis stored under the ``"right"`` or
    ``"left"`` position, and not already present in ``indexers``, receive an
    indexer derived from the requested one: ``[key - 1, key]`` for ``"right"``
    and ``[key, key + 1]`` for ``"left"``. Siblings at any other position are
    left untouched.

    Parameters
    ----------
    grid
        Object exposing ``axes``, a mapping of axis name to an object whose
        ``coords`` maps a position name to a dimension name.
    indexers
        Mapping of dimension name to a scalar or a sequence of integer
        positions. It is not modified. Dimensions that are not a coordinate
        of any grid axis are passed through unchanged.

    Returns
    -------
    dict
        Copy of ``indexers`` with the derived entries added.

    Raises
    ------
    NotImplementedError
        If a sibling indexer would have to be derived from a key that is not
        a scalar or a 1-D sequence of at most one element (for example a
        slice or a multi-element list).
    """
    # TODO: make this a Grid.isel method
    # TODO: cache positions on the grid object.
    # Reverse lookup: dimension name -> (axis name, position on that axis).
    dim_to_pos = {
        name: (axis_name, pos)
        for axis_name, axis in grid.axes.items()
        for pos, name in axis.coords.items()
    }

    new_indexers = dict(indexers)
    for dim, key in indexers.items():
        if dim not in dim_to_pos:
            # Not a coordinate of any grid axis, so nothing is linked to it;
            # keep the indexer as given instead of failing with KeyError.
            continue
        axis, pos = dim_to_pos[dim]
        print(axis, pos)

        key = np.asarray([key] if np.isscalar(key) else key)

        for other_pos, other_dim in grid.axes[axis].coords.items():
            if other_dim in new_indexers:
                # Explicitly requested (or already derived): never overwrite.
                continue

            if key.ndim != 1 or key.size > 1:
                raise NotImplementedError(
                    f"cannot derive an indexer for {other_dim!r} from the "
                    f"indexer given for {dim!r}; only scalars and one-element "
                    "sequences are supported"
                )
            if other_pos == pos:
                continue
            # NOTE(review): `key - 1` is -1 when key is 0 and `key + 1` can run
            # past the end of the dimension; boundaries are not handled here.
            if other_pos == "right":
                new_indexers[other_dim] = np.insert(key, 0, key - 1)
            elif other_pos == "left":
                new_indexers[other_dim] = np.append(key, key + 1)
    print(new_indexers)
    return new_indexers


def get_grid_var_names(grid):
    """Return the names of all variables attached to ``grid`` as one tuple.

    The tuple lists the coordinate names of every axis in ``grid.axes`` first,
    followed by the ``name`` of every metric variable in ``grid._metrics``.
    """
    from itertools import chain

    coord_names = chain.from_iterable(
        axis.coords.values() for axis in grid.axes.values()
    )
    metric_names = (
        var.name for group in grid._metrics.values() for var in group
    )
    return tuple(coord_names) + tuple(metric_names)


class XgcmGridIndex(Index):
    """Prototype xarray index that bundles an ``xgcm.Grid`` with its variables.

    The index keeps the grid itself, the names of all grid variables (axis
    coordinates and metrics, see ``get_grid_var_names``) and one
    ``PandasIndex`` per axis coordinate. Selection on one coordinate is
    expanded to the linked coordinates of the same axis through
    ``grid_isel_indexers``.

    The grid is expected to carry ``_saved_coords`` and ``_saved_metrics``
    attributes (set by the notebook after constructing the grid); ``isel``
    uses them to rebuild a grid on the subset dataset.
    """

    # based off Benoit's RasterIndex in
    # https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

    def __init__(self, grid, indexes):
        """Store ``grid`` and the per-coordinate ``indexes`` mapping."""
        display("Creating new index... __init__")
        self.grid = grid

        # all variable names
        self.index_var_names = get_grid_var_names(grid)
        self._indexes = indexes

    # TODO: what goes in options?
    @classmethod
    def from_variables(cls, variables, options):
        """Build the index from ``variables``; ``options["grid"]`` is required.

        One ``PandasIndex`` is created for every coordinate name found on the
        grid's axes. The remaining options are forwarded to
        ``PandasIndex.from_variables``.
        """
        # Work on a copy so the caller's options mapping is left intact.
        options = dict(options)
        grid = options.pop("grid")

        coord_names = [
            name for axis in grid.axes.values() for name in axis.coords.values()
        ]

        indexes = {
            key: PandasIndex.from_variables({key: variables[key]}, options=options)
            for key in coord_names
        }
        return cls(grid, indexes)

    # TODO: variables=None?
    # set_xindex tries to pass variables; this seems like a bug
    def create_variables(self, variables=None):
        """Return the grid dataset's variables for every indexed name.

        ``variables`` is accepted but not used.
        """
        return {name: self.grid._ds._variables[name] for name in self.index_var_names}

    # TODO: see notes about IndexSelResult
    #    The latter is a small class that stores positional indexers (indices)
    #    and that could also store new variables, new indexes,
    #    names of variables or indexes to drop,
    #    names of dimensions to rename, etc.
    def sel(self, labels, **kwargs):
        """Translate label indexers into positional ones for linked dimensions.

        Raises
        ------
        KeyError
            If any requested label is absent from the coordinate's index.
        """
        # sel needs to only handle keys in labels
        # since it delegates to isel.
        # we handle all entries in ._indexes there
        results = {}

        # convert provided label indexers to positional indexers
        for name, keys in labels.items():
            pdindex = self._indexes[name].index
            idxs = pdindex.get_indexer(keys)
            # get_indexer marks labels it cannot find with -1; passing that on
            # would be read as the position -1 instead of reporting the miss.
            if (idxs == -1).any():
                raise KeyError(
                    f"not all labels in {keys!r} were found in the index for {name!r}"
                )
            results[name] = idxs

        # bring in linked dimensions
        results = grid_isel_indexers(self.grid, results)
        return IndexSelResult(dim_indexers=results)

    def isel(self, indexers):
        """Return a new index built on the positionally subset grid dataset."""
        indexers = grid_isel_indexers(self.grid, indexers)
        display(indexers)

        # TODO: check dim names in indexes
        results = {}
        for k, index in self._indexes.items():
            if k in indexers:
                results[k] = index.isel({k: indexers[k]})
            else:
                results[k] = index
        # display(results)
        subset_ds = self.grid._ds.isel(indexers)
        # display(subset_ds)
        saved_coords = self.grid._saved_coords
        saved_metrics = self.grid._saved_metrics
        new_grid = xgcm.Grid(
            subset_ds,
            coords=saved_coords,
            metrics=saved_metrics,
        )
        # Carry the monkey-patched attributes over to the new grid; without
        # them a second isel on the returned index fails with AttributeError.
        new_grid._saved_coords = saved_coords
        new_grid._saved_metrics = saved_metrics
        return type(self)(new_grid, results)

    def __repr__(self):
        return "XGCM/Index"


import itertools

# Load the test dataset and wrap it in an xgcm.Grid.
ds, coords, metrics = xgcm.test.datasets.datasets_grid_metric("C")
grid = xgcm.Grid(ds, coords=coords, metrics=metrics)

# monkey patch attributes to make it easier to recreate a Grid
grid._saved_coords, grid._saved_metrics = coords, metrics

# Coordinate names of every grid axis; these are the indexed dimensions.
dim_names = tuple(
    name for axis in grid.axes.values() for name in axis.coords.values()
)
# All grid variables: the axis coordinates followed by the metric variables.
grid_var_names = dim_names + tuple(
    var.name for group in grid._metrics.values() for var in group
)
# spatial_dims = [dim for dim in ds.dims if dim != "time"]
# ds = ds.drop_indexes(spatial_dims)
newds = (
    ds
    # Set grid variables as coords
    .set_coords(grid_var_names)
    # Need to drop existing indexed dims; somewhat annoying
    .drop_indexes(dim_names)
    # Attach one XgcmGridIndex spanning all grid variables
    .set_xindex(grid_var_names, index_cls=XgcmGridIndex, grid=grid)
)
newds

## Test index propagation

In [None]:
# Reduce over "time" and display the result, to see whether the custom grid
# index is carried through the reduction.
newds.mean("time")

## Subset with .sel

In [None]:
# Label-based selection on "xt"; presumably routed through XgcmGridIndex.sel,
# assuming "xt" is one of the indexed grid coordinates.
newds.sel(xt=[2])

In [None]:
# NOTE(review): the `xgcm` dataset accessor registered later in this notebook
# only defines `diff`, so `interp` is not available through it as written.
newds.xgcm.interp(axis="X")

Not sure why this is failing. Seems like I need to make a bunch of variables IndexVariables? That seems unnecessary.

In [None]:
# Run xarray's internal invariant checks on the `.sel` result, with the
# default-index check switched off.
_assert_internal_invariants(newds.sel(xt=[2]), check_default_indexes=False)

## Subset with .isel

This isn't working: `indexers={"xt": [2]}` needs to be updated before the variables are subset, but we have no mechanism to allow that. We'd need something like an `IndexISelResult`, for example? The alternative, converting from positional to label indexers and then back to positional, is quite silly.

In [None]:
# Same invariant check for the `.isel` path; `check_default_indexes` is passed
# by keyword so the call reads like the `.sel` check above.
_assert_internal_invariants(newds.mean("time").isel(xt=[2]), check_default_indexes=False)

## Accessor

Bah, this doesn't work because the index isn't propagated when a DataArray is extracted from the Dataset, and xgcm can only operate on DataArrays.

In [None]:
@xr.register_dataset_accessor("xgcm")
class XgcmAccessor:
    """Dataset accessor exposing grid operations through an ``XgcmGridIndex``.

    The grid is taken from the first ``XgcmGridIndex`` found among the
    wrapped dataset's own indexes.
    """

    def __init__(self, obj):
        """Wrap ``obj`` and look up its grid.

        Raises
        ------
        ValueError
            If ``obj`` carries no ``XgcmGridIndex``.
        """
        self._obj = obj

        # just pick the first XGCM Grid Index -- of the wrapped dataset, not
        # of whatever dataset happens to be bound to a global name
        grid_index = next(
            (
                index
                for index in obj.xindexes.values()
                if isinstance(index, XgcmGridIndex)
            ),
            None,
        )
        if grid_index is None:
            raise ValueError(
                "the dataset has no XgcmGridIndex; cannot locate an xgcm grid"
            )
        self.grid = grid_index.grid

    def diff(self, name, *args, **kwargs):
        """Apply ``grid.diff`` to the variable ``name`` of the wrapped dataset.

        Extra positional and keyword arguments are forwarded to ``grid.diff``.
        """
        return self.grid.diff(self._obj[name], *args, **kwargs)


# Exercise the accessor: difference the "tracer" variable along the "X" axis
# using the grid found on the dataset's index.
newds.xgcm.diff("tracer", axis="X")