In [1]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import jax

# 64-bit mode is disabled, so JAX arrays default to 32-bit precision.
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp

# Project modules used below: run configuration/driver (numerics), plotting,
# initial-condition builders (boundary_conditions), variable conversions (physics).
from compressible_1d import numerics, plot, boundary_conditions, physics


## Input parameters

In [None]:
# --- user input start ---

end_time = 5e-2  # [s]
tube_length = 1  # [m]
n_cells = 50
gamma = 1.4  # ratio of specific heats

boundary_condition: str = "periodic"  # periodic, reflective, transmissive

# init in primitives and not normalized
rho1 = 1.225  # density, [kg/m3]
u1 = 0.0  # velocity, [m/s]
p1 = 1e5  # pressure , [N/m2]

rho2 = 1.3  # density, [kg/m3]
u2 = 0.0  # velocity, [m/s]
p2 = 7e5  # pressure , [N/m2]

is_debug = True  # if True, conservation check will print results verbosely
is_abort = False  # if True, simulation will abort if diagnostic checks fail

# Only a single ghost layer per side is currently supported.
n_ghost_cells = 1  # per side
if n_ghost_cells > 1:
    raise NotImplementedError("Physics not yet validated for higher order.")

# --- user input end ---


# Uniform cell width.
delta_x = tube_length / n_cells

# Left (state 1) and right (state 2) primitive states stacked as columns:
# U_primitive has shape (3, 2) with rows (rho, u, p).
U1_primitive = jnp.array([rho1, u1, p1])
U2_primitive = jnp.array([rho2, u2, p2])
U_primitive = jnp.stack([U1_primitive, U2_primitive], axis=1)
# Conserved form of both states; presumably normalized by the left-state
# reference values passed as rho_ref / p_ref -- confirm in physics.to_conserved.
U_conserved = physics.to_conserved(U_primitive, rho_ref=rho1, p_ref=p1, gamma=gamma)

# NOTE(review): the time step is computed from the dimensional primitives, while
# the solver is later run on U_conserved -- confirm both use the same units.
delta_t = numerics.calculate_dt(U_primitive, gamma=gamma, delta_x=delta_x)
print(f"delta_t calculated according to CFL: {delta_t:.3e}s")

# int() truncates, so n_steps * delta_t <= end_time.
n_steps = int(end_time / delta_t)
print(f"total time steps: {n_steps}")

## Choose how to initialize the domain and run

In [None]:
# Two-state (shock-tube) initial condition: column 0 of U_conserved is the left
# state, column 1 the right state.
# NOTE(review): rows 1 and 2 of U_conserved are conserved quantities (presumably
# momentum and total energy) but are passed to the `u_*` / `p_*` keywords --
# confirm initialize_two_domains expects conserved values.
U_init = boundary_conditions.initialize_two_domains(
    rho_left=U_conserved[0, 0],
    rho_right=U_conserved[0, 1],
    u_left=U_conserved[1, 0],
    u_right=U_conserved[1, 1],
    p_left=U_conserved[2, 0],
    p_right=U_conserved[2, 1],
    n_cells=n_cells,
)
cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
    solver_type="lf",
)
U_solutions_lf = numerics.run(cfd_input)
# NOTE(review): plotted as returned by the solver; the HLLC cells further down
# convert with physics.to_primitives first.
plot.plot_U_heatmaps(U_solutions_lf).show()

In [None]:
# Sine-wave initial condition (mode 2, a=0.01) built from the left state
# (column 0 of U_conserved).
# NOTE(review): later cells use boundary_conditions.initialize_sine_wave_conservatives
# (U0=..., gamma=..., amplitude=...); confirm initialize_sine_wave still exists.
U_init = boundary_conditions.initialize_sine_wave(
    U=U_conserved[:, 0],
    a=0.01,
    mode=2,
    n_cells=n_cells,
)

# Pass gamma and boundary_condition explicitly, like every other run in this
# notebook, so the values configured in the input cell are actually used.
cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
)
U_solutions = numerics.run(cfd_input)

## HLLC solver: shock tube

In [4]:
# --- user input start ---

end_time = 1e-1  # [s]
tube_length = 1  # [m]
n_cells = 50
gamma = 1.4  # ratio of specific heats

