# Stream Power Law with Numba

### 🌊 What you'll see
- ⚙️ D8 steepest-descent routing using `@njit`.
- 🧮 Flow accumulation (drainage area) in a fully compiled loop.
- 🪨 Implicit stream power integration in the style of Braun & Willett (2013), solved node by node with Newton iterations.
- 🏔️ Uplift + erosion shaping a synthetic landscape over time.

## Setup


In [None]:
!pip install numba

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit

plt.rcParams.update({'figure.figsize': (6, 5)})  # default size for figures that don't pass their own figsize
np.random.seed(10)  # seed NumPy's global RNG so the noise added to the initial topography is reproducible


### 📐 Model parameters

In [None]:
# --- Grid ---------------------------------------------------------------
nx = ny = 180                 # number of cells along each axis
cell_size = 50.0              # metres
pixel_area = cell_size ** 2   # m² covered by one cell

# --- Stream power law E = K * A**m * S**n (Braun & Willett style) ---------
K = 2.5e-6                    # erodibility coefficient
m = 0.5                       # drainage-area exponent
n = 1.0                       # slope exponent
uplift_rate = 5e-4            # m/yr
time_step = 2000.0            # years per model step
n_steps = 120                 # number of model steps

# --- Initial surface ----------------------------------------------------
init_relief = 200.0           # m, amplitude of the Gaussian dome
noise_amplitude = 10.0        # m, standard deviation of the white noise


### 🏞️ Initial topography

In [None]:
x = np.linspace(-1, 1, nx)   # normalised x coordinate of each column of cells
y = np.linspace(-1, 1, ny)   # normalised y coordinate of each row of cells
X, Y = np.meshgrid(x, y, indexing='ij')   # shape (nx, ny), indexed [ix, iy]

# Gaussian dome centred at (0, -0.3), tilted upward toward +x and +y,
# roughened with white noise, then shifted so that min(z0) == 0.
dome = init_relief * np.exp(-2.5 * (X**2 + (Y + 0.3)**2))
z0 = dome + 40.0 * X + 20.0 * Y + noise_amplitude * np.random.randn(nx, ny)
z0 = z0 - z0.min()

# Transpose for display so x runs horizontally and y runs upward.
plt.figure(figsize=(6, 5))
plt.imshow(z0.T, origin='lower', cmap='terrain')
plt.title('Initial elevation (m)')
plt.colorbar(label='m')
plt.xticks([])
plt.yticks([])
plt.show()


### 🧠 Numba kernels

In [None]:
@njit(cache=True)
def compute_receivers(elev, cell_size):
    """Find the D8 steepest-descent receiver of every cell.

    Each cell is compared with its (up to) 8 neighbours. The gradient to a
    neighbour is drop / distance, where distance is ``cell_size`` for
    orthogonal neighbours and ``cell_size * sqrt(2)`` for diagonal ones. Only
    strictly lower neighbours are considered. Ties keep the first neighbour
    visited.

    Parameters
    ----------
    elev : 2-D float array, indexed [ix, iy]
    cell_size : float, grid spacing

    Returns
    -------
    receivers : int64 array of length nx * ny
        Flat index (ix * ny + iy) of the receiver, or -1 when no neighbour
        is strictly lower.
    slope : float64 array of length nx * ny
        Gradient to the chosen receiver, 0.0 where receivers == -1.
    """
    rows, cols = elev.shape
    receivers = np.full(rows * cols, -1, np.int64)
    slope = np.zeros(rows * cols, np.float64)
    for i in range(rows):
        for j in range(cols):
            here = elev[i, j]
            best_slope = 0.0
            best_node = -1
            for di in (-1, 0, 1):
                ni = i + di
                if ni < 0 or ni >= rows:
                    continue
                for dj in (-1, 0, 1):
                    nj = j + dj
                    if (di == 0 and dj == 0) or nj < 0 or nj >= cols:
                        continue
                    drop = here - elev[ni, nj]
                    if drop > 0.0:
                        gradient = drop / (cell_size * np.sqrt(di * di + dj * dj))
                        if gradient > best_slope:
                            best_slope = gradient
                            best_node = ni * cols + nj
            flat = i * cols + j
            receivers[flat] = best_node
            slope[flat] = best_slope
    return receivers, slope

