# IntervalIndex experiment

In [5]:
%load_ext watermark

import numpy as np
import pandas as pd
import xarray as xr

%watermark -iv

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
pandas: 2.0.3
xarray: 2023.6.0
numpy : 1.24.4



## TODO

- try structured dtype?

## Example data

In [97]:
import numpy as np

left = np.arange(0.5, 3.6, 1)
right = np.arange(1.5, 4.6, 1)
bounds = np.stack([left, right])

ds = xr.Dataset(
    {"data": ("x", [5, 6, 7, 8])},
    coords={
        "x": ("x", [1, 2, 3, 4], {"bounds": "x_bounds"}),
        "x_bounds": (("bnds", "x"), bounds),
    },
)
ds

## Approach 1. Dropping bounds variable

### From [benbovy](https://github.com/pydata/xarray/discussions/7041#discussioncomment-4936891)

I think one of the best approaches would be to have a custom Xarray IntervalIndex that only supports a unique 1-d coordinate.

That said, an extension array adaptor like the one at the end of @dcherian's notebook would be nice to have. As another example, geopandas uses extension arrays for the geometry columns so this could be useful for other cases like martinfleis/xvec.

Perhaps something like this could work for IntervalIndex:

1. Expose a custom xarray.indexes.IntervalIndex that wraps a pd.IntervalIndex
    - Like for PandasIndex, its corresponding 1-d coordinate wraps the pandas index in a xarray.core.indexing.PandasIndexingAdapter
    - It could actually be implemented as a subclass of PandasIndex?
2. How to use set_xindex?
    - obj.set_xindex("x", IntervalIndex) where "x" is a 1-d coordinate (if it wraps an pd.arrays.IntervalArray or a pd.IntervalIndex just reuse it, otherwise create a new pandas interval index with default bounds).
    - obj.set_xindex("x", IntervalIndex, use_bounds_coord="x_bounds") where use_bounds_coord is an option specific to the Xarray IntervalIndex. In this case the "x_bounds" coordinate remains unindexed, it could be manually removed after creating the index.
2. Getting back the bounds coordinate, e.g., with something like obj.assign(x_bounds=obj.xindexes["x"].get_bounds_coord())

#### TODO

1. pd.IntervalIndex cannot support alternative "central" values. We'll need a more involved approach to do that.
2. We'll need a "decoding" function to break this in to two variables prior to writing to disk

In [98]:
from xarray import Variable
from xarray.indexes import PandasIndex


class XarrayIntervalIndex(PandasIndex):
    def __init__(self, index, dim, coord_dtype):
        assert isinstance(index, pd.IntervalIndex)

        # for PandasIndex
        self.index = index
        self.dim = dim
        self.coord_dtype = coord_dtype

    @classmethod
    def from_variables(cls, variables, options):
        assert len(variables) == 1
        (dim,) = tuple(variables)
        bounds = options["bounds"]
        assert isinstance(bounds, (xr.DataArray, xr.Variable))

        (axis,) = bounds.get_axis_num(set(bounds.dims) - {dim})
        left, right = np.split(bounds.data, 2, axis=axis)
        index = pd.IntervalIndex.from_arrays(left.squeeze(), right.squeeze())
        coord_dtype = bounds.dtype

        return cls(index, dim, coord_dtype)

    def create_variables(self, variables):
        from xarray.core.indexing import PandasIndexingAdapter

        newvars = {self.dim: xr.Variable(self.dim, PandasIndexingAdapter(self.index))}
        return newvars

    def __repr__(self):
        string = f"Xarray{self.index!r}"
        return string

    def to_pandas_index(self):
        return self.index

    @property
    def mid(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)

    @property
    def left(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)

    @property
    def right(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)


ds1 = (
    ds.drop_indexes("x")
    .set_xindex("x", XarrayIntervalIndex, bounds=ds.x_bounds)
    .drop_vars("x_bounds")
)
ds1

In [93]:
ds1.data.sel(x=1.1)

## Approach 2. Create new PandasIntervalIndex

instead of using `PandasIndex` wrapping `pd.IntervalIndex`

