In [None]:
import math
import os
import sys

import dask.array as da
import datashader as ds
import numpy as np
import rasterio
import rioxarray
import xarray as xr
from datashader.colors import Elevation
from datashader.transfer_functions import shade, stack, set_background
from rasterio.warp import Resampling, calculate_default_transform, reproject
from xrspatial import hillshade
from xrspatial.utils import ArrayTypeFunctionMapping

In [None]:
def shade_elev_hs(elev_3D, band=1):
    """Render one band of an elevation raster stacked with a hillshade overlay.

    Parameters
    ----------
    elev_3D : xr.DataArray or array-like
        Elevation values. Non-DataArray input is wrapped with dims
        ``('band', 'y', 'x')``; DataArray input is presumably laid out the
        same way (band first) — TODO confirm for callers.
    band : int, default 1
        1-based index of the band to render.

    Returns
    -------
    The datashader ``stack`` of the Elevation-colormapped band and its
    eq_hist-shaded hillshade (blue/grey, alpha 0.2-0.8).
    """
    if isinstance(elev_3D, xr.DataArray):
        elevation = elev_3D
    else:
        elevation = xr.DataArray(elev_3D, dims=('band', 'y', 'x'))

    terrain = shade(elevation[band - 1], cmap=Elevation)
    relief = shade(
        hillshade(elevation[band - 1]),
        cmap=['blue', 'grey'],
        how='eq_hist',
        min_alpha=0.2,
        alpha=0.8,
    )
    return stack(terrain, relief)

In [None]:
def add_coords_crs_on_arr(arr: xr.DataArray, coords_list, crs):
    """Attach coordinates and a CRS to ``arr`` in place, then return it.

    Parameters
    ----------
    arr : xr.DataArray
        Array to annotate; it is modified in place.
    coords_list : iterable of [name, values]
        Each entry's first element is the coordinate name, the second its values.
    crs
        CRS to record on the array via rioxarray.

    Returns
    -------
    xr.DataArray
        The same ``arr`` object, with coordinates, CRS and transform written.
    """
    for entry in coords_list:
        name, values = entry[0], entry[1]
        arr.coords[name] = values
    # Record the CRS, then derive the affine transform from the new coords.
    arr.rio.set_crs(crs, inplace=True)
    arr.rio.write_crs(crs, inplace=True)
    arr.rio.write_transform(inplace=True)
    return arr

In [None]:
def reproject_coords(src_shape, src_coords, src_crs, dst_crs):
    """Compute the band/y/x coordinates an array would have after reprojection.

    An uninitialised (band, y, x) template of ``src_shape`` is given
    ``src_coords`` and ``src_crs``, reprojected to ``dst_crs``, and its
    coordinates are returned as ``[name, values]`` pairs (see
    ``gen_coords_list``).
    """
    template = add_coords_crs_on_arr(
        xr.DataArray(np.empty(src_shape), dims=('band', 'y', 'x')),
        src_coords,
        src_crs,
    )
    return gen_coords_list(template.rio.reproject(dst_crs))

In [None]:
def gen_coords_list(src_da):
    """Return ``[name, coordinate]`` pairs for the band/y/x dimensions of ``src_da``.

    The pairs follow the order of ``src_da.dims``, not the order of the
    coordinate mapping. Callers pair ``result[i]`` with array axis ``i``
    when slicing coordinates per dask block. A coordinate named band/y/x
    that is not a dimension (e.g. a scalar ``band`` left after selecting one
    band) is skipped, because it does not correspond to any axis.

    Parameters
    ----------
    src_da : xr.DataArray
        Array whose dimension coordinates are collected.

    Returns
    -------
    list of [str, coordinate]
        One entry per band/y/x dimension that has a coordinate, in axis order.
    """
    spatial_dims = ('band', 'y', 'x')
    return [
        [dim, src_da.coords[dim]]
        for dim in src_da.dims
        if dim in spatial_dims and dim in src_da.coords
    ]