boundary_condition: str = "periodic"  # periodic, reflective, transmissive

# init in primitives and not normalized
# Right state = left state scaled by 0.125 (density) and 0.1 (pressure).
rho1 = 1.0 * 1.225  # density, [kg/m3]
u1 = 0.0  # velocity, [m/s]
p1 = 1.0 * 1e5  # pressure , [N/m2]

rho2 = 0.125 * 1.225  # density, [kg/m3]
u2 = 0.0  # velocity, [m/s]
p2 = 0.1 * 1e5  # pressure , [N/m2]

is_debug = False  # if True, conservation check will print results verbosely
is_abort = False  # if True, simulation will abort if diagnostic checks fail

# Only a single ghost layer per side is currently supported.
n_ghost_cells = 1  # per side
if n_ghost_cells > 1:
    raise NotImplementedError("Physics not yet validated for higher order.")

# --- user input end ---


# Uniform cell width.
delta_x = tube_length / n_cells

# Left/right primitive states as columns: shape (3, 2), rows (rho, u, p).
U1_primitive = jnp.array([rho1, u1, p1])
U2_primitive = jnp.array([rho2, u2, p2])
U_primitive = jnp.stack([U1_primitive, U2_primitive], axis=1)
# Presumably normalized by rho_ref / p_ref -- confirm in physics.to_conserved.
U_conserved = physics.to_conserved(U_primitive, rho_ref=rho1, p_ref=p1, gamma=gamma)

# cmax is presumably the target CFL number (0.8 here).
# NOTE(review): computed from dimensional primitives while the solver runs on
# U_conserved -- confirm the units agree.
delta_t = numerics.calculate_dt(U_primitive, gamma=gamma, delta_x=delta_x, cmax=0.80)
print(f"delta_t calculated according to CFL: {delta_t:.3e}s")

# int() truncates, so n_steps * delta_t <= end_time.
n_steps = int(end_time / delta_t)
print(f"total time steps: {n_steps}")

delta_t calculated according to CFL: 4.733e-05s
total time steps: 2112


In [5]:
# Two-state shock-tube run with solver_type="hllc".
# NOTE(review): rows 1 and 2 of U_conserved are passed to the `u_*` / `p_*`
# keywords -- confirm initialize_two_domains expects conserved values.
U_init = boundary_conditions.initialize_two_domains(
    rho_left=U_conserved[0, 0],
    rho_right=U_conserved[0, 1],
    u_left=U_conserved[1, 0],
    u_right=U_conserved[1, 1],
    p_left=U_conserved[2, 0],
    p_right=U_conserved[2, 1],
    n_cells=n_cells,
)
cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
    solver_type="hllc",
)
U_solutions_hllc = numerics.run(cfd_input)
# Convert the solver output to primitive variables before plotting.
plot.plot_U_heatmaps(physics.to_primitives(U_solutions_hllc, gamma)).show()

step 0, 	mass density max-min: 8.7500e-01
step 100, 	mass density max-min: 8.7500e-01
step 200, 	mass density max-min: 8.7500e-01
step 300, 	mass density max-min: 8.7500e-01
step 400, 	mass density max-min: 8.7500e-01
step 500, 	mass density max-min: 8.7500e-01
step 600, 	mass density max-min: 8.7500e-01
step 700, 	mass density max-min: 8.7500e-01
step 800, 	mass density max-min: 8.7500e-01
step 900, 	mass density max-min: 8.7500e-01
step 1000, 	mass density max-min: 8.7500e-01
step 1100, 	mass density max-min: 8.7500e-01
step 1200, 	mass density max-min: 8.7500e-01
step 1300, 	mass density max-min: 8.7500e-01
step 1400, 	mass density max-min: 8.7500e-01
step 1500, 	mass density max-min: 8.7500e-01
step 1600, 	mass density max-min: 8.7500e-01
step 1700, 	mass density max-min: 8.7500e-01
step 1800, 	mass density max-min: 8.7500e-01
step 1900, 	mass density max-min: 8.7500e-01
step 2000, 	mass density max-min: 8.7500e-01
step 2100, 	mass density max-min: 8.7500e-01


