In [None]:
import jax.numpy as jnp
from jaxtyping import Float, Array

from compressible_1d import solver, diagnose, boundary_conditions

%load_ext autoreload
%autoreload 2

In [9]:
def step(
    U_field: Float[Array, "3 n_cells"],
    n_ghost_cells: int,
    delta_x: float,
    delta_t: float,
    gamma: float,
) -> Float[Array, "3 n_cells"]:
    """Advance the field by one explicit time step of size ``delta_t``.

    The field is padded with ghost cells using the "reflective" boundary
    condition. For every interior cell, the fluxes through its right and
    left interfaces are evaluated with ``solver.lax_friedrichs`` and the
    cell is updated as ``U + delta_t * (-(F_plus - F_minus) / delta_x)``.
    The result is passed to ``diagnose.check_all`` together with the input
    field before being returned.

    Args:
        U_field: State array with one column per cell (3 rows).
            NOTE(review): presumably the conserved variables of the 1D
            compressible equations; confirm against ``solver``.
        n_ghost_cells: Number of ghost cells added on each side. Must be at
            least 1 because the update reads the neighbours ``i - 1`` and
            ``i + 1`` of every interior cell.
        delta_x: Cell width.
        delta_t: Time-step size.
        gamma: Forwarded unchanged to ``solver.lax_friedrichs``.

    Returns:
        The updated field, with the same shape as ``U_field`` (ghost cells
        stripped).

    Raises:
        ValueError: If ``n_ghost_cells`` is smaller than 1.
    """
    if n_ghost_cells < 1:
        raise ValueError(
            f"n_ghost_cells must be >= 1 for the 3-point stencil, got {n_ghost_cells}"
        )

    n_cells = U_field.shape[-1]

    U_field_with_ghosts: Float[Array, "3 n_cells+2*n_ghost_cells"] = (
        boundary_conditions.apply_boundary_condition(
            U_field, boundary_condition_type="reflective", n_ghosts=n_ghost_cells
        )
    )
    U_next_field: Float[Array, "3 n_cells+2*n_ghost_cells"] = jnp.zeros_like(
        U_field_with_ghosts
    )

    for i in range(n_ghost_cells, n_cells + n_ghost_cells):
        # Flux through the right interface (between cells i and i + 1).
        flux_plus = solver.lax_friedrichs(
            U_field_with_ghosts[:, i],
            U_field_with_ghosts[:, i + 1],
            gamma=gamma,
            diffusivity_scale=0.0,
        )
        # Flux through the left interface (between cells i - 1 and i).
        flux_minus = solver.lax_friedrichs(
            U_field_with_ghosts[:, i - 1],
            U_field_with_ghosts[:, i],
            gamma=gamma,
            diffusivity_scale=0.0,
        )

        dU_dt = -1 / delta_x * (flux_plus - flux_minus)

        # JAX arrays are immutable: `arr[:, i] = ...` raises TypeError, so
        # use the functional `.at[...].set(...)` update and rebind the name.
        U_next_field = U_next_field.at[:, i].set(
            U_field_with_ghosts[:, i] + delta_t * dU_dt
        )

    # Strip the ghost cells with an explicit end index; a negative end index
    # (`-n_ghost_cells`) would yield an empty slice for a zero ghost width.
    U_next_interior = U_next_field[:, n_ghost_cells : n_ghost_cells + n_cells]

    # diagnose solution
    diagnose.check_all(U_next_interior, U_field)

    return U_next_interior

NameError: name 'Float' is not defined

In [None]:
# Simulation parameters; lengths and times are normalized.
gamma = 1.4
n_cells = 10
n_ghost_cells = 1  # ghost cells on each side of the domain
delta_x = 1 / n_cells  # width of a single cell
delta_t = 1e-4

# Initial state: an all-zero array with 3 rows and one column per cell.
U_init = jnp.zeros(shape=(3, n_cells))
U_field = U_init

# one update step

In [8]:
from compressible_1d import boundary_conditions_test

%load_ext autoreload
%autoreload 2
boundary_conditions_test.test_reflective_bc()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
