# A minimal example of accumulating a raster predictor in GRIT

Make sure the your Python environment has the packages in `requirements.txt` installed (e.g. `pip install -r requirements.txt`).

In [1]:
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
import geopandas as gpd
from exactextract import exact_extract

## Input arguments

In [15]:
AGGREGATION_METHOD = "sum"  # valid: mean | sum
GRIT_UNIT = "segments"  # valid: segments | reaches
GRIT_REGION = "SA"  # any of the 7 GRIT continental region ID
GRIT_READ_KW = dict(where='domain = "ORIN"')  # subset catchments/rivers here, eg. by domain or catchment_id
MOCK_RASTER_LONLAT = False  # True for a lonlat mock raster, False for a projected EqualEarth raster
MOCK_RASTER_RESOLUTION = 1000  # in degrees if MOCK_RASTER_LONLAT=True, in meters otherwise

## Helper functions
These help with downloading GRIT and creating a fake predictor raster.

In [3]:
def grit_file(**kw):
    '''Download or return path to local GRIT file given file, region and epsg keywords.'''
    import urllib.request
    import zipfile

    filename = "GRITv1.0_{file}_{region}_EPSG{epsg}.gpkg".format(**kw)
    url = f"https://zenodo.org/records/17435232/files/{filename}.zip?download=1"
    
    if not os.path.exists(filename):
        print(f"Downloading {url}...")
        zip_path = filename+".zip"
        with tqdm(unit='B', unit_scale=True, unit_divisor=1024) as t:
            def reporthook(block_num, block_size, total_size):
                if t.total is None and total_size:
                    t.total = total_size
                t.update(block_size)
        
            urllib.request.urlretrieve(url, zip_path, reporthook=reporthook)
        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(".")
        os.remove(zip_path)
    return filename

In [4]:
def mock_global_raster_file(output_path, resolution=1, lonlat=True):
    '''Create a random tif raster file with a given arcmin resolution and return the path.'''
    import rasterio
    import numpy as np
    from rasterio.transform import from_origin
    if os.path.exists(output_path):
        return output_path
    res = resolution
    hemis = 180 if lonlat else 20038000
    width = int(hemis*2 / res)
    height = int(hemis / res)
    transform = from_origin(west=-hemis, north=hemis/2, xsize=res, ysize=res)
    print(f"Creating raster with of size {width}x{height} at {transform}...")
    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=1,
        dtype="float32",
        crs="EPSG:4326" if lonlat else "EPSG:8857",
        transform=transform,
        tiled=True,
        blockxsize=512,
        blockysize=512,
        compress="LZW",
    ) as dst:
        for row in tqdm(range(0, height, 512)):
            h = min(512, height - row)
            block = np.random.rand(h, width).astype("float32")
            dst.write(block, 1, window=((row, row + h), (0, width)))
    return output_path

## Download catchments and segments/reaches
These are written and cached in the current directory. Zenodo may be slow, be patient.

In [5]:
catchments = grit_file(file=GRIT_UNIT[:-1]+"_catchments", region=GRIT_REGION, epsg="4326")
rivers = grit_file(file=GRIT_UNIT, region=GRIT_REGION, epsg="4326")

Downloading https://zenodo.org/records/17435232/files/GRITv1.0_segment_catchments_SA_EPSG4326.gpkg.zip?download=1...


919MB [03:18, 4.85MB/s]                                                                                                                                


Unzipping GRITv1.0_segment_catchments_SA_EPSG4326.gpkg.zip...


## Create a mock raster
These are written and cached in the current directory.

In [6]:
out_path = f"predictor_raster_{MOCK_RASTER_RESOLUTION}{'dg' if MOCK_RASTER_LONLAT else 'm'}.tif"
predictor_raster = mock_global_raster_file(out_path, resolution=MOCK_RASTER_RESOLUTION, lonlat=MOCK_RASTER_LONLAT)

Creating raster with of size 40076x20038 at | 1000.00, 0.00,-20038000.00|
| 0.00,-1000.00, 10019000.00|
| 0.00, 0.00, 1.00|...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:40<00:00,  1.00s/it]


## Extract catchment values

In [7]:
catchment_df = gpd.read_file(catchments, **GRIT_READ_KW).set_index("global_id")
if not MOCK_RASTER_LONLAT:
    catchment_df = catchment_df.to_crs(epsg=8857)
catchment_df

Unnamed: 0_level_0,cat,catchment_id,area,domain,geometry
global_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
440000078,1,440000074,7.416900,ORIN,"MULTIPOLYGON (((-6751830 1576769.999, -6751890..."
440000158,2,440000152,174.036487,ORIN,"MULTIPOLYGON (((-6774000 1525949.999, -6774000..."
440000162,3,440000156,2.048963,ORIN,"MULTIPOLYGON (((-6764460 1518119.999, -6764490..."
440000163,4,440000156,11.668837,ORIN,"MULTIPOLYGON (((-6769320 1518449.999, -6769320..."
440000164,5,440000156,0.306000,ORIN,"MULTIPOLYGON (((-6766770 1515419.999, -6766800..."
...,...,...,...,...,...
440039511,30310,440001190,2.627663,ORIN,"MULTIPOLYGON (((-7356030 605339.999, -7355970 ..."
440039512,30311,440001190,1.380150,ORIN,"MULTIPOLYGON (((-7357170 602429.999, -7357200 ..."
440039513,30312,440001190,0.271125,ORIN,"MULTIPOLYGON (((-7358730 602579.999, -7358760 ..."
440039514,30313,440001190,3.791587,ORIN,"MULTIPOLYGON (((-7359930 603899.999, -7359885 ..."


