# Generic grid search using xarray


This notebook develops a general way of looping over many combinations of parameters and saving the results in an xarray dataset.



In [1]:
from matplotlib import pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
import os
import rioxarray
from typing import Optional, Tuple
import subprocess
import hvplot.xarray
from tqdm.autonotebook import tqdm
import dask
import sys
sys.path.append("..") 
from scripts.funcs import fsm, square_melt_region, map_water_depth

  from tqdm.autonotebook import tqdm


In [2]:
from typing import Callable
import pandas as pd
import xarray as xr
from tqdm.autonotebook import tqdm
import xarray as xr

def gridSearch(function: Callable, **kwargs) -> xr.Dataset:
    """
    Perform a grid search by iterating over all combinations of input parameters and running a given function.

    Parameters:
    function (callable): The function to be executed for each combination of input parameters. It is called with one keyword argument per searched parameter and should return either an xarray dataset, an xarray dataarray, or something xr.DataArray can wrap (a numpy array or a scalar).
    **kwargs: Keyword arguments mapping each parameter name to an iterable of the values to try.

    Returns:
    xr_unstacked (xarray.Dataset): The results of every run, concatenated and unstacked so that each searched parameter is its own dimension. If `function` does not return a dataset, the values are stored in a variable called 'result'.

    Raises:
    ValueError: If no parameters are supplied.

    Example:
    #### Define a function to be executed for each combination of input parameters
    def my_function(param1, param2):
        ##### Perform some computation using the input parameters
        result = param1 + param2
        return result

    #### Perform a grid search by iterating over all combinations of input parameters
    results = gridSearch(my_function, param1=[1, 2, 3], param2=[4, 5])

    """

    if not kwargs:
        raise ValueError(
            "gridSearch needs at least one parameter to search over, "
            "e.g. gridSearch(function, a=[1, 2])"
        )

    # names of the parameters, in the order they were supplied
    p_names = list(kwargs)

    # every combination of the parameter values, as a pandas MultiIndex
    multiIndex = pd.MultiIndex.from_product(list(kwargs.values()), names=p_names)

    # loop over every combination of parameters stored in multiIndex
    xr_out_list = []
    for mi in tqdm(multiIndex):

        # keyword arguments for this run
        inputs = dict(zip(p_names, mi))

        # run the function with this combination of inputs
        single_iteration_result = function(**inputs)

        # Wrap plain arrays/scalars first and attach the parameters afterwards.
        # Passing coords= to the DataArray constructor instead would replace the
        # coordinates of a returned DataArray, and would try to use the parameter
        # names as dimension names when their number happens to equal the rank
        # of a returned numpy array.
        if not isinstance(single_iteration_result, (xr.Dataset, xr.DataArray)):
            single_iteration_result = xr.DataArray(single_iteration_result)

        # label the result with the parameter values that produced it
        xr_out_list.append(single_iteration_result.assign_coords(inputs))

    # concatenate the list of results along a new dimension
    xr_stacked = xr.concat(xr_out_list, dim='stacked_dim')

    # index the new dimension with the multiIndex of parameter combinations
    mindex_coords = xr.Coordinates.from_pandas_multiindex(multiIndex, 'stacked_dim')
    xr_stacked = xr_stacked.assign_coords(mindex_coords)

    # separate the multiIndex into one dimension per parameter; only the
    # dimension created here is unstacked, so any multi-index already present
    # in the function's output is left alone
    xr_unstacked = xr_stacked.unstack('stacked_dim')

    # convert to a dataset if the result is a data array
    if isinstance(xr_unstacked, xr.DataArray):
        xr_unstacked = xr_unstacked.to_dataset(name='result')

    return xr_unstacked

### Define two functions to test the gridSearch function
def returns_an_xrdataset(a: float = 1111.0, b: float = 5555.0, c: float = 2222.0) -> xr.core.dataset.Dataset:
    """
    Dataset-valued test function for gridSearch.

    Opens the xarray tutorial dataset "air_temperature", overwrites its 'air'
    variable with a * air + b * air + c, and returns the dataset.
    """
    dataset = xr.tutorial.open_dataset("air_temperature")
    air = dataset['air']
    # the right-hand side is evaluated from the untouched variable before it is replaced
    dataset['air'] = a * air + b * air + c
    return dataset

