# Looking into set_xindex()

In [1]:
import xarray as xr 
import numpy as np
import pandas as pd

from xarray import Index
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import merge_sel_results

In [2]:
from __future__ import annotations

import copy
import datetime
import inspect
import itertools
import math
import sys
import warnings
from collections import defaultdict
from collections.abc import (
    Collection,
    Hashable,
    Iterable,
    Iterator,
    Mapping,
    MutableMapping,
    Sequence,
)
from functools import partial
from html import escape
from numbers import Number
from operator import methodcaller
from os import PathLike
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
    from numpy.exceptions import RankWarning  # type: ignore[attr-defined,unused-ignore]
except ImportError:
    from numpy import RankWarning

import pandas as pd

from xarray.coding.calendar_ops import convert_calendar, interp_calendar
from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from xarray.core import (
    alignment,
    duck_array_ops,
    formatting,
    formatting_html,
    ops,
    utils,
)
from xarray.core import dtypes as xrdtypes
from xarray.core._aggregations import DatasetAggregations
from xarray.core.alignment import (
    _broadcast_helper,
    _get_broadcast_dims_map_common_coords,
    align,
)
from xarray.core.arithmetic import DatasetArithmetic
from xarray.core.common import (
    DataWithCoords,
    _contains_datetime_like_objects,
    get_chunksizes,
)
from xarray.core.computation import unify_chunks
from xarray.core.coordinates import (
    Coordinates,
    DatasetCoordinates,
    assert_coordinate_consistent,
    create_coords_with_default_indexes,
)
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.indexes import (
    Index,
    Indexes,
    PandasIndex,
    PandasMultiIndex,
    assert_no_index_corrupted,
    create_default_index_implicit,
    filter_indexes_from_coords,
    isel_indexes,
    remove_unused_levels_categories,
    roll_indexes,
)
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import (
    dataset_merge_method,
    dataset_update_method,
    merge_coordinates_without_align,
    merge_core,
)
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
    Self,
    T_ChunkDim,
    T_Chunks,
    T_DataArray,
    T_DataArrayOrSet,
    T_Dataset,
    ZarrWriteModes,
)
from xarray.core.utils import (
    Default,
    Frozen,
    FrozenMappingWarningOnValuesAccess,
    HybridMappingProxy,
    OrderedSet,
    _default,
    decode_numpy_dict_values,
    drop_dims_from_indexers,
    either_dict_or_kwargs,
    emit_user_level_warning,
    infix_dims,
    is_dict_like,
    is_duck_array,
    is_duck_dask_array,
    is_scalar,
    maybe_wrap_array,
)
from xarray.core.variable import (
    IndexVariable,
    Variable,
    as_variable,
    broadcast_variables,
    calculate_dimensions,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import array_type, is_chunked_array
from xarray.plot.accessor import DatasetPlotAccessor
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

if TYPE_CHECKING:
    from dask.dataframe import DataFrame as DaskDataFrame
    from dask.delayed import Delayed
    from numpy.typing import ArrayLike

    from xarray.backends import AbstractDataStore, ZarrStore
    from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
    from xarray.core.dataarray import DataArray
    from xarray.core.groupby import DatasetGroupBy
    from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
    from xarray.core.resample import DatasetResample
    from xarray.core.rolling import DatasetCoarsen, DatasetRolling
    from xarray.core.types import (
        CFCalendar,
        CoarsenBoundaryOptions,
        CombineAttrsOptions,
        CompatOptions,
        DataVars,
        DatetimeLike,
        DatetimeUnitOptions,
        Dims,
        DsCompatible,
        ErrorOptions,
        ErrorOptionsWithWarn,
        InterpOptions,
        JoinOptions,
        PadModeOptions,
        PadReflectOptions,
        QueryEngineOptions,
        QueryParserOptions,
        ReindexMethodOptions,
        SideOptions,
        T_Xarray,
    )
    from xarray.core.weighted import DatasetWeighted
    from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint



In [10]:
class ToyIndex(xr.Index):
    """Toy custom index that wraps one ``PandasIndex`` stored under the key ``'x'``.

    Instances are normally created through :meth:`from_variables`, which
    builds the wrapped index and passes it to ``__init__``.
    """

    def __init__(self, x_indexes):
        # ``x_indexes`` is the ``{'x': PandasIndex}`` dict assembled in
        # ``from_variables`` and handed over via ``cls(x_indexes)``.
        self._xindexes = x_indexes

    @classmethod
    def from_variables(cls, variables, **kwargs):
        """Create a ``ToyIndex`` from a mapping holding one coordinate variable.

        Only the index is created here; the variables themselves are not
        transformed. The coordinate is expected to be one-dimensional.

        Parameters
        ----------
        variables : mapping
            Maps a coordinate name to its variable object. Must contain
            exactly one entry.
        **kwargs
            Accepted so callers can pass extra keywords (e.g. ``options=``);
            they are not used by this method.

        Returns
        -------
        ToyIndex
            A new index whose ``_xindexes`` dict holds the wrapped
            ``PandasIndex`` under the key ``'x'``.

        Raises
        ------
        ValueError
            If ``variables`` does not contain exactly one entry.
        """
        # Raise explicitly instead of using ``assert``, which is stripped
        # when Python runs with -O.
        if len(variables) != 1:
            raise ValueError(
                "ToyIndex.from_variables expects exactly one coordinate "
                f"variable, got {len(variables)}"
            )

        options = {'dim': 'x', 'name': 'x'}
        # The storage key is hard-coded to 'x', whatever the variable is named.
        x_indexes = {
            'x': PandasIndex.from_variables(dict(variables), options=options)
        }

        return cls(x_indexes)

In [3]:
# Build a random DataArray of length 10 along 'x' with an integer coordinate.
da = xr.DataArray(
    data=np.random.rand(10),
    dims='x',
    coords={'x': np.arange(10)},
)
# Attach the same two attributes to the array and to its 'x' coordinate.
da.attrs.update(xkey='2', xvals=list(np.arange(10)))
da.x.attrs.update(xkey='2', xvals=list(np.arange(10)))
# Wrap the array in a Dataset under the variable name 'var1'.
ds = xr.Dataset({'var1': da})

In [4]:
# Drop the index attached to 'x' so the walkthrough starts from a coordinate
# without an index (`ds._indexes` is shown empty further down).
ds = ds.drop_indexes('x')

## `set_xindex()`
takes: 
- coord_names (str) (eg. 'x')
- index_cls <-- this is the custom index class we create
- `**options` <- options passed to index constructor

steps of `xr.set_xindex()`:   
1. Checks whether `coord_names` is a scalar or is not a sequence object; if either is true, wraps `coord_names` in a list
```{python} 
if is_scalar(coord_names) or not isinstance(coord_names, Sequence):
        coord_names = [coord_names]
```
2. if an `index_cls` type isn't passed, set it to `PandasIndex` or `PandasMultiIndex` depending on the number of coords

```{python}
if index_cls is None: 
    if len(coord_names) == 1:
        index_cls = PandasIndex
    else:
        index_cls = PandasMultiIndex
```
3. if `index_cls` is passed but isn't a child of `xr.Index` raise error
```{python}
else:
            if not issubclass(index_cls, Index):
                raise TypeError(f"{index_cls} is not a subclass of xarray.Index")
```

4.  Checking for invalid coords
- `.set_xindex()` is applied as a method of da, so coord_names shouldn't have any coords that aren't in da._coord_names (self._coord_names)

#### 1. check that `coord_names` is sequence

In [5]:
coord_names = 4

In [6]:
# Normalize `coord_names`: a scalar or any non-Sequence value is wrapped in a
# one-element list (here the int 4 becomes [4]).
if is_scalar(coord_names) or not isinstance(coord_names, Sequence):
        coord_names = [coord_names]

In [7]:
coord_names

[4]

#### 2. Check that `index_cls` exists, inherits `xr.Index`

In [11]:
index_cls = ToyIndex

In [12]:
# No index class supplied: fall back to a pandas-backed index, picking the
# multi-index variant when more than one coordinate name was given.
if index_cls is None: 
    if len(coord_names) == 1:
        index_cls = PandasIndex
    else:
        index_cls = PandasMultiIndex
        
else:
            # A user-supplied class must be a subclass of xarray's Index.
            if not issubclass(index_cls, Index):
                raise TypeError(f"{index_cls} is not a subclass of xarray.Index")


In [13]:
index_cls

__main__.ToyIndex

#### 3. Check for invalid coords
- `.set_xindex()` is applied as a method of da, so coord_names shouldn't have any coords that aren't in da._coord_names (self._coord_names)
- creates an `invalid_coords` object by subtracting ds._coord_names from coord_names
- if `invalid_coords` exists, 
    checks if the invalid coordinate:
        - is a data variable --> msg to use set_coords
        - is not a data variable --> msg that var doesn't exist

Pretend to pass `spatial_ref` to set_xindex knowing it's not a coord of `ds`:

In [14]:
coord_names = ['x','spatial_ref']

Will produce `spatial_ref` as an invalid coord

In [15]:
invalid_coords = set(coord_names) - ds._coord_names

In [16]:
invalid_coords

{'spatial_ref'}

`set_xindex()` will then check if the variable passed doesn't exist or if they are just data variables and not coordinate variables

`ds` has variables `x` and `var1`; `spatial_ref` is not one of them, so it is reported below as a variable that doesn't exist

In [17]:
# Build a multi-line error that separates names missing from the dataset
# altogether from names that exist as variables but are not coordinates.
# With coord_names = ['x', 'spatial_ref'] this cell raises ValueError.
if invalid_coords:
        msg = ["invalid coordinate(s)"]
        # names that are not among the dataset's variables at all
        no_vars = invalid_coords - set(ds._variables)
        # the remaining invalid names exist as variables but are not coordinates
        data_vars = invalid_coords - no_vars
        if no_vars:
            msg.append(f"those variables don't exist: {no_vars}")
        if data_vars:
            msg.append(
                f"those variables are data variables: {data_vars}, use `set_coords` first"
            )
        raise ValueError("\n".join(msg))


ValueError: invalid coordinate(s)
those variables don't exist: {'spatial_ref'}

In [18]:
coord_names = ['x']

In [19]:
# Same check as before, now with coord_names = ['x']: every requested name is
# a coordinate of `ds`, so `invalid_coords` is empty and nothing is raised.
invalid_coords = set(coord_names) - ds._coord_names
if invalid_coords:
        msg = ["invalid coordinate(s)"]
        # names that are not among the dataset's variables at all
        no_vars = invalid_coords - set(ds._variables)
        # the remaining invalid names exist as variables but are not coordinates
        data_vars = invalid_coords - no_vars
        if no_vars:
            msg.append(f"those variables don't exist: {no_vars}")
        if data_vars:
            msg.append(
                f"those variables are data variables: {data_vars}, use `set_coords` first"
            )
        raise ValueError("\n".join(msg))

#### 4. Check if coordinates already have index with `indexed_coords`
- compares the indexes of the data object (`ds._indexes`) to the passed coord_names, if an element of coord_names exists in ds._indexes, raise error saying coord already has index

In [20]:
ds._indexes

{}

In [21]:
coord_names = ['x']

`&` operator finds intersection of 2 sets -- returns only the elements in both sets

In [22]:
indexed_coords = set(coord_names) & set(ds._indexes)

In [23]:
indexed_coords

set()

#### 5. Create `coord_vars` 
- this is where `set_xindex()` takes information from `ds` to pass to `index_cls` (as `coord_vars`)

- `coord_vars` is a dict 
    - created by extracting the `xr.Variable` obj from ds for each element of `coord_names`  
        ex. `coord_names = ['x']`, `ds._variables` is dict w/ keys: 'x','var1'  
        ```{python} 
        coord_vars = {'x': xr.IndexVariable 'x', 
                                array: ...
                                attrs: ...}
        ```

In [24]:
ds

In [25]:
coord_names

['x']

In [26]:
ds._variables

{'x': <xarray.Variable (x: 10)> Size: 80B
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 Attributes:
     xkey:     2
     xvals:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'var1': <xarray.Variable (x: 10)> Size: 80B
 array([0.06937865, 0.1381784 , 0.19135759, 0.99689398, 0.83931757,
        0.9145519 , 0.21902417, 0.01673527, 0.43451646, 0.65392184])
 Attributes:
     xkey:     2
     xvals:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

In [27]:
coord_vars = {name: ds._variables[name] for name in coord_names}

In [28]:
coord_vars

{'x': <xarray.Variable (x: 10)> Size: 80B
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 Attributes:
     xkey:     2
     xvals:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

#### 6. Create index, new coordinate variables
Here `xr.set_xindex()` creates the new index object by calling the `from_variables()` constructor method of `index_cls` and passing `coord_vars` to it
- the new index will have the coords specified in coord_names
- `from_variables()` creates a dict w/ keys = coord name (hardcoded) and vals = output of PandasIndex.from_variables()

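- this calls `index_cls.create_variables()` 
    - I didn't realize they both are called within `set_xindex()`; what if `create_variables()` is not implemented?
    -- `ToyIndex` inherits from `xr.Index`, so the call falls back to the base-class `create_variables()`; judging by the output below, that default simply returns the coordinate variables it was given, so it doesn't matter that `ToyIndex` doesn't implement it (the `PandasIndex` stored inside `ToyIndex` is not involved in this call)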

In [29]:
coord_vars #this is the variables object in the ToyIndex class 

{'x': <xarray.Variable (x: 10)> Size: 80B
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 Attributes:
     xkey:     2
     xvals:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

`set_xindex()` creates the `index` object below:

In [30]:
index = ToyIndex.from_variables(coord_vars, options = {'dim':'x'})

In [31]:
index

<__main__.ToyIndex at 0x7fa2deba7d50>

A look at what occurs in `from_variables()`:

```{python} 
x_indexes = {
    'x': PandasIndex.from_variables({k:v}, options={'dim':'x'}) for k,v in coord_vars.items()
            }
```

a dictionary (x_indexes in class) is created where the key is the coord name, the value (d below) is a `PandasIndex` object created from `coord_vars`:

In [32]:
d = PandasIndex.from_variables({'x':coord_vars['x']},options={'dim':'x'})

In [33]:
# NOTE: `{d}` is a set literal holding the PandasIndex, not a dict; the
# ToyIndex class stores the equivalent mapping {'x': <PandasIndex>}.
x_indexes = {d}

In [34]:
x_indexes

{PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64', name='x'))}

In [35]:
new_coord_vars = index.create_variables(coord_vars)

In [36]:
new_coord_vars

{'x': <xarray.Variable (x: 10)> Size: 80B
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 Attributes:
     xkey:     2
     xvals:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

#### 7. Check if new index is a `pandasmultiindex`
if yes:
-create `variables` dict and `indexes` dict
- if only 1 coordname passed, variables = ds._variables.copy(), 
                                indexes = ds._indexes.copy()
                                
if len(coordname) > 1: 
    - reorder variables, indexes so that coords with the same index are next to each other

#### 8. Type hint `variables`, `indexes`
- this is type hinting ? 
    - shows expected data types of variables
`variables: dict[Hashable, Variable]`
- variables is a dict
- keys of variables will be hashable
- values of variables will be Variable type (`xr.Variable`?)

In [37]:
variables : dict[Hashable, Variable]
indexes : dict[Hashable, Index]

#### Make `variables`, `indexes`
- now taking `ds` (or self) and new index object, create `variables` and `indexes` that will replace the original values on ds when returned
- if only one coord:

Create `variables` by copying `._variables` from `ds`. 

In [38]:
variables = ds._variables.copy()

Do the same for `ds._indexes` but this will be empty because `ds` doesn't have any indexes in our case

In [39]:
indexes = ds._indexes.copy()

Use coord_names (passed to `.set_xindex()`) to create `name`

In [40]:
# `coord_names` holds a single element here, so popping from a list copy
# yields that one name without mutating `coord_names` itself.
name = list(coord_names).pop()
name

'x'

Check if `name` in `new_coord_vars` (output of `index_cls.create_variables()`)

In [41]:
new_coord_vars['x']

If the name of the desired index is in the name of the new coordinate variables created from `index_cls`, add a new key-value pair to `variables` with that `xr.Variable` object:

In [42]:
if name in new_coord_vars:
    variables[name] = new_coord_vars[name]

Add new element to `indexes` with the `index_cls` object created above with .from_variables()

In [43]:
indexes[name] = index

#### 8. Replace elements of `ds` with `variables`, `indexes` and `coord_names`
```{python} 
ds._replace(
    variables = variables,
    coord_names = ds._coord_names | set(coord_names),
    indexes=indexes
    )
```
    

In [44]:
# Hand the updated `variables` and `indexes` to `ds._replace`, with the passed
# names unioned into the existing coordinate names; the result is bound to
# `ds_out`.
ds_out = ds._replace(
    variables = variables,
    coord_names = ds._coord_names | set(coord_names),
    indexes = indexes,
)

In [45]:
ds