## Sine-wave initial condition

In [None]:
# Mode-2 sine-wave initial condition (amplitude 0.1) built from the left state
# (column 0 of U_conserved), run with solver_type="hllc".
U_init = boundary_conditions.initialize_sine_wave_conservatives(
    U0=U_conserved[:, 0],
    gamma=gamma,
    amplitude=0.1,
    mode=2,
    n_cells=n_cells,
)

cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
    solver_type="hllc",
)
# Note: reuses (overwrites) the name of the shock-tube HLLC result above.
U_solutions_hllc = numerics.run(cfd_input)
plot.plot_U_heatmaps(physics.to_primitives(U_solutions_hllc, gamma)).show()

In [None]:
# Same mode-2 sine-wave initial condition as the HLLC run, now with
# solver_type="lf". The result is stored as U_solutions_lf (not
# U_solutions_hllc) so the HLLC result from the previous cell is not silently
# overwritten by the LF run.
U_init = boundary_conditions.initialize_sine_wave_conservatives(
    U0=U_conserved[:, 0],
    gamma=gamma,
    amplitude=0.1,
    mode=2,
    n_cells=n_cells,
)

cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
    solver_type="lf",
)
U_solutions_lf = numerics.run(cfd_input)
plot.plot_U_heatmaps(physics.to_primitives(U_solutions_lf, gamma)).show()

## Exact Riemann solver

In [1]:
# --- user input start ---
# Run the setup/import cell first: this cell uses `jnp`, `physics` and `numerics`
# from there (the saved NameError output came from running it on a fresh kernel).

end_time = 1e-1  # [s]
tube_length = 1  # [m]
n_cells = 400
gamma = 1.4  # ratio of specific heats

# NOTE(review): with periodic boundaries the two domain ends are also in contact,
# presumably forming a second discontinuity -- confirm "periodic" (vs.
# "transmissive") is intended when comparing against a single Riemann problem.
boundary_condition: str = "periodic"  # periodic, reflective, transmissive

# init in primitives and not normalized
# Right state = left state scaled by 0.125 (density) and 0.1 (pressure).
rho1 = 1.0 * 1.225  # density, [kg/m3]
u1 = 0.0  # velocity, [m/s]
p1 = 1.0 * 1e5  # pressure , [N/m2]

rho2 = 0.125 * 1.225  # density, [kg/m3]
u2 = 0.0  # velocity, [m/s]
p2 = 0.1 * 1e5  # pressure , [N/m2]

is_debug = False  # if True, conservation check will print results verbosely
is_abort = False  # if True, simulation will abort if diagnostic checks fail

# Only a single ghost layer per side is currently supported.
n_ghost_cells = 1  # per side
if n_ghost_cells > 1:
    raise NotImplementedError("Physics not yet validated for higher order.")

# --- user input end ---


# Uniform cell width.
delta_x = tube_length / n_cells

# Left/right primitive states as columns: shape (3, 2), rows (rho, u, p).
U1_primitive = jnp.array([rho1, u1, p1])
U2_primitive = jnp.array([rho2, u2, p2])
U_primitive = jnp.stack([U1_primitive, U2_primitive], axis=1)
# Presumably normalized by rho_ref / p_ref -- confirm in physics.to_conserved.
U_conserved = physics.to_conserved(U_primitive, rho_ref=rho1, p_ref=p1, gamma=gamma)

# cmax is presumably the target CFL number (0.4 here).
delta_t = numerics.calculate_dt(U_primitive, gamma=gamma, delta_x=delta_x, cmax=0.40)
print(f"delta_t calculated according to CFL: {delta_t:.3e}s")

# int() truncates, so n_steps * delta_t <= end_time.
n_steps = int(end_time / delta_t)
print(f"total time steps: {n_steps}")

NameError: name 'jnp' is not defined