@njit(cache=True)
def topological_order(receivers):
    """Order nodes so that every node comes after all of its donors.

    This is Kahn's algorithm on the receiver graph. A node's in-degree is the
    number of donors pointing at it. Nodes with no donors are emitted first,
    and a receiver is emitted once its last donor has been emitted. The
    result therefore runs upstream -> downstream. A node whose in-degree
    never reaches zero (only possible on a cycle) is left out.

    Parameters
    ----------
    receivers : 1-D int array, flat receiver index per node or -1

    Returns
    -------
    int64 array of node indices in donor-before-receiver order.
    """
    size = receivers.size
    pending = np.zeros(size, np.int64)  # donors of each node not yet emitted
    for donor in range(size):
        target = receivers[donor]
        if target != -1:
            pending[target] += 1

    # `order` doubles as the FIFO queue: [cursor:count) is still to process.
    order = np.empty(size, np.int64)
    count = 0
    for node in range(size):
        if pending[node] == 0:
            order[count] = node
            count += 1

    cursor = 0
    while cursor < count:
        target = receivers[order[cursor]]
        cursor += 1
        if target != -1:
            pending[target] -= 1
            if pending[target] == 0:
                order[count] = target
                count += 1
    return order[:cursor]

@njit(cache=True)
def accumulate_area(order, receivers, pixel_area):
    """Flow accumulation: total area draining through each node.

    Parameters
    ----------
    order : 1-D int array
        Node indices in which every node appears after all of its donors
        (the upstream-to-downstream order produced by ``topological_order``).
    receivers : 1-D int array
        Flat index of each node's receiver, or -1 if it has none.
    pixel_area : float
        Area contributed by a single cell.

    Returns
    -------
    area : 1-D float64 array
        For each node, its own ``pixel_area`` plus the area of every node
        that drains through it.
    """
    area = np.full(receivers.size, pixel_area, np.float64)
    # Walk upstream -> downstream. A node's total is complete only after all
    # of its donors have been added, so it must be visited after them.
    # Walking the order in reverse would forward partial totals downstream.
    for idx in range(order.size):
        node = order[idx]
        rec = receivers[node]
        if rec != -1:
            area[rec] += area[node]
    return area

@njit(cache=True)
def implicit_stream_power(elev, receivers, order, area, dt, K, m, n, uplift, cell_size):
    """Advance elevations by one step of the implicit stream power law.

    For each node with a receiver, solves

        z_new = z_old + dt * (uplift - K * A**m * S**n),
        S     = (z_new - z_rec_new) / L,

    where ``z_rec_new`` is the receiver's already-updated elevation and ``L``
    is the distance to the receiver: ``cell_size``, or ``cell_size * sqrt(2)``
    for diagonal receivers, as in ``compute_receivers``. The scalar equation
    is solved with Newton iterations: at most 40, stopping when the residual
    drops below 1e-6. S is floored at 1e-5, and the iterate is kept strictly
    above the receiver. Nodes without a receiver (-1) are only uplifted.

    Parameters
    ----------
    elev : 2-D float array [ix, iy], elevations at the start of the step
    receivers : flat receiver index per node (ix * ny + iy), or -1
    order : donor-before-receiver node order (see ``topological_order``);
        traversed from the end so every receiver is updated before its donors
    area : flat drainage area per node
    dt, K, m, n, uplift : time step and stream power / uplift parameters
    cell_size : grid spacing

    Returns
    -------
    New elevation array with the same shape as ``elev`` (``elev`` itself is
    not modified).
    """
    new_elev = elev.copy()
    ny = elev.shape[1]
    # One cell's area. Derived from the argument, not a notebook global,
    # which Numba would freeze at compile/cache time.
    min_area = cell_size * cell_size
    for idx in range(order.size - 1, -1, -1):
        node = order[idx]
        rec = receivers[node]
        ix = node // ny
        iy = node % ny
        z_old = elev[ix, iy]
        target = z_old + uplift * dt
        if rec == -1:
            new_elev[ix, iy] = target
            continue
        rx = rec // ny
        ry = rec % ny
        z_rec = new_elev[rx, ry]
        off_x = ix - rx
        off_y = iy - ry
        dist = cell_size * np.sqrt(off_x * off_x + off_y * off_y)
        coeff = K * max(area[node], min_area) ** m
        z_guess = max(z_rec + 0.1, target)
        for _ in range(40):
            raw_slope = (z_guess - z_rec) / dist
            if raw_slope > 1e-5:
                slope = raw_slope
                dslope_dz = 1.0 / dist
            else:
                # Slope is clamped to the floor, so it no longer depends on z.
                slope = 1e-5
                dslope_dz = 0.0
            # Residual f(z) = z - z_old - dt * (uplift - K A^m S^n)
            f = z_guess - target + dt * coeff * slope ** n
            if abs(f) < 1e-6:
                break
            # df/dz = 1 + dt * K A^m * n S^(n-1) * dS/dz  (erosion term enters with +)
            der = 1.0 + dt * coeff * n * slope ** (n - 1.0) * dslope_dz
            z_guess -= f / der
            if z_guess <= z_rec:
                z_guess = z_rec + 1e-4
        new_elev[ix, iy] = z_guess
    return new_elev

