Skip to content

jejjohnson/jaxsw

Repository files navigation

Simple Ocean Models in JAX

Motivation

Sea surface height is a gateway variable to other important ocean properties, e.g. sea surface temperature, geostrophic currents. There are many massive models that attempt to model this, e.g. NEMO, MOM6, MITGCM. However they are very expensive and quite difficult to run. So there are many small models that are useful approximations, e.g. Quasi-Geostrophic and Shallow Water. This repo attempts to showcase how we can use some modern tools to construct dynamical systems for PDEs.

What makes this different from the tons and tons of different implementations is that we will be using JAX. JAX is basically numpy on steroids because the API is very similar but we also get some of the modern toolsets along with speed. Most importantly, JAX is differentiable. Having a differentiable model is important because it allows us to:

  • Learn some of the hyperparameters if necessary
  • Embed this in machine learning models where differentiability is needed

Why Not PyTorch?

We could easily just use PyTorch. However, there are some advantanges to JAX over other languages like PyTorch and TensorFlow:

  • Familiar Numpy-Like API which is nice for newcomers in the scientific community
  • CPU/GPU/TPU capabilities with minimal code changes
  • Gradient Operators instead of storing the transformations in the tensors
  • Functional-like language which is easier to read for newcomers
  • Auto-Vectorization so we can easily parallize the operators for multiple dimensions without code changes (note: TensorFlow has this)
  • JIT compilation speeds up the code by a lot (note: both PyTorch and TensorFlow has this)

Applications

This library will be relatively general but this will be a development platform for the following applications:

  • Generate Simulations
  • Surrogate Models
  • Data Assimilation

Main Components

Without making it too complicated, we settled on a few key objects that the package will comprise of.

Domain

This will be the object to define the grids where all of the fields live. It will be easy to access the coordinates, boundaries, grids and cell volumes. We don't need to store the grid all of the time, instead we just generate it as we see fit.

Operators

This will be a suite of functions for different gradient calculations and combined operations for well-known equations. We will primarily focus on finite difference operators with the finiteDiffX package. At a later date, we can introduce spectral and finite volume methods.

Integrators

We will use the diffrax package to do the time integration. We'll use the method-of-lines technique to formulate all of our PDEs to calculate the RHS of the equation for the state at $t$. Then we can propagate them through the time integrator to get the state at $t+1$.

Params, State & Equations of Motion

We will have a general API for how we can keep store parameters, initialize states and pass thew both through the equation of motion. To handle what's differentiable and what is not, we will use the equinox package.

Configs

We will use the hydra package to keep track of the configurations and to initialize parameters for experiments.


Installation

pip

We can directly install it via pip from the

pip install "git+https://github.com/jejjohnson/jaxsw.git"

Cloning

We can also clone the git repository

git clone https://github.com/jejjohnson/jaxsw.git
cd jaxsw

poetry

The easiest way to get started is to simply use the poetry package which installs all necessary dev packages as well

poetry install

pip

We can also install via pip as well

pip install .

Conda

We also have a conda environment with all of the equivalent dependencies.

conda env create -f environments/jax_linux_cpu.yaml
conda activate jaxsw

Contributions


Acknowledgements

  • qg_utils - useful functions for dealing with QG equations
  • jaxdf - Nice API for defining operators for PDEs.
  • jax-cfd - Nice API for defining PDEs
  • invobs-data-assimilation - Nice API for Dynamical Systems
  • MASSH - The differentiable QG and SW models applied to sea surface height interpolation.
  • qgm_pytorch - Quasi-Geostrophic Model in PyTorch
  • QGNet - QG implementation in PyTorch with convolutions.

Releases

No releases published

Packages

No packages published

Languages