# Creating a custom Xarray index

## STILL TO DO
- Plug in a different transform (or leave it for later)
- Move the text over from the outline


This tutorial demonstrates how to create a custom Xarray index object and attach it with `.set_xindex()`. We describe scenarios where non-standard index functionality is helpful and walk through the methods and components of Xarray `Index` objects.

In [1]:
import xarray as xr 
import numpy as np
import pandas as pd
from collections.abc import Sequence
from copy import deepcopy


from xarray import Index
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import merge_sel_results
import matplotlib.pyplot as plt
from xarray.core.indexes import Index, PandasIndex, get_indexer_nd
from xarray.core.indexing import merge_sel_results

## Define sample data
Create sample data: an `xr.Dataset` with one data variable indexed along the `'x'` dimension (plus a length-1 `'y'` dimension added with `expand_dims`). We would like to query this dataset using labels from a different coordinate reference system (e.g., longitude). The information describing the relationship between the two coordinate reference systems is stored in the attributes of a scalar `spatial_ref` coordinate.

In [2]:
def create_sample_data(kwargs):
    """Build a toy Dataset whose index coordinate is described by a scalar
    ``spatial_ref`` coordinate.

    Parameters
    ----------
    kwargs : dict
        Mapping with keys ``'factor'``, ``'range'``, ``'idx_name'``,
        ``'real_name'`` and ``'data_len'`` (as produced by ``make_kwargs``).
        ``'range'`` is a ``[start, stop, step]`` triple handed to
        ``np.arange`` to build the index coordinate; ``'data_len'`` must equal
        the length of that coordinate.

    Returns
    -------
    xr.Dataset
        Dataset with a random ``var1`` data variable along ``idx_name`` plus a
        length-1 ``'y'`` dimension added by ``expand_dims``, an ``idx_name``
        coordinate, and a scalar ``spatial_ref`` coordinate whose attrs hold
        ``factor``, ``range``, ``idx_name`` and ``real_name``.
    """
    attrs = {key: kwargs[key] for key in ('factor', 'range', 'idx_name', 'real_name')}
    idx_name = kwargs['idx_name']
    start, stop, step = kwargs['range'][0], kwargs['range'][1], kwargs['range'][2]

    da = xr.DataArray(
        data=np.random.rand(kwargs['data_len']),
        dims=(idx_name,),
        # Key the coordinate by idx_name (it was hard-coded to 'x', which
        # failed for any other index name because it did not match dims).
        coords={idx_name: np.arange(start, stop, step)},
    )
    ds = xr.Dataset({'var1': da})

    # A scalar DataArray whose only job is to carry the CRS metadata in attrs.
    spatial_ref = xr.DataArray()
    spatial_ref.attrs = attrs
    ds['spatial_ref'] = spatial_ref
    ds = ds.set_coords('spatial_ref')

    return ds.expand_dims({'y': 1})

In [3]:
def make_kwargs(factor, range_ls, data_len, idx_name='x', real_name='lon'):
    """Assemble the keyword dict consumed by ``create_sample_data``.

    Parameters
    ----------
    factor : number
        Scale between query labels and native index labels (stored in attrs).
    range_ls : list
        ``[start, stop, step]`` for ``np.arange`` building the index coordinate.
    data_len : int
        Number of random values to generate; should match the range length.
    idx_name : str, optional
        Name of the native index dimension/coordinate (default ``'x'``).
    real_name : str, optional
        Name of the "real-world" coordinate the labels represent
        (default ``'lon'``).

    Returns
    -------
    dict
        Keys ``factor``, ``range``, ``idx_name``, ``real_name``, ``data_len``.
    """
    return {
        'factor': factor,
        'range': range_ls,
        'idx_name': idx_name,
        'real_name': real_name,
        'data_len': data_len,
    }

In [4]:
# Two overlapping toy datasets: x = 0..9 (factor 2) and x = 8..17 (factor 5).
sample_ds1 = create_sample_data(make_kwargs(factor=2, range_ls=[0, 10, 1], data_len=10))
sample_ds2 = create_sample_data(make_kwargs(factor=5, range_ls=[8, 18, 1], data_len=10))

Create copies of the original datasets to compare against later.

