---
title: Quasi-Geostrophic Equations
---

```python
import os
import sys

from pyprojroot import here

# walk up the directory tree to the project root (the folder containing `.home`)
root = here(project_files=[".home"])

# make the project's `lib` package importable
sys.path.append(str(root))
```

```python
import pytreeclass as pytc
import jax.numpy as jnp
import jax
from jax.experimental.ode import odeint

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

# project-local modules (importable once the project root is on sys.path)
from lib._src.dynamical.base import DynamicalSystem
from lib._src.dynamical.qg import Qgm

# plotting style
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# IPython magics: reload edited modules automatically
%load_ext autoreload
%autoreload 2
```

## Dynamical System

A dynamical system is specified by three components:

* Equation of Motion
* Observation Operator
* Integrate
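A minimal sketch of how these three pieces can fit together is shown below. It is not the interface of `lib._src.dynamical.base.DynamicalSystem`: the function names, the placeholder decay dynamics, the every-other-variable observation operator and the fixed-step fourth-order Runge-Kutta integrator (written with `jax.lax.scan`) are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def equation_of_motion(x, t):
    # Placeholder dynamics for illustration only: linear decay dx/dt = -x
    return -x


def observe(x):
    # Hypothetical observation operator H: keep every other state variable
    return x[::2]


def integrate(rhs, x0, dt, n_steps):
    """Fixed-step RK4; returns the trajectory with shape (n_steps + 1, *x0.shape)."""

    def rk4_step(x, i):
        t = i * dt
        k1 = rhs(x, t)
        k2 = rhs(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = rhs(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = rhs(x + dt * k3, t + dt)
        x_next = x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return x_next, x_next  # (new carry, value stacked into the trajectory)

    _, traj = jax.lax.scan(rk4_step, x0, jnp.arange(n_steps))
    return jnp.concatenate([x0[None], traj], axis=0)


traj = integrate(equation_of_motion, jnp.ones(4), dt=0.01, n_steps=100)
obs = jax.vmap(observe)(traj)  # one observation vector per time step
```

Keeping the right-hand side, the observation operator and the integrator as separate functions means any of them can be swapped (for example a different `rhs`) without touching the others.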

### Equation of Motion

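$$
\frac{dx_i}{dt} = (x_{i+1} - x_{i-2})\,x_{i-1} - x_i + F, \qquad i = 1, \dots, N,
$$

where the indices are cyclic ($x_{i+N} = x_i$) and $F$ is a constant forcing. This is the Lorenz-96 system; $F = 8$ is the usual choice because it produces chaotic behaviour.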
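The right-hand side maps directly onto array shifts: with periodic indexing, $x_{i+1}$, $x_{i-2}$ and $x_{i-1}$ are `jnp.roll` of the state by $-1$, $2$ and $1$. The sketch below assumes a 1D state vector; the name `lorenz96` is hypothetical. Its `(x, t, *args)` signature follows the convention of `jax.experimental.ode.odeint` (imported above), so it could be integrated as `odeint(lorenz96, x0, t, F)`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def lorenz96(x, t, forcing):
    """Lorenz-96 tendency dx_i/dt = (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F, cyclic in i."""
    x_ip1 = jnp.roll(x, -1)  # x_{i+1}
    x_im2 = jnp.roll(x, 2)   # x_{i-2}
    x_im1 = jnp.roll(x, 1)   # x_{i-1}
    return (x_ip1 - x_im2) * x_im1 - x + forcing


F = 8.0
# x_i = F for every i is a fixed point, so all tendencies vanish there
tendency_at_rest = lorenz96(F * jnp.ones(40), 0.0, F)
```

A small perturbation of that fixed point is the usual way to start a chaotic run at $F = 8$.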


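The next cell checks the linearised QG step with two standard verification tests. Writing $\mathcal{M}$ for one nonlinear model step, $\mathbf{M}$ for its tangent-linear model around the reference state $x_0$ (`SSH0`) and $\mathbf{M}^\top$ for its adjoint, with perturbation $\delta x$ (`dSSH`) and adjoint input $\delta y$ (`adSSH0`), the tests are:

$$
\epsilon(\lambda) = \frac{\lVert \mathcal{M}(x_0 + \lambda\,\delta x) - \mathcal{M}(x_0) - \mathbf{M}(\lambda\,\delta x) \rVert}{\lVert \mathbf{M}(\lambda\,\delta x) \rVert} = \mathcal{O}(\lambda)
$$

$$
\frac{\langle \mathbf{M}\,\delta x,\; \delta y \rangle}{\langle \delta x,\; \mathbf{M}^{\top} \delta y \rangle} = 1
$$

The first follows from a Taylor expansion: the numerator is $\mathcal{O}(\lambda^2)$ and the denominator is $\mathcal{O}(\lambda)$, so $\epsilon$ should fall linearly with $\lambda$ until floating-point round-off dominates. The second is the definition of the adjoint and should hold to round-off.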
```python
ny, nx = 10, 10
dx = 10e3 * jnp.ones((ny, nx))  # grid spacing in x
dy = 12e3 * jnp.ones((ny, nx))  # grid spacing in y
dt = 300  # model time step

# SSH used to build the model; MDT (mean dynamic topography) is generated but not used
SSH0 = np.random.random((ny, nx))
MDT = np.random.random((ny, nx))
c = 2.5

qgm = Qgm(dx=dx, dy=dy, dt=dt, c=c, SSH=SSH0, qgiter=1, mdt=None)  # use mdt=MDT to include the MDT

# Reference state x0 around which the model is linearised
SSH0 = jnp.array(1e-2 * np.random.random((ny, nx)))

# Perturbation dx
dSSH = jnp.array(1e-2 * np.random.random((ny, nx)))

# Adjoint input dy
adSSH0 = jnp.array(1e-2 * np.random.random((ny, nx)))

# Tangent-linear test: ||M(x0 + lambd*dx) - M(x0) - M'(lambd*dx)|| / ||M'(lambd*dx)|| should shrink like lambd
SSH2 = qgm.step_jit(SSH0, dSSH)  # nonlinear step from the reference state
print("Tangent test:")
for p in range(10):
    lambd = 10 ** (-p)

    # nonlinear step from the perturbed state
    SSH1 = qgm.step_jit(SSH0 + lambd * dSSH, dSSH)

    # tangent-linear model applied to the scaled perturbation
    dSSH1 = qgm.step_tgl_jit(dh0=lambd * dSSH, h0=SSH0)

    # ignore grid points where the result is NaN
    mask = jnp.isnan(SSH1 - SSH2 - dSSH1)
    ps = jnp.linalg.norm((SSH1 - SSH2 - dSSH1)[~mask].flatten()) / jnp.linalg.norm(
        dSSH1[~mask]
    )

    print("%.E" % lambd, "%.E" % ps)

# Adjoint (dot-product) test: <M' dx, dy> / <dx, M'^T dy> should equal 1
dSSH1 = qgm.step_tgl_jit(dh0=dSSH, h0=SSH0)
adSSH1 = qgm.step_adj_jit(adSSH0, SSH0)
mask = jnp.isnan(dSSH1 + adSSH1 + SSH0 + dSSH)

ps1 = jnp.inner(dSSH1[~mask].flatten(), adSSH0[~mask].flatten())
ps2 = jnp.inner(dSSH[~mask].flatten(), adSSH1[~mask].flatten())

print("\nAdjoint test:", ps1 / ps2)
```

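```
Tangent test:
1E+00 1E-02
1E-01 5E-04
1E-02 3E-06
1E-03 3E-07
1E-04 3E-08
1E-05 3E-09
1E-06 5E-10
1E-07 5E-09
1E-08 4E-08
1E-09 4E-07

Adjoint test: 0.7405060898978333
```

The tangent-linear test behaves as expected: the relative error falls roughly in proportion to $\lambda$ down to $\lambda = 10^{-6}$ and then grows again as floating-point round-off takes over. The adjoint ratio, however, should equal 1 up to round-off for an exact adjoint; a value of about 0.74 suggests that `step_adj_jit` is not the exact transpose of `step_tgl_jit` in this configuration, or that the NaN masking drops points that contribute to the inner products.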
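As a reference for what passing tests look like, the same two checks can be run against JAX's own automatic differentiation. This sketch does not use `Qgm`: `toy_step` is a hypothetical quadratic model step on a 10 × 10 field, the tangent-linear action comes from `jax.jvp` and the adjoint action from `jax.vjp`, and everything runs in double precision. Because `vjp` is the exact transpose of `jvp`, the dot-product ratio should be 1 to within round-off.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the tests


def toy_step(h, dt=0.05):
    # Hypothetical nonlinear step (forward Euler of a quadratic, periodic
    # advection-like tendency) standing in for one model time step
    tendency = (jnp.roll(h, -1, axis=1) - jnp.roll(h, 1, axis=1)) * h - 0.1 * h
    return h + dt * tendency


k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (10, 10)
h0 = 1e-2 * jax.random.uniform(k0, shape)   # reference state (cf. SSH0)
dh = 1e-2 * jax.random.uniform(k1, shape)   # perturbation (cf. dSSH)
adh = 1e-2 * jax.random.uniform(k2, shape)  # adjoint input (cf. adSSH0)

h1 = toy_step(h0)

# Tangent-linear test: relative error should shrink roughly like lam
errors = []
for p in range(8):
    lam = 10.0 ** (-p)
    _, tl = jax.jvp(toy_step, (h0,), (lam * dh,))  # M(lam * dh)
    err = jnp.linalg.norm(toy_step(h0 + lam * dh) - h1 - tl) / jnp.linalg.norm(tl)
    errors.append(float(err))
    print(f"{lam:.0E} {errors[-1]:.0E}")

# Adjoint (dot-product) test: <M dh, adh> / <dh, M^T adh> should be 1
_, tl = jax.jvp(toy_step, (h0,), (dh,))
_, vjp_fn = jax.vjp(toy_step, h0)
(ad,) = vjp_fn(adh)
ratio = jnp.vdot(tl, adh) / jnp.vdot(dh, ad)
print("Adjoint test:", float(ratio))
```

Since `toy_step` is quadratic, the tangent-test numerator is exactly $\mathcal{O}(\lambda^2)$, so the printed errors drop by a factor of 10 per decade of $\lambda$ until round-off sets in.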
