# Differentiable Shallow Water Equations

We present a differentiable solver for the shallow water equations (SWE) based on `paddle-harmonics`.

In [1]:
import os
import sys
from math import ceil, floor

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn as nn

# make the parent directory importable so that paddle_harmonics can be found
sys.path.append("../")
from paddle_harmonics.sht import *
from paddle_harmonics.examples import ShallowWaterSolver

In [2]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = paddle.CUDAPlace(0) if paddle.device.cuda.device_count() > 0 else paddle.CPUPlace()

The shallow water solver class `ShallowWaterSolver` is defined in `shallow_water_equations.py` and is imported above from `paddle_harmonics.examples`.

In [3]:
# --- grid and spectral resolution ---
nlat = 512               # number of latitudes
nlon = nlat * 2          # twice as many longitudes as latitudes
lmax = 128               # number of spectral modes in l
mmax = lmax              # number of spectral modes in m

# --- time stepping ---
dt = 75                              # time step in seconds
maxiter = 12 * int(86400 / dt)       # number of steps covering 12 days

# build the solver and move it to the selected device
swe_solver = ShallowWaterSolver(nlat, nlon, dt, lmax=lmax, mmax=mmax).to(device)

# grid coordinates provided by the solver
lons, lats = swe_solver.lons, swe_solver.lats

# (row, col) index pairs of the upper triangle of an lmax x mmax matrix
jj, ii = paddle.triu_indices(lmax, mmax)


W1106 16:16:32.939814 68712 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 12.3, Runtime API Version: 12.4
W1106 16:16:32.940515 68712 gpu_resources.cc:164] device: 0, cuDNN Version: 9.4.


In [4]:
# Initial state (spectral coefficients) from the solver's Galewsky initial condition;
# the time-stepping cell below starts from a copy of it.
# NOTE(review): in the saved run this cell failed with a float64/float32 dtype
# mismatch (ValueError in the output below) — TODO check which dtypes the solver
# uses internally before re-running.
uspec0 = swe_solver.galewsky_initial_condition()

W1106 16:16:35.730650 68712 dygraph_functions.cc:88576] got different data type, run type promotion automatically, this may cause data type been changed.
W1106 16:16:35.732159 68712 multiply_fwd_func.cc:76] got different data type, run type promotion automatically, this may cause data type been changed.
W1106 16:16:35.732687 68712 dygraph_functions.cc:86562] got different data type, run type promotion automatically, this may cause data type been changed.
W1106 16:16:35.787277 68712 dygraph_functions.cc:85972] got different data type, run type promotion automatically, this may cause data type been changed.
W1106 16:16:35.829222 68712 dygraph_functions.cc:81398] got different data type, run type promotion automatically, this may cause data type been changed.


ValueError: (InvalidArgument) The type of data we are trying to retrieve (float64) does not match the type of data (float32) currently contained in the container.
  [Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():10 != phi::CppTypeToDataType<T>::Type():11.] (at /nfs/github/paddle/Paddle_fix2/paddle/phi/core/dense_tensor.cc:153)


We are now ready to run the simulation. To integrate in time, we use the third-order Adams-Bashforth method. As we are currently not interested in gradients, we can wrap the time-stepping loop in a `paddle.no_grad()` context.

In [None]:
# Ring buffer of the last three tendencies needed by Adams-Bashforth 3:
# inew -> current step, inow -> previous step, iold -> two steps back.
dudtspec = paddle.zeros([3, 3, swe_solver.lmax, swe_solver.mmax], dtype=paddle.complex128)
inew = 0
inow = 1
iold = 2

uspec = uspec0.clone().to(device)

# save every nskip-th state for later plotting:
nskip = 50
nsnapshots = maxiter // nskip + 1
utspec = paddle.zeros([nsnapshots, *uspec.shape], dtype=paddle.complex128).to(device)