In [None]:
# Two-state shock-tube run with solver_type="exact".
# NOTE(review): rows 1 and 2 of U_conserved are passed to the `u_*` / `p_*`
# keywords -- confirm initialize_two_domains expects conserved values.
U_init = boundary_conditions.initialize_two_domains(
    rho_left=U_conserved[0, 0],
    rho_right=U_conserved[0, 1],
    u_left=U_conserved[1, 0],
    u_right=U_conserved[1, 1],
    p_left=U_conserved[2, 0],
    p_right=U_conserved[2, 1],
    n_cells=n_cells,
)
cfd_input = numerics.Input(
    delta_x=delta_x,
    delta_t=delta_t,
    U_init=U_init,
    gamma=gamma,
    n_steps=n_steps,
    n_ghost_cells=n_ghost_cells,
    is_debug=is_debug,
    is_abort=is_abort,
    boundary_condition=boundary_condition,
    solver_type="exact",
)
U_solutions = numerics.run(cfd_input)
# NOTE(review): unlike the HLLC cells, the primitive conversion is commented out
# and the raw solver output is plotted -- confirm which variables are intended.
# plot.plot_U_heatmaps(physics.to_primitives(U_solutions, gamma)).show()
plot.plot_U_heatmaps(U_solutions).show()

## Plotting

In [None]:
# visualize result
# Uses whichever run last assigned U_solutions (the default sine-wave cell or
# the exact-solver cell), so the output depends on execution order.
plot.plot_U_heatmaps(U_solutions).show()
plot.plot_U_field(U_solutions).show()

In [None]:
"""
Exact 1D Riemann solver for the Euler equations (ideal gas).

Implements the classical two-acoustic-waves + contact structure and
solves for the star-region pressure p* via a safeguarded Newton method
with PVRS/two-rarefaction/two-shock initial guesses. After finding
(p*, u*), the module can sample the self-similar solution U(x/t).

Author: ChatGPT (for Henri @ Proxima Fusion)
License: MIT
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Iterable
import math
import numpy as np

# Type alias for "array-like" inputs: any iterable of floats or a NumPy array.
ArrayLike = Iterable[float] | np.ndarray


@dataclass()
class Primitive:
    """Primitive state of an ideal gas: density, velocity and pressure."""

    rho: float  # density
    u: float  # velocity
    p: float  # pressure

    def sound_speed(self, gamma: float) -> float:
        """Return the sound speed sqrt(gamma * p / rho).

        A negative squared sound speed is clipped to zero before the square root.
        """
        c_squared = gamma * self.p / self.rho
        if c_squared < 0.0:
            c_squared = 0.0
        return math.sqrt(c_squared)


@dataclass()
class StarRegion:
    """State between the left and right nonlinear waves of a Riemann solution.

    Pressure and velocity are continuous across the contact; density is not,
    so a value is kept for each side of it.
    """

    p: float  # star pressure p*
    u: float  # star velocity u*
    rhoL: float  # density between the left wave and the contact
    rhoR: float  # density between the contact and the right wave


@dataclass()
class WaveSpeeds:
    """Speeds bounding the waves of the self-similar Riemann solution.

    For a rarefaction, ``*_head`` and ``*_tail`` bound the fan. For a shock,
    ``*_head`` holds the single shock speed and ``*_tail`` is ``None``.
    """

    SL_head: float  # left rarefaction head, or left shock speed
    SL_tail: float | None  # left rarefaction tail; None for a left shock
    SR_head: float  # right rarefaction head, or right shock speed
    SR_tail: float | None  # right rarefaction tail; None for a right shock
    SM: float  # contact speed (= u*)


@dataclass
class Solution:
    """Exact Riemann solution: star-region state plus the bounding wave speeds."""

    star: StarRegion
    speeds: WaveSpeeds

    def sample(
        self, xi: ArrayLike, left: Primitive, right: Primitive, gamma: float
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample primitive variables (rho,u,p) at similarity coordinates xi = x/t.

        Parameters
        ----------
        xi : array-like
            One or many similarity coordinates x/t (any shape; a scalar is
            treated as shape (1,)).
        left, right : Primitive
            Initial left/right states.
        gamma : float
            Ratio of specific heats.

        Returns
        -------
        (rho, u, p) arrays with the shape of ``np.atleast_1d(xi)``.

        Notes
        -----
        Points exactly on the contact (xi == SM) take the left-star state.
        Inside a rarefaction fan the profiles follow from the characteristic
        relation u - a = xi (left fan) or u + a = xi (right fan) together with
        isentropy, e.g. for the left fan
            u = 2/(g+1) * (a_L + (g-1)/2 * u_L + xi)
            a = 2/(g+1) * (a_L + (g-1)/2 * (u_L - xi))
        """
        xi_arr = np.atleast_1d(np.asarray(xi, dtype=float))
        # Iterate over a flat view so multi-dimensional xi works; reshape at the end.
        flat_xi = xi_arr.ravel()

        L, R = left, right
        aL = L.sound_speed(gamma)
        aR = R.sound_speed(gamma)
        pstar, ustar = self.star.p, self.star.u

        SLh, SLt = self.speeds.SL_head, self.speeds.SL_tail
        SRh, SRt = self.speeds.SR_head, self.speeds.SR_tail
        SM = self.speeds.SM

        gp1 = gamma + 1.0
        gm1 = gamma - 1.0

        def left_rarefaction(s):
            # Left fan: u - a = s.
            u_s = (2.0 / gp1) * (aL + 0.5 * gm1 * L.u + s)
            a_s = (2.0 / gp1) * (aL + 0.5 * gm1 * (L.u - s))
            ratio = a_s / aL
            return (
                L.rho * ratio ** (2.0 / gm1),
                u_s,
                L.p * ratio ** (2.0 * gamma / gm1),
            )

        def right_rarefaction(s):
            # Right fan: u + a = s.
            u_s = (2.0 / gp1) * (-aR + 0.5 * gm1 * R.u + s)
            a_s = (2.0 / gp1) * (aR - 0.5 * gm1 * (R.u - s))
            ratio = a_s / aR
            return (
                R.rho * ratio ** (2.0 / gm1),
                u_s,
                R.p * ratio ** (2.0 * gamma / gm1),
            )

        left_state = (L.rho, L.u, L.p)
        right_state = (R.rho, R.u, R.p)
        left_star = (self.star.rhoL, ustar, pstar)
        right_star = (self.star.rhoR, ustar, pstar)

        rho = np.empty_like(flat_xi)
        u = np.empty_like(flat_xi)
        p = np.empty_like(flat_xi)

        for i, s in enumerate(flat_xi):
            if s <= SM:  # left of (or on) the contact
                if s <= SLh:
                    state = left_state
                elif SLt is not None and s <= SLt:
                    state = left_rarefaction(s)
                else:  # behind the left shock or past the left fan tail
                    state = left_star
            elif SRt is None:  # right wave is a shock moving at SRh
                state = right_star if s <= SRh else right_state
            elif s <= SRt:
                state = right_star
            elif s <= SRh:
                state = right_rarefaction(s)
            else:
                state = right_state
            rho[i], u[i], p[i] = state

        shape = xi_arr.shape
        return rho.reshape(shape), u.reshape(shape), p.reshape(shape)


