# Shallow Moment Tutorial (Simple)

## Imports

In [1]:
# | code-fold: true
# | code-summary: "Load packages"
# | output: false

import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi, Piecewise

from library.fvm.solver import HyperbolicSolver, Settings
from library.fvm.ode import RK1
import library.fvm.reconstruction as recon
import library.fvm.timestepping as timestepping
import library.fvm.flux as flux
import library.fvm.nonconservative_flux as nc_flux
from library.model.boundary_conditions import BoundaryCondition
from library.model.models.basisfunctions import Basisfunction, Legendre_shifted
from library.model.models.basismatrices import Basismatrices
from library.misc.misc import Zstruct

from library.model.models.shallow_moments import ShallowMoments2d
import library.model.initial_conditions as IC
import library.model.boundary_conditions as BC
import library.misc.io as io
from library.mesh.mesh import compute_derivatives
from tests.pdesoft import plots_paper


import library.mesh.mesh as petscMesh
import library.postprocessing.postprocessing as postprocessing
from library.mesh.mesh import convert_mesh_to_jax
import argparse



In [None]:
level = 4
offset = 1+level
n_fields = 3 + 2 * level

settings = Settings(
    name="SME",
    model=Zstruct(
        level=level,
        parameters=Zstruct.from_dict({
        "g": 9.81,
        'ex': 0.,
        'ey': 0.,
        'ez': 1.,
        "C": 1.0,
        "nu": 0.000001,
        "lamda": 7,
        "rho": 1,
        "eta": 1,
        "c_slipmod": 1 / 7.0,
    }),
    ),
    solver=Zstruct(time_end=3.0, CFL=0.45),
    output=Zstruct(
        directory=f"outputs/junction_{level}", filename="SME", output_snapshots=30
    ),
)

In [3]:
inflow_dict = { 
    0: lambda t, x, dx, q, qaux, p, n: Piecewise((0.1, t < 0.2),(q[0], True)),
    1: lambda t, x, dx, q, qaux, p, n: Piecewise((-0.3, t < 0.2),(-q[1], True)),
                }
inflow_dict.update({
    1+i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level)
})
inflow_dict.update({
    1+offset+i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level+1)
})

bcs = BC.BoundaryConditions(
    [
        BC.Lambda(physical_tag="inflow", prescribe_fields=inflow_dict),
        BC.Wall(physical_tag="wall"),
    ]
)

def custom_ic(x):
    Q = np.zeros(3 + 2 * level, dtype=float)
    Q[0] = 0.01
    return Q

ic = IC.UserFunction(custom_ic)

model = ShallowMoments2d(
    fields=3 + 2 * level,
    aux_fields=2,
    parameters=settings.model.parameters.as_dict(),
    boundary_conditions=bcs,
    initial_conditions=ic,
    settings={"friction": ["newtonian", "slip_mod"]},
    basis=Basismatrices(basis=Legendre_shifted(order=level+1)),
)

main_dir = os.getenv("SMS")
mesh = petscMesh.Mesh.from_gmsh(
    os.path.join(main_dir, "meshes/channel_junction/mesh_2d_coarse.msh")
    # os.path.join(main_dir, "meshes/channel_junction/mesh_2d_fine.msh")
)

mesh = convert_mesh_to_jax(mesh)
class SMESolver(HyperbolicSolver):
    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        dudx = compute_derivatives(Q[1]/Q[0], mesh, derivatives_multi_index=[[0, 0]])[:,0]
        dvdy = compute_derivatives(Q[1+offset]/Q[0], mesh, derivatives_multi_index=[[0, 1]])[:,0]
        Qaux = Qaux.at[0].set(dudx)
        Qaux = Qaux.at[1].set(dvdy)
        return Qaux
solver = SMESolver(settings)




In [None]:
Qnew, Qaux = solver.solve(mesh, model)

io.generate_vtk(os.path.join(settings.output.directory, f"{settings.name}.h5"))

