# Naive Linear Shallow Water

In [2]:
import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.config import config
import numpy as np
import numba as nb
import pandas as pd
import equinox as eqx
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange, repeat, reduce
from tqdm.notebook import tqdm, trange
from jaxtyping import Array, Float
import wandb

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
config.update("jax_enable_x64", True)


%matplotlib inline
%load_ext autoreload
%autoreload 2

## Equations

Taking the equation from [wikipedia](https://en.wikipedia.org/wiki/Shallow_water_equations#Non-conservative_form).

**Non-Conservative Form**

$$
\begin{aligned}
\frac{\partial h}{\partial t} &+ 
\frac{\partial}{\partial x}\left((H+h)u\right) +
\frac{\partial}{\partial y}\left((H+h)v\right)= 0 \\
\frac{\partial u}{\partial t} &+ u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y} - fv =
-g\frac{\partial h}{\partial x} -ku + \nu \left( \frac{\partial^2 u}{\partial x^2} + 
\frac{\partial^2 u}{\partial y^2} \right)\\
\frac{\partial v}{\partial t} &+ u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y} + fu =
-g\frac{\partial h}{\partial y} -kv + 
\nu \left( \frac{\partial^2 v}{\partial x^2} + \frac{\partial^2 v}{\partial y^2} \right)\\
\end{aligned}
$$ (eq:sw)


| Symbol | Variable | Unit |
|:---------:|:------|:----:|
| $u$ | Zonal Velocity | $m/s$ |
| $v$ | Meridional Velocity | $m/s$ |
| $H$ | Mean Height | $m$ |
| $h$ | Height Deviation | $m$ |
| $b$ | Topographical Height | $m$ |


**Velocities**. The $u,v$ represent the zonal and meridional velocities in the x,y directions respectively.

**Heights** ($H,h,b$). 
The $H$ represents the mean height of the horizontal pressure surface. 
The $h$ represents the height deviation of the horizontal pressure surface from its mean height.
$b$ represents the topographical height from a reference $D$.

$$
\begin{aligned}
\eta(x,y,t) &= H(x,y) + h(x,y,t) \\
H(x,y) &= D + b(x,y)
\end{aligned}
$$

**Constants** ($f,g,k,\nu$). $f$ is the Coriolis parameter, $g$ is the acceleration due to gravity, $k$ is the viscous drag coefficient, and $\nu$ is the kinematic viscosity.

**Linear Version**

$$
\begin{aligned}
\frac{\partial h}{\partial t} &+  H\left(\frac{\partial u}{\partial x} +
\frac{\partial v}{\partial y}\right)= 0 \\
\frac{\partial u}{\partial t} &- fv = -g\frac{\partial h}{\partial x} -ku \\
\frac{\partial v}{\partial t} &+ fu = -g\frac{\partial h}{\partial y} -kv\\
\end{aligned}
$$

**Advection Term**

$$
\begin{aligned}
u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y} \\
u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y}
\end{aligned}
$$


**Diffusion Term**

$$
\begin{aligned}
\nu \left( \frac{\partial^2 u}{\partial x^2} + 
\frac{\partial^2 u}{\partial y^2} \right) \\
\nu \left( \frac{\partial^2 v}{\partial x^2} + \frac{\partial^2 v}{\partial y^2} \right)
\end{aligned}
$$

## Quasi-Geostrophic Equations

$$
\begin{aligned}
\partial_t q &= - \det\boldsymbol{J}(\psi,q) - \beta\partial_x\psi \\
q &= \nabla^2 \psi - \frac{1}{L_R^2}\psi \\
\psi &= \frac{g}{f_0}h \\
u &= -\partial_y\psi \\
v &= \partial_x\psi \\
f &= 2\Omega\sin\theta_0 + \frac{1}{R}2\Omega\cos\theta_0 y
\end{aligned}
$$

where:
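* $f_0=2\Omega\sin\theta_0$ is the Coriolis parameter at the reference latitude $\theta_0$
* $\beta=\frac{2\Omega\cos\theta_0}{R}$ is the meridional gradient of the Coriolis parameter ($\beta$-plane approximation $f = f_0 + \beta y$), with $R$ the planetary radius
* $L_R = \sqrt{gH}/f_0$ is the Rossby radius of deformation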
* $\Omega$ is the angular frequency of rotation

Source:
* [Geophysical Fluid Dynamics - Pedlosky](https://doi.org/10.1007/978-1-4612-4650-3)
* [Atmospheric and Oceanic Fluid Dynamics - Vallis](https://doi.org/10.1017/9781107588417)

## Model Parameters

In [24]:
# Grid: number of points and spacing [m] along x and y.
n_x, dx = 100, 20e3
n_y, dy = 101, 20e3

# Physical constants: gravitational acceleration, mean layer depth H and the
# (spatially constant) Coriolis parameter f.
gravity = 9.81
depth = 100.
coriolis_param = 2e-4

# Derived scales: gravity-wave phase speed c = sqrt(g H) and the Rossby
# radius of deformation L_R = c / f.
phase_speed = jnp.sqrt(gravity * depth)
rossby_radius = phase_speed / coriolis_param

print(f"Nx: {n_x*dx:,.0f} [m]")
print(f"dx: {dx:,.0f} [m]")
print(f"Ny: {n_y*dy:,.0f} [m]")
print(f"dy: {dy:,.0f} [m]")
print(f"Phase Speed: {phase_speed:,.2f} m/s")
print(f"Rossby Radius: {rossby_radius:,.2f} [m]")

Nx: 2,000,000 [m]
dx: 20,000 [m]
Ny: 2,020,000 [m]
dy: 20,000 [m]
Phase Speed: 31.32 m/s
Rossby Radius: 156,604.60 [m]


In [8]:
# CFL-limited time step: half the time a gravity wave travelling at the
# phase speed c = sqrt(g H) needs to cross the smallest grid spacing.
dt = 0.5 * min(dx, dy) / phase_speed
print(f"dt: {dt:.2f} secs")

dt: 319.28 secs


### Grid Setup

In [25]:
# 1-D coordinates of the grid points [m].
x = jnp.arange(n_x) * dx
y = jnp.arange(n_y) * dy

# 2-D coordinate fields of shape (n_y, n_x), matching the (row = y,
# column = x) layout of u, v and h.  With indexing="ij", meshgrid(y, x)
# returns the y-coordinates FIRST and the x-coordinates second, so the
# outputs must be unpacked as (Y, X); unpacking them as (X, Y) silently swaps
# the axes and misplaces anything built from X and Y (e.g. the initial bump).
Y, X = jnp.meshgrid(y, x, indexing="ij")
assert X.shape == Y.shape == (n_y, n_x)

### Initial Conditions


In [26]:
# Initial state: fluid at rest plus a Gaussian bump of amplitude 1.0 on top of
# the mean depth, with an e-folding width of one Rossby radius.  The bump is
# meant to sit at x = x[n_x // 2] and y = y[n_y - 2] (mid-domain in x, one row
# in from the high-y edge); this relies on X and Y holding the x- and
# y-coordinates respectively.
h0 = depth + 1.0 * jnp.exp(
    -((X - x[n_x // 2]) ** 2) / rossby_radius**2
    - ((Y - y[n_y - 2]) ** 2) / rossby_radius**2
)
u0, v0 = np.zeros_like(h0), np.zeros_like(h0)

In [35]:
# JAX arrays are immutable, so the NumPy "allocate with empty, then fill in
# place" idiom buys nothing here: `.at[...].set` already returns a fresh
# array.  Convert the initial fields directly, using JAX's default float
# dtype (the same dtype jnp.empty would have produced).
u = jnp.asarray(u0, dtype=float)
v = jnp.asarray(v0, dtype=float)
h = jnp.asarray(h0, dtype=float)

# All prognostic fields live on the same (n_y, n_x) grid.
assert u.shape == v.shape == h.shape == (n_y, n_x)

### Boundaries

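The boundary values of $h$ lie outside the physical domain and must never be used. We set them to NaN so that any accidental use propagates visibly.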

In [38]:
# Poison the outermost rows and columns of h with NaN.  Only interior points
# are meant to be used, so any stencil that accidentally reads a boundary
# value produces NaN instead of a silently wrong number.
h = (
    h.at[0, :].set(jnp.nan)
    .at[-1, :].set(jnp.nan)
    .at[:, 0].set(jnp.nan)
    .at[:, -1].set(jnp.nan)
)

## Timings

In [43]:
# Number of time steps and the model time [s] at the start of each step.
n_t = 100
ts = dt * jnp.arange(n_t)

In [44]:
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


$$
\begin{aligned}
\partial_t u &- f\bar{v} = - g\partial_x h  \\
\partial_t v  &+  f\bar{u} = - g\partial_y h\\
\partial_t h& + H\left(\partial_x u + \partial_y v\right) = 0
\end{aligned}
$$

In [None]:
def update_u(u, v, h, dt, dx, coriolis_param, gravity):
    """Advance the zonal velocity ``u`` by one explicit time step.

    Discretises the linear u-momentum equation

        du/dt = f * v_bar - g * dh/dx

    on an Arakawa C-grid.  Assumed layout (matching the (n_y, n_x) arrays used
    in this notebook): ``h[j, i]`` is a cell centre, ``u[j, i]`` sits on the
    east face of that cell and ``v[j, i]`` on its north face.  ``v_bar`` is
    therefore the average of the four ``v`` points surrounding each ``u`` point.

    Parameters
    ----------
    u, v, h : jax array, shape (n_y, n_x)
        Current velocities and height.  ``h`` may carry a NaN halo in its
        outermost rows/columns.
    dt, dx : float
        Time step and grid spacing along x.
    coriolis_param, gravity : float
        Coriolis parameter ``f`` and gravitational acceleration ``g``.

    Returns
    -------
    jax array, shape (n_y, n_x)
        The new ``u``.  Only the interior is updated, and the eastern-wall
        column ``u[:, -2]`` is held at zero (no normal flow).
    """
    v_avg = 0.25 * (v[1:-1, 1:-1] + v[:-2, 1:-1] + v[1:-1, 2:] + v[:-2, 2:])
    dh_dx = (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
    u_new = u.at[1:-1, 1:-1].add(dt * (coriolis_param * v_avg - gravity * dh_dx))
    # The last interior column lies on the eastern wall, and its gradient reads
    # the (NaN) halo of h.  Enforcing zero normal flow there also clears
    # those NaNs.
    return u_new.at[:, -2].set(0.0)

In [None]:
# Coriolis term f * v_bar of the u-momentum equation: average the four v
# points surrounding each interior u point.  Assumes an Arakawa C-grid layout
# (u on the east faces of h cells, v on their north faces).
v_avg = 0.25 * (v[1:-1, 1:-1] + v[:-2, 1:-1] + v[1:-1, 2:] + v[:-2, 2:])

In [45]:
# Integrate the linear shallow-water equations with a forward-backward scheme
# on an Arakawa C-grid (assumed layout: h at cell centres, u on east faces,
# v on north faces).  u is stepped first, v uses the updated u, and h uses
# both updated velocities.  Work on copies so that re-running this cell
# restarts from the initial state instead of continuing a previous run.
u_t, v_t, h_t = u, v, h
for t in tqdm(ts):
    # u-momentum: du/dt = f * v_bar - g * dh/dx
    v_avg = 0.25 * (v_t[1:-1, 1:-1] + v_t[:-2, 1:-1] + v_t[1:-1, 2:] + v_t[:-2, 2:])
    u_t = u_t.at[1:-1, 1:-1].add(
        dt * (coriolis_param * v_avg - gravity * (h_t[1:-1, 2:] - h_t[1:-1, 1:-1]) / dx)
    )
    # No flow through the eastern wall.  This also overwrites the NaNs that
    # the x-gradient picks up from the NaN halo of h in that column.
    u_t = u_t.at[:, -2].set(0.0)

    # v-momentum: dv/dt = -f * u_bar - g * dh/dy
    u_avg = 0.25 * (u_t[1:-1, 1:-1] + u_t[1:-1, :-2] + u_t[2:, 1:-1] + u_t[2:, :-2])
    v_t = v_t.at[1:-1, 1:-1].add(
        dt * (-coriolis_param * u_avg - gravity * (h_t[2:, 1:-1] - h_t[1:-1, 1:-1]) / dy)
    )
    # No flow through the northern wall (same NaN-halo reasoning as for u).
    v_t = v_t.at[-2, :].set(0.0)

    # Continuity: dh/dt = -H * (du/dx + dv/dy)
    h_t = h_t.at[1:-1, 1:-1].add(
        -dt * depth * (
            (u_t[1:-1, 1:-1] - u_t[1:-1, :-2]) / dx
            + (v_t[1:-1, 1:-1] - v_t[:-2, 1:-1]) / dy
        )
    )

  0%|          | 0/100 [00:00<?, ?it/s]

### Plots

In [28]:
# Plot settings: symmetric colour limit for the eta anomaly, presumably the
# redraw cadence in time steps (not used in this section), and the rough
# number of quiver arrows per axis.
plot_range, plot_every, max_quivers = 0.5, 2, 21

In [29]:
def prepare_plot():
    """Create a figure showing the initial state, with a colorbar for eta.

    Draws ``h0``, ``u0`` and ``v0`` via ``update_plot`` at ``t = 0``.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
        The new single-panel figure and its axes, ready for further
        ``update_plot`` calls.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    mesh = update_plot(0, h0, u0, v0, ax)
    # Attach the colorbar explicitly to this figure and axes rather than going
    # through pyplot's implicit "current figure".  That figure may no longer
    # be `fig` once update_plot has called plt.pause.
    fig.colorbar(mesh, ax=ax, label=r"$\eta$ [m]")
    return fig, ax

In [30]:
def update_plot(t, h, u, v, ax):
    """Redraw the height anomaly and velocity field on ``ax``.

    Parameters
    ----------
    t : float
        Model time in seconds; shown in the title divided by 86 400 (days).
    h : array, shape (n_y, n_x)
        Height field; ``h - depth`` is plotted, interior points only.
    u, v : array, shape (n_y, n_x)
        Velocity components, drawn as a thinned quiver field, but only if
        any sampled arrow is non-zero.
    ax : matplotlib.axes.Axes
        Axes that is cleared and redrawn.

    Returns
    -------
    matplotlib.collections.QuadMesh
        The ``pcolormesh`` artist, e.g. for attaching a colorbar.
    """
    eta = h - depth

    # Thin the quiver field to roughly `max_quivers` arrows per axis.  Clamp
    # the step to at least 1: with fewer grid points than `max_quivers`, the
    # integer division gives 0, and a zero slice step raises ValueError.
    quiver_stride = (
        slice(1, -1, max(1, n_y // max_quivers)),
        slice(1, -1, max(1, n_x // max_quivers)),
    )

    ax.clear()
    cs = ax.pcolormesh(
        x[1:-1] / 1e3,
        y[1:-1] / 1e3,
        eta[1:-1, 1:-1],
        vmin=-plot_range, vmax=plot_range, cmap="RdBu_r",
    )

    # Skip the quiver when the flow is at rest (all arrows would be zero).
    if np.any((u[quiver_stride] != 0) | (v[quiver_stride] != 0)):
        ax.quiver(
            x[quiver_stride[1]] / 1e3,
            y[quiver_stride[0]] / 1e3,
            u[quiver_stride],
            v[quiver_stride],
            clip_on=False
        )

    ax.set_aspect("equal")
    ax.set_xlabel(r"$x$ [km]")
    ax.set_ylabel(r"$y$ [km]")
    ax.set_title(
        f"t={t/86_400:5.2f}, R={rossby_radius/1e3:5.1f}, c={phase_speed:5.1f} m/s"
    )
    plt.pause(0.1)
    return cs