# --- Core solver utilities -------------------------------------------------


def _f_rare(p: float, state: Primitive, gamma: float) -> Tuple[float, float]:
    """Return (f, df/dp) for rarefaction branch relating velocity drop to p.

    f(p) = 2 a / (gamma-1) * [ (p/pK)^{(gamma-1)/(2*gamma)} - 1 ]
    """
    pK = state.p
    rhoK = state.rho
    aK = state.sound_speed(gamma)
    if p <= 0:
        p = 1e-16
    pr = p / pK
    expo = (gamma - 1.0) / (2.0 * gamma)
    f = (2.0 * aK / (gamma - 1.0)) * (pr**expo - 1.0)
    # derivative
    df = (aK / (gamma * pK)) * pr ** (expo - 1.0)
    return f, df


def _f_shock(p: float, state: Primitive, gamma: float) -> Tuple[float, float]:
    """Return (f, df/dp) for shock branch using Rankine–Hugoniot.

    f(p) = (p - pK) * sqrt( A / (p + B) ), where
    A = 2 / ((gamma+1) rhoK), B = (gamma-1)/(gamma+1) * pK
    """
    pK = state.p
    rhoK = state.rho
    A = 2.0 / ((gamma + 1.0) * rhoK)
    B = (gamma - 1.0) / (gamma + 1.0) * pK
    denom = p + B
    denom = max(denom, 1e-300)
    root = math.sqrt(A / denom)
    f = (p - pK) * root
    # derivative via product rule
    df = root * (1.0 - 0.5 * (p - pK) / denom)
    return f, df