[32m2025-07-25 13:32:54.670[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback[0m:[36m36[0m - [1miteration: 1, time: 0.003386, dt: 0.003386, time_stamp: 0.333333[0m
[32m2025-07-25 13:32:54.691[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback[0m:[36m36[0m - [1miteration: 2, time: 0.003652, dt: 0.000266, time_stamp: 0.333333[0m
[32m2025-07-25 13:32:54.711[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback[0m:[36m36[0m - [1miteration: 3, time: 0.003918, dt: 0.000266, time_stamp: 0.333333[0m
[32m2025-07-25 13:32:54.730[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback[0m:[36m36[0m - [1miteration: 4, time: 0.004184, dt: 0.000266, time_stamp: 0.333333[0m
[32m2025-07-25 13:32:54.755[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback[0m:[36m36[0m - [1miteration: 5, time: 0.004450, dt: 0.000266, time_stamp: 0.333333[0m
[32m2025-07-25 13:32:54.778[0m | [1mINFO  

In [None]:
postprocessing.vtk_interpolate_3d(model, settings, Nz=20)

converted 0
converted 1
converted 2
converted 3
converted 4
converted 5
converted 6
converted 7
converted 8
converted 9
write 3d: /home/ingo/Git/Zoomy/outputs/junction_0/fields3d.h5


In [None]:
import os
import numpy as np
import pyvista as pv
import panel as pn
import vtk
from glob import glob

pn.extension("vtk")

# Set up
main_dir = os.getenv("SMS")
settings = io.load_settings(f"outputs/junction_{level}")
output_dir = os.path.join(main_dir, settings.output.directory)
vtk_files = sorted(glob(os.path.join(output_dir, "out.*.vtk")))
max_vtk_files = 10
if len(vtk_files) > max_vtk_files:
    offset = int(len(vtk_files)  / 10)
    vtk_files = vtk_files[::offset]

# Constants
x_fixed, y_fixed = 0.3, 0.05
n_points = 10
scale = 4.0
scale_v = 0.3

def scale_mesh_by_height(mesh, scale=1.):
    # Only modify points if necessary
    pmesh = mesh.cell_data_to_point_data()
    points = pmesh.points.copy()
    if "0" in pmesh.point_data:
        points[:, 2] = pmesh["0"] * points[:, 2] * scale
        mesh.points = points
    return mesh

def add_velocity_field(mesh):
    # Add V field for vector (optional)
    try:
        f1 = mesh["1"]
        f2 = mesh["2"]
        V = np.column_stack((f1, f2, np.zeros_like(f1)))
        mesh["V"] = V
    except KeyError:
        pass
    return mesh

# Utility: Load mesh & update field_selector
def load_mesh(vtk_path):
    mesh = pv.read(vtk_path)
    mesh = add_velocity_field(mesh)
    mesh = scale_mesh_by_height(mesh, scale)

    return mesh

meshes = [load_mesh(vtk_path) for vtk_path in vtk_files]


# Widgets
field_selector = pn.widgets.Select(name="Select Field", options=[], sizing_mode='stretch_width')
time_slider = pn.widgets.IntSlider(name="Time Step", start=0, end=len(vtk_files) - 1, step=1, value=0, sizing_mode='stretch_width')
show_mesh_checkbox = pn.widgets.Checkbox(name="Show Mesh", value=False)


vtk_pane_container = pn.Column()
plotter = pv.Plotter()
plotter.set_background("lightgray")
vtk_pane = pn.pane.VTK(plotter.ren_win, height=500, sizing_mode="stretch_width")
vtk_pane_container.append(vtk_pane)




def update_plot(event=None):
    plotter.clear()
    # plotter = pv.Plotter(off_screen=True)
    # plotter.set_background("lightgray")
    try:
        plotter.remove_scalar_bar()
    except:
        pass
    
    mesh = meshes[time_slider.value]




    # Setup available fields
    fields = list(mesh.cell_data.keys())
    
    if field_selector.options != fields:
        field_selector.options = fields
        field_selector.value = fields[0] if fields else None
    scalar_name = field_selector.value if field_selector.value in mesh.cell_data else "0"

    vmin, vmax = mesh.get_data_range(arr_var=scalar_name)
    plotter.add_mesh(mesh, scalars=scalar_name, opacity=0.5, clim=[vmin, vmax], 
                    scalar_bar_args=dict(       
                    title=scalar_name,
                    vertical=True,             
                    interactive=False,
                    outline=False,
                    title_font_size=35,
                    label_font_size=30,
                    fmt="%.5f",))

    # if show_mesh_checkbox.value == True:
    #     plotter.add_mesh(mesh, style='wireframe', color='black', opacity=0.3)
    plotter.reset_camera()
    vtk_pane.object = plotter.ren_win
    vtk_pane.param.trigger('object') 
    
# Trigger update
time_slider.param.watch(update_plot, "value")
field_selector.param.watch(update_plot, "value")
show_mesh_checkbox.param.watch(update_plot, "value")

# Initial load
update_plot()

# Layout
sidebar = pn.Column(
    "## Controls",
    time_slider,
    field_selector,
    show_mesh_checkbox,
    width=250,
)

layout = pn.Row(sidebar,pn.Spacer(width=5), vtk_pane_container)
layout


BokehModel(combine_events=True, render_bundle={'docs_json': {'6f7450a6-3858-4d79-bf87-292d5409434b': {'version…