# Shallow Moments with topography (Simple)

## Imports

In [None]:
# | code-fold: true
# | code-summary: "Load packages"
# | output: false

import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi, Piecewise

from library.fvm.solver import HyperbolicSolver, Settings
from library.fvm.ode import RK1
import library.fvm.reconstruction as recon
import library.fvm.timestepping as timestepping
import library.fvm.flux as flux
import library.fvm.nonconservative_flux as nc_flux
from library.model.boundary_conditions import BoundaryCondition
from library.model.models.basisfunctions import Basisfunction, Legendre_shifted
from library.model.models.basismatrices import Basismatrices
from library.misc.misc import Zstruct

from library.model.models.shallow_moments_topo import ShallowMomentsTopo, ShallowMomentsTopoNumerical
from library.model.models.shallow_moments import ShallowMoments2d

import library.model.initial_conditions as IC
import library.model.boundary_conditions as BC
import library.misc.io as io
from library.mesh.mesh import compute_derivatives
from tests.pdesoft import plots_paper
import library.postprocessing.visualization as visu


import library.mesh.mesh as petscMesh
import library.postprocessing.postprocessing as postprocessing
from library.mesh.mesh import convert_mesh_to_jax
import argparse

In [None]:
# Number of higher-order velocity moments (presumably 0 = depth-averaged shallow water).
level = 0
# Number of moment fields per horizontal direction (hu_0..hu_level); it is also the
# index distance from hu_0 to hv_0 in the state vector.
offset = 1 + level
# Total number of state fields of the topography model:
# [b, h, hu_0..hu_level, hv_0..hv_level] -> 2 + 2 * (level + 1).
n_fields = 4 + 2 * level
settings = Settings(
    name="SME",
    output=Zstruct(
        directory=f"outputs/topo_junction_{level}", filename="SME", output_snapshots=30
    ),
)

In [None]:
# Prescribed values on the "inflow" boundary, keyed by state-vector index.
# The topography model's state layout is [b, h, hu_0..hu_level, hv_0..hv_level]
# (the same layout custom_ic and update_qaux index into), so hu_0 sits at
# index 2 and hv_0 at index 2 + offset. Keys derived from the topography-free
# layout would collide with index 2 and drop the hu_0 inflow.
IDX_H = 1
IDX_HU0 = 2
IDX_HV0 = IDX_HU0 + offset

inflow_dict = {
    # h: 0.1 while t < 0.2.
    # NOTE(review): afterwards this returns q[0], which is b in this layout;
    # q[1] (h) may be intended — confirm.
    IDX_H: lambda t, x, dx, q, qaux, p, n: Piecewise((0.1, t < 0.2), (q[0], True)),
    # hu_0: inflow of -0.3 while t < 0.2, then -q[2] (sign-flipped hu_0).
    IDX_HU0: lambda t, x, dx, q, qaux, p, n: Piecewise((-0.3, t < 0.2), (-q[2], True)),
}
# Higher-order x-moments hu_1..hu_level are set to zero.
inflow_dict.update({
    IDX_HU0 + 1 + i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level)
})
# All y-moments hv_0..hv_level are set to zero.
inflow_dict.update({
    IDX_HV0 + i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level + 1)
})

bcs = BC.BoundaryConditions(
    [
        BC.Lambda(physical_tag="inflow", prescribe_fields=inflow_dict),
        BC.Wall(physical_tag="wall"),
    ]
)

def custom_ic(x):
    """Build the initial state vector; the position ``x`` is ignored.

    Returns a float array of length ``4 + 2 * level`` that is zero everywhere
    except index 1 (the height that update_qaux divides by), which is 0.01.
    """
    state_size = 4 + 2 * level
    initial_state = np.zeros(state_size, dtype=float)
    initial_state[1] = 0.01
    return initial_state

# Wrap the pointwise initial condition, then assemble the topography moment model
# from it, the boundary conditions and the chosen moment level.
ic = IC.UserFunction(custom_ic)

model = ShallowMomentsTopoNumerical(
    level=level,
    initial_conditions=ic,
    boundary_conditions=bcs,
)

