<img SRC="https://avatars2.githubusercontent.com/u/31697400?s=400&u=a5a6fc31ec93c07853dd53835936fd90c44f7483&v=4" WIDTH=125 ALIGN="right">

# Caching

*O.N. Ebbens, Artesia, 2021*

Groundwater flow models are often data-intensive. Execution times can be shortened significantly by caching data. In nlmod we use a specific caching method called [memoization](https://en.wikipedia.org/wiki/Memoization). This notebook explains how caching is implemented in `nlmod`. The first three chapters explain how to use caching in nlmod. The last chapter contains more technical details on the implementation and limitations of caching in nlmod.

### Contents<a name="TOC"></a>
1. [Cache directory](#cachedir)
2. [Caching in nlmod](#cachingnlmod)
3. [Checking the cache](#3)
4. [Technical](#4)

In [1]:
import matplotlib.pyplot as plt
import flopy
import os
import geopandas as gpd
import xarray as xr
import logging

import nlmod

# show information (log messages) when functions are called
logging.basicConfig(level=logging.INFO)
print(f'nlmod version: {nlmod.__version__}')

nlmod version: 0.1.1b


### [1. Cache directory](#TOC)<a name="cachedir"></a>

When you create a model you usually start by assigning a model workspace. This is a directory where model data is stored. The `nlmod.util.get_model_dirs()` function can be used to create a file structure in two steps:
1. The model workspace directory is created if it does not exist yet. 
2. Two subdirectories are created: 'figure' and 'cache'. 

By calling the function below, we create the `figdir` and `cachedir` variables containing the paths of the subdirectories. In this notebook we use this `cachedir` to write and read cached data. You can also define your own cache directory.

In [2]:
model_ws = 'model5'

# Model directories
figdir, cachedir = nlmod.util.get_model_dirs(model_ws)

print(model_ws)
print(figdir)
print(cachedir)

model5
model5\figure
model5\cache
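The two steps that `get_model_dirs` performs can be sketched with the standard library. `make_model_dirs` is a hypothetical stand-in written for illustration, not the actual implementation of `nlmod.util.get_model_dirs`; it only assumes the 'figure' and 'cache' subdirectory names mentioned above.

```python
import os


def make_model_dirs(model_ws):
    """Hypothetical sketch: create the model workspace with a 'figure' and a
    'cache' subdirectory and return the paths of both subdirectories."""
    figdir = os.path.join(model_ws, 'figure')
    cachedir = os.path.join(model_ws, 'cache')
    # exist_ok=True makes the call idempotent: rerunning the cell does not fail
    for directory in (model_ws, figdir, cachedir):
        os.makedirs(directory, exist_ok=True)
    return figdir, cachedir
```

Because existing directories are left alone, the same call can safely be repeated every time the notebook is run.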


### [2. Caching in nlmod](#TOC)<a name="cachingnlmod"></a>

In `nlmod` you can use the `get_combined_layer_models` function to obtain a layer model based on `regis`.

In [3]:
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=100., delc=100., use_geotop=False)

INFO:nlmod.read.regis:redefining current extent: [95000.0, 105000.0, 494000.0, 500000.0], fit to regis raster
INFO:nlmod.read.regis:new extent is [95000.0, 105000.0, 494000.0, 500000.0] model has 60 rows and 100 columns
INFO:nlmod.read.regis:resample regis data to structured modelgrid



As you may notice, this function takes some time to complete because the data is downloaded and projected onto the desired model grid. Every time you run this function you have to wait for the process to finish, which results in an unhealthy number of coffee breaks. This is why we use caching. We store the cache as netCDF files. The `layer_model` variable is an `xarray.Dataset`, which you can write to and read from a netCDF file using the code below.

In [18]:
# write netcdf with layer model data
layer_model.to_netcdf(os.path.join(cachedir, 'layer_test.nc'))

In [19]:
# read netcdf with layer model data
layer_model = xr.open_dataset(os.path.join(cachedir, 'layer_test.nc'))

Reading and writing netCDF files is the main principle behind caching in `nlmod`. We write the `layer_model` to a netCDF file when we call the `get_combined_layer_models` function for the first time. The next time we call the function we can read the cached netCDF file instead, which reduces execution time significantly. You can use this caching simply by specifying a `cachedir` and a `cachename` in the function call.

In [20]:
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=100., delc=100., use_geotop=False,
                                                         cachedir=cachedir, 
                                                         cachename='combined_layer_ds.nc')

INFO:nlmod.cache:using cached data -> combined_layer_ds.nc
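Stripped of all checks, the write-once, read-afterwards principle looks like the sketch below. It is a simplified illustration with a hypothetical helper `cached_call`, and it uses `pickle` instead of netCDF so that it runs without xarray.

```python
import os
import pickle


def cached_call(func, cachedir, cachename, *args, **kwargs):
    """Hypothetical helper: return the cached result if the cache file exists,
    otherwise run func and store its result (pickle stands in for netCDF)."""
    fname = os.path.join(cachedir, cachename)
    if os.path.exists(fname):
        # cache hit: read the stored result instead of running func
        with open(fname, 'rb') as f:
            return pickle.load(f)
    # cache miss: run func once and write the result for next time
    result = func(*args, **kwargs)
    with open(fname, 'wb') as f:
        pickle.dump(result, f)
    return result
```

Because this naive version only looks at the file name, calling it again with different arguments but the same `cachename` returns the stale result. That is why `nlmod` also compares the function arguments before it uses a cached file.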


This type of caching is applied to a number of functions in nlmod that return an xarray dataset. When you call these functions with the `cachedir` and `cachename` arguments, the following steps are taken:
1. Check whether there is a netCDF file with the specified cachename in the specified cache directory. If the file exists, go to step 2, otherwise go to step 3.
2. Read the netCDF file and return it as an xarray dataset if:
    1. The cached dataset was created using the same function arguments as the current function call. 
    2. The module where the function is defined has not been changed after the cache was created.

   If either condition is not met, go to step 3.
3. Run the function to obtain an xarray dataset. Save this dataset as a netCDF file, using the specified cachename and cache directory, for next time. Also return the dataset.
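The three steps can be sketched as a decorator. This is a simplified illustration, not the `nlmod` implementation: the hypothetical `cache_pickle` decorator stores the result with `pickle` instead of netCDF, compares arguments with plain `==` (so NumPy arrays and xarray objects are not handled), and implements step 2B under the assumption that "changed" means the module file has a newer modification time than the cache file. When the module file cannot be found, for example in an interactive session, the sketch skips step 2B.

```python
import functools
import os
import pickle
import sys


def _module_mtime(func):
    """Modification time of the file that defines func, or None if unknown."""
    module = sys.modules.get(func.__module__)
    path = getattr(module, '__file__', None)
    if path and os.path.exists(path):
        return os.path.getmtime(path)
    return None


def cache_pickle(func):
    """Hypothetical decorator following steps 1-3 (pickle instead of netCDF)."""

    @functools.wraps(func)
    def wrapper(*args, cachedir=None, cachename=None, **kwargs):
        if cachedir is None or cachename is None:
            return func(*args, **kwargs)
        fname = os.path.join(cachedir, cachename)
        call_args = {'args': args, 'kwargs': kwargs}

        # step 1: is there a cache file with this name?
        if os.path.exists(fname):
            with open(fname, 'rb') as f:
                cached = pickle.load(f)
            # step 2A: was the cache created with the same arguments?
            same_args = cached['call_args'] == call_args
            # step 2B: is the module file not newer than the cache file?
            module_mtime = _module_mtime(func)
            unchanged = module_mtime is None or module_mtime <= os.path.getmtime(fname)
            if same_args and unchanged:
                return cached['result']

        # step 3: run the function and store the result together with its arguments
        result = func(*args, **kwargs)
        with open(fname, 'wb') as f:
            pickle.dump({'call_args': call_args, 'result': result}, f)
        return result

    return wrapper
```

Storing the arguments in the same file as the result keeps the sketch short; `nlmod` keeps them in a separate file so that the netCDF file itself stays readable with any netCDF tool.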

The following functions use the caching as described above:
- nlmod.read.regis.get_combined_layer_models
- nlmod.read.rws.surface_water_to_model_dataset
- nlmod.read.knmi.add_knmi_to_model_dataset
- nlmod.read.jarkus.find_sea_cells
- nlmod.read.jarkus.bathymetry_to_model_dataset
- nlmod.read.geotop.get_geotop_dataset
- nlmod.read.ahn.get_ahn_at_grid

### [3. Checking the cache](#TOC)<a name="3"></a>
One of the steps in the caching process (step 2A) checks whether the cache was created using the same function arguments as the current function call. This check has some limitations:
- Only function arguments of these types are checked: int, float, bool, str, bytes, list, tuple, dict, numpy.ndarray, xarray.DataArray and xarray.Dataset. If a function argument has a different type, the cache is never used. More types may be added to the check in the future.
- If one of the function arguments is an xarray Dataset, the check is somewhat different: it only verifies that the dataset has the same dimensions and coordinates as the cached netCDF file. The data variables of the dataset may differ.
- Because datasets are checked differently, a function can have at most one argument of type xarray Dataset. If more than one xarray Dataset is given, the cache decorator raises a TypeError.
- If one of the function arguments is a file path of type str, we only check that the cached file path is the same as the current file path. We do not check whether the file was changed after the cache was created.
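A simplified sketch of this argument check, limited to the plain Python types in the list (the NumPy and xarray cases are left out). The function name `args_match` is hypothetical and the code is an illustration of the rules above, not the `nlmod` source.

```python
# plain Python types whose values can be compared with ==
SUPPORTED_TYPES = (int, float, bool, str, bytes, list, tuple, dict)


def args_match(current, cached):
    """Hypothetical check: compare {name: value} dicts of current and cached arguments."""
    # the same argument names must be present in both calls
    if current.keys() != cached.keys():
        return False
    for name, value in current.items():
        old = cached[name]
        # a change of type (e.g. int -> float) invalidates the cache
        if type(value) is not type(old):
            return False
        # unsupported types cannot be checked, so the cache is never used
        if not isinstance(value, SUPPORTED_TYPES):
            return False
        if value != old:
            return False
    return True
```

Note that `args_match({'fname': 'data.csv'}, {'fname': 'data.csv'})` returns `True` even if `data.csv` was edited in between, which is exactly the file path limitation listed above.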

You can test this yourself by running the cell below a few times with different function arguments. The logs show whether the cache is used or not.

In [None]:
# layer model
layer_model = nlmod.read.regis.get_combined_layer_models(extent=[95000.0, 105000.0, 494000.0, 500000.0],
                                                         delr=50., delc=100., use_geotop=False,
                                                         cachename='combined_layer_ds.nc',
                                                         cachedir=cachedir)
layer_model

INFO:nlmod.read.regis:redefining current extent: [95000.0, 105000.0, 494000.0, 500000.0], fit to regis raster
INFO:nlmod.read.regis:new extent is [95000.0, 105000.0, 494000.0, 500000.0] model has 60 rows and 200 columns
INFO:nlmod.read.regis:resample regis data to structured modelgrid


### Clearing the cache

Sometimes you want to get rid of all the cached files to free disk space or to support your minimalistic lifestyle. You can use the `clear_cache` function to clear all cached files in a specific cache directory.

In [4]:
nlmod.cache.clear_cache(cachedir)

this will remove all cached files in {cachedir} are you sure [Y/N] y
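What clearing the cache amounts to can be sketched as follows. `clear_cache_dir` is a hypothetical helper, not the source of `nlmod.cache.clear_cache`; it assumes the cache consists of the `.nc` files plus the `.pklz` files that hold the pickled function arguments, and its `ask` parameter is an addition so the confirmation prompt can be skipped.

```python
import os


def clear_cache_dir(cachedir, ask=True, extensions=('.nc', '.pklz')):
    """Hypothetical sketch: delete cached files in cachedir and return their names."""
    if ask:
        answer = input(f'this will remove all cached files in {cachedir}, are you sure [Y/N] ')
        if answer.strip().lower() != 'y':
            return []
    removed = []
    for fname in os.listdir(cachedir):
        # only remove cache files; leave any other files in the directory alone
        if fname.endswith(extensions):
            os.remove(os.path.join(cachedir, fname))
            removed.append(fname)
    return sorted(removed)
```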


### [4. Technical](#TOC)<a name="4"></a>

The caching is implemented in the `nlmod.cache` module. The `cache_netcdf` decorator function handles most of the logic. The check on function arguments (step 2A) is done by storing all function arguments in a dictionary and saving this dictionary as a pickle file next to the netCDF file.
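Storing the arguments in a dictionary can be sketched with `inspect.signature(...).bind`. This is an alternative to the `f'arg{i}'` naming used in the `cache_netcdf` prototype in this section, not what `nlmod` does: binding maps positional and keyword arguments to their parameter names and fills in defaults, so `f(extent, 100., 100.)` and `f(extent, delr=100., delc=100.)` produce the same dictionary. The helper names are hypothetical.

```python
import inspect
import pickle


def call_arguments(func, *args, **kwargs):
    """Map the arguments of a call to func onto its parameter names, defaults included."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)


def save_call_arguments(fname, func, *args, **kwargs):
    """Pickle the argument dictionary, as is done next to each cached netCDF file."""
    arguments = call_arguments(func, *args, **kwargs)
    with open(fname, 'wb') as f:
        pickle.dump(arguments, f)
    return arguments
```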

Limitations:
- If you read a netCDF file into an xarray dataset, the file stays open and locked until the dataset is closed. In the meantime you cannot overwrite the file or read it again.
- All function arguments are pickled and saved together with the netCDF file. If the function arguments use a lot of memory, this process can become slow. Take this into account when you decide to use the cache decorator.
- Function arguments that cannot be pickled result in an error.
- If one of the function arguments is an xarray Dataset, we only check whether the dataset has the same dimensions and coordinates as the cached netCDF file. There is no check on the variables (DataArrays) in the dataset, because checking all of them would simply take too much time. Also, most of the time it is not necessary to check all the variables, as they are not used to create the cached file. There is one exception where a data variable is used to create the cached file: `nlmod.read.jarkus.bathymetry_to_model_dataset` uses the 'Northsea' DataArray to create a bathymetry dataset. If the `bathymetry_to_model_dataset` function simply accessed the 'Northsea' DataArray via `model_ds['Northsea']`, there would be no check whether the 'Northsea' DataArray used to create the cache is the same as the one in the current function call. The current solution is to make the 'Northsea' DataArray a separate function argument of `bathymetry_to_model_dataset`.
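The pickling limitations can be checked up front. The helper below is a hypothetical sketch, not part of `nlmod`: it reports whether an argument can be pickled and how many bytes the pickle takes, which gives a rough idea of the cost of storing large arguments.

```python
import pickle


def pickle_size(value):
    """Return the size in bytes of the pickled value, or None if it cannot be pickled."""
    try:
        return len(pickle.dumps(value))
    except (pickle.PicklingError, AttributeError, TypeError):
        # e.g. lambdas, local functions, open files and generators
        return None
```

For example, `pickle_size(lambda x: x)` returns `None`, so such an argument would make the cache decorator fail.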

In [92]:
import functools
import os
import pickle

import numpy as np
import xarray as xr


def _is_valid_cache(func_args_dic, func_args_dic_cache):
    """Return True if the cached function arguments equal the current ones."""
    # check if cache and function call have the same argument names
    if set(func_args_dic) != set(func_args_dic_cache):
        print('cache not valid: different argument names')
        return False

    for key, item in func_args_dic.items():
        cached_item = func_args_dic_cache[key]

        # check if cache and function call have the same argument types
        if type(item) != type(cached_item):
            print(f'cache not valid: type of argument {key} changed')
            return False

        # check if cache and function call have the same argument values
        if isinstance(item, (int, float, bool, str, bytes, list, tuple, dict)):
            if item != cached_item:
                print(f'cache not valid: value of argument {key} changed')
                return False
        elif isinstance(item, np.ndarray):
            if not np.array_equal(item, cached_item):
                print(f'cache not valid: value of argument {key} changed')
                return False
        else:
            # unsupported argument type: the cache is never used
            print(f'cache not valid: cannot check argument {key} of type {type(item)}')
            return False

    return True


def cache_netcdf(func):
    """Decorator that caches the xarray Dataset returned by func as netCDF."""

    @functools.wraps(func)
    def decorator(*args, cachedir=None, cachename='test', **kwargs):

        # no cache directory given: just run the function
        if cachedir is None:
            return func(*args, **kwargs)

        if not cachename.endswith('.nc'):
            cachename += '.nc'

        fname_cache = os.path.join(cachedir, cachename)
        # the function arguments are pickled next to the netCDF file
        fname_pickle_cache = os.path.splitext(fname_cache)[0] + '.pklz'
        func_args_dic = {f'arg{i}': arg for i, arg in enumerate(args)}
        func_args_dic.update(kwargs)

        # check if cache is valid
        if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache):
            with open(fname_pickle_cache, 'rb') as f:
                func_args_dic_cache = pickle.load(f)

            if _is_valid_cache(func_args_dic, func_args_dic_cache):
                print('read cache')
                return xr.open_dataset(fname_cache)

        # (re)run function to create cache
        result = func(*args, **kwargs)

        if not isinstance(result, xr.Dataset):
            raise TypeError(f'expected xarray Dataset, got {type(result)} instead')

        result.to_netcdf(fname_cache)
        with open(fname_pickle_cache, 'wb') as fpklz:
            pickle.dump(func_args_dic, fpklz)

        return result

    return decorator

Requirements for the cache:
- Store the cache as a netCDF file alongside the model. This way, the cache can also be read independently of everything else.
- Only use the cache if the function is called with exactly the same parameters.
- Only use the cache if the module containing the function has not been modified since the cache was created.
- Overwrite the cache so that multiple cache files do not accumulate.


In [96]:
model_ds

NameError: name 'model_ds' is not defined

In [93]:
from nlmod.read.regis import *

# the cachedir and cachename keyword arguments are handled by the cache_netcdf decorator
@cache_netcdf
def get_regis_dataset(extent, delr, delc, botm_layer=b'AKc'):
    """get a regis dataset projected on the modelgrid.

    Parameters
    ----------
    extent : list, tuple or np.array
        desired model extent (xmin, xmax, ymin, ymax)
    delr : int or float,
        cell size along rows, equal to dx
    delc : int or float,
        cell size along columns, equal to dy
    botm_layer : bytes or str, optional
        regis layer that is used as the bottom of the model. This layer is
        included in the model. The default is b'AKc', which is the bottom
        layer of regis. Call nlmod.regis.get_layer_names() to get a list of
        regis names.

    Returns
    -------
    regis_ds : xarray dataset
        dataset with regis data projected on the modelgrid.
    """
    # check extent
    extent2, nrow, ncol = fit_extent_to_regis(extent, delr, delc)
    for coord1, coord2 in zip(extent, extent2):
        if coord1 != coord2:
            raise ValueError(
                'extent not fitted to regis please fit to regis first, use the nlmod.regis.fit_extent_to_regis function')

    # get local regis dataset
    regis_url = 'http://www.dinodata.nl:80/opendap/REGIS/REGIS.nc'

    regis_ds_raw = xr.open_dataset(regis_url, decode_times=False)

    # set x and y dimensions to cell center
    regis_ds_raw['x'] = regis_ds_raw.x_bounds.mean('bounds')
    regis_ds_raw['y'] = regis_ds_raw.y_bounds.mean('bounds')

    # slice extent
    regis_ds_raw = regis_ds_raw.sel(x=slice(extent[0], extent[1]),
                                    y=slice(extent[2], extent[3]))

    # slice layers
    if isinstance(botm_layer, str):
        botm_layer = botm_layer.encode('utf-8')

    layer_no = np.where((regis_ds_raw.layer == botm_layer).values)[0][0]
    regis_ds_raw = regis_ds_raw.sel(layer=regis_ds_raw.layer[:layer_no + 1])

    # slice data vars
    regis_ds_raw = regis_ds_raw[['top', 'bottom', 'kD', 'c', 'kh', 'kv']]
    regis_ds_raw = regis_ds_raw.rename_vars({'bottom': 'bot'})

    # rename layers
    regis_ds_raw = regis_ds_raw.rename({'layer': 'layer_old'})
    regis_ds_raw.coords['layer'] = regis_ds_raw.layer_old.astype(
        str)  # could also use assign_coords
    regis_ds_raw2 = regis_ds_raw.swap_dims({'layer_old': 'layer'})

    # convert regis dataset to grid
    logger.info('resample regis data to structured modelgrid')
    regis_ds = mdims.resample_dataset_to_structured_grid(regis_ds_raw2, 
                                                         extent,
                                                         delr, delc)
    regis_ds.attrs['extent'] = extent
    regis_ds.attrs['delr'] = delr
    regis_ds.attrs['delc'] = delc
    regis_ds.attrs['gridtype'] = 'structured'

    for datavar in regis_ds:
        regis_ds[datavar].attrs['source'] = 'REGIS'
        regis_ds[datavar].attrs['url'] = regis_url
        regis_ds[datavar].attrs['date'] = dt.datetime.now().strftime('%Y%m%d')

    return regis_ds


In [22]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 50.)

In [94]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 100., cachedir='cache')

read cache


In [15]:
a = get_regis_dataset([95000.0, 105000.0, 494000.0, 500000.0], 100., 100.)

In [49]:
func_args_dic

{'arg0': 1, 'arg1': 2, 'arg2': 3, 'botm_layer': b'AKc', 'test': 'fiets'}

In [50]:
# pickle.dump needs an open file object, not a file name
with open('test.pklz', 'wb') as f:
    pickle.dump(func_args_dic, f)

In [16]:
a

In [38]:
from joblib import Memory
import time
location = 'cache'
memory = Memory(location, verbose=0)

@memory.cache
def get_result(x):
    for i in range(5):
        print(i)
        time.sleep(1)
    
    result = str(x*19)
    
    return result



In [35]:
os.makedirs('cache2', exist_ok=True)

In [36]:
memory = Memory('cache2', verbose=0)

In [39]:
get_result(10)

'190'

In [32]:
get_result = memory.cache(get_result)
get_result(10)

0
1
2
3
4


'190'

In [33]:
memory.clear()



In [22]:
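import functools
import hashlib
import os
import time


def cache_file(func):
    """Cache the (string) result of func in a text file named after its arguments."""

    @functools.wraps(func)
    def inner(*args, **kwargs):
        # build a key from the arguments and hash it into a valid file name
        key = str((args, tuple(sorted(kwargs.items()))))
        fname = hashlib.sha256(key.encode('utf-8')).hexdigest() + '.txt'

        if os.path.exists(fname):
            with open(fname, 'r') as fo:
                return fo.read()

        # run the function only once and store its result
        result = func(*args, **kwargs)
        with open(fname, 'w') as fo:
            fo.write(result)
        return result

    return inner


@cache_file
def get_result(x, fname):
    for i in range(5):
        print(i)
        time.sleep(1)

    result = str(x*19)

    return result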

In [28]:
get_result(8, fname='test.txt')

(8,)
{'fname': 'test.txt'}
((8,), (('fname', 'test.txt'),))


'152'

In [29]:
%time get_result(9, fname='test.txt')

(9,)
{'fname': 'test.txt'}
((9,), (('fname', 'test.txt'),))
Wall time: 2 ms


'171'