def _phi(p: float, L: Primitive, R: Primitive, gamma: float) -> Tuple[float, float]:
    """Pressure function phi(p) = f_L(p) + f_R(p) + (u_R - u_L) and its derivative.

    Each side uses the rarefaction branch when p <= p_K and the shock branch
    otherwise; the star pressure p* is the root of phi.

    Returns
    -------
    (phi, dphi/dp)
    """
    left_branch = _f_rare if p <= L.p else _f_shock
    right_branch = _f_rare if p <= R.p else _f_shock
    fL, dL = left_branch(p, L, gamma)
    fR, dR = right_branch(p, R, gamma)
    return fL + fR + (R.u - L.u), dL + dR


def _pvrs_guess(L: Primitive, R: Primitive, gamma: float) -> float:
    """Pressure–Velocity Riemann Solver (PVRS) initial guess.
    Toro, Eq. (4.45)."""
    aL = L.sound_speed(gamma)
    aR = R.sound_speed(gamma)
    cbar = 0.5 * (aL + aR)
    pavg = 0.5 * (L.p + R.p) - 0.125 * (R.u - L.u) * (L.rho + R.rho) * cbar
    return max(1e-16, pavg)


def _two_rarefaction_guess(L: Primitive, R: Primitive, gamma: float) -> float:
    aL = L.sound_speed(gamma)
    aR = R.sound_speed(gamma)
    term = (aL + aR) - 0.5 * (gamma - 1.0) * (R.u - L.u)
    expo = (2.0 * gamma) / (gamma - 1.0)
    p = (
        term
        / (
            aL * L.p ** (-(gamma - 1.0) / (2.0 * gamma))
            + aR * R.p ** (-(gamma - 1.0) / (2.0 * gamma))
        )
    ) ** expo
    return max(1e-16, p)


def _two_shock_guess(L: Primitive, R: Primitive, gamma: float) -> float:
    GL = math.sqrt(
        (2.0 / ((gamma + 1.0) * L.rho))
        / ((L.p * (gamma - 1.0) / (gamma + 1.0)) + 1e-300)
    )
    GR = math.sqrt(
        (2.0 / ((gamma + 1.0) * R.rho))
        / ((R.p * (gamma - 1.0) / (gamma + 1.0)) + 1e-300)
    )
    p = (GL * L.p + GR * R.p - (R.u - L.u)) / (GL + GR)
    return max(1e-16, p)


def _vacuum_would_form(L: Primitive, R: Primitive, gamma: float) -> bool:
    aL = L.sound_speed(gamma)
    aR = R.sound_speed(gamma)
    return (R.u - L.u) >= (2.0 / (gamma - 1.0)) * (aL + aR)


def _initial_pressure_guess(L: Primitive, R: Primitive, gamma: float) -> float:
    """Choose a starting p* from the PVRS, two-rarefaction and two-shock estimates.

    Follows Toro's adaptive selection:
    - use the PVRS estimate when the data pressure ratio is mild and the
      estimate lies between the data pressures;
    - use the two-rarefaction estimate when it falls below both (two
      rarefactions expected);
    - otherwise use the two-shock estimate.
    """
    p_pv = _pvrs_guess(L, R, gamma)
    p_min, p_max = min(L.p, R.p), max(L.p, R.p)
    if p_max / p_min < 2.0 and p_min <= p_pv <= p_max:
        return p_pv
    if p_pv < p_min:
        return _two_rarefaction_guess(L, R, gamma)
    return _two_shock_guess(L, R, gamma)


def _bracket_star_pressure(
    L: Primitive, R: Primitive, gamma: float, max_expand: int = 60
) -> Tuple[float, float]:
    """Return (p_lo, p_hi), widened until phi(p_lo) <= 0 <= phi(p_hi) where possible.

    Both branch derivatives of phi are positive, so phi increases with p and a
    sign change brackets the single root. At most ``max_expand`` geometric
    widenings are tried on each side.
    """
    p_lo = min(L.p, R.p) * 1e-6
    p_hi = max(L.p, R.p) * 1e3
    for _ in range(max_expand):
        if _phi(p_lo, L, R, gamma)[0] <= 0.0:
            break
        p_lo *= 0.1
    for _ in range(max_expand):
        if _phi(p_hi, L, R, gamma)[0] >= 0.0:
            break
        p_hi *= 10.0
    return p_lo, p_hi


