# Generic grid search using xarray


Create a generalized way of looping over many combinations of parameters and saving the results in an xarray dataset.



In [219]:
from matplotlib import pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
import os
import rioxarray
from typing import Optional, Tuple
import subprocess
import hvplot.xarray
from tqdm.autonotebook import tqdm

import sys
sys.path.append("..") 
from scripts.funcs import fsm, rectangular_melt_region, square_melt_region, map_water_depth

  from tqdm.autonotebook import tqdm


In [180]:
def returns_a_float(a=1111.0, b=5555.0, c=2222.0, verbose=False):
    """Toy scalar-valued function used to exercise the grid search.

    Parameters
    ----------
    a, b, c : float
        The three values to add together.
    verbose : bool, optional
        When True, print the received arguments before returning. Defaults
        to False, which keeps the function silent as it was before.

    Returns
    -------
    float
        The sum ``a + b + c``.
    """
    if verbose:
        print("start of call to f")
        print(f"a = {a}")
        print(f"b = {b}")
        print(f"c = {c}")
        print("end of call to f")

    return a + b + c

# Quick check of the toy function: 1.0 + 2.0 + 3.0
returns_a_float(1.0, 2.0, 3.0)


6.0

In [182]:
def returns_an_xrdataset(a=1111.0, b=5555.0, c=2222.0):
    """Toy Dataset-valued function used to exercise the grid search.

    The parameters ``a``, ``b`` and ``c`` are accepted so the function can be
    called with the same keywords as ``returns_a_float``, but they do not
    influence the output.

    Returns
    -------
    xr.Dataset
        Two random 2x2 variables, ``temperature`` and ``precipitation``, on
        the dimensions ``lat`` ([10, 20]) and ``lon`` ([30, 40]). The values
        come from NumPy's global random state, so they differ between calls
        unless that state is seeded by the caller.
    """
    # Create some data
    temperature = 15 + 8 * np.random.randn(2, 2)
    precipitation = 10 * np.random.rand(2, 2)

    # Create the coordinate arrays for the two dimensions
    latitudes = xr.DataArray([10, 20], dims='lat')
    longitudes = xr.DataArray([30, 40], dims='lon')

    # Create the Dataset
    return xr.Dataset(
        {
            'temperature': (['lat', 'lon'], temperature),
            'precipitation': (['lat', 'lon'], precipitation)
        },
        coords={
            'lat': latitudes,
            'lon': longitudes
        }
    )
# Display one example Dataset produced with the default arguments
returns_an_xrdataset()

In [220]:
def gridSearch(function, **kwargs):
    """Evaluate ``function`` on every combination of the given parameter values.

    Parameters
    ----------
    function : callable
        Called once per combination as ``function(**inputs)``, where
        ``inputs`` maps each keyword name to a single value. It may return an
        ``xr.Dataset``, an ``xr.DataArray``, or anything ``xr.DataArray(...)``
        can wrap (for example a float or a NumPy array).
    **kwargs : iterable
        One iterable of candidate values per parameter, e.g. ``a=[0.1, 0.2]``.
        The keyword names must match the keyword arguments of ``function``.

    Returns
    -------
    xr.Dataset or xr.DataArray
        All results combined into a single object that has one dimension per
        keyword in ``kwargs``, in addition to the dimensions of the
        individual results.

    Raises
    ------
    ValueError
        If no parameters are supplied.

    Notes
    -----
    Repeated values within one parameter list produce a non-unique index,
    which the final unstack step is expected to reject - TODO confirm.
    """
    if not kwargs:
        raise ValueError("gridSearch needs at least one keyword argument listing the parameter values to search over")

    # extract the names and the candidate values of the parameters
    p_names = list(kwargs)
    p_values_list = list(kwargs.values())

    # create a multiIndex holding every combination of parameter values
    multiIndex = pd.MultiIndex.from_product(p_values_list, names=p_names)

    # loop over every combination of parameters stored in multiIndex
    xr_out_list = []
    for combination in tqdm(multiIndex):

        # create a dictionary of inputs for the function from the values stored in multiIndex
        inputs = dict(zip(p_names, combination))

        # run the function with this combination of inputs
        single_iteration_result = function(**inputs)

        # Wrap plain values (floats, NumPy arrays, ...) so that everything is an
        # xarray object. Datasets and DataArrays are used as they are, so their
        # own coordinates are kept.
        if not isinstance(single_iteration_result, (xr.Dataset, xr.DataArray)):
            single_iteration_result = xr.DataArray(single_iteration_result)

        # label the result with the parameter values that produced it
        # (scalar coordinates, so the dimensions of the result are untouched)
        xr_out_list.append(single_iteration_result.assign_coords(inputs))

    # concatenate the list of results into a single xarray
    xr_stacked = xr.concat(xr_out_list, dim='stacked_dim')

    # add the multiIndex to the xarray
    mindex_coords = xr.Coordinates.from_pandas_multiindex(multiIndex, 'stacked_dim')
    xr_stacked = xr_stacked.assign_coords(mindex_coords)

    # unstack only the grid dimension - i.e. separate the multiIndex into one
    # dimension per parameter - leaving any other MultiIndex dimension that the
    # function's own output may carry untouched
    xr_unstacked = xr_stacked.unstack('stacked_dim')

    return xr_unstacked