In [5]:
# Dataset.copy() is shallow by default (data buffers are shared). Take deep
# copies so these reference datasets stay independent of any later in-place
# changes to the working datasets.
old_ds1 = sample_ds1.copy(deep=True)
old_ds2 = sample_ds2.copy(deep=True)

Spend some time here motivating the problem; transfer over the text on limitations.
- Emphasize the parts of the methods needed later, but don't walk through the entire thing.
- Think about which specific parts of the `set_xindex` notebook to reuse.

## Define a custom index class

In [6]:
# Toy custom index: query labels come from a second coordinate reference
# system (e.g. longitude) and are mapped onto the dataset's native ``x``
# coordinate using parameters stored in the attrs of ``spatial_ref``.
class ToyIndex_scalar(xr.Index):  # custom index inherits from xarray's Index
    """Toy Xarray index built from an ``x`` coordinate and a scalar
    ``spatial_ref`` coordinate.

    ``spatial_ref.attrs`` must contain:

    - ``'factor'``: query labels are divided by this to get native ``x``
      labels (see :meth:`transform`).
    - ``'range'``: ``[start, stop, step]`` handed to ``np.arange`` to
      materialize the native ``x`` coordinate (``stop`` is exclusive).
    - ``'idx_name'``: key used for the transformed labels.

    Parameters
    ----------
    x_indexes : dict
        ``{'x': PandasIndex, 'spatial_ref': Variable}``.
    variables : dict, optional
        Coordinate variables the index was built from; must contain
        ``'spatial_ref'`` when given. Without it ``spatial_ref`` is None and
        ``sel``/``join``/``reindex_like`` cannot be used.
    """

    def __init__(self, x_indexes, variables=None):
        self.indexes = variables  # originating coordinate variables (may be None)
        self._xindexes = x_indexes
        # spatial_ref's attrs drive transform/sel/join/reindex_like.
        self.spatial_ref = None if variables is None else variables['spatial_ref']
        # Set by join(); lets create_variables() be called without arguments.
        self.joined_var = None

    @classmethod
    def from_variables(cls, variables, **kwargs):
        """Create the index from the coordinate variables passed to ``set_xindex``.

        ``variables`` maps each coordinate name given to ``set_xindex`` to its
        Variable, i.e. ``{name: ds._variables[name] for name in coord_names}``.
        Extra keyword arguments (e.g. ``options``) are accepted and ignored.

        Raises
        ------
        ValueError
            If the coordinates are not exactly ``'x'`` and ``'spatial_ref'``.
        """
        if set(variables) != {'x', 'spatial_ref'}:
            raise ValueError(
                "ToyIndex_scalar expects exactly the coordinates 'x' and "
                f"'spatial_ref', got {list(variables)!r}"
            )

        # Only 1-D coordinates get a PandasIndex; spatial_ref is 0-D.
        options = {'dim': 'x', 'name': 'x'}
        x_indexes = {
            name: PandasIndex.from_variables({name: var}, options=options)
            for name, var in variables.items()
            if var.ndim == 1
        }
        # spatial_ref is kept as the raw Variable: its attrs carry the CRS info.
        x_indexes['spatial_ref'] = variables['spatial_ref']

        return cls(x_indexes, variables)

    def create_variables(self, variables=None):
        """Return the coordinate variables for this index.

        When ``variables`` is empty or None (as when xarray calls this on an
        index produced by :meth:`join`), the variables stored in
        ``joined_var`` are used instead.

        Raises
        ------
        ValueError
            If no ``variables`` are given and the index was not produced by
            :meth:`join`.
        """
        if not variables:
            if self.joined_var is None:
                raise ValueError(
                    "create_variables() needs `variables` unless the index "
                    "was produced by join()"
                )
            variables = self.joined_var

        idx_variables = {}
        for index in self._xindexes.values():
            # spatial_ref is stored as a plain Variable rather than an Index;
            # skip it here and pass it through unchanged below.
            if isinstance(index, xr.Variable):
                continue
            idx_variables.update(index.create_variables(variables))

        idx_variables['spatial_ref'] = variables['spatial_ref']
        return idx_variables

    def transform(self, value):
        """Map query label(s) onto native ``x`` labels.

        For a slice, the bounds are divided by ``factor``. ``None`` bounds stay
        open and the step is unchanged. Scalars and list-likes are divided by
        ``factor`` and truncated with ``int()``. A scalar is wrapped into a
        one-element list, so the selected dimension is kept.

        Returns
        -------
        dict
            ``{idx_name: transformed_label}``.
        """
        fac = self.spatial_ref.attrs['factor']
        key = self.spatial_ref.attrs['idx_name']

        if isinstance(value, slice):
            # Open-ended slices (e.g. slice(None, 10)) used to crash on None / fac.
            start = None if value.start is None else value.start / fac
            stop = None if value.stop is None else value.stop / fac
            return {key: slice(start, stop, value.step)}

        # Wrap scalars so single values and list-likes share one code path.
        # np.ndim also treats numpy arrays as list-like; they are not a
        # collections.abc.Sequence, so the old check wrapped them whole.
        values = [value] if np.ndim(value) == 0 else list(value)

        # TODO: replace this placeholder (divide, truncate toward zero) with a
        # real CRS transform.
        return {key: [int(v / fac) for v in values]}

    def _materialize_index(self):
        """Build a PandasIndex over ``np.arange(*spatial_ref.attrs['range'])``."""
        params = self.spatial_ref.attrs['range']
        return PandasIndex(np.arange(params[0], params[1], params[2]), dim='x')

    @staticmethod
    def _range_from_values(values):
        """Encode evenly spaced, integer-valued 1-D labels as an
        ``np.arange``-style ``[start, stop, step]`` list (``stop`` exclusive)."""
        if len(values) == 0:
            # e.g. an inner join of non-overlapping indexes
            return [0, 0, 1]
        start = int(values[0])
        last = int(values[-1])
        step = int((last - start) / (len(values) - 1)) if len(values) > 1 else 1
        # np.arange excludes `stop`, so step past the last label. Storing the
        # last label itself silently dropped it when the index was rebuilt.
        return [start, last + step, step]

    def sel(self, labels, method=None, tolerance=None):
        """Select using labels expressed in the query CRS.

        ``labels`` is the ``{coord_name: label}`` dict xarray passes to
        ``Index.sel``. The single label is transformed with :meth:`transform`
        and then looked up in a PandasIndex materialized from
        ``spatial_ref.attrs['range']``. ``method`` and ``tolerance`` are
        forwarded to ``PandasIndex.sel``.

        Raises
        ------
        TypeError
            If ``labels`` is not a dict.
        """
        if not isinstance(labels, dict):
            raise TypeError(f"labels must be a dict, got {type(labels).__name__}")

        label = next(iter(labels.values()))
        transformed = self.transform(label)
        return self._materialize_index().sel(
            transformed, method=method, tolerance=tolerance
        )

    def equals(self, other):
        """Return True if ``other`` is a ToyIndex_scalar with an equal ``x``
        index and an identical ``spatial_ref`` (values *and* attrs)."""
        if not isinstance(other, type(self)):
            return False
        same_x = self._xindexes['x'].equals(other._xindexes['x'])
        # Variable.equals ignores attrs, but attrs (factor/range) are what
        # define this index, so compare spatial_ref with identical().
        same_ref = self._xindexes['spatial_ref'].identical(
            other._xindexes['spatial_ref']
        )
        return same_x and same_ref

    def join(self, other, how='inner'):
        """Join the native ``x`` labels of ``self`` and ``other``.

        Returns a new ToyIndex_scalar. Its ``spatial_ref`` is a copy of
        ``self``'s, with ``attrs['range']`` rewritten to describe the joined
        labels. Its ``joined_var`` holds the variables that
        :meth:`create_variables` uses when called without arguments.

        Assumes the joined labels are evenly spaced and integer-valued
        (they are re-encoded as a range). ``factor`` is taken from ``self``.
        """
        # Use other's stored PandasIndex directly. Rebuilding it with
        # np.arange(first, last, step) used to drop its last label.
        joined = self._materialize_index().join(other._xindexes['x'], how=how)
        joined_values = joined.index.array

        # Deep copy so updating the range never mutates self's spatial_ref.
        spatial_ref = deepcopy(self.spatial_ref)
        spatial_ref.attrs['range'] = self._range_from_values(joined_values)

        idx_var = xr.IndexVariable(dims=(joined.dim,), data=joined_values)
        idx_dict = {'x': idx_var, 'spatial_ref': spatial_ref}

        # Pass the variables so the new index has a spatial_ref; without it,
        # sel/join/reindex_like on aligned results failed on None.attrs.
        new_obj = type(self)({'x': joined, 'spatial_ref': spatial_ref}, idx_dict)
        new_obj.joined_var = idx_dict
        return new_obj

    def reindex_like(self, other, method=None, tolerance=None):
        """Return ``{'x': indexer}`` mapping ``other``'s ``x`` labels to
        positions in this index (pandas uses -1 for labels not found)."""
        target = other._xindexes['x'].index
        # Pass method/tolerance by keyword. pandas' positional order is
        # (target, method, limit, tolerance), so the old positional call put
        # tolerance into `limit`.
        indexer = self._materialize_index().index.get_indexer(
            target, method=method, tolerance=tolerance
        )
        return {'x': indexer}
     
        