def _star_density(p: float, state: Primitive, gamma: float) -> float:
    """Density on the `state` side of the contact for star pressure p."""
    ratio = p / state.p
    if p <= state.p:  # rarefaction: isentropic relation
        return state.rho * ratio ** (1.0 / gamma)
    mu = (gamma - 1.0) / (gamma + 1.0)
    return state.rho * (ratio + mu) / (mu * ratio + 1.0)  # shock: Rankine–Hugoniot


def exact_star_state(
    left: Primitive,
    right: Primitive,
    gamma: float = 1.4,
    tol: float = 1e-10,
    max_iter: int = 100,
) -> StarRegion:
    """Solve for (p*, u*, rho*_L, rho*_R) of the exact Riemann solution.

    p* is the root of phi(p) = f_L(p) + f_R(p) + (u_R - u_L), which increases
    with p. The root is found by Newton's method safeguarded with a bracket
    [p_lo, p_hi]:
    - the bracket is tightened after every evaluation using the sign of phi;
    - any Newton step that leaves the bracket (or has a non-positive slope)
      is replaced by bisection.
    Iteration stops when |p_new - p| <= tol * (1 + p) or after ``max_iter``
    steps, in which case the last iterate is used.

    Raises
    ------
    ValueError
        If a density or pressure is non-positive, or if the data would
        generate a vacuum.
    """
    L, R = left, right

    if L.rho <= 0 or R.rho <= 0 or L.p <= 0 or R.p <= 0:
        raise ValueError("Non-physical input: rho>0 and p>0 required.")

    if _vacuum_would_form(L, R, gamma):
        raise ValueError(
            "Vacuum formation detected by expansion criterion. Handle vacuum case separately if needed."
        )

    p_lo, p_hi = _bracket_star_pressure(L, R, gamma)
    p = min(max(_initial_pressure_guess(L, R, gamma), p_lo), p_hi)

    for _ in range(max_iter):
        phi, dphi = _phi(p, L, R, gamma)
        if phi == 0.0:
            break
        # phi is increasing, so its sign tells which side of the root p lies on.
        if phi < 0.0:
            p_lo = p
        else:
            p_hi = p
        p_new = p - phi / dphi if dphi > 0.0 else 0.5 * (p_lo + p_hi)
        if not (p_lo < p_new < p_hi):
            # Newton left the bracket: take a bisection step instead.
            p_new = 0.5 * (p_lo + p_hi)
        converged = abs(p_new - p) <= tol * (1.0 + p)
        p = p_new
        if converged:
            break

    # With p* obtained, compute u* from the left/right wave functions.
    fL, _ = _f_rare(p, L, gamma) if p <= L.p else _f_shock(p, L, gamma)
    fR, _ = _f_rare(p, R, gamma) if p <= R.p else _f_shock(p, R, gamma)
    ustar = 0.5 * (L.u + R.u + fR - fL)

    return StarRegion(
        p=p,
        u=ustar,
        rhoL=_star_density(p, L, gamma),
        rhoR=_star_density(p, R, gamma),
    )