`exact_extract` is the main raster extraction command that can be customised to specific cases. As the catchments are multipolygon features, we first explode them to single features and then aggregate them again to `global_id`s. See the exactextract documentation for more methods and optimisations: https://isciences.github.io/exactextract/index.html

In [8]:
catchment_values_exp = exact_extract(
    predictor_raster,
    catchment_df.explode().reset_index(),
    AGGREGATION_METHOD,
    progress=True,
    include_cols=["global_id"],
    output="pandas",
    strategy="raster-sequential",
)
catchment_values = catchment_values_exp.groupby("global_id").agg(AGGREGATION_METHOD)[AGGREGATION_METHOD]
catchment_values

POLYGON ((-2.0038e+07 -1.0019e+07, 2.0038e+07 -1.0019e+07, 2.0038e+07 -9.429e+06, -2.0038e+07 -9.429e+06, -2.0038e+07 -1.0019e+07)): 100%|█| 100.0/100 


global_id
440000078     3.760942
440000158    87.677047
440000162     1.424704
440000163     5.514544
440000164     0.187596
               ...    
440039511     1.403974
440039512     0.598489
440039513     0.151712
440039514     1.986900
440039515     0.357106
Name: sum, Length: 30314, dtype: float64

## Accumulate predictor in the river network
This uses the attributes of the segments/reaches to route the predictor downstream. 

In [9]:
river_lines = gpd.read_file(rivers, layer="lines", **GRIT_READ_KW).set_index("global_id")
river_lines.downstream_line_ids = river_lines.downstream_line_ids.apply(lambda s: list(map(int, s.split(",") if s else [])))
river_lines[["strahler_order", "downstream_line_ids", "width_adjusted", "drainage_area_out"]]

Unnamed: 0_level_0,strahler_order,downstream_line_ids,width_adjusted,drainage_area_out
global_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
440000078,2,[],30.000000,209.736898
440000158,1,[],30.000000,174.006904
440000162,1,[440027561],234.964997,2.055600
440000163,1,[440035186],77.881202,11.655900
440000164,1,[440035187],30.000000,0.315000
...,...,...,...,...
440039511,39,[440039510],241.248650,9182.732397
440039512,46,[440039287],357.792989,9207.853197
440039513,45,[440039512],329.114543,9206.311497
440039514,44,[440039513],221.163850,9206.180997


The predictor is partioned by the `width_adjusted` attribute (or 1 anywhere GRIT doesn't have a width) when routed downstream of a bifurcation. This function splits the input the downstream line(s).

In [10]:
def partition(input, upstream_line):
    weights_abs = river_lines.loc[upstream_line.downstream_line_ids, "width_adjusted"].fillna(1)
    weights = weights_abs/weights_abs.sum()
    return input*weights

We loop over Strahler orders, adding the output of our partition function to the downstream unit.

In [11]:
def ordered_routing(values):
    accumulated_values = values.copy()
    for i, idx in tqdm(sorted(river_lines.groupby("strahler_order").groups.items())):
        for ii, l in river_lines.loc[idx].iterrows():
            accumulated_values[l.downstream_line_ids] += partition(accumulated_values[ii], l)
    return accumulated_values

We either just accumulated the predictor with the sum method or we weight it by drainage area for the mean method.

In [12]:
if AGGREGATION_METHOD == "sum":
    accumulated_values = ordered_routing(catchment_values)
    norm = np.log10
if AGGREGATION_METHOD == "mean":
    partitioned_area = ordered_routing(catchment_df["area"])
    accumulated_values = ordered_routing(catchment_values * catchment_df["area"]) / partitioned_area
    norm = lambda x: x

accumulated_values.describe()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 583/583 [00:08<00:00, 64.91it/s]


count     30314.000000
mean       5684.553391
std       28356.137216
min           0.002735
25%          47.175773
50%         153.831363
75%        1194.769942
max      455638.263795
Name: sum, dtype: float64

## Visualise
To avoid a massive vector in the interactive map, first filter and simplify rivers.

In [13]:
river_simple = river_lines.copy()[river_lines.drainage_area_out > 10000]
river_simple["geometry"] = river_simple.geometry.simplify(tolerance=0.1, preserve_topology=True)
river_simple["value"] = accumulated_values

In [16]:
import folium
import branca.colormap as cm

vmin, vmax = accumulated_values.min(), accumulated_values.max()
colormap = cm.linear.Accent_05.scale(norm(vmin), norm(vmax))
colormap.caption = "Accumulated predictor (logged)"
scale_rivers = lambda da: da**0.5 * 0.02 + 0.1
m = folium.Map(location=[0, 0], zoom_start=7, tiles="CartoDB Positron")
folium.GeoJson(
    river_simple,
    name="Rivers",
    style_function=lambda feature: {
        "color": colormap(norm(feature["properties"]["value"])),
        "weight": scale_rivers(feature["properties"]["drainage_area_out"]),
        "opacity": 0.8
    },
    tooltip=folium.GeoJsonTooltip(fields=["value", "drainage_area_out", "name"])
).add_to(m)
colormap.add_to(m)
folium.LayerControl().add_to(m)

# Optional: zoom to data
bounds = river_simple.total_bounds
m.fit_bounds([[bounds[1], bounds[0]], [bounds[3], bounds[2]]])

m