## Checking out the custom index

*** Open question: does `reindex_like` need to return a variables-like object that can be passed to `create_variables`?

In [10]:
# Remove the default pandas index on x so that set_xindex can attach
# ToyIndex_scalar to it in the next cell.
sample_ds1, sample_ds2 = sample_ds1.drop_indexes('x'), sample_ds2.drop_indexes('x')

In [11]:
#%pdb off

In [12]:
# Build one ToyIndex_scalar per dataset from its x and spatial_ref coordinates.
ds1 = sample_ds1.set_xindex(coord_names=['x', 'spatial_ref'], index_cls=ToyIndex_scalar)
ds2 = sample_ds2.set_xindex(coord_names=['x', 'spatial_ref'], index_cls=ToyIndex_scalar)


## Align

In [13]:
# Keep the first aligned dataset. Indexing avoids assigning to `_`, which
# would shadow IPython's last-output variable and stop it updating.
inner_align = xr.align(ds1, ds2, join='inner')[0]

In [14]:
# Keep the first aligned dataset. Indexing avoids assigning to `_`, which
# would shadow IPython's last-output variable and stop it updating.
outer_align = xr.align(ds1, ds2, join='outer')[0]

In [15]:
# Display ds1 after an outer join with ds2 on the custom index.
outer_align