In [None]:
def gen_chunk_locations(numblocks):
    """List every block position ``[i, j, k]`` of a 3-D blocked array.

    Only ``numblocks[0]``, ``numblocks[1]`` and ``numblocks[2]`` are used.
    Positions are produced in row-major order (last axis varies fastest).
    """
    return [
        [band_pos, row_pos, col_pos]
        for band_pos in range(numblocks[0])
        for row_pos in range(numblocks[1])
        for col_pos in range(numblocks[2])
    ]

In [None]:
def list_chunk_shapes(chunks, chunk_locations):
    """Return the shape of the block at each position in ``chunk_locations``.

    ``chunks`` holds, per axis, the sizes of that axis' blocks (dask's
    ``.chunks`` format). Each returned shape is a list with one size per axis.
    """
    per_axis = [list(axis_chunks) for axis_chunks in chunks]
    ndim = len(per_axis)
    return [
        [per_axis[axis][location[axis]] for axis in range(ndim)]
        for location in chunk_locations
    ]

In [None]:
def gen_array_locations(chunk_locations, shapes):
    """Compute the half-open ``[start, stop]`` index range each block covers per axis.

    Parameters
    ----------
    chunk_locations : list of [i, j, k]
        Position of each block along the three axes, as produced by
        ``gen_chunk_locations``.
    shapes : list of [int, int, int]
        Shape of each block, in the same order as ``chunk_locations``.

    Returns
    -------
    list of [[start, stop], [start, stop], [start, stop]]
        One entry per block. Along each axis, ``start`` is the sum of the
        sizes of all earlier block positions on that axis.
    """
    ndim = 3
    # Size of the block at each position along each axis. Blocks sharing a
    # position on an axis share their size there, so the first one seen is used.
    # (Offsets must be indexed by position along the axis, not by the flat
    # block index, or irregular chunkings get wrong starts.)
    sizes = [{} for _ in range(ndim)]
    for location, shape in zip(chunk_locations, shapes):
        for axis in range(ndim):
            sizes[axis].setdefault(location[axis], shape[axis])

    # Start offset of each position = running sum of earlier positions' sizes.
    starts = []
    for axis_sizes in sizes:
        offset = 0
        axis_starts = {}
        for position in sorted(axis_sizes):
            axis_starts[position] = offset
            offset += axis_sizes[position]
        starts.append(axis_starts)

    return [
        [
            [starts[axis][location[axis]], starts[axis][location[axis]] + shape[axis]]
            for axis in range(ndim)
        ]
        for location, shape in zip(chunk_locations, shapes)
    ]

In [None]:
def _reprojected_block_shapes(shapes, array_locations, src_coords, src_crs, dst_crs):
    """Reproject an empty template of every source block and return the resulting shapes."""
    block_shapes = []
    for shape, block_location in zip(shapes, array_locations):
        template = xr.DataArray(np.empty(shape), dims=('band', 'y', 'x'))
        # Slice every axis' coordinate (band included) to the block's index range,
        # so blocks along any axis get coordinates of matching length.
        block_coords = [
            [coord_name, coord_arr[start:end]]
            for (coord_name, coord_arr), (start, end) in zip(src_coords, block_location)
        ]
        template = add_coords_crs_on_arr(template, block_coords, src_crs)
        block_shapes.append(template.rio.reproject(dst_crs).data.shape)
    return block_shapes


def _chunks_from_block_shapes(block_shapes, chunk_locations, numblocks):
    """Build a per-axis dask ``chunks`` tuple from the reprojected shape of each block."""
    shape_at = {tuple(loc): shape for loc, shape in zip(chunk_locations, block_shapes)}
    last = [n - 1 for n in numblocks]
    new_chunks = []
    for axis, n in enumerate(numblocks):
        sizes = []
        for position in range(n):
            # Read the size along `axis` from the block at `position` on that
            # axis and at the last position on every other axis.
            location = list(last)
            location[axis] = position
            sizes.append(shape_at[tuple(location)][axis])
        new_chunks.append(tuple(sizes))
    return new_chunks


