In [None]:
import rioxarray
import rasterio
from rasterio.warp import Resampling, calculate_default_transform, reproject
import datashader as ds
from datashader.transfer_functions import shade, stack, set_background
import dask.array as da
import xarray as xr
import numpy as np
import sys
from xrspatial.utils import ArrayTypeFunctionMapping
import math
from datashader.colors import Elevation
from xrspatial import hillshade

In [None]:
def shade_elev_hs(elev_3D, band=1):
    """Render one band of an elevation raster with a hillshade overlay.

    Parameters
    ----------
    elev_3D : indexable raster
        Indexed positionally along its first axis to select the band.
    band : int, optional
        1-based band number; the layer used is ``elev_3D[band - 1]``.

    Returns
    -------
    The result of ``stack`` applied to two shaded layers: the band coloured
    with the ``Elevation`` colormap, and its hillshade drawn on top with a
    blue-to-grey ramp.
    """
    layer = elev_3D[band - 1]

    # Base layer: raw elevation values mapped through the Elevation palette.
    elevation_img = shade(layer, cmap=Elevation)

    # Overlay: histogram-equalised hillshade, kept partially transparent.
    relief_img = shade(
        hillshade(layer),
        cmap=['blue', 'grey'],
        how='eq_hist',
        min_alpha=0.2,
        alpha=0.8,
    )

    return stack(elevation_img, relief_img)

In [None]:
def reproject_nparr(source, src_transform, src_crs, dst_crs, src_bounds, **kwargs):
    """Reproject a ``(band, height, width)`` array into another CRS.

    Parameters
    ----------
    source : numpy.ndarray
        3-D array; its shape is unpacked as ``(bands, height, width)``.
    src_transform
        Transform of ``source``, forwarded to ``reproject``.
    src_crs, dst_crs
        Source and destination coordinate reference systems.
    src_bounds : sequence
        Bounds of ``source``, unpacked positionally into
        ``calculate_default_transform`` after ``width`` and ``height``.
    **kwargs
        Extra options forwarded to ``reproject``. ``src_nodata`` and
        ``dst_nodata`` are recognised here and default to ``None``.

    Returns
    -------
    tuple
        ``(destination, transform)`` exactly as returned by ``reproject``.
    """
    band_count, height, width = source.shape

    dst_transform, dst_width, dst_height = calculate_default_transform(
        src_crs,
        dst_crs,
        width,
        height,
        *src_bounds
    )

    # Pop (not just read) the nodata options: they are passed explicitly
    # below, and leaving them in ``kwargs`` would hand ``reproject`` the same
    # keyword twice and raise ``TypeError``.
    src_nodata = kwargs.pop('src_nodata', None)
    dst_nodata = kwargs.pop('dst_nodata', None)

    # Size and type the output after the input so every band is kept and
    # values are not squeezed into a fixed single-band uint16 buffer.
    destination = np.zeros((band_count, dst_height, dst_width),
                           dtype=source.dtype)

    destination, transform = reproject(source,
                                       destination=destination,
                                       src_transform=src_transform,
                                       src_crs=src_crs,
                                       src_nodata=src_nodata,
                                       dst_transform=dst_transform,
                                       dst_crs=dst_crs,
                                       dst_nodata=dst_nodata,
                                       **kwargs)

    return (destination, transform)

In [None]:
def _numpy_reproject(arr, dst_crs, **kwargs):
    data = arr.data
    src_transform = arr.rio.transform()
    src_crs = arr.rio.crs
    _, src_height, src_width = data.shape
    src_bounds = rasterio.transform.array_bounds(src_height, src_width, src_transform)
    
    reprojected_data, dst_transform = reproject_nparr(data,
                                                      src_transform,
                                                      src_crs,
                                                      dst_crs,
                                                      src_bounds,
                                                      **kwargs)
    
    return (reprojected_data, dst_transform)
    
    
    return reprojected_da


def _cupy_reproject():
    raise NotImplementedError('cupy not implemented yet')


def _dask_cupy_reproject():
    raise NotImplementedError('dask cupy not implemented yet')


def _dask_reproject():
    pass

In [None]:
def xrs_reproject(arr, dst_crs, **kwargs):
    """Reproject a ``(band, y, x)`` raster DataArray to ``dst_crs``.

    The backend is chosen from the array type via ``ArrayTypeFunctionMapping``
    and is called as ``backend(arr, dst_crs, **kwargs)``; it must return
    ``(reprojected_data, dst_transform)``.

    Parameters
    ----------
    arr : xarray.DataArray
        Source raster; its ``band`` coordinate is copied onto the result.
    dst_crs
        Destination coordinate reference system.
    **kwargs
        Forwarded unchanged to the selected backend.

    Returns
    -------
    xarray.DataArray
        Array with dims ``('band', 'y', 'x')``. Rows are stored with ``y``
        ascending and the coordinates are pixel centres. The CRS and a
        transform recomputed from the coordinates are written through the
        ``rio`` accessor.
    """
    mapper = ArrayTypeFunctionMapping(numpy_func=_numpy_reproject,
                                      cupy_func=_cupy_reproject,
                                      dask_cupy_func=_dask_cupy_reproject,
                                      dask_func=_dask_reproject)
    reprojected_data, dst_transform = mapper(arr)(arr, dst_crs, **kwargs)

    # Reverse the row axis only, so y ascends with the row index. Applying
    # np.rot90 twice to a 3-D array rotates axes (0, 1) and therefore also
    # reverses the band order, which would no longer match the copied
    # ``band`` coordinate.
    reprojected_data = reprojected_data[:, ::-1, :]

    _, dst_height, dst_width = reprojected_data.shape
    left, bottom, right, top = rasterio.transform.array_bounds(
        dst_height, dst_width, dst_transform)

    xres = (right - left) / dst_width
    yres = (top - bottom) / dst_height

    # Pixel-centre coordinates. Rows now run bottom-to-top, so y must start
    # at the bottom edge; starting at the transform's y offset (the top edge)
    # would place every row above the raster's real extent.
    dst_xs = np.arange(dst_width) * xres + (left + xres / 2)
    dst_ys = np.arange(dst_height) * yres + (bottom + yres / 2)

    reprojected_da = xr.DataArray(reprojected_data,
                                  dims=('band', 'y', 'x'),
                                  coords={'y': dst_ys, 'x': dst_xs})
    reprojected_da.coords['band'] = arr.coords['band']
    reprojected_da.rio.write_crs(dst_crs, inplace=True)
    # Called without a transform argument, as in the original code —
    # presumably rioxarray derives it from the x/y coordinates; confirm.
    reprojected_da.rio.write_transform(inplace=True)

    return reprojected_da

In [None]:
# Demo: open the elevation raster with rioxarray, reproject it to EPSG:3857
# with xrs_reproject, and show the shaded result as this cell's output.
# NOTE(review): the input path is absolute and machine-specific, so this cell
# only runs where that file exists — consider a configurable data directory.
elev_da = rioxarray.open_rasterio('/Users/ls/Downloads/elevation.tif')
reprojected = xrs_reproject(elev_da, 'EPSG:3857')
shade_elev_hs(reprojected)