# Shallow Water Equations in 2D with JAX (Simple Tutorial)


::: {.callout-note title="Reference"}
The following verification is based on the SWASHES benchmark paper:

```
 @article{Delestre_2013, 
 title={SWASHES: a compilation of Shallow Water Analytic Solutions for Hydraulic and Environmental Studies}, 
 volume={72}, 
 ISSN={0271-2091, 1097-0363}, DOI={10.1002/fld.3741}, 
 note={arXiv:1110.0288 [physics]}, 
 number={3}, 
 journal={International Journal for Numerical Methods in Fluids}, 
 author={Delestre, Olivier and Lucas, Carine and Ksinant, Pierre-Antoine and Darboux, Frédéric and Laguerre, Christian and Vo, Thi Ngoc Tuoi and James, Francois and Cordier, Stephane}, 
 year={2013}, 
 month=may, 
 pages={269–300} 
}
```

:::

## Imports

In [1]:
# | code-fold: true
# | code-summary: "Load packages"
# | output: false

import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi, Matrix
import pytest

from sympy import MutableDenseNDimArray as Arr

from library.python.fvm.solver_jax import HyperbolicSolver, Settings
from library.python.fvm.ode import RK1
import library.python.fvm.reconstruction as recon
import library.python.fvm.timestepping as timestepping
import library.python.fvm.flux as flux
import library.python.fvm.nonconservative_flux as nc_flux
from library.model.boundary_conditions import BoundaryCondition
from library.model.basemodel import Model, Function
from attr import field, define
# from library.model.model import *
import library.model.initial_conditions as IC
import library.model.boundary_conditions as BC
import library.python.misc.io as io
from library.python.mesh.mesh import compute_derivatives
from library.python.misc.misc import Zstruct
from tests.pdesoft import plots_paper
from library.python.transformation.to_jax import JaxRuntimeModel
from library.python.transformation.to_numpy import NumpyRuntimeModel


import library.python.mesh.mesh as petscMesh
import library.python.postprocessing.postprocessing as postprocessing
from library.python.mesh.mesh import convert_mesh_to_jax
import argparse

In [9]:
@define(frozen=True, slots=True, kw_only=True)
class SWE(Model):
    """Shallow water equations in conservative form.

    The state is ``q = [h, h*u_1, ..., h*u_dim]``: the water depth followed by
    one momentum component per spatial dimension, i.e. ``dimension + 1``
    variables. Default parameters are gravity ``g`` plus ``ex, ey, ez``
    (the latter are not used by the methods of this class; presumably they
    are consumed by the base ``Model`` -- TODO confirm).
    """

    dimension: int = 2
    # Number of state variables (h plus one momentum per dimension). It is
    # computed per instance by ``_default_variables``: a plain
    # ``default=dimension + 1`` would be evaluated once at class-definition
    # time (with dimension == 2) and stay 3 even for ``SWE(dimension=1)``.
    variables: Zstruct = field(init=False)
    aux_variables: Zstruct = field(default=2)
    _default_parameters: dict = field(
        init=False,
        factory=lambda: {"g": 9.81, "ex": 0.0, "ey": 0.0, "ez": 1.0}
        )

    @variables.default
    def _default_variables(self):
        # `dimension` is declared before `variables`, so it is already set here.
        return self.dimension + 1

    def project_2d_to_3d(self):
        """Lift the depth-averaged state to a 6-entry symbolic array at height z.

        Returns ``[0, h, u, v, 0, p]`` where ``u, v = hu/h, hv/h`` (``v = 0``
        when ``dimension == 1``) and ``p = rho_w * g * h * (1 - z)``.
        Entries 0 and 4 are fixed to zero (presumably bottom elevation and
        vertical velocity -- confirm against the post-processing consumer).
        ``z = self.position[2]`` appears to be a depth-normalised vertical
        coordinate in [0, 1] -- TODO confirm.
        """
        out = Arr.zeros(6)
        dim = self.dimension
        z = self.position[2]
        h = self.variables[0]
        U = [hu / h for hu in self.variables[1:1+dim]]
        rho_w = 1000.
        # NOTE(review): hard-coded gravity; unlike `flux`, this ignores
        # `self.parameters.g`, so a user-supplied g is not reflected here.
        g = 9.81
        out[0] = 0
        out[1] = h
        out[2] = U[0]
        out[3] = 0 if dim == 1 else U[1]
        out[4] = 0
        out[5] = rho_w * g * h * (1-z)
        return out

    def flux(self):
        """Conservative flux ``F(q)`` as an array of shape ``(n_variables, dim)``.

        Row 0 is the mass flux ``h*U``; the remaining rows are the momentum
        flux ``h*U*U^T + g/2 * h**2 * I``.
        """
        dim = self.dimension
        h = self.variables[0]
        U = Matrix([hu / h for hu in self.variables[1:1+dim]])
        g = self.parameters.g
        I = Matrix.eye(dim)
        F = Arr.zeros(self.variables.length(), dim)
        F[0, :] = list(h * U)
        F[1:, :] = h * U * U.T + g/2 * h**2 * I
        return F
    