def calc_chunksizes_rio(src_shape, src_coords, chunks, numblocks, src_crs, dst_crs):
    """Estimate the dask chunks of a 3-D (band, y, x) array after reprojection.

    Each source block is reprojected as an empty template carrying that
    block's coordinates and ``src_crs``. The per-axis chunk sizes are then
    read from the resulting block shapes.

    Parameters
    ----------
    src_shape : tuple
        Shape of the source array (unused; kept for interface compatibility).
    src_coords : list of [name, coordinate]
        Source coordinates, one entry per axis in axis order
        (as produced by ``gen_coords_list``).
    chunks : tuple of tuples
        dask ``.chunks`` of the source array.
    numblocks : tuple of int
        dask ``.numblocks`` of the source array (3 axes).
    src_crs, dst_crs
        Source and destination CRS.

    Returns
    -------
    list of tuple
        For each axis, the reprojected size of each block position along it.

    NOTE(review): blocks in the same row/column may reproject to different
    sizes, while dask chunks assume a regular grid. Only sizes from the last
    position on the other axes are used (as before) — verify the assembled
    output.
    """
    chunk_locations = gen_chunk_locations(numblocks)
    shapes = list_chunk_shapes(chunks, chunk_locations)
    array_locations = gen_array_locations(chunk_locations, shapes)
    block_shapes = _reprojected_block_shapes(
        shapes, array_locations, src_coords, src_crs, dst_crs)
    return _chunks_from_block_shapes(block_shapes, chunk_locations, numblocks)

In [None]:
def _block_reproject(data, src_coords, src_crs, dst_crs, block_info=None):
    if block_info is not None:
        data_arr = xr.DataArray(data, dims=('band', 'y', 'x'))
        in_arr_loc = block_info[0]['array-location']
        block_coords = []
        for i in range(len(in_arr_loc)):
            dim_start = in_arr_loc[i][0]
            dim_end = in_arr_loc[i][1]
            dim_coord_name = src_coords[i][0]
            dim_coord = src_coords[i][1]
            block_dim_coord = dim_coord[dim_start:dim_end]
            block_coords.append((dim_coord_name, block_dim_coord))
        data_arr = add_coords_crs_on_arr(data_arr, block_coords, src_crs)
        reprojected_block_da = data_arr.rio.reproject(dst_crs)
        reprojected_block = reprojected_block_da.data
        return reprojected_block

In [None]:
def _cupy_reproject():
    raise NotImplementedError('cupy is not supported yet; please use numpy or dask')
    
def _dask_cupy_reproject():
    raise NotImplementedError('dask cupy not implemented yet; please use numpy or dask')

In [None]:
def _numpy_reproject(arr, dst_crs, **kwargs):
    return arr.rio.reproject(dst_crs, **kwargs)

In [None]:
def _dask_reproject(arr, dst_crs, **kwargs):
    """Reproject a dask-backed DataArray block by block.

    Output chunk sizes are estimated with ``calc_chunksizes_rio``, then each
    block is reprojected by ``_block_reproject`` through ``da.map_blocks``.
    Extra keyword arguments are forwarded to ``da.map_blocks``.

    Returns
    -------
    dask.array.Array
        The reprojected data as a raw uint16 dask array (not a DataArray).

    ``arr`` itself is left unmodified. The uint16 cast is applied to a local
    copy of its data, because the notebook reuses the source array after
    calling this.
    """
    # NOTE(review): casting to uint16 does not preserve negative or fractional
    # values (e.g. below-sea-level elevations) — confirm this cast is intended.
    data = arr.data.astype(np.uint16)
    src_coords = gen_coords_list(arr)
    src_crs = arr.rio.crs
    chunks = calc_chunksizes_rio(data.shape, src_coords, data.chunks,
                                 data.numblocks, src_crs, dst_crs)
    return da.map_blocks(_block_reproject, data,
                         src_coords, src_crs, dst_crs,
                         dtype=np.uint16,
                         chunks=chunks, **kwargs)

