In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import os
import warnings
import zipfile
from functools import partial

import colorcet
import numpy as np
import requests
import rioxarray
import scipy.ndimage
import xarray as xr
from colorcet import bmw, coolwarm
from matplotlib import cm
from tqdm import tqdm

In [2]:
from xpublish.routers import XYZFactory
from xpublish.utils.ows import DataShader, MatplotLib

In [4]:
# Download the HydroSHEDS grid archive (once) and unpack it next to the notebook.
url = 'https://edcintl.cr.usgs.gov/downloads/sciweb1/shared/hydrosheds/sa_30s_zip_grid/sa_acc_30s_grid.zip'
filename = os.path.basename(url)
name = filename[:filename.find('_grid')]
adffile = name + '/' + name + '/w001001.adf'

# Skip the download entirely when the extracted raster is already on disk.
if not os.path.exists(adffile):
    chunk_size = 1024
    # The context manager releases the streamed connection even if the copy fails.
    with requests.get(url, stream=True) as r:
        # Fail loudly on HTTP errors instead of saving an error page as the archive.
        r.raise_for_status()
        # content-length is optional; without it the progress bar has no total.
        content_length = r.headers.get('content-length')
        total = int(content_length) // chunk_size + 1 if content_length else None
        with open(filename, 'wb') as f:
            for chunk in tqdm(r.iter_content(chunk_size=chunk_size), total=total):
                if chunk:
                    f.write(chunk)
    # Close the archive handle after extraction (and avoid shadowing the builtin `zip`).
    with zipfile.ZipFile(filename) as archive:
        archive.extractall('.')

In [7]:
# Open the extracted grid file as an xarray object.
# NOTE(review): xr.open_rasterio has been deprecated/removed in recent xarray
# releases; rioxarray.open_rasterio is the usual replacement, but it appears not
# to provide the `nodatavals` attribute that a later cell reads — confirm before
# switching.
ds = xr.open_rasterio(adffile)

In [8]:
# Keep only the first band and drop the leftover scalar `band` coordinate.
# Guarded so the cell can be re-run: `ds` is overwritten here, and a second
# isel(band=0) would fail once the `band` dimension is gone.
if "band" in ds.dims:
    ds = ds.isel(band=0).drop_vars("band")
attrs = ds.attrs
# Nodata marker(s) read from the raster attributes; used below to mask cells.
nodata = ds.nodatavals

In [12]:
# Wrap the raster in a Dataset under the variable name "dem", replace nodata
# cells with NaN, and split the result into 500 x 500 chunks.
dem_raw = ds.to_dataset(name="dem")
dem = dem_raw.where(dem_raw != nodata, np.nan).chunk({"x": 500, "y": 500})

## Renderers

There are two renderers available: `DataShader` and `MatplotLib`.

A renderer is a specialised subclass of the Renderer class:

```python
class Renderer:
    def __init__(
        self, interpolation={}, aggregation={}, normalization={}, color_mapping={}
    ):
        self.interp_params = interpolation
        self.agg_params = aggregation
        self.norm_params = normalization
        self.cm_params = color_mapping

    def interpolation(self, arr):
        return arr

    def aggregation(self, arr):
        return arr

    def normalization(self, arr):
        return arr

    def color_mapping(self, arr):
        return arr
```



In [23]:
# DataShader renderer with a logarithmic colour scale spanning the whole raster.
# The raster is materialised once and reused for both span bounds instead of
# being loaded separately for the minimum and the maximum.
dem_values = dem["dem"].values
renderer = DataShader(
    aggregation={"upsample_method": "linear"},
    color_mapping={
        "cmap": colorcet.rainbow,
        "how": "log",
        "span": [float(np.nanmin(dem_values)), float(np.nanmax(dem_values))],
        "alpha": 255,
    },
)

In [41]:
# MatplotLib renderer with log normalisation. vmin/vmax are passed as callables
# so that min and max are calculated on each individual tile.
per_tile_log_norm = {
    "method": cm.colors.LogNorm,
    "method_kwargs": {"vmin": np.nanmin, "vmax": np.nanmax},
}
renderer = MatplotLib(
    normalization=per_tile_log_norm,
    color_mapping={"cm": "jet"},
)

In [20]:
# MatplotLib renderer with a log normalisation (selected by name) whose bounds
# are the global min/max of the raster. Uses `dem`, the dataset built above
# (`ds2` is not defined anywhere in this notebook), and reads the values once.
dem_values = dem["dem"].values
renderer = MatplotLib(
    normalization={
        "method": "LogNorm",
        "method_kwargs": {
            "vmin": float(np.nanmin(dem_values)),
            "vmax": float(np.nanmax(dem_values)),
        },
    },
    color_mapping={"cm": "Blues"},
)