with paddle.no_grad():
    for step in range(maxiter + 1):
        t = step * dt

        if step % nskip == 0:
            utspec[step // nskip] = uspec
            print(f"t={t/3600:.2f} hours")

        dudtspec[inew] = swe_solver.dudtspec(uspec)

        # Third-order Adams-Bashforth with a start-up phase:
        #  step 0: all three slots hold the same tendency, so the AB3 weights
        #          (23 - 16 + 5)/12 = 1 reduce to forward Euler;
        #  step 1: filling the "old" slot with 2*f_now - f_new turns the AB3
        #          weights into exactly 3/2*f_new - 1/2*f_now, i.e. AB2.
        if step == 0:
            dudtspec[inow] = dudtspec[inew]
            dudtspec[iold] = dudtspec[inew]
        elif step == 1:
            dudtspec[iold] = 2.0 * dudtspec[inow] - dudtspec[inew]

        uspec = uspec + swe_solver.dt * (
            (23./12.) * dudtspec[inew]
            - (16./12.) * dudtspec[inow]
            + (5./12.) * dudtspec[iold]
        )

        # implicit hyperdiffusion for vort and div (components 1 and onwards).
        uspec[1:] = swe_solver.hyperdiff * uspec[1:]

        # rotate the ring buffer: new -> now, now -> old, and the old slot
        # is overwritten by the next step's tendency.
        inew = (inew - 1) % 3
        inow = (inow - 1) % 3
        iold = (iold - 1) % 3


In [None]:
# vorticity (component 1) of the state reached at the end of the simulation
final_vorticity_spec = uspec[1]

fig = plt.figure()
im = swe_solver.plot_specdata(final_vorticity_spec, fig, cmap="twilight_shifted")
plt.show()

### Plotting a video

Let us plot the vorticity for our rollout:

In [None]:
# prepare figure for animation
fig = plt.figure(figsize=(8, 6), dpi=72)
moviewriter = animation.writers['pillow'](fps=20)
moviewriter.setup(fig, './plots/zonal_jet.gif', dpi=72)

plot_pvrt = False

for i in range(utspec.shape[0]):
    t = i*nskip*dt

    if plot_pvrt:
        variable = swe_solver.potential_vorticity(utspec[i])
    else:
        variable = swe_solver.spec2grid(utspec[i, 1])

    plt.clf()
    # swe_solver.plot_griddata(variable, cmap=cmap, vmin=-0.2, vmax=1.8, title=f'zonal jet t={t/3600:.2f} hours')
    swe_solver.plot_griddata(variable, fig, cmap="twilight_shifted", antialiased=False)
    plt.draw()
    moviewriter.grab_frame()


moviewriter.finish()

## Conservation of potential vorticity

In [None]:
# Spectral coefficients of the potential vorticity for every saved snapshot.
# The snapshot count is taken from utspec itself, so this cell stays consistent
# with the simulation cell even if nskip or maxiter are changed there.
pvrttspec = paddle.zeros([utspec.shape[0], lmax, mmax], dtype=paddle.complex128).to(device)
for i in range(utspec.shape[0]):
    pvrttspec[i] = swe_solver.grid2spec(swe_solver.potential_vorticity(utspec[i]))

In [None]:
# L2 norm of the potential vorticity for every saved snapshot.
# The last axis is the order m; presumably only m >= 0 is stored (real-valued
# field), which is why the m > 0 columns are counted twice and m = 0 once.
total_vrt = pvrttspec.abs()**2
total_vrt = paddle.sqrt(
    paddle.sum(total_vrt[..., :1], axis=(-1, -2))
    + paddle.sum(2 * total_vrt[..., 1:], axis=(-1, -2))
).cpu()
# snapshot times in seconds (one snapshot every nskip steps)
t = nskip*dt * paddle.arange(utspec.shape[0])

fig, ax = plt.subplots()
ax.plot(t.numpy() / 3600.0, (total_vrt / total_vrt[0]).numpy(), label='Spectral Solver')
ax.set_title('Potential vorticity norm over time')
ax.set_xlabel('time [hours]')
ax.set_ylabel('normalized L2 norm')
# small headroom so the curve, which starts at exactly 1, is not hidden by the frame
ax.set_ylim(0, 1.05)
ax.legend(loc='lower left')
plt.show()