# Demo 1: a function that returns an xr.Dataset. The result gets its own name
# so that the second demo does not overwrite it before it can be inspected.
xr_unstacked_dataset = gridSearch(returns_an_xrdataset, a=[0.1, 0.2], b=[0.3, 0.4], c=[0.5, 0.6])

# Demo 2: a function that returns a plain float.
xr_unstacked = gridSearch(returns_a_float, a=[0.1, 0.2], b=[0.3, 0.4], c=[0.5, 0.6])

xr_unstacked



  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


In [187]:
def add_center_of_mass(results, 
                       x_center_of_melt, 
                       y_center_of_melt):
    """Add the water's depth-weighted centroid and its distance from the melt centre.

    Three variables are written into ``results`` (which is modified in place
    and also returned): ``x_center_of_mass``, ``y_center_of_mass`` and ``L``,
    the straight-line distance between that centroid and the point
    ``(x_center_of_melt, y_center_of_melt)``.

    Parameters
    ----------
    results : xr.Dataset
        Must provide a ``water_depth`` variable and ``x`` / ``y`` coordinates.
    x_center_of_melt, y_center_of_melt : float
        Coordinates of the centre of the melt region.

    Returns
    -------
    xr.Dataset
        The same ``results`` object with the three variables added.
    """
    # Missing depths contribute zero weight (see centroid_test.ipynb for notes on this method)
    depth_weights = results.water_depth.fillna(0)
    horizontal_dims = ['x', 'y']

    # Depth-weighted mean of each coordinate over the horizontal plane
    results['x_center_of_mass'] = results.x.weighted(depth_weights).mean(dim=horizontal_dims)
    results['y_center_of_mass'] = results.y.weighted(depth_weights).mean(dim=horizontal_dims)

    results.x_center_of_mass.attrs = dict(
        long_name='x coordinate of the center of mass',
        description='the x coordinate of the center of mass of the water, i.e. the depth-weighted centroid',
    )
    results.y_center_of_mass.attrs = dict(
        long_name='y coordinate of the center of mass',
        description='the y coordinate of the center of mass of the water, i.e. the depth-weighted centroid',
    )

    # Euclidean distance from the centre of the melt region to the centroid
    x_offset = results['x_center_of_mass'] - x_center_of_melt
    y_offset = results['y_center_of_mass'] - y_center_of_melt
    results['L'] = (x_offset**2 + y_offset**2)**(1/2)
    results.L.attrs = dict(
        long_name='distance between the center of mass and the center of the melt region',
        description='the distance between the center of mass and the center of the melt region',
    )
    return results