In [146]:
class BoundsArrayWrapper:
    def __init__(self, array: np.ndarray, axis: int):
        assert bounds.ndim == 2
        assert bounds.shape[axis] == 2
        self.axis = axis
        self.array = array

    def __array__(self):
        return self.array

    def get_duck_array(self):
        return self.array

    def values(self):
        return self.array

    @property
    def shape(self):
        return shape

    @property
    def ndim(self):
        return self.array.ndim - 1

    @property
    def shape(self):
        return tuple(s for ax, s in enumerate(self.array.shape) if ax != self.axis)

    @property
    def data(self):
        raise
        return self.array

    @property
    def dtype(self):
        return self.array.dtype

    def __array_ufunc__(self, *args, **kwargs):
        return self.array.__array_ufunc__(*args, **kwargs)

    def __array_function__(self, func, types, args, kwargs):
        raise NotImplementedError

    def __repr__(self):
        return f"BoundsArray{repr(self.array)[5:]}"

    def __getitem__(self, key):
        print(key)
        if len(key) != 2 or key[1] != ...:
            raise
        if isinstance(key, tuple):
            label, ell = key
            if ell != Ellipsis:
                ell, label = label, ell
        else:
            label = key

        newkey = [label, label]
        newkey[self.axis] = Ellipsis

        # normalize key depending on axis
        return type(self)(self.array[tuple(newkey)], axis=self.axis)


@xr.register_dataarray_accessor("bounds")
class BoundsAccessor:
    def __init__(self, da):
        self.da = da

    def wrap(self):
        return


wrapped = BoundsArrayWrapper(bounds, 0)
wrapped

BoundsArray([[0.5, 1.5, 2.5, 3.5],
       [1.5, 2.5, 3.5, 4.5]])

In [138]:
wrapped[([1, 1], ...)]

([1, 1], Ellipsis)


BoundsArray([[1.5, 1.5],
       [2.5, 2.5]])

In [192]:
from typing import Any, Hashable

from xarray.core.indexes import Index, PandasIndex, get_indexer_nd
from xarray.core.indexing import IndexSelResult, merge_sel_results
from xarray.core.variable import Variable


class XarrayIntervalIndex(PandasIndex):
    # based off Benoit's RasterIndex in
    # https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

    def __init__(self, index, dim, coord_dtype):
        # TODO: hardcoded variable names

        assert isinstance(index, pd.IntervalIndex)

        # for PandasIndex
        self.index = index
        self.dim = dim
        self.coord_dtype = coord_dtype

    @classmethod
    def from_variables(cls, variables, options):
        assert len(variables) == 2

        for k, v in variables.items():
            if isinstance(v.data, BoundsArrayWrapper):
                bounds_name, bounds = k, v
            elif v.ndim == 1:
                dim, other = k, v

        axis = bounds.data.axis
        other_axis = 1 if axis == 0 else 1
        left, right = np.split(bounds.data.array, 2, axis=axis)
        bounds = bounds.data
        index = pd.IntervalIndex.from_arrays(left.squeeze(), right.squeeze())

        coord_dtype = bounds.array.dtype

        return cls(index, dim, coord_dtype)

    def create_variables(self, variables):
        bounds_array = BoundsArrayWrapper(
            np.stack([self.index.left, self.index.right]), axis=0
        )
        bounds = Variable(dims=self.dim, data=bounds_array)
        mid = Variable(dims=self.dim, data=self.index.mid)

        newvars = {}
        for k, v in variables.items():
            if isinstance(v.data, BoundsArrayWrapper):
                newvars[k] = bounds
            else:
                newvars[k] = mid
        return newvars

    def __repr__(self):
        string = f"Xarray{self.index!r}"
        return string

    def to_pandas_index(self):
        return self.index

    @property
    def mid(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)

    @property
    def left(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)

    @property
    def right(self):
        return PandasIndex(self.index.right, self.dim, self.coord_dtype)


ds = xr.Dataset(
    {"data": (("time", "x"), np.arange(20).reshape(5, 4))},
    coords={"x": [1, 2, 3, 4], "x_bounds": (("bnds", "x"), bounds)},
)
tindex = pd.date_range("2001-01-01", "2001-01-05", freq="D")
left = tindex - pd.DateOffset(hours=12)
right = tindex + pd.DateOffset(hours=12)
tbounds = np.stack([left, right])
ds.coords["time"] = ("time", tindex)
ds.coords["time_bounds"] = ("time", BoundsArrayWrapper(tbounds, axis=0))
wrapped = BoundsArrayWrapper(ds.x_bounds.data, axis=ds.x_bounds.get_axis_num("bnds"))
ds.update({"x_bounds": ("x", wrapped)})
print(ds)