In [None]:
def reprojection_rio(arr: xr.DataArray, dst_crs, **kwargs):
    """Reproject ``arr`` to ``dst_crs``, choosing an implementation by array backend.

    Dispatch is done by ``ArrayTypeFunctionMapping``:
    - numpy: ``_numpy_reproject``
    - dask: ``_dask_reproject``
    - cupy / dask+cupy: the placeholder backends

    ``arr``, ``dst_crs`` and ``**kwargs`` are passed to the selected function,
    and its result is returned.
    """
    backend = ArrayTypeFunctionMapping(
        numpy_func=_numpy_reproject,
        cupy_func=_cupy_reproject,
        dask_func=_dask_reproject,
        dask_cupy_func=_dask_cupy_reproject,
    )(arr)
    return backend(arr, dst_crs, **kwargs)

## Reprojection: examples with several common Earth projections
### Xarray-spatial can reproject xarray DataArrays to a different coordinate reference system (CRS), given in standard notation such as `EPSG:3857` or `ESRI:54030`

In [None]:
# Path to the global elevation GeoTIFF. Set the ELEVATION_TIF environment
# variable to point elsewhere instead of editing this machine-specific default.
ELEVATION_TIF = os.environ.get('ELEVATION_TIF', '/Users/ls/Downloads/elevation.tif')
earth_lat_lon = rioxarray.open_rasterio(ELEVATION_TIF, chunks='auto')
shade_elev_hs(earth_lat_lon)

In [None]:
# Web Mercator (EPSG:3857).
web_mercator_crs = 'EPSG:3857'
# Reproject the whole raster; reprojection_rio picks the numpy or dask
# implementation based on how earth_lat_lon is backed.
# NOTE(review): later cells reuse earth_lat_lon — confirm this call leaves it unmodified.
web_mercator_earth = reprojection_rio(earth_lat_lon, web_mercator_crs)
shade_elev_hs(web_mercator_earth)

In [None]:
# World Robinson (ESRI:54030).
robinson_crs = 'ESRI:54030'
# `[:,:,:]` selects every band, row and column, i.e. the whole raster.
robinson_earth = reprojection_rio(earth_lat_lon[:,:,:], robinson_crs)
shade_elev_hs(robinson_earth)

In [None]:
# Transverse Mercator CRS EPSG:3004 (presumably Monte Mario / Italy zone 2 — confirm).
transverse_mercator_crs = 'EPSG:3004'
# Only the first 50 rows along y are reprojected.
# NOTE(review): the reason for the 50-row window is not stated — presumably to
# limit the extent for a local projection; confirm.
transverse_mercator_earth = reprojection_rio(earth_lat_lon[:,:50,:], transverse_mercator_crs)
shade_elev_hs(transverse_mercator_earth)

In [None]:
# North America Lambert Conformal Conic (ESRI:102009).
lambert_crs = 'ESRI:102009'
# Rows 25-124 and columns 50-199 of the raster: presumably a window around
# North America (per the `_NA` suffix) — TODO confirm the indices match that region.
lambert_earth_NA = reprojection_rio(earth_lat_lon[:,25:125,50:200], lambert_crs)
shade_elev_hs(lambert_earth_NA)

In [None]:
# NOTE(review): EPSG:29873 appears to be "Timbalai 1948 / RSO Borneo", a
# Hotine (rectified skew) oblique Mercator, not the Space Oblique Mercator
# that the variable name suggests — confirm the intended CRS.
space_oblique_crs = 'EPSG:29873'
# First 100 rows and first 150 columns of the raster.
space_oblique_earth = reprojection_rio(earth_lat_lon[:,:100,:150], space_oblique_crs)
shade_elev_hs(space_oblique_earth)