@njit(cache=True)
def update_landscape(elev, dt, K, m, n, uplift, cell_size):
    """Run one model step: route flow, accumulate area, then erode.

    Returns
    -------
    (new_elev, receivers, slope, area, relief)
        ``receivers``, ``slope`` and ``area`` are computed from the input
        surface ``elev`` (before erosion). ``relief`` is
        ``new_elev.max() - new_elev.min()``.
    """
    rcv, steepest = compute_receivers(elev, cell_size)
    stack = topological_order(rcv)
    drainage = accumulate_area(stack, rcv, cell_size * cell_size)
    updated = implicit_stream_power(
        elev, rcv, stack, drainage, dt, K, m, n, uplift, cell_size
    )
    return updated, rcv, steepest, drainage, updated.max() - updated.min()


### ▶️ Run the simulation

In [None]:
elevation = z0.copy()
relief_history = []   # max - min elevation after each step (m)
mean_height = []      # mean elevation after each step (m)
# Placeholder so the drainage-area plot still has data if n_steps == 0;
# replaced by update_landscape's output on every step.
area = np.full(nx * ny, pixel_area)

# Fixed base level for the open border. Border cells normally have no lower
# neighbour (receiver == -1), so implicit_stream_power uplifts them every
# step. Re-clamping to the post-step minimum would let the base level drift
# upward with the uplift and cancel rock uplift relative to base level.
base_level = elevation.min()

for step in range(n_steps):
    elevation, receivers, slope, area, relief = update_landscape(
        elevation, time_step, K, m, n, uplift_rate, cell_size
    )

    # Open boundary: pin every border cell to the fixed base level.
    elevation[0, :] = base_level
    elevation[-1, :] = base_level
    elevation[:, 0] = base_level
    elevation[:, -1] = base_level

    # Record diagnostics for the surface actually carried into the next step.
    relief = elevation.max() - elevation.min()
    relief_history.append(relief)
    mean_height.append(elevation.mean())


### 🖼️ Visualise final state

In [None]:
panels = [
    ('Initial elevation (m)', z0),
    ('Final elevation (m)', elevation),
    ('Elevation change (m)', elevation - z0),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)
for ax, (title, field) in zip(axes, panels):
    # Arrays are indexed [ix, iy]; transpose so x runs horizontally, y upward.
    im = ax.imshow(field.T, origin='lower', cmap='terrain')
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.show()


### 📈 Diagnostics

In [None]:
# Entry k of relief_history / mean_height is recorded after model step k + 1,
# i.e. at t = (k + 1) * time_step years. Deriving the axis from the history
# length keeps it aligned with whatever was actually recorded.
time_kyr = np.arange(1, len(relief_history) + 1) * time_step / 1_000

fig, ax = plt.subplots(1, 2, figsize=(11, 4))
ax[0].plot(time_kyr, relief_history, color='tab:red')
ax[0].set_xlabel('Time (kyr)')
ax[0].set_ylabel('Relief (m)')
ax[0].set_title('Relief evolution')

ax[1].plot(time_kyr, mean_height, color='tab:blue')
ax[1].set_xlabel('Time (kyr)')
ax[1].set_ylabel('Mean elevation (m)')
ax[1].set_title('Mean elevation evolution')
plt.show()


### 🧭 Flow routing sanity check

In [None]:
# `area` is flat (index ix * ny + iy); reshape it back onto the [ix, iy] grid.
drainage_area = area.reshape(nx, ny)
log_area = np.log10(drainage_area.T + 1)  # transposed so x runs horizontally

plt.figure(figsize=(6, 5))
plt.imshow(log_area, origin='lower', cmap='viridis')
plt.title('log10 drainage area (m²)')
plt.colorbar(label='log₁₀(A)')
for hide_ticks in (plt.xticks, plt.yticks):
    hide_ticks([])
plt.show()