In [11]:


# Every physical boundary tag of the mesh gets a `BC.Wall` condition
# (note: this includes the "inflow" and "outflow" tags).
bcs = BC.BoundaryConditions(
    [BC.Wall(physical_tag=tag) for tag in ("wall", "inflow", "outflow")]
)

def custom_ic(x, x_dam=5., h_left=0.005, h_right=0.001):
    """Piecewise-constant initial state ``Q = [h, hu, hv]`` at position ``x``.

    The water depth is ``h_left`` where ``x[0] < x_dam`` and ``h_right``
    otherwise; both momentum components are zero.

    Parameters
    ----------
    x : array-like
        Position; only the first coordinate ``x[0]`` is used.
    x_dam : float, optional
        Location of the depth step (default 5.0).
    h_left, h_right : float, optional
        Depth left / right of the step (defaults 0.005 and 0.001).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(3,)`` with dtype float.
    """
    Q = np.zeros(3, dtype=float)
    Q[0] = np.where(x[0] < x_dam, h_left, h_right)
    return Q

# Initial condition: depth step with zero momentum (see `custom_ic`).
ic = IC.UserFunction(custom_ic)

model = SWE(
    dimension=2,
    boundary_conditions=bcs,
    initial_conditions=ic,
)

# Symbolic quasilinear matrix of the model, shown for inspection.
print(model.quasilinear_matrix())

# NOTE(review): in the recorded run this constructor raised
# `AttributeError: 'Piecewise' object has no attribute 'tolist'`; the cause is
# inside NumpyRuntimeModel / the model's symbolic expressions, not in this cell.
rmodel = NumpyRuntimeModel(model)

# Smoke test: evaluate the numeric quasilinear matrix at an arbitrary sample
# state q = [h, hu, hv] = [1, 2, 3] with aux values [1, 2].
q = np.linspace(1, 3, 3)
qaux = np.linspace(1, 2, 2)
param = model.parameter_values
print(q, qaux, param)
print(rmodel.quasilinear_matrix(q, qaux, param))


# NOTE(review): the solver run below is commented out, so `settings` -- used by
# the post-processing and plotting cells that follow -- is never defined on a
# fresh kernel. Re-enable it (it needs the ZOOMY_DIR environment variable)
# before running those cells. Also check the filename: `filename="swe.h5"`
# combined with the later `f"{settings.output.filename}.h5"` would give
# `swe.h5.h5` -- TODO confirm which name the solver actually writes.
# main_dir = os.getenv("ZOOMY_DIR")
# mesh = petscMesh.Mesh.from_gmsh(
#     os.path.join(main_dir, "meshes/channel_quad_2d/mesh.msh")
# )

# class SWESolver(HyperbolicSolver):
#     def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
#         dudx = compute_derivatives(Q[1]/Q[0], mesh, derivatives_multi_index=[[0, 0]])[:,0]
#         dvdy = compute_derivatives(Q[2]/Q[0], mesh, derivatives_multi_index=[[0, 1]])[:,0]
#         Qaux = Qaux.at[0].set(dudx)
#         Qaux = Qaux.at[1].set(dvdy)
#         return Qaux

# settings = Settings(name="ShallowWater", output=Zstruct(directory="outputs/shallow_water_2d", filename="swe.h5", clean_directory=True))

# solver = SWESolver(time_end=6, settings=settings)
# Qnew, Qaux = solver.solve(mesh, model)

[[[0, 0], [g*q0 - q1**2/q0**2, -q1*q2/q0**2], [-q1*q2/q0**2, g*q0 - q2**2/q0**2]], [[1, 0], [2*q1/q0, q2/q0], [q2/q0, 0]], [[0, 1], [0, q1/q0], [q1/q0, 2*q2/q0]]]
1


AttributeError: 'Piecewise' object has no attribute 'tolist'

In [None]:
# Convert the HDF5 solver output to VTK, then write a 3D extrusion of the 2D
# solution (presumably via `model.project_2d_to_3d`, with Nz=20 vertical
# samples -- confirm in postprocessing).
# NOTE(review): relies on `settings`, which is only defined in the
# commented-out solver code earlier in the notebook.
h5_output = os.path.join(settings.output.directory, f"{settings.output.filename}.h5")
io.generate_vtk(h5_output)
postprocessing.vtk_project_2d_to_3d(model, settings, Nz=20, filename='out_3d')

In [None]:
# Plot the solution stored in the solver's HDF5 output file.
h5_file = os.path.join(
    settings.output.directory,
    settings.output.filename + ".h5",
)
fig = plots_paper.plot_swe(h5_file)

In [None]:
@pytest.mark.nbworking
def test_working():
    """Placeholder smoke test that always passes.

    Presumably collected through the ``nbworking`` marker when the notebook
    is run as part of the test suite.
    """
    notebook_reached_end = True
    assert notebook_reached_end