def returns_a_float(a: float = 1111.0, b: float = 5555.0, c: float = 2222.0) -> float:
    """Scalar-valued test function for gridSearch: return the sum of the three parameters."""
    total = a + b
    total += c
    return total

# Test the function as follows. Each call should return a dataset holding the output of the function for every combination of the supplied parameter values.
# All other parameters (the ones not specified in the function call) are left at their default values.
# First with a function that returns an xarray dataset, searching over one, two and three parameters:
ds = gridSearch(returns_an_xrdataset, a=[0.1, 0.2])
ds = gridSearch(returns_an_xrdataset, a=[0.1, 0.2], b=[0.3, 0.4])
ds2 = gridSearch(returns_an_xrdataset, a=[0.1, 0.2], b=[0.3, 0.4], c=[0.5, 0.6])

# Then with a function that returns a plain float:
ds = gridSearch(returns_a_float, a=[0.1, 0.2])
ds = gridSearch(returns_a_float, a=[0.1, 0.2], b=[0.3, 0.4])
ds = gridSearch(returns_a_float, a=[0.1, 0.2], b=[0.3, 0.4], c=[0.5, 0.6])

# display the three-parameter dataset result
ds2


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
def add_center_of_mass(results, 
                       x_center_of_melt, 
                       y_center_of_melt):
    """
    Add the depth-weighted centroid of the water to `results`.

    Three variables are written into `results` (which is modified in place and
    also returned):
      - 'x_center_of_mass', 'y_center_of_mass': the means of the x and y
        coordinates over the 'x' and 'y' dimensions, weighted by
        results.water_depth with missing depths counted as zero
        (see centroid_test.ipynb for notes on this method);
      - 'L': the straight-line distance from that centroid to the point
        (x_center_of_melt, y_center_of_melt).

    Parameters:
    results: dataset with a 'water_depth' variable and 'x' and 'y' coordinates.
    x_center_of_melt, y_center_of_melt: coordinates of the center of the melt region.

    Returns:
    results: the same dataset with the three variables added.
    """
    # missing depths contribute no weight
    depth_weights = results['water_depth'].fillna(0)
    horizontal_dims = ['x', 'y']

    results['x_center_of_mass'] = results['x'].weighted(depth_weights).mean(dim=horizontal_dims)
    results['y_center_of_mass'] = results['y'].weighted(depth_weights).mean(dim=horizontal_dims)

    results['x_center_of_mass'].attrs = {
        'long_name': 'x coordinate of the center of mass',
        'description': 'the x coordinate of the center of mass of the water, i.e. the depth-weighted centroid',
    }
    results['y_center_of_mass'].attrs = {
        'long_name': 'y coordinate of the center of mass',
        'description': 'the y coordinate of the center of mass of the water, i.e. the depth-weighted centroid',
    }

    # offset of the centroid from the center of the melt region
    x_offset = results['x_center_of_mass'] - x_center_of_melt
    y_offset = results['y_center_of_mass'] - y_center_of_melt
    results['L'] = (x_offset**2 + y_offset**2)**(1/2)
    results['L'].attrs = {
        'long_name': 'distance between the center of mass and the center of the melt region',
        'description': 'the distance between the center of mass and the center of the melt region',
    }

    return results