In [188]:
def fsm_xarray(dem_filename: str = "rema_subsets/dem_small_2.tif",
                            melt_magnitude: float = 0.1,
                            x_center_of_melt: float = 817500.0,
                            y_center_of_melt: float = 1.9325e6,
                            melt_width: float = 5000) -> "xr.Dataset":
    """Run one ``fsm`` simulation for a square melt region and collect the results.

    Parameters
    ----------
    dem_filename : str
        Path of the DEM raster; it is opened here and also passed to ``fsm``.
    melt_magnitude : float
        Passed to ``square_melt_region`` as the melt magnitude.
    x_center_of_melt, y_center_of_melt : float
        Centre of the melt region, in the DEM's x / y coordinates.
    melt_width : float
        Passed to ``square_melt_region`` as ``width``.

    Returns
    -------
    xr.Dataset
        Contains ``water_depth``, ``dem``, ``melt`` and ``bounds``, plus the
        variables added by ``add_center_of_mass`` (``x_center_of_mass``,
        ``y_center_of_mass``, ``L``).
    """
    # Load the DEM
    # chunks={} asks for a dask-backed array, so the raster is read lazily
    dem = rioxarray.open_rasterio(dem_filename, chunks={})
    # remove length-1 dimensions
    dem = dem.squeeze()

 
    # Build the melt field. The helper also returns a filename for the melt
    # (presumably a file it has written - TODO confirm in scripts.funcs) and
    # the bounds of the region.
    melt, melt_filename, bounds = square_melt_region(dem, 
                                                        melt_magnitude, 
                                                        x_center_of_melt=x_center_of_melt, 
                                                        y_center_of_melt=y_center_of_melt, 
                                                        width=melt_width)  
    # fsm is given the two file names rather than the in-memory arrays, so it
    # must run after square_melt_region has produced melt_filename
    water_depth = fsm(dem_filename, melt_filename=melt_filename)        
    
    # name the xr.DataArrays
    # (xr.merge uses these names as the variable names of the Dataset)
    water_depth.name = 'water_depth'
    dem.name = 'dem'
    melt.name = 'melt'

    # add information about the coordinates and variables in attributes
    melt.attrs = {'units': 'meters', 'long_name': 'surface melt', 'description': 'the surface melt as a function of x and y'}

    # merge the xr.DataArrays into a xr.Dataset
    results = xr.merge([water_depth, dem, melt])
    results = results.drop_vars('band')   # drop this unneeded variable

    # store the bounds of the melt region as a 1-D variable on its own dimension
    bounds = np.array(bounds)
    results['bounds'] = xr.DataArray(bounds, dims=['bounds_index'], name='bounds')
    results.bounds.attrs = {'long_name': 'bounds of the rectangular melt region', 'description': 'the bounds of the rectangular melt region: (xmin, ymin, xmax, ymax)'}

    # add the depth-weighted centroid of the water and its distance from the melt centre
    results = add_center_of_mass(results, 
                       x_center_of_melt, 
                       y_center_of_melt)
     

    
    return results


In [None]:
# Run the model once with the default arguments of fsm_xarray and display the dataset
results = fsm_xarray()
results

