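# Shallow Moment Equations Tutorial (Simple)

This notebook sets up the shallow moment equations (SME) on a 2D channel-junction mesh with a time-limited inflow boundary condition. It runs the JAX finite-volume solver and exports the results to VTK.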

## Imports

In [1]:
# | code-fold: true
# | code-summary: "Load packages"
# | output: false

import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi, Piecewise

from library.fvm.solver_jax import HyperbolicSolver, Settings
from library.fvm.ode import RK1
import library.fvm.reconstruction as recon
import library.fvm.timestepping as timestepping
import library.fvm.flux as flux
import library.fvm.nonconservative_flux as nc_flux
from library.model.boundary_conditions import BoundaryCondition
from library.model.models.basisfunctions import Basisfunction, Legendre_shifted
from library.model.models.basismatrices import Basismatrices
from library.misc.misc import Zstruct

from library.model.models.shallow_moments import ShallowMoments2d
import library.model.initial_conditions as IC
import library.model.boundary_conditions as BC
import library.misc.io as io
from library.mesh.mesh import compute_derivatives
from tests.pdesoft import plots_paper
import library.postprocessing.visualization as visu


import library.mesh.mesh as petscMesh
import library.postprocessing.postprocessing as postprocessing
from library.mesh.mesh import convert_mesh_to_jax
import argparse



## Model

In [2]:
# Moment level of the ShallowMoments2d model.
level = 0
# Index shift between the two momentum-moment blocks of the state vector:
# Q[1 .. offset] and Q[1 + offset .. 2 * offset] (see the BC and solver cells).
offset = level + 1
# Total number of state fields: one leading field plus `offset` moments for each
# of the two momentum blocks (equals 3 + 2 * level).
n_fields = 1 + 2 * offset

# Output location/cadence for the solver's snapshots.
output_config = Zstruct(
    directory=f"outputs/junction_{level}",
    filename="SME",
    output_snapshots=30,
)
settings = Settings(name="SME", output=output_config)



In [3]:
# Boundary values for the "inflow" tag, keyed by state-vector index. Each lambda
# receives (t, x, dx, q, qaux, p, n) and returns the prescribed value (a sympy
# Piecewise or a constant).
# Field layout used throughout this notebook (presumably depth + momentum
# moments; TODO confirm against ShallowMoments2d): Q[0] is the leading field,
# Q[1 .. offset] the x-momentum moments, Q[1 + offset .. 2 * offset] the
# y-momentum moments.
inflow_dict = {
    # 0.1 until t = 0.2, afterwards the interior value q[0].
    0: lambda t, x, dx, q, qaux, p, n: Piecewise((0.1, t < 0.2), (q[0], True)),
    # -0.3 until t = 0.2, afterwards the negated interior value -q[1].
    1: lambda t, x, dx, q, qaux, p, n: Piecewise((-0.3, t < 0.2), (-q[1], True)),
}
# Zero the higher-order x-momentum moments Q[2 .. offset]. The keys must start at
# 2: starting at 1 would overwrite the prescribed Q[1] entry above whenever
# level >= 1.
inflow_dict.update(
    {2 + i: (lambda t, x, dx, q, qaux, p, n: 0) for i in range(level)}
)
# Zero every y-momentum moment Q[1 + offset .. 2 * offset].
inflow_dict.update(
    {1 + offset + i: (lambda t, x, dx, q, qaux, p, n: 0) for i in range(level + 1)}
)

bcs = BC.BoundaryConditions(
    [
        BC.Lambda(physical_tag="inflow", prescribe_fields=inflow_dict),
        BC.Wall(physical_tag="wall"),
    ]
)

def custom_ic(x):
    """Return the initial state vector at point ``x``.

    The state is spatially uniform: field 0 is set to 0.01 and every other field
    is zero. ``x`` is accepted for the IC.UserFunction interface but not used.

    Returns:
        numpy.ndarray of length ``n_fields`` (the config-cell field count), dtype float.
    """
    # Size from the shared `n_fields` so the IC always matches the model's field count.
    Q = np.zeros(n_fields, dtype=float)
    # Field 0 is the divisor used to form velocities in SMESolver (presumably depth h).
    Q[0] = 0.01
    return Q

# Wrap the pointwise initial-condition function for the model.
ic = IC.UserFunction(custom_ic)

# Assemble the SME model from its initial state, boundary conditions and moment level.
model = ShallowMoments2d(
    initial_conditions=ic,
    boundary_conditions=bcs,
    level=level,
)

# Directory containing meshes/channel_junction/, taken from the ZOOMY_DIR environment variable.
main_dir = os.getenv("ZOOMY_DIR")
if main_dir is None:
    # Fail with an actionable message instead of os.path.join's TypeError on None.
    raise RuntimeError(
        "Environment variable ZOOMY_DIR is not set; point it at the directory "
        "that contains meshes/channel_junction/."
    )

# Swap "mesh_2d_coarse.msh" for "mesh_2d_fine.msh" to run on the finer mesh.
mesh_file = os.path.join(main_dir, "meshes/channel_junction/mesh_2d_coarse.msh")
mesh_petsc = petscMesh.Mesh.from_gmsh(mesh_file)

# The converted mesh is what solver.solve receives in the Solve section.
mesh = convert_mesh_to_jax(mesh_petsc)
class SMESolver(HyperbolicSolver):
    """HyperbolicSolver whose auxiliary fields hold two velocity gradients.

    After each update, Qaux[0] holds d(u)/dx and Qaux[1] holds d(v)/dy, where
    u = Q[1] / Q[0] and v = Q[1 + offset] / Q[0].
    """

    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        """Return ``Qaux`` with rows 0 and 1 replaced by du/dx and dv/dy.

        Only ``Q``, ``Qaux`` and ``mesh`` (plus the module-level ``offset``) are
        used; the remaining arguments belong to the overridden hook's signature.
        """
        # NOTE(review): no guard for Q[0] == 0 (dry cells); the division then
        # produces inf/nan.
        u = Q[1] / Q[0]
        v = Q[1 + offset] / Q[0]
        # Multi-index entries are assumed to be (order in x, order in y), consistent
        # with [0, 1] being used for d/dy; TODO confirm against compute_derivatives.
        # d/dx is therefore [1, 0] ([0, 0] would return u itself, not its gradient).
        dudx = compute_derivatives(u, mesh, derivatives_multi_index=[[1, 0]])[:, 0]
        dvdy = compute_derivatives(v, mesh, derivatives_multi_index=[[0, 1]])[:, 0]
        # JAX arrays are immutable; .at[].set returns an updated copy.
        Qaux = Qaux.at[0].set(dudx)
        Qaux = Qaux.at[1].set(dvdy)
        return Qaux


solver = SMESolver(settings=settings)




## Solve

In [4]:
# Run the solver on the JAX mesh and unpack the returned (Q, Qaux) pair.
solve_result = solver.solve(mesh, model)
Qnew, Qaux = solve_result

[32m2025-08-30 07:18:53.921[0m | [1mINFO    [0m | [36mlibrary.fvm.solver_jax[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 1, time: 0.003386, dt: 0.003386, next write at time: 0.011111[0m
[32m2025-08-30 07:18:53.940[0m | [1mINFO    [0m | [36mlibrary.fvm.solver_jax[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 2, time: 0.003652, dt: 0.000266, next write at time: 0.022222[0m
[32m2025-08-30 07:18:53.999[0m | [1mINFO    [0m | [36mlibrary.fvm.solver_jax[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 3, time: 0.003918, dt: 0.000266, next write at time: 0.033333[0m
[32m2025-08-30 07:18:54.040[0m | [1mINFO    [0m | [36mlibrary.fvm.solver_jax[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 4, time: 0.004184, dt: 0.000266, next write at time: 0.044444[0m
[32m2025-08-30 07:18:54.060[0m | [1mINFO    [0m | [36mlibrary.fvm.solver_jax[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miterati

## Visualization

In [5]:
# Convert the solver's HDF5 output (<output directory>/<filename>.h5) to VTK.
h5_file = os.path.join(settings.output.directory, f"{settings.output.filename}.h5")
io.generate_vtk(h5_file)
# Optional 3D projection of the 2D result:
# postprocessing.vtk_project_2d_to_3d(model, settings, Nz=20, filename='out_3d')

In [6]:
# visu.pyvista_3d(settings.output.directory, scale=1.0)