
# Multi-Scale Relief Model (MSRM) in Python

This notebook adapts the original Google Earth Engine JavaScript implementation by H.A. Orengo for local Python execution.

It computes a **Multi-Scale Relief Model (MSRM)** from a DEM raster by:
1. Creating low-pass filtered surfaces at multiple radii.
2. Building relief models from the difference between consecutive surfaces.
3. Averaging all relief models into the final MSRM.
4. Optionally producing a stretched visualization and hillshade.
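The four steps above can be written compactly as follows. The notation is a restatement that mirrors the variable names in the code cells below (`rr` is the resolution in m/pixel, `i`, `n`, `x` as in Section 4, and $f_{\min}$ is already raised to at least `rr`). It is not a quotation of the paper's equations. $\mathrm{LP}_r$ denotes a square moving average of side $2r+1$ pixels.

```latex
\begin{aligned}
i &= \left\lfloor \left(\frac{f_{\min} - \mathrm{rr}}{2\,\mathrm{rr}}\right)^{1/x} \right\rfloor,
\qquad
n = \left\lceil \left(\frac{f_{\max} - \mathrm{rr}}{2\,\mathrm{rr}}\right)^{1/x} \right\rceil, \\
r_k &= \operatorname{round}\!\left(k^{x}\right), \qquad k = i, \dots, n, \\
\mathrm{RM}_k &= \mathrm{LP}_{r_k} - \mathrm{LP}_{r_{k+1}}, \qquad k = i, \dots, n-1, \\
\mathrm{MSRM} &= \frac{1}{n - i} \sum_{k=i}^{n-1} \mathrm{RM}_k .
\end{aligned}
```

Because consecutive terms cancel, at any pixel where every $\mathrm{RM}_k$ is defined the average equals $(\mathrm{LP}_{r_i} - \mathrm{LP}_{r_n})/(n-i)$, up to floating-point rounding.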

> Notes:
> - The original algorithm description can be found in: Orengo, H.A. & Petrie, C.A. (2018). Multi-scale relief model (MSRM): a new algorithm for the visualization of subtle topographic change of variable size in digital elevation models. *Earth Surface Processes and Landforms*, 43(6): 1361-1369. https://doi.org/10.1002/esp.4317
> - This notebook is designed to run on standard Python environments and Google Colab.


## 1) Dependency check and imports

This cell checks whether required packages are available and then imports them.

If any dependency is missing, install from the provided conda YAML (`environment_MSRM.yml`) or with `pip install <package>`.


In [None]:

import importlib
import sys

required_packages = [
    "numpy",
    "rasterio",
    "scipy",
    "matplotlib",
    "tqdm",
]

print("Python:", sys.version.split()[0])
missing = []
for pkg in required_packages:
    try:
        importlib.import_module(pkg)
        print(f"✅ {pkg}: available")
    except Exception as exc:
        print(f"❌ {pkg}: missing ({exc})")
        missing.append(pkg)

if missing:
    print("\nMissing packages:", ", ".join(missing))
    print("Install them before running the remaining cells.")
else:
    print("\nAll required packages are available.")

import os
from pathlib import Path
import math
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import rasterio
from rasterio.transform import Affine
from scipy.ndimage import uniform_filter
import matplotlib.pyplot as plt
from matplotlib.colors import LightSource
from tqdm.auto import tqdm



## 2) User inputs

Set your input and output paths, then the three algorithm parameters:
- `fmax`: maximum feature size to detect (meters)
- `fmin`: minimum feature size to detect (meters)
- `x`: scaling factor

### Colab
If using Google Colab, your input may be under `/content/drive/MyDrive/...` after mounting Drive.
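A minimal sketch for mounting Drive, assuming the standard `google.colab.drive.mount` helper is available. The `IN_COLAB` flag is a hypothetical name introduced here. Outside Colab the import fails and the cell simply records that fact. After mounting, point `input_raster_path` at the file under `/content/drive/MyDrive/...`.

```python
# Mount Google Drive when running in Colab; elsewhere this does nothing.
try:
    from google.colab import drive  # only importable inside Colab
    drive.mount("/content/drive")
    IN_COLAB = True  # hypothetical flag name
except ImportError:
    IN_COLAB = False

print("Running in Colab:", IN_COLAB)
```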


In [None]:

# ===== User-configurable parameters =====
input_raster_path = "./input_dem.tif"      # DEM path
output_msrm_path = "./MSRM_output.tif"     # Output MSRM path (GeoTIFF)

# Optional outputs for visualization products
output_stretch_path = "./MSRM_stretch.tif"  # Set to None to skip writing
output_hillshade_path = "./MSRM_hillshade.tif"  # Set to None to skip writing

# Algorithm parameters (meters for fmin/fmax)
fmax = 1200.0
fmin = 40.0
x = 2

# Performance controls
use_all_cores = True
manual_workers = None   # Example: 4 (only used if use_all_cores=False)
progress_bar = True

print("Input raster:", input_raster_path)
print("Output MSRM:", output_msrm_path)
print("Parameters -> fmin:", fmin, "fmax:", fmax, "x:", x)



## 3) Utility functions

These functions implement the core processing structure.

### Important implementation notes
- The original GEE square kernel low-pass filter is implemented here as a square moving average (`uniform_filter`) with side `(2 * radius + 1)`.
- NoData values are handled with a **NaN-aware mean filter**.
- Relief models are parallelised: each relief surface is computed in a separate process task. If tasks exceed available cores, they are automatically queued.
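To illustrate the NaN-aware mean, here is a small standalone sketch that uses the same idea. It takes the moving average of values with NaN replaced by 0 and divides by the moving average of a validity mask. `nan_aware_mean` is a hypothetical name, float64 is used only for readability, and the `0.5 / size**2` empty-window cut-off is an assumption chosen to tolerate round-off. With `mode="nearest"`, edge pixels are replicated, so border cells carry extra weight near the raster edges.

```python
import numpy as np
from scipy.ndimage import uniform_filter

def nan_aware_mean(arr, radius):
    """Square moving average of side 2*radius+1 that ignores NaN cells."""
    size = 2 * radius + 1
    valid = np.isfinite(arr).astype(np.float64)
    filled = np.where(np.isfinite(arr), arr, 0.0)
    numer = uniform_filter(filled, size=size, mode="nearest")  # mean of values, NaN counted as 0
    denom = uniform_filter(valid, size=size, mode="nearest")   # fraction of valid cells
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(denom > 0.5 / size**2, numer / denom, np.nan)

dem = np.array([[1.0, 2.0, 3.0],
                [4.0, np.nan, 6.0],
                [7.0, 8.0, 9.0]])
lp = nan_aware_mean(dem, radius=1)
print(lp[1, 1])  # mean of the 8 valid neighbours
print(lp[0, 0])  # corner: edge replication weights the corner pixel 4 times
```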


In [None]:

def nanmean_square_filter(arr: np.ndarray, radius: int) -> np.ndarray:
    """NaN-aware square moving average with radius in pixels."""
    if radius < 0:
        raise ValueError("radius must be >= 0")
    size = int(2 * radius + 1)
    valid = np.isfinite(arr).astype(np.float32)
    filled = np.where(np.isfinite(arr), arr, 0.0).astype(np.float32)

    numer = uniform_filter(filled, size=size, mode="nearest")
    denom = uniform_filter(valid, size=size, mode="nearest")

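    # A window counts as empty only when its valid fraction is below half of one
    # cell's weight (1 / size**2); this absorbs round-off in the running-sum filter.
    out = np.where(denom > 0.5 / size**2, numer / np.maximum(denom, 1e-12), np.nan).astype(np.float32)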
    return out


def compute_rr(transform: Affine, crs=None, bounds=None) -> float:
    """Estimate raster resolution in meters/pixel.

    For projected CRS, transform units are assumed to already be meters.
    For geographic CRS (degrees), convert degree/pixel to meter/pixel at raster mid-latitude.
    """
    x_res = abs(float(transform.a))
    y_res = abs(float(transform.e))

    if crs is not None and getattr(crs, "is_geographic", False):
        if bounds is None:
            raise ValueError("bounds are required to estimate meter resolution for geographic CRS.")

        lat = math.radians((float(bounds.bottom) + float(bounds.top)) / 2.0)

        # Approximate meters/degree at latitude (WGS84-style approximation).
        m_per_deg_lat = (
            111132.92
            - 559.82 * math.cos(2 * lat)
            + 1.175 * math.cos(4 * lat)
            - 0.0023 * math.cos(6 * lat)
        )
        m_per_deg_lon = (
            111412.84 * math.cos(lat)
            - 93.5 * math.cos(3 * lat)
            + 0.118 * math.cos(5 * lat)
        )

        x_res *= m_per_deg_lon
        y_res *= m_per_deg_lat

    rr = float((x_res + y_res) / 2.0)
    if rr <= 0:
        raise ValueError(f"Computed raster resolution must be > 0. Got {rr}.")
    return rr


def compute_i_n(fmin_val: float, fmax_val: float, rr: float, x_val: float):
    """Compute i and n indexes from the original equations with input guards."""
    rr = float(rr)
    x_val = float(x_val)
    fmin_val = float(fmin_val)
    fmax_val = float(fmax_val)

    if rr <= 0:
        raise ValueError(f"rr must be > 0. Got {rr}.")
    if x_val <= 0:
        raise ValueError(f"x must be > 0. Got {x_val}.")
    if fmax_val <= fmin_val:
        raise ValueError(f"fmax must be > fmin. Got fmin={fmin_val}, fmax={fmax_val}.")

    base_i = (fmin_val - rr) / (2 * rr)
    base_n = (fmax_val - rr) / (2 * rr)

    # Numerical safety: avoid negative values caused by round-off when fmin ~= rr.
    base_i = max(base_i, 0.0)
    base_n = max(base_n, 0.0)

    i = math.floor(base_i ** (1.0 / x_val))
    n = math.ceil(base_n ** (1.0 / x_val))
    return i, n


def rm_task(args):
    """Worker task: compute one relief model = LP(r1) - LP(r2)."""
    dem, ndx2, i, x_val = args
    r1 = int(round((ndx2 + i) ** x_val))
    r2 = int(round((ndx2 + i + 1) ** x_val))
    lp1 = nanmean_square_filter(dem, r1)
    lp2 = nanmean_square_filter(dem, r2)
    rm = (lp1 - lp2).astype(np.float32)
    return ndx2, r1, r2, rm



## 4) Load DEM and derive algorithm indexes

This cell reads your raster and computes:
- raster resolution (`rr`)
- corrected minimum feature size (`fmin_corrected`)
- indexes `i` and `n`
- number of relief models (`n - i`)
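As a worked example of the index equations, the sketch below assumes the default `fmin = 40`, `fmax = 1200`, and `x = 2`, plus two hypothetical resolutions (1 m and 5 m). The expression `(f - rr) / (2 * rr)` is simply `f = (2r + 1) * rr` solved for `r`. In other words, a feature of size `f` meters matches a square window of radius `r` pixels. `sketch_i_n` is a hypothetical stand-in that mirrors `compute_i_n`.

```python
import math

def sketch_i_n(fmin, fmax, rr, x):
    """Mirror of the i/n equations: window side (2r + 1) * rr meters matches feature size f."""
    fmin = max(fmin, rr)  # same correction as fmin_corrected
    i = math.floor(max((fmin - rr) / (2 * rr), 0.0) ** (1.0 / x))
    n = math.ceil(max((fmax - rr) / (2 * rr), 0.0) ** (1.0 / x))
    return i, n

examples = {}
for rr in (1.0, 5.0):
    i, n = sketch_i_n(40.0, 1200.0, rr, 2.0)
    radii = [int(round(k ** 2.0)) for k in range(i, n + 1)]
    examples[rr] = (i, n, radii)
    print(f"rr={rr} m -> i={i}, n={n}, relief models={n - i}, "
          f"radii {radii[0]}..{radii[-1]} px")
# rr=1.0 m -> i=4, n=25, relief models=21, radii 16..625 px
# rr=5.0 m -> i=1, n=11, relief models=10, radii 1..121 px
```

At 1 m resolution the largest window is 2 × 625 + 1 = 1251 pixels (about 1.25 km) across. This is slightly above `fmax` because `n` is rounded up.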


In [None]:

input_path = Path(input_raster_path)
if not input_path.exists():
    raise FileNotFoundError(f"Input raster not found: {input_path.resolve()}")

with rasterio.open(input_path) as src:
    dem = src.read(1).astype(np.float32)
    profile = src.profile.copy()
    transform = src.transform
    nodata = src.nodata
    crs = src.crs
    bounds = src.bounds

if nodata is not None:
    dem = np.where(dem == nodata, np.nan, dem)

rr = compute_rr(transform, crs=crs, bounds=bounds)

fmin_val = float(fmin)
fmax_val = float(fmax)
x_val = float(x)

fmin_corrected = max(fmin_val, rr)
if fmax_val <= fmin_corrected:
    raise ValueError("fmax must be greater than fmin (after correction to rr).")
if x_val <= 0:
    raise ValueError("x must be > 0.")

i, n = compute_i_n(fmin_corrected, fmax_val, rr, x_val)
if n <= i:
    raise ValueError(f"Invalid i/n combination: i={i}, n={n}. Adjust fmin/fmax/x.")

num_relief_models = n - i
radii_list = [int(round(k ** x_val)) for k in range(i, n + 1)]

print(f"Raster CRS: {crs}")
print(f"Raster resolution rr: {rr:.6f} m")
print(f"Corrected fmin: {fmin_corrected} m")
print(f"fmax: {fmax_val} m")
print(f"Scaling factor x: {x_val}")
print(f"i: {i}, n: {n}, relief models: {num_relief_models}")
print("Low-pass radii (pixels):", radii_list)



## 5) Parallel relief-model computation

Each relief model, `LP(r_k) - LP(r_{k+1})`, is computed as a separate parallel task.

If there are more relief models than CPU cores, tasks are queued automatically by `ProcessPoolExecutor`.
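The larger-radius surface of one relief model is the smaller-radius surface of the next. Each low-pass surface can therefore be computed once and then differenced, which roughly halves the filtering work. The sketch below is a minimal sequential alternative and assumes the DEM fits in memory. Because it runs in a single process, it also avoids pickling the DEM to every worker. It additionally sidesteps the `spawn` start method used by default on Windows and macOS, under which functions defined inside a notebook may not be importable by child processes. `nan_mean` and `relief_models_cached` are hypothetical names, and the empty-window cut-off is an assumption.

```python
import numpy as np
from scipy.ndimage import uniform_filter

def nan_mean(arr, radius):
    """NaN-aware square moving average of side 2*radius+1."""
    size = 2 * radius + 1
    valid = np.isfinite(arr).astype(np.float32)
    filled = np.where(np.isfinite(arr), arr, 0.0).astype(np.float32)
    numer = uniform_filter(filled, size=size, mode="nearest")
    denom = uniform_filter(valid, size=size, mode="nearest")
    return np.where(denom > 0.5 / size**2, numer / np.maximum(denom, 1e-12), np.nan)

def relief_models_cached(dem, i, n, x):
    """Return radii and RM_k = LP(r_k) - LP(r_{k+1}) for k = i..n-1, filtering each radius once."""
    radii = [int(round(k ** x)) for k in range(i, n + 1)]
    prev = nan_mean(dem, radii[0])
    rms = []
    for r in radii[1:]:
        cur = nan_mean(dem, r)
        rms.append((prev - cur).astype(np.float32))
        prev = cur  # reused as the smaller-radius surface of the next model
    return radii, rms

rng = np.random.default_rng(0)
demo = rng.normal(size=(64, 64)).astype(np.float32)
radii, rms = relief_models_cached(demo, 1, 4, 2.0)
print(radii, len(rms))

# In the notebook this could replace the pool (radius_pairs would not be built):
# _, results = relief_models_cached(dem, i, n, x_val)
```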


In [None]:

cpu_total = os.cpu_count() or 1
workers = cpu_total if use_all_cores else (manual_workers or 1)
workers = max(1, int(workers))

print(f"CPU cores detected: {cpu_total}")
print(f"Workers used: {workers}")

jobs = [(dem, ndx2, i, float(x)) for ndx2 in range(num_relief_models)]
results = [None] * num_relief_models
radius_pairs = []

with ProcessPoolExecutor(max_workers=workers) as ex:
    futures = [ex.submit(rm_task, job) for job in jobs]
    iterator = as_completed(futures)
    if progress_bar:
        iterator = tqdm(iterator, total=len(futures), desc="Computing relief models")

    for fut in iterator:
        ndx2, r1, r2, rm = fut.result()
        results[ndx2] = rm
        radius_pairs.append((ndx2, r1, r2))

radius_pairs = sorted(radius_pairs, key=lambda t: t[0])
print("First relief-model radius pairs (ndx2, r1, r2):", radius_pairs[:5])



## 6) Build MSRM, normalize stretch, and hillshade

- `MSRM` results from averaging all relief models.
- Values are rounded to 3 decimals.
- The stretch subtracts the mean and divides by `2σ`, so the range `mean ± 1σ` maps to `[0, 1]`; values outside it are clipped.
- The stretched file is useful for calculating the hillshade; however, for visualisation it is recommended to select min/max display values manually on the unstretched MSRM file.
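To see which part of the distribution lands in `[0, 1]`, the sketch below generalizes the same mean/σ stretch with a hypothetical width parameter `k`. It maps `[mean - kσ, mean + kσ]` linearly onto `[0, 1]`. Dividing by `2σ`, as the cell below does, corresponds to `k = 1`, while `k = 2` would spread `mean ± 2σ` over the same range. `sd_stretch` is an illustrative name, not part of the original script.

```python
import numpy as np

def sd_stretch(arr, k=1.0):
    """Linearly map [mean - k*sigma, mean + k*sigma] onto [0, 1] and clip the rest."""
    mean = float(np.nanmean(arr))
    sigma = float(np.nanstd(arr)) or 1e-12  # guard against a constant raster
    out = (arr - mean) / (2.0 * k * sigma) + 0.5
    return np.clip(out, 0.0, 1.0).astype(np.float32)  # NaN (NoData) passes through

vals = np.array([-1.0, 1.0])  # mean 0, population sigma 1
print(sd_stretch(vals, k=1.0))  # [0. 1.]
print(sd_stretch(vals, k=2.0))  # [0.25 0.75]
```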


In [None]:

rm_stack = np.stack(results, axis=0).astype(np.float32)
msrm_raw = np.nanmean(rm_stack, axis=0).astype(np.float32)
msrm = np.round(msrm_raw * 1000.0) / 1000.0

mean_val = float(np.nanmean(msrm))
sigma_val = float(np.nanstd(msrm))
if sigma_val == 0:
    sigma_val = 1e-12

msrm_stretch = ((msrm - mean_val) / (sigma_val * 2.0)) + 0.5
msrm_stretch = np.clip(msrm_stretch, 0.0, 1.0).astype(np.float32)

ls = LightSource(azdeg=315, altdeg=45)
hillshade = ls.hillshade(msrm_stretch, vert_exag=1000.0, dx=rr, dy=rr).astype(np.float32)

print(f"MSRM mean: {mean_val:.6f}")
print(f"MSRM std dev: {sigma_val:.6f}")



## 7) Save outputs

This writes the main MSRM output as GeoTIFF and (optionally) stretched raster and hillshade.


In [None]:

out_profile = profile.copy()
out_profile.update(dtype="float32", count=1, compress="deflate", nodata=np.nan)

output_msrm_path = str(Path(output_msrm_path))
with rasterio.open(output_msrm_path, "w", **out_profile) as dst:
    dst.write(msrm.astype(np.float32), 1)
print("✅ Wrote:", output_msrm_path)

if output_stretch_path:
    output_stretch_path = str(Path(output_stretch_path))
    with rasterio.open(output_stretch_path, "w", **out_profile) as dst:
        dst.write(msrm_stretch.astype(np.float32), 1)
    print("✅ Wrote:", output_stretch_path)

if output_hillshade_path:
    output_hillshade_path = str(Path(output_hillshade_path))
    with rasterio.open(output_hillshade_path, "w", **out_profile) as dst:
        dst.write(hillshade.astype(np.float32), 1)
    print("✅ Wrote:", output_hillshade_path)



## 8) Quick visualization

This preview is for QA only. The exported GeoTIFF(s) are the final outputs.
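Section 6 recommends choosing min/max display values manually on the unstretched MSRM. Each relief model is a difference of two smoothed surfaces, so MSRM values typically straddle zero. A symmetric, percentile-based range with a diverging colormap is therefore one common starting point. This is a sketch under that assumption. `symmetric_limits` and the 98th-percentile default are hypothetical choices, not part of the original workflow.

```python
import numpy as np

def symmetric_limits(arr, pct=98.0):
    """Return (-lim, lim) where lim is the pct-th percentile of |arr|, ignoring NaNs."""
    lim = float(np.nanpercentile(np.abs(arr), pct))
    return -lim, lim

demo = np.linspace(-1.0, 1.0, 101)
demo[0] = np.nan  # NoData cells are skipped by nanpercentile
vmin, vmax = symmetric_limits(demo, pct=100.0)
print(vmin, vmax)

# In the notebook, for example:
# vmin, vmax = symmetric_limits(msrm)
# plt.imshow(msrm, cmap="RdBu_r", vmin=vmin, vmax=vmax)
```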


In [None]:

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

im0 = axes[0].imshow(msrm, cmap="terrain")
axes[0].set_title("MSRM (unstretched)")
plt.colorbar(im0, ax=axes[0], fraction=0.046)

im1 = axes[1].imshow(msrm_stretch, cmap="viridis", vmin=0, vmax=1)
axes[1].set_title("MSRM stretched [0,1]")
plt.colorbar(im1, ax=axes[1], fraction=0.046)

im2 = axes[2].imshow(hillshade, cmap="gray", vmin=0, vmax=1)
axes[2].set_title("Hillshade from stretched MSRM")
plt.colorbar(im2, ax=axes[2], fraction=0.046)

for ax in axes:
    ax.axis("off")

plt.tight_layout()
plt.show()
