In [1]:
# | code-fold: true
# | code-summary: "Load packages"
# | output: false

import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi

from library.fvm.solver import Solver, Settings
from library.fvm.ode import RK1
import library.fvm.reconstruction as recon
import library.fvm.timestepping as timestepping
import library.fvm.flux as flux
import library.fvm.nonconservative_flux as nc_flux
from library.model.boundary_conditions import BoundaryCondition
from library.model.models.basisfunctions import Basisfunction, Legendre_shifted
from library.model.models.basismatrices import Basismatrices

from library.model.model import *
import library.model.initial_conditions as IC
import library.model.boundary_conditions as BC
import library.misc.io as io
from library.mesh.mesh import compute_derivatives
from tests.pdesoft import plots_paper


import library.mesh.mesh as petscMesh
import library.postprocessing.postprocessing as postprocessing
from library.mesh.mesh import convert_mesh_to_jax
import argparse

No module named 'precice'


In [2]:
level = 4
offset = 1+level
n_fields = 3 + 2 * level
settings = Settings(
    name="ShallowMoments",
    parameters={
        "g": 9.81,
        'ex': 0.,
        'ey': 0.,
        'ez': 1.,
        "C": 1.0,
        "nu": 0.000001,
        "lamda": 7,
        "rho": 1,
        "eta": 1,
        "c_slipmod": 1 / 7.0,
    },
    reconstruction=recon.constant,
    num_flux=flux.Zero(),
    nc_flux=nc_flux.segmentpath(),
    compute_dt=timestepping.adaptive(CFL=0.45),
    time_end=3.0,
    output_snapshots=30,
    output_dir=f"outputs/junction_{level}",
)


inflow_dict = { 
    0: lambda t, x, dx, q, qaux, p, n: Piecewise((0.1, t < 0.2),(q[0], True)),
    1: lambda t, x, dx, q, qaux, p, n: Piecewise((-0.3, t < 0.2),(-q[1], True)),
                }
inflow_dict.update({
    1+i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level)
})
inflow_dict.update({
    1+offset+i: lambda t, x, dx, q, qaux, p, n: 0 for i in range(level+1)
})

bcs = BC.BoundaryConditions(
    [
        BC.Lambda(physical_tag="inflow", prescribe_fields=inflow_dict),
        BC.Wall(physical_tag="wall"),
    ]
)

def custom_ic(x):
    Q = np.zeros(3 + 2 * level, dtype=float)
    Q[0] = 0.01
    return Q

ic = IC.UserFunction(custom_ic)

model = ShallowMoments2d(
    fields=3 + 2 * level,
    aux_fields=2,
    parameters=settings.parameters,
    boundary_conditions=bcs,
    initial_conditions=ic,
    settings={"friction": ["newtonian", "slip_mod"]},
    basis=Basismatrices(basis=Legendre_shifted(order=level+1)),
)

main_dir = os.getenv("SMS")
mesh = petscMesh.Mesh.from_gmsh(
    os.path.join(main_dir, "meshes/channel_junction/mesh_2d_coarse.msh")
    # os.path.join(main_dir, "meshes/channel_junction/mesh_2d_fine.msh")
)

mesh = convert_mesh_to_jax(mesh)
class SMESolver(Solver):
    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        dudx = compute_derivatives(Q[1]/Q[0], mesh, derivatives_multi_index=[[0, 0]])[:,0]
        dvdy = compute_derivatives(Q[1+offset]/Q[0], mesh, derivatives_multi_index=[[0, 1]])[:,0]
        Qaux = Qaux.at[0].set(dudx)
        Qaux = Qaux.at[1].set(dvdy)
        return Qaux
solver = SMESolver()
Qnew, Qaux = solver.jax_fvm_unsteady_semidiscrete(mesh, model, settings)

io.generate_vtk(os.path.join(settings.output_dir, f"{settings.name}.h5"))