In [16]:
# Display ds1 after an inner join with ds2 on the custom index.
inner_align

In [17]:
# Is whether these work related to reindex_like not being implemented for PandasIndex?
# But that defaults to inner, and these successfully produce left and right, so that shouldn't be it.
#left_align,_ = xr.align(ds1, ds2, join='left')
#right_align,_ = xr.align(ds1, ds2, join='right')

## Selection

In [18]:
# 14 in the query CRS divided by factor 2 gives native x = 7. The custom
# transform wraps the scalar into a one-element list, so x is kept.
ds1.sel(x=14)

In [19]:
# Asserting on `Dataset == Dataset` only checks that the resulting Dataset has
# data variables, so it could never fail. Compare the selected value instead.
# ds1 keeps a length-1 x dimension (scalar labels are wrapped into a list),
# while old_ds1 drops it, so compare via .item().
assert ds1.sel(x=14)['var1'].item() == old_ds1.sel(x=7)['var1'].item()

In [20]:
# A list of query labels: 8, 10, 14 divided by factor 2 gives native x = 4, 5, 7.
ds1.sel(x=[8,10,14])

In [21]:
# Query labels 8, 10, 14 divided by factor 2 give native x = 4, 5, 7. Compare
# only the underlying arrays: ds1 carries a ToyIndex_scalar while old_ds1 keeps
# the default pandas index, so a whole-Dataset comparison is not like-for-like.
assert np.array_equal(old_ds1.sel(x=[4, 5, 7])['var1'].data, ds1.sel(x=[8, 10, 14])['var1'].data)

In [22]:
# A slice of query labels: the bounds are divided by factor 2, giving slice(2.0, 9.0).
ds1.sel(x=slice(4,18))

In [23]:
# slice(4, 18) in the query CRS becomes slice(2.0, 9.0) in native x. Label
# slices include both ends, so this should match native x = 2..9.
assert np.array_equal(old_ds1.sel(x=slice(2, 9))['var1'].data, ds1.sel(x=slice(4, 18))['var1'].data)