# `roll_to` and `sel_periodic` (version 2)

This aims to help with the problem of periodic (circular, wrapped, modulo) selection described in [Circular longitude axis #623](https://github.com/pydata/xarray/issues/623). [@j08lue](https://github.com/j08lue) uses a rolled index array, as follows:

```python
import numpy as np


def wrap_iselkw(ds, dim, istart, istop):
    """Returns a kw dict for indices from `istart` to `istop` that wrap around dimension `dim`"""
    n = len(ds[dim])
    if istart > istop:
        istart -= n
    return {dim: np.mod(np.arange(istart, istop), n)}
```
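The index arithmetic is easiest to see without a dataset. The sketch below is a hypothetical NumPy-only variant (`wrap_indices` is our name, not part of xarray) that takes the dimension length `n` directly instead of reading it from `ds[dim]`:

```python
import numpy as np


def wrap_indices(n, istart, istop):
    """Hypothetical NumPy-only analogue of wrap_iselkw: integer indices
    from istart up to (but excluding) istop, wrapping around length n."""
    if istart > istop:
        # Shift the start back by one period so that arange runs forward.
        istart -= n
    return np.mod(np.arange(istart, istop), n)


# With 8 longitude cells, select cells 6, 7, 0, 1 (wrapping past the end).
print(wrap_indices(8, 6, 2).tolist())  # [6, 7, 0, 1]
```

Note that the result has one entry per selected column, which is the "N columns" cost discussed below.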

In [roll method #624](https://github.com/pydata/xarray/issues/624), [@rabernat](https://github.com/rabernat) proposed a `roll` method (since added to xarray), as follows:

```python
import xarray as xr  # the package was still called `xray` when this was proposed


def roll(darr, n, dim):
    """Clone of numpy.roll for xarray objects."""
    left = darr.isel(**{dim: slice(None, -n)})
    right = darr.isel(**{dim: slice(-n, None)})
    return xr.concat([right, left], dim=dim, data_vars='minimal',
                     coords='minimal')
```

[@rabernat's](https://github.com/rabernat) method appears to be better because it concatenates two contiguous sub-arrays instead of gathering N individual columns. Both approaches work only with integer indices and do not support coordinate labels.
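To illustrate the two-sub-array idea, here is a NumPy-only sketch of the same split-and-swap (the name `roll_by_concat` is hypothetical), checked against `np.roll`. Unlike the xarray version above, it reduces `n` modulo the array length, an assumption we add so that shifts with `|n| >= len(a)` also match `np.roll`:

```python
import numpy as np


def roll_by_concat(a, n):
    """Hypothetical NumPy analogue of the roll() proposal: split the array
    into two contiguous pieces and swap them."""
    n %= len(a)  # make shifts with |n| >= len(a) behave like np.roll
    left = a[:-n]   # empty when n == 0
    right = a[-n:]  # the whole array when n == 0
    return np.concatenate([right, left])


a = np.arange(6)
print(roll_by_concat(a, 2).tolist())  # [4, 5, 0, 1, 2, 3]
assert np.array_equal(roll_by_concat(a, 2), np.roll(a, 2))
```

Only two slices are copied regardless of how many columns move, which is why this approach scales better than building a per-column index array.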

We introduce a function `sel_periodic` and a helper function `roll_to` that work in coordinate-label space.

`sel_periodic(ds, dim, value, period=360.0)` works like `sel`, except that it assumes the coordinate is periodic, covers exactly one period, has monotonically increasing labels, and includes label 0.0. It supports scalar, array, and slice selections. In all cases it normalizes the selection criteria modulo `period`. If, after normalization, the lower bound of a slice is greater than its upper bound (i.e. the range wraps around), it first rolls the dataset with `roll_to` to the upper bound so that the selection is contiguous. It also applies modulo-`period` normalization to the coordinate, and it supports both positive and negative `slice.step`. TODO: change the API to match xarray's `sel`.
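The modulo normalization of selection labels can be sketched with NumPy alone. `normalize_labels` below is a hypothetical, simplified stand-in for `__normalize_vals`: it maps array-like labels into the half-open window `[c0, c0 + period)`, where `c0` is the lower cell edge of the coordinate, and it omits the `right` flag that `__normalize_vals` uses to keep the upper edge for slice stops:

```python
import numpy as np


def normalize_labels(vals, c0, period=360.0):
    """Hypothetical helper: map array-like labels into [c0, c0 + period).
    Assumes c0 <= 0 <= c0 + period, as sel_periodic does."""
    v = np.mod(np.asarray(vals, dtype=float), period)  # now in [0, period)
    v[v >= c0 + period] -= period  # fold the top of [0, period) below zero
    return v


# A -180..180 grid has lower cell edge c0 = -180.
print(normalize_labels([-708.75, -1.25, 381.25, -176.25], c0=-180.0).tolist())
# [11.25, -1.25, 21.25, -176.25]
```

With `c0 = -180.0`, the bounds of `slice(-200, 200)` normalize to `160.0` and `-160.0`; because the start is now greater than the stop, `sel_periodic` has to roll the data before slicing.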

`roll_to(ds, dim, value, period=360.0)` is a helper function used by `sel_periodic`, and it can also be used on its own. It rolls `ds` so that the first coordinate label greater than or equal to `value` comes first, then adds `period` to the labels that wrapped around so that the coordinate is monotonically increasing again, and finally shifts the coordinate by a whole period, if needed, so that its range includes 0.0. It assumes that the input coordinate is monotonically increasing, covers exactly one period, and overlaps `value`. If `value` is outside the coordinate range, the function does nothing.
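The label arithmetic of `roll_to` can likewise be sketched on a bare NumPy array. `roll_labels_to` is a hypothetical name; it reproduces the roll, the `+ period` fix-up, and the whole-period shift performed by `__normalize_dim`, but it only handles the coordinate labels, not any data variables:

```python
import numpy as np


def roll_labels_to(labels, val, period=360.0):
    """Hypothetical NumPy-only sketch of roll_to's label arithmetic."""
    labels = np.asarray(labels, dtype=float)
    idx = np.flatnonzero(labels >= val)
    n = idx[0] if idx.size else 0
    if n == 0:
        # val is at or below the first label, or above the last: no-op.
        return labels.copy()
    rolled = np.roll(labels, -n)
    # Labels that wrapped around are now smaller than val: lift them one period.
    out = np.where(rolled < val, rolled + period, rolled)
    # Mirror __dim_range/__normalize_dim: extend by half the closing gap to
    # get cell edges, then shift by one period if 0.0 lies outside them.
    half_gap = (period - (out[-1] - out[0])) / 2.0
    if out[0] - half_gap > 0.0:
        out -= period
    elif out[-1] + half_gap < 0.0:
        out += period
    return out


print(roll_labels_to([0.0, 90.0, 180.0, 270.0], 100.0).tolist())
# [-180.0, -90.0, 0.0, 90.0]
```

On a 2.5° grid with labels `1.25 .. 358.75`, rolling to `180.0` gives labels `-178.75 .. 178.75`, i.e. the same grid expressed in the `-180..180` convention.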


In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from collections.abc import Iterable

In [None]:
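def __dim_range(ds, dim, period=360.0):
    """Return the cell-edge range (c0, c1) of `dim`: the first and last
    labels extended outward by half of the gap that completes one period."""
    c0, c1 = ds[dim].values[0], ds[dim].values[-1]
    d = (period - (c1 - c0)) / 2.0
    c0, c1 = c0 - d, c1 + d
    return c0, c1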


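def __normalize_vals(v0, vals, period=360.0, right=False):
    """Map scalar or array-like `vals` into the window [v0, v0 + period).
    With right=True, the upper edge v0 + period is kept rather than being
    wrapped to v0 (used for slice stops). Assumes v0 <= 0 <= v0 + period."""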
    vs = vals if isinstance(vals, Iterable) else [vals]
    
    v1 = v0 + period
    assert v0 <= 0.0 <= v1
    
    vs = np.mod(vs, period)
    if right:
        vs[vs > v1] -= period
    else:
        vs[vs >= v1] -= period

    vs = vs if isinstance(vals, Iterable) else vs[0]

    return vs


def __normalize_dim(ds, dim, period=360.0):
    """Shift `dim` by one period, if needed, so that its cell-edge range
    includes 0.0. Modifies `ds` in place; make a copy first if necessary."""
    c0, c1 = __dim_range(ds, dim, period)
    if c0 > 0.0:
        ds[dim] = ds[dim] - period
    elif c1 < 0.0:
        ds[dim] = ds[dim] + period


def roll_to(ds, dim, val, period=360.0):
    """Rolls ds to the first label of dim that is greater than or equal to
    val, and then makes dim monotonically increasing. Assumes that dim
    is monotonically increasing, covers exactly one period, and overlaps
    val. If val is outside of dim, this function does nothing.
    """
    a = np.argwhere(ds[dim].values >= val)
    n = a[0, 0] if a.shape[0] != 0 else 0
    if n != 0:
        ds = ds.copy()
        ds = ds.roll(**{dim: -n}, roll_coords=True)
        ds[dim] = xr.where(ds[dim] < val, ds[dim] + period, ds[dim])
        __normalize_dim(ds, dim, period)
    return ds


def sel_periodic(ds, dim, vals, period=360.0):
    """Assumes that dim is monotonically increasing, covers exactly one period, and overlaps 0.0.
    Examples: lon: 0..360, -180..180, -90..270, -360..0, etc.
    TODO: change API to match xarray's `sel`
    """
    c0, c1 = __dim_range(ds, dim, period)
    print(f"*** sel_periodic (input): {c0}..{c1}: {vals}")

    if isinstance(vals, slice):
        if vals.step is None or vals.step >= 0:
            s0 = __normalize_vals(c0, vals.start, period)
            s1 = __normalize_vals(c0, vals.stop, period, True)
        else:
            s0 = __normalize_vals(c0, vals.stop, period)
            s1 = __normalize_vals(c0, vals.start, period, True)

        print(f"*** sel_periodic (normalized): {c0}..{c1}: {s0=}, {s1=}")

        if s0 > s1:
            ds = roll_to(ds, dim, s1, period)
            c0, c1 = __dim_range(ds, dim, period)
            s0 = __normalize_vals(c0, s0, period)
            s1 = __normalize_vals(c0, s1, period, True)
            print(f"*** sel_periodic (rolled): {c0}..{c1}: {s0=}, {s1=}")

        if vals.step is None or vals.step >= 0:
            vals = slice(s0, s1, vals.step)
        else:
            vals = slice(s1, s0, vals.step)

        print(f"*** sel_periodic (slice): {c0}..{c1}: {vals}") 

    else:
        vals = __normalize_vals(c0, vals, period=period)
        print(f"*** sel_periodic (array): {c0}..{c1}: {vals}")

    ds = ds.sel({dim: vals})
    
    return ds


In [None]:
ds = xr.open_dataset("/local/ikh/data5/noaa/cac/cmap/monthly/latest_rotating/cmap_mon_latest_float64.nc")
#ds = ds.pipe(roll_to, "lon", 0.0)
da = ds["rain1"]
da

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", np.array([-708.75, -1.25, 381.25, -176.25]))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", 731.25)["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(-370, 361, -1))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(361, -370, -1))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(361, -370))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(-370, 361))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(-160, 160))["lon"]

In [None]:
da.isel(time=0).pipe(sel_periodic, "lon", slice(-200, 200))["lon"]

In [None]:
pd.to_datetime(11, format='%m').month_name()

In [None]:
z = ds.sel(lon=slice(170, -170, -2))
__normalize_dim(z, "lon")
z["lon"].values