iteration: 1.0, time: 0.013545709229399058, dt: 0.013545709229399058, time_stamp: 0.10344827586206896
iteration: 2.0, time: 0.017829238598125588, dt: 0.004283529368726528, time_stamp: 0.10344827586206896
iteration: 3.0, time: 0.022112767966852118, dt: 0.004283529368726528, time_stamp: 0.10344827586206896
iteration: 4.0, time: 0.02612638814985772, dt: 0.004013620183005602, time_stamp: 0.10344827586206896
iteration: 5.0, time: 0.02991049380891096, dt: 0.003784105659053239, time_stamp: 0.10344827586206896
iteration: 6.0, time: 0.033612501514719596, dt: 0.00370200770580864, time_stamp: 0.10344827586206896
iteration: 7.0, time: 0.03728158488102703, dt: 0.0036690833663074298, time_stamp: 0.10344827586206896
iteration: 8.0, time: 0.04092637078279678, dt: 0.0036447859017697537, time_stamp: 0.10344827586206896
iteration: 9.0, time: 0.04450964691364398, dt: 0.0035832761308471983, time_stamp: 0.10344827586206896
iteration: 10.0, time: 0.048061795110555604, dt: 0.0035521481969116213, time_stamp: 

In [3]:
postprocessing.vtk_interpolate_3d(model, settings.output_dir,  os.path.join(settings.output_dir, f"{settings.name}.h5"), Nz=20)

converted 0
converted 1
converted 2
converted 3
converted 4
converted 5
converted 6
converted 7
converted 8
converted 9
converted 10
converted 11
converted 12
converted 13
converted 14
converted 15
converted 16
converted 17
converted 18
converted 19
converted 20
converted 21
converted 22
converted 23
converted 24
converted 25
converted 26
converted 27
converted 28
converted 29
write 3d: /home/ingo/Git/sms/outputs/junction_4/fields3d.h5


In [None]:
import pyvista as pv
main_dir = os.getenv("SMS")
vtk_path = os.path.join(main_dir, os.path.join(settings.output_dir, "out.7.vtk"))
mesh = pv.read(vtk_path)

print(mesh.cell_data.keys())

scale = 4.
scale_v = 0.3

f1 = mesh['1']  # component in x (iHat)
f2 = mesh['2']  # component in y (jHat)

# Create a 2D vector field (z-component is zero)
vectors = np.column_stack((f1, f2, np.zeros_like(f1)))

mesh['V'] = vectors
mesh = mesh.cell_data_to_point_data()


# Get current point coordinates
points = mesh.points.copy()

idx = ((points[:, 0] - x_fixed)**2 + (points[:, 1] - y_fixed)**2 + (points[:, 2] - 1.)**2).argmin()


# Replace the z-coordinate
points[:, 2] = mesh['0']  * points[:, 2] * scale # index 2 is the z-component

# Apply the modified coordinates back to the mesh
mesh.points = points

# Create a line along z from 0 to 1 at fixed (x, y)
x_fixed, y_fixed = 0.3, 0.05   # choose based on your domain
n_points = 10


# Extract their z-coordinates
z_fixed = points[idx, 2]

start = (x_fixed, y_fixed, 0.0)
end = (x_fixed, y_fixed, z_fixed)
line = pv.Line(start, end, resolution=n_points - 1)

# # Sample vector field along the line
sampled = line.sample(mesh)
print(sampled['V'].shape)

# # Compute magnitude of vectors (for scaling)
sampled['V_mag'] = np.linalg.norm(sampled['V'], axis=1) * scale_v

# # Generate glyphs scaled by magnitude
glyphs = sampled.glyph(orient='V', scale='V_mag')

# Create a plotter and add mesh and glyphs
plotter = pv.Plotter()

# Add mesh with 50% transparency
plotter.add_mesh(mesh, scalars='V', opacity=0.5, show_scalar_bar=False)

# Add solid glyphs (arrows)
plotter.add_mesh(glyphs, color='white')  # or any other color you like

# Display the scene
plotter.show()

['0', '1', '2', '3', '4', 'aux_0']
0.06547667161931342
(10, 3)


Widget(value='<iframe src="http://localhost:35109/index.html?ui=P_0x7f3635dd6ae0_44&reconnect=auto" class="pyv…