def wave_speeds(
    left: Primitive, right: Primitive, star: StarRegion, gamma: float = 1.4
) -> WaveSpeeds:
    """Compute the speeds bounding each wave of the exact Riemann solution.

    For a rarefaction the head and tail speeds are returned. For a shock the
    single shock speed is stored in the ``*_head`` field and ``*_tail`` is None.

    Shock speeds (Toro, Riemann Solvers and Numerical Methods, Ch. 4):
        S_L = u_L - a_L * sqrt((g+1)/(2g) * p*/p_L + (g-1)/(2g))
        S_R = u_R + a_R * sqrt((g+1)/(2g) * p*/p_R + (g-1)/(2g))

    The previous version used u_K -/+ sqrt(A_K (p* + B_K)). That form carries a
    factor 2/(g+1) where (g+1)/2 belongs, so it underestimated shock speeds;
    for Sod's problem it gave 1.46 instead of 1.752.
    """
    L, R = left, right
    aL = L.sound_speed(gamma)
    aR = R.sound_speed(gamma)
    pstar, ustar = star.p, star.u

    # Coefficients shared by both shock-speed formulas.
    c_ratio = (gamma + 1.0) / (2.0 * gamma)
    c_const = (gamma - 1.0) / (2.0 * gamma)

    # Left wave
    if pstar <= L.p:  # rarefaction
        SL_head = L.u - aL
        aLstar = math.sqrt(gamma * pstar / star.rhoL)
        SL_tail = ustar - aLstar
    else:  # shock
        SL_head = L.u - aL * math.sqrt(c_ratio * pstar / L.p + c_const)
        SL_tail = None

    # Right wave
    if pstar <= R.p:  # rarefaction
        SR_head = R.u + aR
        aRstar = math.sqrt(gamma * pstar / star.rhoR)
        SR_tail = ustar + aRstar
    else:  # shock
        SR_head = R.u + aR * math.sqrt(c_ratio * pstar / R.p + c_const)
        SR_tail = None

    return WaveSpeeds(
        SL_head=SL_head, SL_tail=SL_tail, SR_head=SR_head, SR_tail=SR_tail, SM=ustar
    )


def solve_exact(
    left: Primitive,
    right: Primitive,
    gamma: float = 1.4,
    tol: float = 1e-10,
    max_iter: int = 100,
) -> Solution:
    """Solve the Riemann problem exactly and bundle the star state with wave speeds.

    ``tol`` and ``max_iter`` are forwarded to the star-pressure iteration; the
    returned Solution can be sampled at similarity coordinates x/t.
    """
    star_state = exact_star_state(left, right, gamma=gamma, tol=tol, max_iter=max_iter)
    return Solution(
        star=star_state,
        speeds=wave_speeds(left, right, star_state, gamma=gamma),
    )


# --- Convenience: Sod's shock tube demo -----------------------------------
# Classic Sod problem (gamma=1.4). Demo objects use `sod_`-prefixed names so that
# running this cell does not overwrite notebook globals such as `gamma`, `rho`,
# `u` or `p` that the simulation cells above rely on.
sod_left = Primitive(rho=1.0, u=0.0, p=1.0)
sod_right = Primitive(rho=0.125, u=0.0, p=0.1)
sod_gamma = 1.4

sod_solution = solve_exact(sod_left, sod_right, sod_gamma)
print(
    f"p*={sod_solution.star.p:.8f}, u*={sod_solution.star.u:.8f}, rhoL*={sod_solution.star.rhoL:.8f}, rhoR*={sod_solution.star.rhoR:.8f}"
)
print(
    f"Speeds: SL_head={sod_solution.speeds.SL_head:.6f}, SL_tail={sod_solution.speeds.SL_tail}, SM={sod_solution.speeds.SM:.6f}, SR_tail={sod_solution.speeds.SR_tail}, SR_head={sod_solution.speeds.SR_head:.6f}"
)

# Sanity check against the classical exact star state of Sod's problem
# (Toro, Ch. 4): p* = 0.30313, u* = 0.92745, rho*_L = 0.42632, rho*_R = 0.26557.
assert abs(sod_solution.star.p - 0.30313) < 5e-5, sod_solution.star.p
assert abs(sod_solution.star.u - 0.92745) < 5e-5, sod_solution.star.u
assert abs(sod_solution.star.rhoL - 0.42632) < 5e-5, sod_solution.star.rhoL
assert abs(sod_solution.star.rhoR - 0.26557) < 5e-5, sod_solution.star.rhoR

# Sample at a few xi points
sod_xi = np.linspace(-1.0, 1.0, 9)
sod_rho, sod_u, sod_p = sod_solution.sample(sod_xi, sod_left, sod_right, sod_gamma)
for s, rr, uu, pp in zip(sod_xi, sod_rho, sod_u, sod_p):
    print(f"xi={s:+.3f}: rho={rr:.6f}, u={uu:.6f}, p={pp:.6f}")