[2K

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 3 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# Plot the single-run results with the project helper imported from scripts.funcs
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'2fef3a1f-a9a4-4bb2-b326-b5787df4aab8': {'version…

In [157]:
# 2 x 2 grid search: two melt-centre x positions times two melt magnitudes.
# Arguments not listed here keep the defaults of fsm_xarray.
results = gridSearch(fsm_xarray, x_center_of_melt = [812500.0, 817500.0], melt_magnitude=[0.2, 2])

[2K

In [158]:
# Display the grid-search dataset (one extra dimension per searched parameter)
results

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 8 graph layers,2 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 48.00 MiB 24.00 MiB Shape (3072, 1024, 2, 2) (3072, 1024, 1, 2) Dask graph 2 chunks in 8 graph layers Data type float32 numpy.ndarray",3072  1  2  2  1024,

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 8 graph layers,2 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 16 graph layers,2 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 48.00 MiB 24.00 MiB Shape (3072, 1024, 2, 2) (3072, 1024, 1, 2) Dask graph 2 chunks in 16 graph layers Data type float32 numpy.ndarray",3072  1  2  2  1024,

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 16 graph layers,2 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [159]:
# Plot the grid-search results with the project helper imported from scripts.funcs
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'087ab4a3-2577-4b63-9aff-f98c609b773c': {'version…

In [221]:
# Grid search over melt magnitude only. Every other argument is passed as a
# one-element list, so it still appears as a length-1 dimension in the output.
results = gridSearch(
    fsm_xarray,
    dem_filename=["rema_subsets/dem_small_2.tif"],
    melt_magnitude=[0.1, 10],
    x_center_of_melt=[817500.0],
    y_center_of_melt=[1.9325e6],
    melt_width=[5000],
)

# To save space, keep just one copy of the dem: take index 0 along every
# dimension other than x, y and bounds_index.
dims_to_drop = [dim_name for dim_name in set(results.dims) if dim_name not in ('x', 'y', 'bounds_index')]
i = dict.fromkeys(dims_to_drop, 0)
results['dem'] = results['dem'].isel(i)
results

  0%|          | 0/2 [00:00<?, ?it/s]

[2K

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 8 graph layers,1 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 8 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 8 graph layers,1 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,24.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1, 2, 1, 1, 1)","(3072, 1024, 1, 1, 1, 1, 1)"
Dask graph,2 chunks in 11 graph layers,2 chunks in 11 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 24.00 MiB 12.00 MiB Shape (3072, 1024, 1, 2, 1, 1, 1) (3072, 1024, 1, 1, 1, 1, 1) Dask graph 2 chunks in 11 graph layers Data type float32 numpy.ndarray",3072  1  2  1  1024  1  1  1,

Unnamed: 0,Array,Chunk
Bytes,24.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1, 2, 1, 1, 1)","(3072, 1024, 1, 1, 1, 1, 1)"
Dask graph,2 chunks in 11 graph layers,2 chunks in 11 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [218]:
# Plot the grid-search results with the project helper imported from scripts.funcs
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'7ac383a0-9764-4edc-a3fa-0df0cf035bf8': {'version…

In [235]:
# Grid search over the position of the melt region: every combination of the
# x and y centres below, with the remaining arguments fixed to a single value.
results = gridSearch(
    fsm_xarray,
    dem_filename=["rema_subsets/dem_small_2.tif"],
    melt_magnitude=[ 10],
    x_center_of_melt=np.arange(800000.0, 832000, 5000),
    y_center_of_melt=np.arange(1.912e6, 1.995e6, 5000),
    melt_width=[5000],
)

# To save space, keep just one copy of the dem: take index 0 along every
# dimension other than x, y and bounds_index.
dims_to_drop = [dim_name for dim_name in set(results.dims) if dim_name not in ('x', 'y', 'bounds_index')]
i = dict.fromkeys(dims_to_drop, 0)
results['dem'] = results['dem'].isel(i)
results

  0%|          | 0/119 [00:00<?, ?it/s]

[2K

In [230]:
# Display the grid-search dataset.
# NOTE(review): the output stored under this cell shows a single parameter
# dimension of length 6, while the cell above runs 119 combinations, and this
# cell's execution count (230) is lower than that cell's (235). The stored
# output therefore appears to come from an earlier run - re-run to refresh.
results

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 8 graph layers,1 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 8 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 8 graph layers,1 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,72.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1, 1, 6, 1, 1)","(3072, 1024, 1, 1, 1, 1, 1)"
Dask graph,6 chunks in 19 graph layers,6 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 72.00 MiB 12.00 MiB Shape (3072, 1024, 1, 1, 6, 1, 1) (3072, 1024, 1, 1, 1, 1, 1) Dask graph 6 chunks in 19 graph layers Data type float32 numpy.ndarray",3072  1  1  1  1024  1  1  6,

Unnamed: 0,Array,Chunk
Bytes,72.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1, 1, 6, 1, 1)","(3072, 1024, 1, 1, 1, 1, 1)"
Dask graph,6 chunks in 19 graph layers,6 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [227]:
# Plot the grid-search results with the project helper imported from scripts.funcs
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'0e7602c3-465f-446d-a0a9-54ba10327d31': {'version…

In [233]:
# Check the data type of the combined water depth variable
results.water_depth.dtype

dtype('float64')