In [4]:
def fsm_xarray(dem_filename = "rema_subsets/dem_small_2.tif",
               melt_magnitude = 0.1,
               x_center_of_melt: float = 817500.0,
               y_center_of_melt: float = 1.9325e6,
               melt_width: float = 5000):
    """
    Run one model experiment and collect its inputs and outputs in a single dataset.

    The DEM is opened from `dem_filename`, a square melt region is built with
    square_melt_region, and fsm is run on the DEM file and the melt file.
    The returned dataset contains 'water_depth', 'dem', 'melt', 'bounds', and the
    center-of-mass variables added by add_center_of_mass.

    NOTE(review): square_melt_region and fsm are imported from scripts.funcs and
    are not visible here. From the way their results are used below,
    square_melt_region appears to return (melt DataArray, melt file name, bounds)
    and fsm a DataArray of water depth - confirm against scripts/funcs.py.

    Parameters:
    dem_filename: path of the DEM raster to open.
    melt_magnitude: passed on to square_melt_region.
    x_center_of_melt, y_center_of_melt: center of the melt region; also used as
        the reference point for the distance 'L' in add_center_of_mass.
    melt_width: passed to square_melt_region as `width`.

    Returns:
    results: the merged xarray dataset described above.
    """
    # open the DEM with chunks={} and remove any length-1 dimensions
    dem = rioxarray.open_rasterio(dem_filename, chunks={}).squeeze()

    # build the melt input, then run the model on the DEM and the melt file
    melt, melt_filename, melt_bounds = square_melt_region(
        dem,
        melt_magnitude,
        x_center_of_melt=x_center_of_melt,
        y_center_of_melt=y_center_of_melt,
        width=melt_width,
    )
    water_depth = fsm(dem_filename, melt_filename=melt_filename)

    # name the arrays so that xr.merge can turn them into dataset variables
    for data_array, label in ((water_depth, 'water_depth'), (dem, 'dem'), (melt, 'melt')):
        data_array.name = label

    # describe the melt variable
    melt.attrs = {'units': 'meters', 'long_name': 'surface melt', 'description': 'the surface melt as a function of x and y'}

    # merge into one dataset and drop the unneeded 'band' variable
    results = xr.merge([water_depth, dem, melt]).drop_vars('band')

    # store the bounds of the melt region along their own dimension
    results['bounds'] = xr.DataArray(np.array(melt_bounds), dims=['bounds_index'], name='bounds')
    results['bounds'].attrs = {'long_name': 'bounds of the rectangular melt region', 'description': 'the bounds of the rectangular melt region: (xmin, ymin, xmax, ymax)'}

    # add the depth-weighted centroid and its distance from the melt center
    return add_center_of_mass(results, x_center_of_melt, y_center_of_melt)


In [5]:
# run a single experiment with all default parameters and display the resulting dataset
results = fsm_xarray()
results

