Skip to content

Commit

Permalink
Add accelerated version of extract_river_profiles
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbartos committed Feb 21, 2022
1 parent c683077 commit 959c792
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 2 deletions.
30 changes: 29 additions & 1 deletion pysheds/_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from heapq import heappop, heappush
import numpy as np
from numba import njit, prange
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, void
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void

# Functions for 'flowdir'

Expand Down Expand Up @@ -1466,6 +1466,34 @@ def _d8_stream_network_iter_numba(fdir, indegree, orig_indegree, startnodes):
endnode = fdir.flat[startnode]
return profiles

@njit(Tuple((List(List(int64)), DictType(int64, int64)))(int64[:,:], uint8[:],
uint8[:], int64[:]),
cache=False)
def _d8_stream_connection_iter_numba(fdir, indegree, orig_indegree, startnodes):
n = startnodes.size
profiles = [[0]]
connections = {0 : 0}
_ = profiles.pop()
_ = connections.pop(0)
for k in range(n):
startnode = startnodes.flat[k]
endnode = fdir.flat[startnode]
profile = [startnode]
while (indegree.flat[startnode] == 0):
profile.append(endnode)
indegree.flat[endnode] -= 1
if (orig_indegree.flat[endnode] > 1):
profiles.append(profile)
chain_start = profile[0]
chain_end = profile[-1]
connections[chain_start] = chain_end
# This might be inefficient if indegree still nonzero
if (indegree.flat[endnode] == 0):
profile = [endnode]
startnode = endnode
endnode = fdir.flat[startnode]
return profiles, connections

@njit(float64[:,:](int64[:,:], int64[:,:], float64[:,:]),
parallel=True,
cache=True)
Expand Down
70 changes: 69 additions & 1 deletion pysheds/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ def _mfd_compute_hand(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
def extract_river_network(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
routing='d8', algorithm='iterative', **kwargs):
"""
Generates river segments from accumulation and flow_direction arrays.
Generates river segments from flow direction and mask.
Parameters
----------
Expand Down Expand Up @@ -1453,6 +1453,74 @@ def extract_river_network(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32)
geo = geojson.FeatureCollection(featurelist)
return geo

def extract_profiles(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
routing='d8', algorithm='iterative', **kwargs):
"""
Generates river segments from flow direction and mask.
Parameters
----------
fdir : Raster
Flow direction data.
mask : Raster
Boolean raster indicating channelized regions
dirmap : list or tuple (length 8)
List of integer values representing the following
cardinal and intercardinal directions (in order):
[N, NE, E, SE, S, SW, W, NW]
routing : str
Routing algorithm to use:
'd8' : D8 flow directions
algorithm : str
Algorithm type to use:
'iterative' : Use an iterative algorithm (recommended).
'recursive' : Use a recursive algorithm.
Additional keyword arguments (**kwargs) are passed to self.view.
Returns
-------
profiles : list of lists of ints
A list containing a collection of river profiles. Each river profile
is a list containing the flat indices of the grid cells inside the
river segment.
connections : dict (int : int)
A dictionary describing the connectivity of the profiles. Each key
and value corresponds to the index of the river profile in profiles.
The key represents the upstream profile and the value represents the
downstream profile that it drains to.
"""
if routing.lower() == 'd8':
fdir_overrides = {'dtype' : np.int64, 'nodata' : fdir.nodata}
else:
raise NotImplementedError('Only implemented for `d8` routing.')
mask_overrides = {'dtype' : np.bool8, 'nodata' : False}
kwargs.update(fdir_overrides)
fdir = self._input_handler(fdir, **kwargs)
kwargs.update(mask_overrides)
mask = self._input_handler(mask, **kwargs)
# Find nodata cells and invalid cells
nodata_cells = self._get_nodata_cells(fdir)
invalid_cells = ~np.in1d(fdir.ravel(), dirmap).reshape(fdir.shape)
# Set nodata cells to zero
fdir[nodata_cells] = 0
fdir[invalid_cells] = 0
maskleft, maskright, masktop, maskbottom = self._pop_rim(mask, nodata=False)
masked_fdir = np.where(mask, fdir, 0).astype(np.int64)
startnodes = np.arange(fdir.size, dtype=np.int64)
endnodes = _self._flatten_fdir_numba(masked_fdir, dirmap).reshape(fdir.shape)
indegree = np.bincount(endnodes.ravel(), minlength=fdir.size).astype(np.uint8)
orig_indegree = np.copy(indegree)
startnodes = startnodes[(indegree == 0)]
profiles, connections = _self._d8_stream_connection_iter_numba(endnodes, indegree,
orig_indegree,
startnodes)
connections = dict(connections)
indices = {profile[0] : index for index, profile in enumerate(profiles)}
connections = {indices[key] : indices.setdefault(value, indices[key])
for key, value in connections.items()}
return profiles, connections

def stream_order(self, fdir, mask, dirmap=(64, 128, 1, 2, 4, 8, 16, 32),
nodata_out=0, routing='d8', algorithm='iterative', **kwargs):
"""
Expand Down

0 comments on commit 959c792

Please sign in to comment.