<img SRC="https://avatars2.githubusercontent.com/u/31697400?s=400&u=a5a6fc31ec93c07853dd53835936fd90c44f7483&v=4" WIDTH=125 ALIGN="right">

# Caching

*O.N. Ebbens, Artesia, 2021*

Groundwater flow models are often data-intensive. Execution times can be shortened significantly by caching data. In nlmod we use a specific caching method called [memoization](https://en.wikipedia.org/wiki/Memoization). This notebook explains how caching is implemented in `nlmod`.

### Contents<a name="TOC"></a>
1. [Cache directory](#cachedir)
2. [Caching in nlmod](#cachingnlmod)
3. [Checking the cache](#3)
4. [Discussion](#4)

In [1]:
import matplotlib.pyplot as plt
import flopy
import os
import geopandas as gpd
import xarray as xr

import nlmod

print(f'nlmod version: {nlmod.__version__}')

nlmod version: 0.1.1b


### [1. Cache directory](#TOC)<a name="cachedir"></a>

When you create a model you usually start by assigning a model workspace. This is a directory where model data is stored. The `nlmod.util.get_model_dirs()` function can be used to create a file structure in two steps:
1. The model workspace directory is created if it does not exist yet.
2. Two subdirectories are created: 'figure' and 'cache'. 

By calling the function below, we create the `figdir` and `cachedir` variables with the paths of these subdirectories. In this notebook we use this `cachedir` to write and read cached data. You can also define your own cache directory.

In [2]:
model_ws = 'model5'

# Model directories
figdir, cachedir = nlmod.util.get_model_dirs(model_ws)

print(model_ws)
print(figdir)
print(cachedir)

model5
model5\figure
model5\cache
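The two steps listed above can be sketched with the standard library alone. This is a minimal sketch, not nlmod's implementation: the name `get_model_dirs_sketch` is hypothetical, and the real `nlmod.util.get_model_dirs` may differ in details.

```python
import os


def get_model_dirs_sketch(model_ws):
    """Hypothetical stand-in for nlmod.util.get_model_dirs.

    Creates the model workspace (step 1) and the 'figure' and 'cache'
    subdirectories (step 2) if they do not exist yet.
    """
    figdir = os.path.join(model_ws, "figure")
    cachedir = os.path.join(model_ws, "cache")
    # makedirs also creates model_ws itself; exist_ok avoids errors on reruns
    for path in (figdir, cachedir):
        os.makedirs(path, exist_ok=True)
    return figdir, cachedir
```

Because of `exist_ok=True`, calling `get_model_dirs_sketch('model5')` a second time simply returns the same paths.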


### [2. Caching in nlmod](#TOC)<a name="cachingnlmod"></a>

In `nlmod` you can use the `get_combined_layer_models` function to obtain a layer model based on `regis`. The function takes some time to complete because the data is read from a server and projected onto the desired model grid. Every time you run this function you have to wait for the process to finish, which results in long execution times and an unhealthy number of coffee breaks. This is why we use caching.

In [3]:
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=100., delc=100., use_geotop=False)

The `layer_model` variable above is an `xarray.Dataset`. An `xarray.Dataset` can easily be written and read using the NetCDF file format. To speed up execution, we write the `layer_model` to a NetCDF file, so the next time we need the `layer_model` we can read the cached NetCDF file instead of downloading the data again.

In [4]:
# write netcdf with layer model data
layer_model.to_netcdf(os.path.join(cachedir, 'combined_layer_test.nc'))

In [5]:
# read netcdf with layer model data
layer_model = xr.open_dataset(os.path.join(cachedir, 'combined_layer_test.nc'))

Reading and writing NetCDF files is the main principle behind caching in `nlmod`. It can be used with any function that returns an xarray Dataset. You can use it simply by specifying a `cachedir` and a `cachename` in the function call.

In [6]:
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=100., delc=100., use_geotop=False,
                                                         cachedir=cachedir, 
                                                         cachename='combined_layer_ds.nc')

When you specify a `cachedir` and a `cachename` in your function call these steps are taken:
1. Check if there is a netCDF file with the specified cachename in the specified cache directory. If the file exists, go to step 2; otherwise go to step 3.
2. Read the netCDF file and return it as an xarray Dataset if both of the following conditions are met (otherwise go to step 3):
    1. The cached dataset was created using the same function arguments as the current function call.
    2. The module where the function is defined has not been changed after the cache was created.
3. Call the function, in this case `get_combined_layer_models`, to obtain an xarray Dataset. Save this dataset as a netCDF file using the specified cachename in the specified cache directory for next time, and return the dataset.
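The three steps can be sketched in plain Python. This is a minimal sketch, not nlmod's implementation: the name `cached_call`, storing the call arguments in a pickle file next to the netCDF file, and comparing file modification times to detect a changed module are all assumptions. It also assumes the arguments are simple values (numbers, strings, lists) that can be compared with `==`.

```python
import inspect
import os
import pickle

import xarray as xr


def cached_call(func, cachedir, cachename, **kwargs):
    """Hypothetical sketch of the three caching steps (not nlmod's code)."""
    fname_nc = os.path.join(cachedir, cachename)
    fname_args = os.path.splitext(fname_nc)[0] + ".pklz"

    # step 1: is there a cached netCDF file (plus the stored arguments)?
    if os.path.exists(fname_nc) and os.path.exists(fname_args):
        with open(fname_args, "rb") as f:
            cached_kwargs = pickle.load(f)
        # step 2.1: same function arguments as the current call?
        same_args = cached_kwargs == kwargs
        # step 2.2: source file of func not modified after the cache was written?
        module_file = inspect.getsourcefile(func)
        module_unchanged = (
            bool(module_file)
            and os.path.isfile(module_file)
            and os.path.getmtime(module_file) <= os.path.getmtime(fname_nc)
        )
        if same_args and module_unchanged:
            # load_dataset reads all data into memory and closes the file
            return xr.load_dataset(fname_nc)

    # step 3: call the function, cache the result and return it
    ds = func(**kwargs)
    ds.to_netcdf(fname_nc)
    with open(fname_args, "wb") as f:
        pickle.dump(kwargs, f)
    return ds
```

If the source file of `func` cannot be found (for example for functions defined interactively), this sketch errs on the safe side and always recomputes.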

### [3. Checking the cache](#TOC)<a name="3"></a>
There are some issues with using cached data. For example, when you modify the model extent, you cannot use the cached data anymore. If we simply read the cached data anyway, we get notoriously indecipherable errors. Therefore some standard checks can be done in the `get_cache_netcdf` function.

The `get_cache_netcdf` function has three optional arguments: `model_ds`, `check_grid` and `check_time`. The `model_ds` argument is used to obtain information about the desired grid and time discretisation. The `check_grid` and `check_time` arguments indicate whether to check if the grid and/or time discretisation of the cached dataset correspond to those of `model_ds`. If one of these checks fails, the cached data is not used and a new dataset is created and cached.
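A grid check of this kind can be sketched by comparing the cell-centre coordinates of the cached dataset with those of the desired model dataset. This is a minimal sketch under assumptions: the name `grid_matches` is hypothetical, and it assumes both datasets store the grid as `x` and `y` coordinates. A time check could compare a `time` coordinate in the same way.

```python
import numpy as np


def grid_matches(cached_ds, model_ds, coords=("x", "y")):
    """Hypothetical grid check: True if the cached dataset has the same
    cell-centre coordinates as the desired model dataset."""
    for name in coords:
        # a missing coordinate means we cannot confirm the grid
        if name not in cached_ds.coords or name not in model_ds.coords:
            return False
        cached = cached_ds[name].values
        desired = model_ds[name].values
        # different number of cells, or cells at different positions
        if cached.shape != desired.shape or not np.allclose(cached, desired):
            return False
    return True
```

With `delc=50` instead of `delc=100`, the number of `y` coordinates doubles, so the shapes differ and the check fails.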

Below you can see what happens if we call the cache function from the previous chapter with a `delc` of 50 instead of 100. When the log level is set to info, the outcome of the checks is logged, and we can see that a new dataset is created because the cached data does not correspond to the desired grid.

Note that these checks do not guarantee that the cached data will be read exactly as you expect. In some cases it is still difficult to know whether the cached data can be used for the current model.

In [8]:
# layer model with delc=50 instead of 100
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=100., delc=50., use_geotop=False,
                                                         cachedir=cachedir,
                                                         cachename='combined_layer_ds.nc')
layer_model

### [4. Technical](#TOC)<a name="4"></a>

The caching is implemented in the `nlmod` caching module, where the `cache_netcdf` function handles most of the work. The function `nlmod.util.get_cache_netcdf` can be wrapped around any function that returns an `xarray.Dataset`. It needs a few extra arguments for this:
- `use_cache`, to indicate if you want to use the cached file if it is available
- `cachedir`, the directory that is used to cache the data
- `cache_name`, the name of the .nc file of the cached data.
- `get_dataset_func`, this is the function that returns the `xarray.Dataset` that you want to cache.

In the cell below we wrap the cache function around the `get_combined_layer_models` function.
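A wrapper with the four arguments listed above can be sketched as follows. This is a hypothetical illustration, not nlmod's `get_cache_netcdf`: the name `get_cache_netcdf_sketch` and the forwarding of extra keyword arguments to `get_dataset_func` are assumptions, and the sketch performs no argument or grid checks.

```python
import os

import xarray as xr


def get_cache_netcdf_sketch(use_cache, cachedir, cache_name, get_dataset_func, **kwargs):
    """Hypothetical wrapper: return the cached dataset if allowed and
    available, otherwise call get_dataset_func(**kwargs) and cache it."""
    fname = os.path.join(cachedir, cache_name)
    if use_cache and os.path.exists(fname):
        # read the cached dataset into memory and close the file
        return xr.load_dataset(fname)
    # no (usable) cache: compute the dataset and write it for next time
    ds = get_dataset_func(**kwargs)
    ds.to_netcdf(fname)
    return ds
```

With nlmod, a call could then look like `get_cache_netcdf_sketch(True, cachedir, 'combined_layer_ds.nc', nlmod.read.regis.get_combined_layer_models, extent=[95000.0, 105000.0, 494000.0, 500000.0], delr=100., delc=100., use_geotop=False)`. Passing the function as an argument keeps the original function untouched, at the cost of a wrapper call at every call site.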


Limitations:
- If you open a NetCDF file as an xarray Dataset, the file stays locked while the dataset is open, so you cannot overwrite the file or read it again.
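One way around this lock, assuming the dataset fits in memory, is to load all data and close the file immediately. `xarray.load_dataset` does this (equivalent to `with xr.open_dataset(fname) as ds: ds.load()`), after which the file can be overwritten. The file name below is only an example.

```python
import os
import tempfile

import numpy as np
import xarray as xr

fname = os.path.join(tempfile.mkdtemp(), "layer_model.nc")
xr.Dataset({"top": ("x", np.zeros(3))}).to_netcdf(fname)

# load everything into memory and close the file handle immediately
ds = xr.load_dataset(fname)

# because the file is closed, it can be overwritten ...
xr.Dataset({"top": ("x", np.ones(3))}).to_netcdf(fname)

# ... and read again, which returns the new values
ds_new = xr.load_dataset(fname)
```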

### [4. Discussion](#TOC)<a name="4"></a>

Caching in its current form has some considerable limitations:
- You need two versions of everything: the original function that returns an xarray Dataset and the wrapper function that does the caching. It is confusing and error-prone to maintain two nearly identical functions.
- If you wrap `get_cache_netcdf` around a function that in turn calls `get_cache_netcdf`, you get unexpected results, because `get_cache_netcdf` does not pass all parameters on to the function it wraps.

There are many Python packages that offer some kind of caching, such as [beaker](https://beaker.readthedocs.io/en/latest/). Most of them use a decorator to cache the output of a function given the values of the function arguments. This works seamlessly for functions with hashable arguments. Unfortunately, an xarray Dataset is not hashable.

In order for these caching packages to work we should probably convert the xarray dataset coordinates to some hashable type.
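Converting the coordinates to a hashable type can be sketched by hashing their names, dtypes, shapes and raw bytes into a single string. This is a minimal sketch under assumptions: the helper `coords_key` is hypothetical, only coordinates are included (data variables are ignored), and object-dtype coordinates are converted to strings first so their bytes are stable.

```python
import hashlib

import numpy as np


def coords_key(ds):
    """Hypothetical helper: turn the coordinates of an xarray Dataset into a
    short, hashable string that a caching package could use as a key."""
    h = hashlib.sha256()
    for name in sorted(ds.coords):
        values = ds[name].values
        if values.dtype == object:
            # object arrays hold pointers, so hash their string form instead
            values = values.astype(str)
        values = np.ascontiguousarray(values)
        h.update(str(name).encode())
        h.update(str(values.dtype).encode())
        h.update(str(values.shape).encode())
        h.update(values.tobytes())
    return h.hexdigest()
```

The resulting string can be passed to a cached function instead of the dataset itself, for example as an argument of a function decorated with `functools.lru_cache`.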

In [92]:
import functools
import os
import pickle

import numpy as np
import xarray as xr


def _is_valid_cache(func_args_dic, func_args_dic_cache):
    """Check whether the arguments of the current call match the cached ones."""
    # check if cache and function call have the same argument names
    if set(func_args_dic) != set(func_args_dic_cache):
        print('cache not valid: different argument names')
        return False

    for key, item in func_args_dic.items():
        item_cache = func_args_dic_cache[key]

        # check if cache and function call have the same argument types
        if type(item) != type(item_cache):
            print(f'cache not valid: different type for argument {key}')
            return False

        # check if cache and function call have the same argument values
        if isinstance(item, np.ndarray):
            if not np.array_equal(item, item_cache):
                print(f'cache not valid: different value for argument {key}')
                return False
        elif isinstance(item, (int, float, str, bytes, list, tuple, dict, type(None))):
            if item != item_cache:
                print(f'cache not valid: different value for argument {key}')
                return False
        else:
            # unknown type: we cannot tell if the values are equal
            print(f'cache not valid: cannot compare argument {key} of type {type(item)}')
            return False

    return True


def cache_netcdf(func):
    """Decorator that caches the xarray Dataset returned by func as a netCDF file."""

    @functools.wraps(func)
    def wrapper(*args, cachedir=None, cachename='test', **kwargs):
        # without a cache directory, simply call the function
        if cachedir is None:
            return func(*args, **kwargs)

        if not cachename.endswith('.nc'):
            cachename += '.nc'

        fname_cache = os.path.join(cachedir, cachename)
        # the function arguments are stored in a pickle file next to the netCDF file
        fname_pickle_cache = os.path.splitext(fname_cache)[0] + '.pklz'
        func_args_dic = {f'arg{i}': arg for i, arg in enumerate(args)}
        func_args_dic.update(kwargs)

        # read the cache if it exists and was created with the same arguments
        if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache):
            with open(fname_pickle_cache, 'rb') as f:
                func_args_dic_cache = pickle.load(f)

            if _is_valid_cache(func_args_dic, func_args_dic_cache):
                print('read cache')
                # load into memory and close the file so it can be overwritten later
                return xr.load_dataset(fname_cache)

        # (re)run the function to create the cache
        result = func(*args, **kwargs)
        if not isinstance(result, xr.Dataset):
            raise TypeError(f'expected xarray Dataset, got {type(result)} instead')

        result.to_netcdf(fname_cache)
        with open(fname_pickle_cache, 'wb') as fpklz:
            pickle.dump(func_args_dic, fpklz)

        return result

    return wrapper

Requirements for the cache:
- Store the cache as a NetCDF file with the model. This way the cache can also be read independently of everything else.
- Only use the cache if the function is called with exactly the same parameters.
- Only use the cache if the module containing the function has not been modified since the cache was created.
- Overwrite the cache so that multiple cache files do not accumulate.
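The third requirement can be sketched by comparing the modification time of the file that defines the function with that of the cache file. This is a minimal sketch under assumptions: the helper `module_modified_since` is hypothetical, and a modification time is only a proxy for "changed" (touching the file without editing it also invalidates the cache).

```python
import inspect
import os


def module_modified_since(func, fname_cache):
    """Hypothetical check: True if the module file that defines ``func`` was
    modified after ``fname_cache`` was written."""
    module = inspect.getmodule(func)
    module_file = getattr(module, "__file__", None)
    if module_file is None or not os.path.exists(fname_cache):
        # unknown source file or no cache yet: treat the cache as unusable
        return True
    return os.path.getmtime(module_file) > os.path.getmtime(fname_cache)
```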



In [93]:
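import datetime as dt

import numpy as np
import xarray as xr

from nlmod.read.regis import *  # assumed to provide fit_extent_to_regis, mdims and logger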

@cache_netcdf
def get_regis_dataset(extent, delr, delc, botm_layer=b'AKc'):
    """Get a REGIS dataset projected onto the model grid.

    Parameters
    ----------
    extent : list, tuple or np.array
        desired model extent (xmin, xmax, ymin, ymax)
    delr : int or float,
        cell size along rows, equal to dx
    delc : int or float,
        cell size along columns, equal to dy
    botm_layer : bytes or str, optional
        REGIS layer that is used as the bottom of the model. This layer is
        included in the model. The default is b'AKc', which is the bottom
        layer of REGIS. Call nlmod.regis.get_layer_names() to get a list of
        REGIS layer names.

    Returns
    -------
    regis_ds : xarray dataset
        dataset with regis data projected on the modelgrid.
    """
    # check extent
    extent2, nrow, ncol = fit_extent_to_regis(extent, delr, delc)
    for coord1, coord2 in zip(extent, extent2):
        if coord1 != coord2:
            raise ValueError(
                'extent is not fitted to REGIS, fit it first using the nlmod.regis.fit_extent_to_regis function')

    # get local regis dataset
    regis_url = 'http://www.dinodata.nl:80/opendap/REGIS/REGIS.nc'

    regis_ds_raw = xr.open_dataset(regis_url, decode_times=False)

    # set x and y dimensions to cell center
    regis_ds_raw['x'] = regis_ds_raw.x_bounds.mean('bounds')
    regis_ds_raw['y'] = regis_ds_raw.y_bounds.mean('bounds')

    # slice extent
    regis_ds_raw = regis_ds_raw.sel(x=slice(extent[0], extent[1]),
                                    y=slice(extent[2], extent[3]))

    # slice layers
    if isinstance(botm_layer, str):
        botm_layer = botm_layer.encode('utf-8')

    layer_no = np.where((regis_ds_raw.layer == botm_layer).values)[0][0]
    regis_ds_raw = regis_ds_raw.sel(layer=regis_ds_raw.layer[:layer_no + 1])

    # slice data vars
    regis_ds_raw = regis_ds_raw[['top', 'bottom', 'kD', 'c', 'kh', 'kv']]
    regis_ds_raw = regis_ds_raw.rename_vars({'bottom': 'bot'})

    # rename layers
    regis_ds_raw = regis_ds_raw.rename({'layer': 'layer_old'})
    regis_ds_raw.coords['layer'] = regis_ds_raw.layer_old.astype(
        str)  # could also use assign_coords
    regis_ds_raw2 = regis_ds_raw.swap_dims({'layer_old': 'layer'})

    # convert regis dataset to grid
    logger.info('resample regis data to structured modelgrid')
    regis_ds = mdims.resample_dataset_to_structured_grid(regis_ds_raw2, 
                                                         extent,
                                                         delr, delc)
    regis_ds.attrs['extent'] = extent
    regis_ds.attrs['delr'] = delr
    regis_ds.attrs['delc'] = delc
    regis_ds.attrs['gridtype'] = 'structured'

    for datavar in regis_ds:
        regis_ds[datavar].attrs['source'] = 'REGIS'
        regis_ds[datavar].attrs['url'] = regis_url
        regis_ds[datavar].attrs['date'] = dt.datetime.now().strftime('%Y%m%d')

    return regis_ds


In [22]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 50.)

In [94]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 100., cachedir='cache')

read cache


In [15]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 100.)


In [38]:
from joblib import Memory
import time
location = 'cache'
memory = Memory(location, verbose=0)

@memory.cache
def get_result(x):
    for i in range(5):
        print(i)
        time.sleep(1)
    
    result = str(x*19)
    
    return result



In [35]:
os.mkdir('cache2')

In [36]:
memory = Memory('cache2', verbose=0)

In [39]:
get_result(10)

'190'

In [32]:
get_result = memory.cache(get_result)
get_result(10)

0
1
2
3
4


'190'

In [33]:
memory.clear()



In [22]:
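import functools
import hashlib
import os
import time


def cache_file(func):
    """Cache the string returned by func in a text file named after its arguments."""

    @functools.wraps(func)
    def inner(*args, **kwargs):
        # build a key from the function name and arguments and hash it,
        # so the key can safely be used as a filename
        key = str((func.__name__, args, tuple(sorted(kwargs.items()))))
        fname_cache = hashlib.md5(key.encode()).hexdigest() + '.txt'

        if os.path.exists(fname_cache):
            with open(fname_cache, 'r') as fo:
                return fo.read()

        # call the function only once and write the result to the cache file
        result = func(*args, **kwargs)
        with open(fname_cache, 'w') as fo:
            fo.write(result)
        return result

    return inner


@cache_file
def get_result(x, fname):
    for i in range(5):
        print(i)
        time.sleep(1)

    result = str(x * 19)

    return result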

In [28]:
get_result(8, fname='test.txt')

(8,)
{'fname': 'test.txt'}
((8,), (('fname', 'test.txt'),))


'152'

In [29]:
%time get_result(9, fname='test.txt')

(9,)
{'fname': 'test.txt'}
((9,), (('fname', 'test.txt'),))
Wall time: 2 ms


'171'