[2K

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 3 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024) (3072, 1024) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024)","(3072, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [6]:
# plot the single-run result with the project helper from scripts.funcs (renders as an interactive Bokeh widget)
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'5f59661b-9c7a-47ca-bf32-0f226e6d4417': {'version…

In [7]:
# 2 x 2 grid search over the x position of the melt region and the melt magnitude (4 model runs)
results = gridSearch(fsm_xarray, x_center_of_melt = [812500.0, 817500.0], melt_magnitude=[0.2, 2])

  0%|          | 0/4 [00:00<?, ?it/s]

[2K

In [9]:
# display the grid-search dataset; each searched parameter is now a dimension
results

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 8 graph layers,2 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 48.00 MiB 24.00 MiB Shape (3072, 1024, 2, 2) (3072, 1024, 1, 2) Dask graph 2 chunks in 8 graph layers Data type float32 numpy.ndarray",3072  1  2  2  1024,

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 8 graph layers,2 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 16 graph layers,2 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 48.00 MiB 24.00 MiB Shape (3072, 1024, 2, 2) (3072, 1024, 1, 2) Dask graph 2 chunks in 16 graph layers Data type float32 numpy.ndarray",3072  1  2  2  1024,

Unnamed: 0,Array,Chunk
Bytes,48.00 MiB,24.00 MiB
Shape,"(3072, 1024, 2, 2)","(3072, 1024, 1, 2)"
Dask graph,2 chunks in 16 graph layers,2 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
# plot the 2 x 2 grid-search result with the project helper from scripts.funcs
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'89a3f45c-f6a3-481d-bc12-b61656129b61': {'version…

In [11]:
# grid search that passes every parameter of fsm_xarray explicitly;
# only melt_magnitude takes more than one value, so this is two model runs
results = gridSearch(fsm_xarray, 
                     dem_filename = ["rema_subsets/dem_small_2.tif"],
                     melt_magnitude = [0.1, 10],
                     x_center_of_melt = [817500.0],
                     y_center_of_melt = [1.9325e6],
                     melt_width = [5000])
# every run stores its own copy of the DEM; to save space we can keep just one copy

def reduce_dims(results, keep_dims=('x', 'y', 'bounds_index')):
    """
    Keep a single copy of the DEM in a grid-search dataset to save space.

    The 'dem' variable is replaced by its slice at index 0 along every one of
    its dimensions that is not listed in `keep_dims`. The dimensions to collapse
    are taken from the DEM itself rather than from the whole dataset, so
    dimensions that only other variables have cannot make the selection fail.

    Parameters:
    results: dataset with a 'dem' variable; it is modified in place.
    keep_dims: names of the dimensions to leave untouched on the DEM. The
        default reproduces the original hard-coded choice.

    Returns:
    results: the same dataset with the reduced 'dem' variable.
    """
    dem = results['dem']
    dims_to_drop = [dim_name for dim_name in dem.dims if dim_name not in keep_dims]

    # take the first entry along every dimension being collapsed
    first_entry = {dim_name: 0 for dim_name in dims_to_drop}
    results['dem'] = dem.isel(first_entry)
    return results


# collapse the DEM to a single copy
results = reduce_dims(results)

  0%|          | 0/2 [00:00<?, ?it/s]

[2K

In [12]:
map_water_depth(results)

BokehModel(combine_events=True, render_bundle={'docs_json': {'b37ed428-0041-4c48-98d2-6dcd4d01283a': {'version…

In [29]:
# grid search over the position of the melt region: 4 x-positions by 9 y-positions (36 model runs)
results = gridSearch(fsm_xarray, 
                     dem_filename = ["rema_subsets/dem_small_2.tif"],
                     melt_magnitude = [ 2],
                     x_center_of_melt = np.arange(800000.0, 832000, 10000),
                     y_center_of_melt = np.arange(1.912e6, 1.995e6, 10000),
                     melt_width = [5000])
# to save space we can just keep one copy of the dem
# (same idea as reduce_dims above, but based on the DEM's own dimensions and keeping 'dem_filename')
dims_to_drop = list(set(results.dem.dims) - set(('x','y','dem_filename')))
# take index 0 along every remaining dimension of the DEM
i = {dim_name: 0 for dim_name in dims_to_drop}
results['dem'] = results.dem.isel(i)
# display the reduced dataset
results

  0%|          | 0/36 [00:00<?, ?it/s]

[2K

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1)","(3072, 1024, 1)"
Dask graph,1 chunks in 9 graph layers,1 chunks in 9 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.00 MiB 12.00 MiB Shape (3072, 1024, 1) (3072, 1024, 1) Dask graph 1 chunks in 9 graph layers Data type float32 numpy.ndarray",1  1024  3072,

Unnamed: 0,Array,Chunk
Bytes,12.00 MiB,12.00 MiB
Shape,"(3072, 1024, 1)","(3072, 1024, 1)"
Dask graph,1 chunks in 9 graph layers,1 chunks in 9 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,432.00 MiB,108.00 MiB
Shape,"(3072, 1024, 1, 1, 4, 9, 1)","(3072, 1024, 1, 1, 1, 9, 1)"
Dask graph,4 chunks in 80 graph layers,4 chunks in 80 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 432.00 MiB 108.00 MiB Shape (3072, 1024, 1, 1, 4, 9, 1) (3072, 1024, 1, 1, 1, 9, 1) Dask graph 4 chunks in 80 graph layers Data type float32 numpy.ndarray",3072  1  1  1  1024  1  9  4,

Unnamed: 0,Array,Chunk
Bytes,432.00 MiB,108.00 MiB
Shape,"(3072, 1024, 1, 1, 4, 9, 1)","(3072, 1024, 1, 1, 1, 9, 1)"
Dask graph,4 chunks in 80 graph layers,4 chunks in 80 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [30]:
# plot the position grid search with the project helper and keep a handle to the plot
plot = map_water_depth(results)
plot

BokehModel(combine_events=True, render_bundle={'docs_json': {'d0f882e1-9442-4c81-b72f-80da3ad47d5f': {'version…

In [32]:
# size of the grid-search dataset in gigabytes (10^9 bytes)
results.nbytes / 1e9

1.371572328

In [15]:
def gridSearchDask(function, **kwargs):
    """
    Build one lazy dask task per combination of input parameters.

    This is the dask counterpart of gridSearch: instead of running `function`
    immediately, each call is wrapped in dask.delayed so that all of the runs
    can be executed together later with dask.compute.

    Parameters:
    function (callable): The function to be executed for each combination of input parameters. It is called with one keyword argument per searched parameter.
    **kwargs: Keyword arguments mapping each parameter name to an iterable of the values to try.

    Returns:
    xr_out_list (list): One dask delayed object per parameter combination, in the order produced by pandas.MultiIndex.from_product. Nothing has been computed yet, and the results are NOT combined into a single dataset or labelled with their parameter values.

    Example:
    #### Define a function to be executed for each combination of input parameters
    def my_function(param1, param2):
        ##### Perform some computation using the input parameters
        result = param1 + param2
        return result

    #### Build the tasks, then run them all
    delayed_results = gridSearchDask(my_function, param1=[1, 2, 3], param2=[4, 5])
    results = dask.compute(*delayed_results)

    """
    import dask

    # names of the parameters, in the order they were supplied
    p_names = list(kwargs)

    # every combination of the parameter values, as a pandas MultiIndex
    multiIndex = pd.MultiIndex.from_product(list(kwargs.values()), names=p_names)

    # wrap the function once; each call below creates a separate lazy task
    delayed_function = dask.delayed(function)

    # one task per combination of parameters stored in multiIndex
    # TODO: concatenate and unstack the computed results as gridSearch does
    xr_out_list = [
        delayed_function(**dict(zip(p_names, mi)))
        for mi in tqdm(multiIndex)
    ]

    return xr_out_list

[2K

[2K

In [16]:
# build ten lazy tasks (2 x-positions by 5 melt magnitudes), then time their execution with dask
# NOTE(review): the recorded output below shows GDAL/TIFF read errors on
# rema_subsets/water_input_file_3.tif and fsm_results/rema_tests/test-3-*.tif;
# presumably tasks running at the same time share intermediate file names - confirm in scripts/funcs.py
resultsDask = gridSearchDask(fsm_xarray, x_center_of_melt = [812500.0, 817500.0], melt_magnitude=[0.2, 0.3, 0.4, 0.5, 0.6])
import dask
%time dask.compute(*resultsDask)


  0%|          | 0/10 [00:00<?, ?it/s]

[2KERROR 1: fsm_results/rema_tests/test-3-wtd.tif: fsm_results/rema_tests/test-3-wtd.tif:Failed to allocate memory for to read TIFF directory (0 elements of 12 bytes each)
ERROR 1: fsm_results/rema_tests/test-3-wtd.tif: TIFFReadDirectory:Failed to read directory at offset 8
ERROR 1: TIFFFetchStripThing:IO error during reading of "StripOffsets"
[2KERROR 1: TIFFResetField:fsm_results/rema_tests/test-3-wtd.tif: Could not find tag 273.
[2KERROR 4: Unable to open fsm_results/rema_tests/test-3-wtd.tif to obtain file list.
Key:       fsm_xarray-e74680d6-d73d-4235-80ee-518e71cfdb27
Function:  execute_task
args:      ((<function apply at 0x10464a020>, <function fsm_xarray at 0x16b07eb60>, [], (<class 'dict'>, [['x_center_of_melt', 812500.0], ['melt_magnitude', 0.4]])))
kwargs:    {}
Exception: "CPLE_AppDefinedError(3, 1, 'rema_subsets/water_input_file_3.tif: TIFFReadDirectory:Failed to read directory at offset 25185460')"

[2K

CPLE_AppDefinedError: rema_subsets/water_input_file_3.tif: TIFFReadDirectory:Failed to read directory at offset 25185460

ERROR 4: `rema_subsets/water_input_file_3.tif' not recognized as a supported file format.
libc++abi: terminating due to uncaught exception of type std::runtime_error: Could not open file 'rema_subsets/water_input_file_3.tif' with GDAL!
ERROR 4: `rema_subsets/water_input_file_3.tif' not recognized as a supported file format.
libc++abi: terminating due to uncaught exception of type std::runtime_error: Could not open file 'rema_subsets/water_input_file_3.tif' with GDAL!
[2KERROR 1: TIFFResetField:fsm_results/rema_tests/test-3-label.tif: Could not find tag 273.
[2K

In [10]:
# start a dask.distributed client and display its summary
# (the recorded output shows a LocalCluster with 4 workers and 8 threads)
from dask.distributed import Client
client = Client()  # set up local cluster on your laptop
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 24.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:63960,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 24.00 GiB

0,1
Comm: tcp://127.0.0.1:63971,Total threads: 2
Dashboard: http://127.0.0.1:63973/status,Memory: 6.00 GiB
Nanny: tcp://127.0.0.1:63963,
Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-6zv6h18m,Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-6zv6h18m

0,1
Comm: tcp://127.0.0.1:63972,Total threads: 2
Dashboard: http://127.0.0.1:63975/status,Memory: 6.00 GiB
Nanny: tcp://127.0.0.1:63965,
Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-b91u982i,Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-b91u982i

0,1
Comm: tcp://127.0.0.1:63977,Total threads: 2
Dashboard: http://127.0.0.1:63979/status,Memory: 6.00 GiB
Nanny: tcp://127.0.0.1:63967,
Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-ekazp3xr,Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-ekazp3xr

0,1
Comm: tcp://127.0.0.1:63978,Total threads: 2
Dashboard: http://127.0.0.1:63981/status,Memory: 6.00 GiB
Nanny: tcp://127.0.0.1:63969,
Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-6e0och4x,Local directory: /var/folders/kl/3mt9f4qs1559xwy3mr60s7980000gp/T/dask-scratch-space/worker-6e0och4x


2024-02-13 12:46:58,185 - distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/Users/jkingslake/miniconda3/envs/full_py_env/lib/python3.11/site-packages/distributed/protocol/core.py", line 160, in loads
    return msgpack.loads(
           ^^^^^^^^^^^^^^
  File "msgpack/_unpacker.pyx", line 194, in msgpack._cmsgpack.unpackb
  File "/Users/jkingslake/miniconda3/envs/full_py_env/lib/python3.11/site-packages/distributed/protocol/core.py", line 152, in _decode_default
    return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jkingslake/miniconda3/envs/full_py_env/lib/python3.11/site-packages/distributed/protocol/pickle.py", line 96, in loads
    return pickle.loads(x)
           ^^^^^^^^^^^^^^^
  File "/Users/jkingslake/Documents/science/meltwater_routing/BFRN_meltwater/python/notebooks/../scripts/funcs.py", line 209, in <module>
    def gri

KeyboardInterrupt: 

In [26]:
# generate a unique file name from a random UUID (32 hexadecimal characters)
# NOTE(review): presumably intended to give each model run its own water-input file
# so that runs executing at the same time do not clash - confirm
import uuid
filename = "water_input_file_" + uuid.uuid4().hex
filename

'water_input_file_c8bd5d3a94514c8794a86aa268f34aed'

2024-02-15 10:41:50,492 - distributed.worker - ERROR - Scheduler was unaware of this worker 'tcp://127.0.0.1:63977'. Shutting down.
2024-02-15 10:41:50,494 - distributed.worker - ERROR - Scheduler was unaware of this worker 'tcp://127.0.0.1:63971'. Shutting down.
2024-02-15 10:41:50,619 - distributed.core - INFO - Connection to tcp://127.0.0.1:63960 has been closed.
2024-02-15 10:41:50,620 - distributed.core - INFO - Connection to tcp://127.0.0.1:63960 has been closed.
2024-02-15 10:41:52,640 - distributed.nanny - ERROR - Worker process died unexpectedly