newds = ds.drop_indexes("x").set_xindex(
    ("x", "x_bounds"),
    XarrayIntervalIndex,
)
newds = newds.drop_indexes("time").set_xindex(
    ("time", "time_bounds"),
    XarrayIntervalIndex,
)
newds

<xarray.Dataset>
Dimensions:      (time: 5, x: 4)
Coordinates:
  * x            (x) int64 1 2 3 4
    x_bounds     (x) float64 BoundsArray([[0.5, 1.5, 2.5, 3.5],        [1.5, ...
  * time         (time) datetime64[ns] 2001-01-01 2001-01-02 ... 2001-01-05
    time_bounds  (time) datetime64[ns] BoundsArray([['2000-12-31T12:00:00.000...
Data variables:
    data         (time, x) int64 0 1 2 3 4 5 6 7 8 ... 12 13 14 15 16 17 18 19


In [198]:
newds.sel(x=[1.1, 2.4], time=["2001-01-02 13:00"]).xindexes

Indexes:
    x            XarrayIntervalIndex([(0.5, 1.5], (1.5, 2.5]], dtype='interval[float64, right]')
    x_bounds     XarrayIntervalIndex([(0.5, 1.5], (1.5, 2.5]], dtype='interval[float64, right]')
    time         XarrayIntervalIndex([(2001-01-02 12:00:00, 2001-01-03 12:00:00]], dtype='interval[datetime64[ns], right]')
    time_bounds  XarrayIntervalIndex([(2001-01-02 12:00:00, 2001-01-03 12:00:00]], dtype='interval[datetime64[ns], right]')

## Approach 3. PandasMetaIndex style approach

### Splitting into 2 arrays

doesn't work with `set_xindex` which does not expect changes in variables in vs variables out

In [2]:
left = np.arange(0.5, 3.6, 1)
right = np.arange(1.5, 4.6, 1)
bounds = np.stack([left, right])
bounds.shape

ds = xr.Dataset(
    {"data": ("x", [1, 2, 3, 4])},
    coords={"x": [1, 2, 3, 4], "x_bounds": (("bnds", "x"), bounds)},
)
ds

In [19]:
iidx = pd.IntervalIndex.from_arrays([0, 1, 2], [1, 2, 3], closed="left")

ds = xr.Dataset(coords={"x": iidx.values})

actual = pd.IntervalIndex(ds.x.variable.data)

assert actual.equals(iidx)
ds.indexes["x"]

Index([[0, 1), [1, 2), [2, 3)], dtype='object', name='x')

In [17]:
iidx

IntervalIndex([[0, 1), [1, 2), [2, 3)], dtype='interval[int64, left]')

In [7]:
from typing import Any, Hashable

from xarray.core.indexes import Index, PandasIndex, get_indexer_nd
from xarray.core.indexing import IndexSelResult, merge_sel_results
from xarray.core.variable import Variable


class XarrayIntervalIndex(Index):
    # based off Benoit's RasterIndex in
    # https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

    def __init__(self, variables):
        print("init")

        assert len(variables) == 2
        self._variables = variables

        for k, v in variables.items():
            if v.ndim == 2:
                self._bounds_name, bounds = k, v
            elif v.ndim == 1:
                dim, other = k, v

        bounds = bounds.transpose(..., dim)
        left, right = bounds.data.tolist()
        self._index = pd.IntervalIndex.from_arrays(left, right)
        self._dim = dim
        self._bounds_dim = (set(bounds.dims) - set(dim)).pop()

    @classmethod
    def from_variables(cls, variables, options):
        print("in from_variables")
        assert len(variables) == 2
        return cls(variables)

    # TODO: variables=None?
    # set_xindex tries to pass variables; this seems like a bug
    def create_variables(self, variables=None):
        print("in create_vars")
        print(variables)
        variables = {
            f"{self._dim}{edge}": Variable(
                dims=(self._dim,), data=getattr(self._index, edge)
            )
            for suffix, edge in [("left", "left"), ("right", "right"), ("", "mid")]
        }
        return variables

    # TODO: see notes about IndexSelResult
    #    The latter is a small class that stores positional indexers (indices)
    #    and that could also store new variables, new indexes,
    #    names of variables or indexes to drop,
    #    names of dimensions to rename, etc.
    def sel(self, labels, **kwargs):
        # sel needs to only handle keys in labels
        # since it delegates to isel.
        # we handle all entries in ._indexes there
        results = self.index.sel({k: labels[k]}, **kwargs)
        return merge_sel_results(results)

    def isel(self, indexers):
        # TODO: check dim names in indexes
        results = {}
        for k, index in self._indexes.items():
            if k in indexers:
                # again possible KDTree / friends here.
                results[k] = index.isel({k: indexers[k]})
            else:
                results[k] = index
        return type(self)(results)

    def __repr__(self):
        string = f"Xarray{self._index!r}"
        return string


newds = ds.drop_indexes("x").set_xindex(
    ("x", "x_bounds"),
    XarrayIntervalIndex,
)
newds

in from_variables
init
in create_vars
{'x': <xarray.Variable (x: 4)>
array([1, 2, 3, 4]), 'x_bounds': <xarray.Variable (bnds: 2, x: 4)>
array([[0.5, 1.5, 2.5, 3.5],
       [1.5, 2.5, 3.5, 4.5]])}


### Preserving 2D bounds variable

Doesn't work because it isn't propagated with DataArray selection

In [4]:
left = np.arange(0.5, 3.6, 1)
right = np.arange(1.5, 4.6, 1)
bounds = np.stack([left, right])
bounds.shape

ds = xr.Dataset(
    {"data": ("x", [1, 2, 3, 4])},
    coords={"x": [1, 2, 3, 4], "x_bounds": (("bnds", "x"), bounds)},
)

In [78]:
from typing import Any, Hashable

from xarray.core.indexes import Index, PandasIndex, get_indexer_nd
from xarray.core.indexing import IndexSelResult, merge_sel_results
from xarray.core.variable import Variable


class XarrayIntervalIndex(Index):
    # based off Benoit's RasterIndex in
    # https://hackmd.io/Zxw_zCa7Rbynx_iJu6Y3LA?view

    def __init__(self, variables):
        # TODO: hardcoded variable names

        assert len(variables) == 2
        self._variables = variables

        for k, v in variables.items():
            if v.ndim == 2:
                self._bounds_name, bounds = k, v
            elif v.ndim == 1:
                dim, other = k, v

        bounds = bounds.transpose(..., dim)
        left, right = bounds.data.tolist()
        self._index = pd.IntervalIndex.from_arrays(left, right)
        self._dim = dim
        self._bounds_dim = (set(bounds.dims) - set(dim)).pop()

    @classmethod
    def from_variables(cls, variables, options):
        assert len(variables) == 2
        return cls(variables)

    # TODO: variables=None?
    # set_xindex tries to pass variables; this seems like a bug
    def create_variables(self, variables=None):
        idx_variables = {}

        bounds = Variable(
            dims=(self._bounds_dim, self._dim),
            data=np.stack([self._index.left, self._index.right], axis=0),
        )
        mid = Variable(dims=(self._dim,), data=self._index.mid)
        return {self._dim: mid, self._bounds_name: bounds}

    # TODO: see notes about IndexSelResult
    #    The latter is a small class that stores positional indexers (indices)
    #    and that could also store new variables, new indexes,
    #    names of variables or indexes to drop,
    #    names of dimensions to rename, etc.
    def sel(self, labels, **kwargs):
        # sel needs to only handle keys in labels
        # since it delegates to isel.
        # we handle all entries in ._indexes there
        results = self.index.sel({k: labels[k]}, **kwargs)
        return merge_sel_results(results)

    def isel(self, indexers):
        # TODO: check dim names in indexes
        results = {}
        for k, index in self._indexes.items():
            if k in indexers:
                # again possible KDTree / friends here.
                results[k] = index.isel({k: indexers[k]})
            else:
                results[k] = index
        return type(self)(results)

    def __repr__(self):
        string = f"Xarray{self._index!r}"
        return string


newds = ds.drop_indexes("x").set_xindex(
    ("x", "x_bounds"),
    XarrayIntervalIndex,
)
newds

In [79]:
newds.data

In [63]:
pd.IntervalIndex.from_arrays(left, right).left

Float64Index([0.5, 1.5, 2.5, 3.5], dtype='float64')

In [48]:
ds.x_bounds.data.transpose().tolist()

[[0.5, 1.5], [1.5, 2.5], [2.5, 3.5], [3.5, 4.5]]

## Approach 4. Extension Array Adaptor

In [21]:
intarray = pd.arrays.IntervalArray.from_arrays([0, 1, 2], [1, 2, 3])

print(
    isinstance(intarray, pd.api.extensions.ExtensionArray),
    intarray.ndim,
    intarray.size,
    intarray.dtype,
    intarray.shape,
)

True 1 3 interval[int64, right] (3,)


In [22]:
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import ExplicitlyIndexed, PandasIndexingAdapter
from xarray.core.utils import NdimSizeLenMixin


class ExtensionArrayAdaptor(NdimSizeLenMixin, ExplicitlyIndexed):
    def __init__(self, array):
        self.array = array

    def __array__(self):
        raise

    def get_duck_array(self):
        return self

    def values(self):
        raise

    @property
    def shape(self):
        return self.array.shape

    @property
    def data(self):
        return self.array

    @property
    def dtype(self):
        return self.array.dtype

    def __array_ufunc__(self, *args, **kwargs):
        return self.array.__array_ufunc__(*args, **kwargs)

    def __array_function__(self, func, types, args, kwargs):
        raise NotImplementedError

    def __repr__(self):
        return f"ExtensionArrayAdaptor({repr(self.array)})"


@xr.register_dataarray_accessor("interval")
class IntervalAccessor:
    def __init__(self, obj):
        assert isinstance(obj, xr.DataArray)
        self._variable = obj._variable.to_base_variable()
        assert isinstance(self._variable._data, PandasIndexingAdapter)
        array = self._variable._data.array.array
        assert isinstance(array, pd.arrays.IntervalArray)

        self._obj = obj
        self._array = array

    def _get_property_var(self, prop):
        vals = getattr(self._array, prop)
        return self._variable.copy(data=np.array(vals))

    def _wrap_property(self, prop):
        return self._obj._replace(self._get_property_var(prop))

    @property
    def mid(self):
        return self._wrap_property("mid")

    @property
    def left(self):
        return self._wrap_property("left")

    @property
    def right(self):
        return self._wrap_property("right")

    def to(self, loc):
        assert loc in ["left", "mid", "right"]
        var = self._get_property_var(loc)
        dim = self._obj.name

        coords = self._obj._coords.copy()
        indexes = self._obj._indexes.copy()

        indexes[dim] = PandasIndex(var._data, dim=dim)
        coords[dim] = coords[dim].copy(data=indexes[dim].index)

        return self._obj._replace(var, coords=coords, indexes=indexes)


da = xr.DataArray(
    np.arange(len(intarray)),
    dims="x",
    coords={"x": ExtensionArrayAdaptor(intarray)},
)
da

  class IntervalAccessor:


RuntimeError: No active exception to reraise

In [6]:
from xarray.core.pycompat import is_duck_array

a = ExtensionArrayAdaptor(intarray)
is_duck_array(a)

True

In [7]:
da._indexes["x"]

PandasIndex(IntervalIndex([(0, 1], (1, 2], (2, 3]], dtype='interval[int64, right]', name='x'))

In [8]:
da

In [9]:
da.x.interval.mid

In [10]:
da.x

In [11]:
v = xr.Variable("x", ExtensionArrayAdaptor(intarray))
v._data

ExtensionArrayAdaptor(<IntervalArray>
[(0, 1], (1, 2], (2, 3]]
Length: 3, dtype: interval[int64, right])

In [12]:
v = xr.IndexVariable("x", ExtensionArrayAdaptor(intarray))
v._data

PandasIndexingAdapter(array=IntervalIndex([(0, 1], (1, 2], (2, 3]], dtype='interval[int64, right]'), dtype=dtype('O'))

In [17]:
v.isel(x=0)

In [18]:
xr.IndexVariable("x", ExtensionArrayAdaptor(intarray))