# Creating a custom Xarray index object

## What is this?

A notebook demonstrating how to create a custom index object with Xarray. It works through an example in which the information needed for a coordinate transformation is stored in an attribute of an Xarray object. The notebook starts with the simplest possible custom index and then adds complexity and functionality step by step.

## Why?

Adapted from the [Xarray guide on creating a custom index](https://docs.xarray.dev/en/stable/internals/how-to-create-custom-index.html).

By default, Xarray label-based indexing relies on `pandas.Index` objects. This can be limiting when:

- You don’t want all coordinate labels explicitly loaded in memory
- Your data is irregular and difficult to fit into the `pandas.Index` structure
- The indexing operation requires additional metadata, which `pandas.Index` has no built-in support for

## Learning goals

- Understand the role and function of Xarray indexes
- Understand the different object classes related to Xarray indexes and how they interrelate
- Implement class methods within a `CustomIndex` class that control how a new type of index, and the coordinate variables based on it, are created
- Implement custom methods such as `.sel()` to add on-the-fly transformation between different coordinate systems

## Some background

### What are indexes?

Xarray objects (`xr.DataArray` and `xr.Dataset`) store multi-dimensional array data together with labeled coordinates, which makes label-based indexing and selection possible. The mapping between coordinate labels and array positions is handled by `Index` objects. Each `Index` is associated with one or more coordinate variables, and its job is to translate labels (label-space) into integer positions along the corresponding dimensions (array-space).

For example, calling `.isel(x=10)` on an Xarray object returns the data stored at the 11th position along the `x` dimension. Here the user asks for the data by its position in array-space directly, so no index is needed. Often, however, you know the coordinate value of the data you want but not its position along the dimension; this is called ‘label-based’ selection. In `da.sel(x=12)`, the user provides a label for a point along the `x` dimension, and the `Index` converts that label to the corresponding position in the array so that the correct value can be returned.
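A small sketch of the two kinds of selection. The `x` labels are deliberately offset from the positions (labels start at 2), so position 10 and label 12 refer to the same element:

```python
import numpy as np
import xarray as xr

# labels 2..21 along x, so label values differ from positions 0..19
da = xr.DataArray(np.arange(20) * 1.5, dims="x", coords={"x": np.arange(20) + 2})

by_position = da.isel(x=10)  # 11th element along x; positions need no index
by_label = da.sel(x=12)      # the index maps label 12 to position 10

print(float(by_position), float(by_label))
```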

### Label-based indexing

Xarray’s label-based indexing operations use `pandas.Index` objects by default. These are powerful but can be rigid. For use cases that `pandas.Index` objects don't support, Xarray lets users create custom index classes. These contain the basic components of a standard `xr.Index` object, but with increased flexibility. Label-based operations that the user executes at the `xr.Dataset` or `xr.DataArray` level must be implemented on the `xr.Index` objects associated with those arrays. For example, to create a

### Related classes

#### `xr.Index`

This is the base class for all Xarray index objects; any custom index class must inherit from `xr.Index`. The base class itself is not meant to be used directly. Instead, you define a subclass of `xr.Index` and attach it to an object, usually with the `.set_xindex()` method, which this tutorial demonstrates.

#### `PandasIndex`

This is a built-in subclass of `xr.Index` that wraps a `pd.Index` object. It is the index type Xarray creates by default for dimension coordinates.
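A quick sketch of the relationship: `PandasIndex` is a subclass of `xr.Index`, it wraps a `pd.Index`, and its `.sel()` method does the label-to-position translation. Calling `.sel()` on an index directly, as below, is only for illustration; normally `da.sel()` does this for you. `PandasIndex` is imported from `xarray.core.indexes`, as elsewhere in this notebook.

```python
import pandas as pd
import xarray as xr
from xarray.core.indexes import PandasIndex

print(issubclass(PandasIndex, xr.Index))

# build a PandasIndex by hand around a pandas.Index
idx = PandasIndex(pd.Index([10, 20, 30]), dim="x")
print(type(idx.index), idx.dim)

# translate a label into a positional indexer (an IndexSelResult)
result = idx.sel({"x": 30})
print(result.dim_indexers)

# apply those positions to an array with matching x labels
da = xr.DataArray([1.0, 2.0, 3.0], dims="x", coords={"x": [10, 20, 30]})
print(float(da.isel(result.dim_indexers)))
```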

### Related class methods

For a custom `xr.Index` class to work with `.set_xindex()`, it must implement:

- A `from_variables()` class method
    - This is a constructor method. Unlike many classes, you (mostly) won't create an instance of your custom index class (call it `CustomIndex`) by calling `__init__()` yourself. Instead, you pass `CustomIndex` to the `.set_xindex()` method of an `xr.DataArray` or `xr.Dataset` (which we'll call `da`). `.set_xindex()` then calls `CustomIndex.from_variables()`, which creates the index object from the coordinate variables of `da`.
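A minimal sketch of that flow, using a hypothetical `RecordingIndex` that does nothing except store what `from_variables()` receives. Any extra keyword arguments passed to `.set_xindex()` (here a made-up `scale=2`) arrive as the `options` mapping:

```python
import numpy as np
import xarray as xr


class RecordingIndex(xr.Index):
    """Hypothetical minimal index that only records what from_variables() receives."""

    def __init__(self, names, options):
        self.names = names
        self.options = options

    @classmethod
    def from_variables(cls, variables, *, options):
        # variables: {coordinate name: xr.Variable} for the coordinates passed to set_xindex()
        # options: any extra keyword arguments passed to set_xindex()
        return cls(list(variables), dict(options))


da = xr.DataArray(np.random.rand(10), dims="x", coords={"x": np.arange(10)})
da = da.drop_indexes("x")  # remove the default PandasIndex first

da_custom = da.set_xindex("x", RecordingIndex, scale=2)
index = da_custom.xindexes["x"]
print(type(index).__name__, index.names, index.options)
```

Note that `__init__()` is never called by us directly here; `.set_xindex()` reaches it through `from_variables()`.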

In [1]:
import xarray as xr 
import numpy as np
import pandas as pd

from xarray import Index
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import merge_sel_results

First, we'll define some data: a one-dimensional `xr.DataArray` with an `x` dimension, a coordinate variable holding the `x` labels, and random data values along `x`. Both `da` and its `x` coordinate variable carry attributes (`xkey` and `xvals`) that will be used later.

In [2]:
da = xr.DataArray(
    data = np.random.rand(10),
    dims = ('x'),
    coords = {
        'x':np.arange(10),
        })
da.attrs['xkey'] = '2'
da.attrs['xvals'] = list(np.arange(10))
da.x.attrs['xkey'] = '2'
da.x.attrs['xvals'] = list(np.arange(10))


In [3]:
da

Note that creating `da` with an `x` coordinate variable that exists along the `x` dimension automatically creates an Index object. 

The index is visible in the HTML repr of `da` and through `da.indexes`. It has an associated name, `'x'`, and contains the values of the `x` dimension coordinate.

In [4]:
da.indexes['x'].name

'x'

In [5]:
da.indexes

Indexes:
    x        Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64', name='x')

**Note the various object types:**
- `da.indexes` is an `xarray` object that holds all of the indexes associated with an object. 

In [6]:
type(da.indexes)

xarray.core.indexes.Indexes

```{note}
`da.indexes` returns the underlying `pandas.Index` objects. The Xarray index objects themselves (here a `PandasIndex`) are available through `da.xindexes`.
```
- `da.indexes['x']` is a `pandas.Index` object

In [7]:
type(da.indexes['x'])

pandas.core.indexes.base.Index
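The Xarray-level index objects are available through `da.xindexes`. A short sketch comparing it with `da.indexes`:

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.indexes import PandasIndex

da = xr.DataArray(np.random.rand(10), dims="x", coords={"x": np.arange(10)})

xindex = da.xindexes["x"]   # the Xarray index object (a PandasIndex)
pd_index = da.indexes["x"]  # the pandas.Index it wraps

print(type(xindex))
print(type(pd_index))
print(xindex.index.equals(pd_index))
```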

## What do we want to do? 

Currently we have a 1D object indexed along the `x` dimension, so we can ask for data using `x` labels and the index will find them. What if we want to interact with this data using a different coordinate system? As an example, suppose we want to query an object that has an `x` dimension by specifying coordinate values of some `longitude` dimension. To do this, we need to create a custom `xr.Index` and attach it to the DataArray.
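Before writing an index, it helps to see the conversion done by hand. The sketch below assumes, as in the rest of this notebook, that a query label equals the `x` label multiplied by the factor stored in the `xkey` attribute (a toy stand-in for a real `x` to `longitude` transform); `sel_in_other_system` is a hypothetical helper:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(10), dims="x", coords={"x": np.arange(10)})
da.attrs["xkey"] = "2"  # scale factor between the two systems, stored as a string


def sel_in_other_system(obj, label):
    """Hypothetical helper: convert a label in the other system to an x label, then select."""
    scale = float(obj.attrs["xkey"])
    return obj.sel(x=int(label / scale))


picked = sel_in_other_system(da, 14)  # 14 / 2 -> x label 7
```

A custom index lets `da.sel(...)` perform this conversion itself, instead of every caller repeating it.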

In [8]:
class ToyIndex1d(xr.Index):  # custom index inherits from xarray's Index base class
    def __init__(self, x_indexes):
        # x_indexes: dict mapping coordinate name -> PandasIndex
        self._x_indexes = x_indexes

    @classmethod
    def from_variables(cls, variables, **kwargs):
        '''Create a ToyIndex1d object from a dict of coordinate variables.

        `variables` maps each coordinate name to its `xr.Variable`.
        The index is created here; the variables themselves are not transformed.
        '''
        options = {'dim': 'x',
                   'name': 'x'}
        x_indexes = {
            'x': PandasIndex.from_variables({k: v}, options=options) for k, v in variables.items()
        }
        # show what was built
        print(type(x_indexes.items()))
        print(type(x_indexes['x']))

        return cls(x_indexes)

First, drop the 'x' index so that the object has no indexes.

In [9]:
da = da.drop_indexes('x')

In [10]:
da

Then, use `.set_xindex()` to create a new `ToyIndex1d` object and attach it to a copy of `da`.

In [11]:
da_new = da.set_xindex('x', ToyIndex1d)

<class 'dict_items'>
<class 'xarray.core.indexes.PandasIndex'>


In [12]:
da_new

### Breaking down `from_variables()`

The object looks the same, except that the index is now a `ToyIndex1d` object instead of a `PandasIndex` object. Let's look more carefully at how this happened. Right now, `ToyIndex1d` is very simple: it only contains the required constructor method `.from_variables()`. When `ToyIndex1d` is passed to `.set_xindex()`, a few things happen under the hood:


#### 1. A `variables` object is created
This is a dictionary mapping the name of each coordinate passed to `.set_xindex()` (here only `x`) to its coordinate variable. Here is how an equivalent dictionary can be created outside of the method:

In [13]:
variables = {name: da[name].variable for name in da.coords}

In [14]:
variables['x'].attrs

{'xkey': '2', 'xvals': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

In [15]:
variables['x']

#### 2. A new indexes dict is created
`x_indexes` is a dict whose keys are index names and whose values are `PandasIndex` objects created from `variables`. This dict is what the new custom index wraps.

In [16]:
options = {'dim':'new_x',
           'name':'x'}
x_indexes = {'new_x':PandasIndex.from_variables({k:v}, options=options) for k,v in variables.items()
            }


Note that `PandasIndex` objects wrap `pd.Index` objects:

In [17]:
print(x_indexes['new_x'])
print(type(x_indexes['new_x']))
print(type(x_indexes['new_x'].index))

PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64', name='x'))
<class 'xarray.core.indexes.PandasIndex'>
<class 'pandas.core.indexes.base.Index'>


In [18]:
x_indexes['new_x'].index

Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64', name='x')

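```{note}
What do the `options` do here? `PandasIndex.from_variables()` takes the dimension from the variable's own `dims` and the index name from the dictionary key, so the `options` passed here don't appear to change the result: `x_indexes['new_x'].dim` is still `'x'`.
```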

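#### 3. A `ToyIndex1d` object is returned
`from_variables()` returns `cls(x_indexes)`. Because `from_variables()` is a class method, its first argument is the class itself, conventionally named `cls`. Calling `cls(x_indexes)` therefore constructs a new `ToyIndex1d` instance, which runs `ToyIndex1d.__init__()` with `x_indexes`.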
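The `cls` convention is plain Python and not specific to Xarray. A minimal sketch with a hypothetical `Celsius` class: the alternative constructor receives the class as `cls`, and calling `cls(...)` runs `__init__()`. Because `cls` is whatever class the method was called on, a subclass gets instances of its own type:

```python
class Celsius:
    def __init__(self, degrees):
        self.degrees = degrees

    @classmethod
    def from_fahrenheit(cls, fahrenheit):
        # cls is the class the method was called on; cls(...) runs __init__()
        return cls((fahrenheit - 32) * 5 / 9)


class LabeledCelsius(Celsius):
    pass


boiling = Celsius.from_fahrenheit(212)
freezing = LabeledCelsius.from_fahrenheit(32)
print(type(boiling).__name__, boiling.degrees)    # Celsius 100.0
print(type(freezing).__name__, freezing.degrees)  # LabeledCelsius 0.0
```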

### Adding some complexity to `ToyIndex1d`

We have created an index class with the bare minimum requirements. Let's add some functionality. The `create_variables()` method creates coordinate variables from an index object; here, `ToyIndex_2` builds them from the `PandasIndex` objects it wraps.

In [19]:
class ToyIndex_2(xr.Index):  # custom index inherits from xarray's Index base class
    def __init__(self, x_indexes):
        # x_indexes: dict mapping coordinate name -> PandasIndex
        assert 'x' in x_indexes.keys()
        self._indexes = x_indexes

    @classmethod
    def from_variables(cls, variables, **kwargs):
        '''Create a ToyIndex_2 object from a dict of coordinate variables.

        `variables` maps each coordinate name to its `xr.Variable`.
        The index is created here; the variables themselves are not transformed.
        '''
        options = {'dim': 'x',
                   'name': 'x'}
        x_indexes = {
            'x': PandasIndex.from_variables({k: v}, options=options) for k, v in variables.items()
        }
        return cls(x_indexes)

    def create_variables(self, variables=None):
        '''Create coordinate variables from the wrapped PandasIndex objects.'''
        idx_variables = {}
        for index in self._indexes.values():
            print(self._indexes.keys())
            # each PandasIndex builds its own coordinate variable,
            # reusing the attributes of the matching input variable
            idx_variables.update(index.create_variables(variables))
        return idx_variables

In [20]:
da = xr.DataArray(
    data = np.random.rand(10),
    dims = ('x'),
    coords = {
        'x':np.arange(10),
        })
da.attrs['xkey'] = '2'
da.attrs['xvals'] = list(np.arange(10))
da.x.attrs['xkey'] = '2'
da.x.attrs['xvals'] = list(np.arange(10))


In [21]:
da = da.drop_indexes('x')

In [22]:
da_new1 = da.set_xindex('x', ToyIndex_2)

dict_keys(['x'])


In [23]:
da_new1

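## Implement `.sel()`

Next, we add a `.sel()` method. Xarray calls an index's `.sel()` with a dict that maps coordinate names to the requested labels, and expects an `IndexSelResult` describing the matching positions in return; here `merge_sel_results()` builds it from the results of the wrapped `PandasIndex` objects. This version divides the requested `x` label by 2 before looking it up, so labels are interpreted in a second coordinate system.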

In [24]:
class ToyIndex_2(xr.Index):  # custom index inherits from xarray's Index base class
    def __init__(self, x_indexes):
        # x_indexes: dict mapping coordinate name -> PandasIndex
        assert 'x' in x_indexes.keys()
        self._indexes = x_indexes

    def pass_attrs(self, variables):
        '''Return the transformation key stored in the x coordinate attributes (not used yet).'''
        assert 'xkey' in variables['x'].attrs
        key = variables['x'].attrs['xkey']
        return key

    @classmethod
    def from_variables(cls, variables, **kwargs):
        '''Create a ToyIndex_2 object from a dict of coordinate variables.

        `variables` maps each coordinate name to its `xr.Variable`.
        The index is created here; the variables themselves are not transformed.
        '''
        options = {'dim': 'x',
                   'name': 'x'}
        x_indexes = {
            'x': PandasIndex.from_variables({k: v}, options=options) for k, v in variables.items()
        }
        return cls(x_indexes)

    def create_variables(self, variables=None):
        '''Create coordinate variables from the wrapped PandasIndex objects.'''
        idx_variables = {}
        for index in self._indexes.values():
            idx_variables.update(index.create_variables(variables))
        return idx_variables

    def sel(self, labels, **kwargs):
        def transform(value):
            '''Convert labels from the query coordinate system to x labels.'''
            assert isinstance(value, dict)
            key = next(iter(value))  # coordinate name, should be 'x'
            # the factor of 2 is hard-coded here; it mirrors the 'xkey' attribute
            transformed_x = int(value[key] / 2)
            return {key: transformed_x}

        idx = transform(labels)
        results = []
        for k, index in self._indexes.items():
            if k in idx:
                # delegate the label -> position lookup to the wrapped PandasIndex
                results.append(index.sel({k: idx[k]}, **kwargs))

        return merge_sel_results(results)

In [25]:
da = xr.DataArray(
    data = np.random.rand(10),
    dims = ('x'),
    coords = {
        'x':np.arange(10),
        })
da.attrs['xkey'] = '2'
da.attrs['xvals'] = list(np.arange(10))
da.x.attrs['xkey'] = '2'
da.x.attrs['xvals'] = list(np.arange(10))


In [26]:
da.sel(x=7)

In [27]:
da = da.drop_indexes('x')

In [28]:
da_new1 = da.set_xindex('x', ToyIndex_2)

In [29]:
da_new1.sel(x=14)

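Cool, `.sel()` works: `da_new1.sel(x=14)` returns the element at `x=7`, because `ToyIndex_2` divides the requested label by 2 before looking it up in the wrapped `x` index. In this simplified example we created a custom Xarray index with label-based selection that handles the translation between two coordinate systems.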
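`ToyIndex_2` hard-codes the factor of 2. A sketch of an attribute-driven variant, assuming the scale factor is stored as a string in the `xkey` attribute of the `x` coordinate (as in the data above). `ScaledIndex` is a hypothetical name, and `merge_sel_results` is the same internal Xarray helper used above, so its import path may change between versions:

```python
import numpy as np
import xarray as xr
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import merge_sel_results


class ScaledIndex(xr.Index):
    """Hypothetical variant of ToyIndex_2 that reads its scale factor from an attribute."""

    def __init__(self, pd_index, scale):
        self._index = pd_index  # PandasIndex for the x coordinate
        self._scale = scale     # query label = x label * scale

    @classmethod
    def from_variables(cls, variables, *, options):
        assert len(variables) == 1, "this sketch supports a single coordinate"
        name, var = next(iter(variables.items()))
        scale = float(var.attrs["xkey"])  # assumption: the factor is stored as 'xkey'
        return cls(PandasIndex.from_variables({name: var}, options={}), scale)

    def create_variables(self, variables=None):
        # expose the original x labels unchanged
        return self._index.create_variables(variables)

    def sel(self, labels, **kwargs):
        results = []
        for name, label in labels.items():
            x_label = int(label / self._scale)  # convert to the x coordinate system
            results.append(self._index.sel({name: x_label}, **kwargs))
        return merge_sel_results(results)


da = xr.DataArray(
    np.random.rand(10), dims="x", coords={"x": ("x", np.arange(10), {"xkey": "2"})}
)
da = da.drop_indexes("x").set_xindex("x", ScaledIndex)

picked = da.sel(x=14)  # 14 / 2 -> x label 7
```

Keeping the scale factor on the index, rather than in `sel()`, means the attribute is read once when the index is built.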

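## Questions:

- The same object seems to be created whether or not `create_variables()` is implemented. What exactly is `create_variables()` doing?
    - `.set_xindex()` calls `create_variables()` on the new index to build the coordinate variables attached to the object. If a subclass doesn't override it, the base `xr.Index.create_variables()` passes the input variables through unchanged, so the original `x` coordinate is kept. `ToyIndex_2.create_variables()` rebuilds an equivalent `x` coordinate from the wrapped `PandasIndex`, so the result looks the same either way.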
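One way to probe this question: the base `xr.Index.create_variables()` returns the variables it is given, which can be checked with a hypothetical subclass that doesn't override it:

```python
import numpy as np
import xarray as xr


class NoCreateVariables(xr.Index):
    """Hypothetical index that relies on the base-class create_variables()."""


var = xr.Variable("x", np.arange(10), attrs={"xkey": "2"})
out = NoCreateVariables().create_variables({"x": var})

print(list(out))                                # ['x']
print(out["x"].identical(var))                  # True: passed through unchanged
print(NoCreateVariables().create_variables())   # {}: nothing to pass through
```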