# Directory (from the ZOOMY_DIR environment variable) that contains "meshes/".
main_dir = os.getenv("ZOOMY_DIR")
if main_dir is None:
    # Fail early with a clear message instead of a TypeError from os.path.join(None, ...).
    raise RuntimeError(
        "Environment variable ZOOMY_DIR is not set; "
        "point it to the directory that contains 'meshes/'."
    )

# Use "meshes/channel_junction/mesh_2d_fine.msh" instead for the fine-resolution run.
mesh_path = os.path.join(main_dir, "meshes/channel_junction/mesh_2d_coarse.msh")
mesh = petscMesh.Mesh.from_gmsh(mesh_path)

# Convert to the JAX mesh representation that is passed to solver.solve.
mesh = convert_mesh_to_jax(mesh)
class SMESolver(HyperbolicSolver):
    """HyperbolicSolver whose ``update_qaux`` stores velocity gradients in ``Qaux``.

    ``Qaux[0]`` receives du/dx and ``Qaux[1]`` receives dv/dy, where
    u = Q[2] / Q[1] and v = Q[2 + offset] / Q[1] (hu_0 / h and hv_0 / h in the
    [b, h, hu_0..hu_level, hv_0..hv_level] layout used by this notebook).
    """

    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        """Return ``Qaux`` with rows 0 and 1 overwritten by du/dx and dv/dy.

        Only ``Q``, ``Qaux`` and ``mesh`` are read; the other arguments are
        accepted for interface compatibility.
        """
        # Cells with h <= h_min are treated as dry: the velocity there is set to 0
        # instead of dividing by (near-)zero h, which would inject inf/NaN.
        h_min = 1e-8
        h = Q[1]
        wet = h > h_min
        # Substitute a harmless divisor in dry cells so no inf/NaN is ever formed.
        h_safe = jnp.where(wet, h, 1.0)
        u = jnp.where(wet, Q[2] / h_safe, 0.0)
        v = jnp.where(wet, Q[2 + offset] / h_safe, 0.0)

        # Multi-index entries are presumably (order in x, order in y), matching the
        # [[0, 1]] used for dv/dy; [[0, 0]] would return u itself, not du/dx.
        dudx = compute_derivatives(u, mesh, derivatives_multi_index=[[1, 0]])[:, 0]
        dvdy = compute_derivatives(v, mesh, derivatives_multi_index=[[0, 1]])[:, 0]
        Qaux = Qaux.at[0].set(dudx)
        Qaux = Qaux.at[1].set(dvdy)
        return Qaux


solver = SMESolver(settings=settings)

## Solve

**Note:** If I change the model to `ShallowMoments2d`, the simulation works fine.

The numerical model (`ShallowMomentsTopoNumerical`) does not work as it is either. I think I need to use the old model and add a function substitution:

1. Stick with `ha / h` instead of `u`.
2. Provide a function `get_hinv`.
3. Substitute `h_inv` with the new `hinv` in the numerical model.

In [None]:
# Run the solver on the JAX mesh with the configured model; the result is unpacked
# as (Qnew, Qaux) — presumably the final state and auxiliary fields (confirm).
Qnew, Qaux = solver.solve(mesh, model)

JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:11.943[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 770, time: 0.009447, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:12.998[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 780, time: 0.009479, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:14.077[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 790, time: 0.009510, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:14.993[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 800, time: 0.009541, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:15.816[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 810, time: 0.009572, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:16.591[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 820, time: 0.009603, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


[32m2025-08-28 22:02:17.486[0m | [1mINFO    [0m | [36mlibrary.fvm.solver[0m:[36mlog_callback_hyperbolic[0m:[36m44[0m - [1miteration: 830, time: 0.009633, dt: 0.000003, next write at time: 0.011111[0m


JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>
JitTracer<float64[]>


## Visualization

In [None]:
# Convert the .h5 file named by settings.output (directory + filename) to VTK.
h5_file = f"{settings.output.filename}.h5"
h5_path = os.path.join(settings.output.directory, h5_file)
io.generate_vtk(h5_path)
# postprocessing.vtk_interpolate_3d(model, settings, Nz=20, filename='out_3d')

In [None]:
# Visualize the contents of the output directory configured in settings.output.
visu.pyvista_3d(settings.output.directory, scale=1.0)