In [56]:
# MatplotLib renderer: global log normalisation, "jet" colormap, alpha 0.5.
# Uses `dem`, the dataset built above (`ds2` is not defined anywhere in this
# notebook), and reads the values once for both bounds.
dem_values = dem["dem"].values
renderer = MatplotLib(
    normalization={
        "method": cm.colors.LogNorm,
        "method_kwargs": {
            "vmin": float(np.nanmin(dem_values)),
            "vmax": float(np.nanmax(dem_values)),
        },
    },
    color_mapping={
        "cm": "jet",
        "cm_kwargs": {"alpha": 0.5},
    },
)

## Transformers:

Transformers apply a transformation to the individual tiles. The examples below follow https://github.com/davidbrochart/xarray_leaflet/blob/master/examples/dynamic.ipynb

In [25]:
def transform1(array, *args, **kwargs):
    """Coarsen a 2-D array with a block-wise maximum.

    The coarsening window along each axis is the axis length integer-divided
    by half of the hard-coded 256-pixel tile edge (i.e. by 128). An axis whose
    window would be 0 or 1 is not coarsened.

    Parameters
    ----------
    array
        Object with a two-entry ``shape`` (read as ``ny, nx``) and a
        ``coarsen`` method accepting ``x``/``y`` window sizes — presumably an
        ``xarray.DataArray`` with dimensions named ``x`` and ``y``.
    *args, **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    The result of ``array.coarsen(..., boundary='pad').max()``.
    """
    tile_width = 256
    tile_height = 256
    ny, nx = array.shape
    candidate_windows = {
        'x': nx // (tile_width // 2),
        'y': ny // (tile_height // 2),
    }
    # Only axes with a window larger than one cell are actually reduced.
    windows = {axis: size for axis, size in candidate_windows.items() if size > 1}
    coarsened = array.coarsen(**windows, boundary='pad')
    with warnings.catch_warnings():
        # Silence RuntimeWarnings raised while reducing (presumably from
        # all-NaN blocks — TODO confirm).
        warnings.simplefilter("ignore", category=RuntimeWarning)
        # Public reduction method instead of the private
        # xr.core.rolling.DataArrayCoarsen path, which is not a stable import.
        return coarsened.max()

In [26]:
def transform2(array, *args, **kwargs):
    """Take the square root of the values, then apply a disk-shaped maximum filter.

    Parameters
    ----------
    array
        Array-like object that ``np.sqrt`` accepts and whose result has an
        assignable ``data`` attribute — presumably an ``xarray.DataArray``.
    *args, **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    The square-rooted object, with its ``data`` replaced by the filtered values.
    """
    radius = 2
    offsets_y, offsets_x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    # 1 inside the disk of the given radius, 0 outside.
    footprint = (offsets_x**2 + offsets_y**2 <= radius**2).astype('uint8')
    with warnings.catch_warnings():
        # RuntimeWarnings raised by the square root are suppressed here.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        rooted = np.sqrt(array)
    rooted.data = scipy.ndimage.maximum_filter(rooted, footprint=footprint)
    return rooted

## Setting the CRS

In [13]:
# EPSG code handed to XYZFactory below (EPSG:4326 is WGS 84 geographic lon/lat).
# NOTE(review): assumes the raster is actually in this CRS — confirm against the source data.
crs_epsg = 4326

## Instantiate the XYZ router

In [28]:
# Build the XYZ tile router from the CRS, the renderer chosen above and the
# two tile transformers.
tile_transformers = [transform1, transform2]
xyz_router = XYZFactory(
    crs_epsg=crs_epsg,
    renderer=renderer,
    transformers=tile_transformers,
)

In [29]:
# Configure the dataset's xpublish `rest` accessor with the XYZ router.
dem.rest(routers=[xyz_router])

<xpublish.rest.RestAccessor at 0x7f345f84a940>

In [30]:
# Patch asyncio before starting the server in the next cell.
# NOTE(review): presumably required because the notebook kernel already runs an
# event loop and the server needs to start a nested one — confirm.
import nest_asyncio 
nest_asyncio.apply()

In [31]:
# Serve the dataset the router was registered on (`dem`); `ds2` is never
# defined in this notebook and would raise NameError on a fresh kernel.
dem.rest.serve(port=9000, host='127.0.0.1', log_level='info')