<a href="https://colab.research.google.com/github/davetew/Modern-Aerospace-Propulsion/blob/main/Shock_Tube_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# One-Dimensional Shock Tube Model


In [None]:
# Import the required Python packages
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# steger_warming_shock_tube.py
# Python translation of the provided MATLAB shock-tube script
# Requires: numpy  (and optionally matplotlib for plots)



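# ---------------------------
# Optional plotting
# ---------------------------
# matplotlib.pyplot is already imported as `plt` in the first code cell; it is
# needed only for plotting the results, not for the simulation itself.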


def gen_grid2(x_prev: np.ndarray, rho_prev: np.ndarray) -> np.ndarray:
    """
    Stand-in for the MATLAB routine GenGrid2(x_old, rho_old).

    No grid adaptation is performed: the density argument is accepted only so
    the call signature matches the MATLAB original, and an independent copy of
    the previous grid is returned.  Swap in a real adaptive grid generator
    here if one is available.
    """
    # Return a copy so the caller can store it without aliasing the old grid.
    unchanged_grid = x_prev.copy()
    return unchanged_grid


def _steger_warming_split(
    rho: np.ndarray,
    u: np.ndarray,
    a: np.ndarray,
    gamma: float,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Split the 1-D Euler flux into right- and left-running parts (Steger–Warming).

    The characteristic speeds u, u + a and u - a are split as
    ``lam+ = max(lam, 0)`` and ``lam- = min(lam, 0)``, so the splitting is
    valid for either flow direction and for supersonic states.  For
    ``0 <= u < a`` it reduces exactly to the closed-form expressions used by
    the original MATLAB script.

    Parameters
    ----------
    rho, u, a : array_like
        Density, velocity and sound speed; broadcast against each other.
    gamma : float
        Ratio of specific heats.

    Returns
    -------
    (F_plus, F_minus) : tuple of ndarray
        Each of shape ``(3,) + broadcast_shape``.  Their sum is the physical
        flux ``[rho*u, rho*u**2 + p, u*(Et + p)]`` with ``p = rho*a**2/gamma``.
    """
    rho, u, a = np.broadcast_arrays(
        np.asarray(rho, dtype=float),
        np.asarray(u, dtype=float),
        np.asarray(a, dtype=float),
    )
    coef = 0.5 * rho / gamma
    kappa = 0.5 * (3.0 - gamma) / (gamma - 1.0)

    def flux(l1, l2, l3):
        # Flux assembled from the (split) eigenvalues l1 ~ u, l2 ~ u+a, l3 ~ u-a.
        return coef * np.stack([
            2.0 * (gamma - 1.0) * l1 + l2 + l3,
            2.0 * (gamma - 1.0) * l1 * u + l2 * (u + a) + l3 * (u - a),
            (gamma - 1.0) * l1 * u**2
            + 0.5 * l2 * (u + a) ** 2
            + 0.5 * l3 * (u - a) ** 2
            + kappa * (l2 + l3) * a**2,
        ])

    speeds = (u, u + a, u - a)
    f_plus = flux(*(np.maximum(s, 0.0) for s in speeds))
    f_minus = flux(*(np.minimum(s, 0.0) for s in speeds))
    return f_plus, f_minus


def _vertex_fluxes(
    rho: np.ndarray, u: np.ndarray, a: np.ndarray, gamma: float
) -> np.ndarray:
    """
    Return the total numerical flux, shape (3, N+1), at the vertices of N cells.

    Interior vertex i receives F+ from cell i-1 and F- from cell i.  Both ends
    are solid walls; the ghost state beyond each wall is the adjacent cell's
    rho and a with u = 0, as in the original MATLAB script.
    NOTE: with this u = 0 ghost, the mass flux through a wall vanishes exactly
    only when the adjacent cell is at rest.
    """
    f_plus, f_minus = _steger_warming_split(rho, u, a, gamma)
    left_plus, _ = _steger_warming_split(rho[0], 0.0, a[0], gamma)
    _, right_minus = _steger_warming_split(rho[-1], 0.0, a[-1], gamma)
    e_plus = np.hstack([left_plus.reshape(3, 1), f_plus])     # (3, N+1)
    e_minus = np.hstack([f_minus, right_minus.reshape(3, 1)])  # (3, N+1)
    return e_plus + e_minus


def _estimate_strike(
    force: np.ndarray, time: np.ndarray, threshold: float = 0.999
) -> tuple[int | None, float]:
    """
    Locate the last step before the wall-to-wall force first drops below
    ``threshold`` times its initial value.

    Returns ``(istrike, strike_time)``, or ``(None, nan)`` when there are fewer
    than two samples, the initial force is zero/non-finite, or no drop occurs.
    """
    if force.size < 2 or not np.isfinite(force[0]) or force[0] == 0.0:
        return None, np.nan
    held = (force / force[0]) > threshold
    drops = np.flatnonzero(np.diff(held.astype(int)) < 0)
    if drops.size == 0:
        return None, np.nan
    istrike = int(drops[0])
    return istrike, float(time[istrike])


def run_shock_tube(
    p_ratio: float = 1.1,      # Ratio.p
    length_ratio: float = 1.0, # Ratio.length
    max_time: float = 1.0,     # MaxTime
    CFL: float = 0.1,
    N_mesh: int = 500,
    gamma: float = 1.4,
    print_progress: bool = True,
):
    """
    Simulate a closed 1-D shock tube with Steger–Warming flux-vector splitting.

    Python port of a MATLAB script.  The grid spans ``0 <= x <= 1 + length_ratio``
    with solid walls at both ends.  The first ``floor(N_mesh * length_ratio /
    (1 + length_ratio))`` grid points hold the high-pressure driver gas and the
    rest the driven gas, so the driven section has (about) unit length.

    Units follow from the initial conditions: the driven gas has density 1 and
    sound speed 1 (pressure ``1 / gamma``).  The driver gas has ``p_ratio``
    times the driven density and pressure, i.e. the same sound speed.  The
    shock Mach number is therefore estimated as ``1 / strike_time``.

    Parameters
    ----------
    p_ratio : float
        Driver-to-driven pressure (and density) ratio.
    length_ratio : float
        Driver-to-driven length ratio; must be >= 0.
    max_time : float
        Marching continues while the current time is <= ``max_time``, so the
        last stored step is the first one past it.
    CFL : float
        Courant number; ``dt = CFL * min(dx / (|u| + a))``.  Must be > 0.
    N_mesh : int
        Number of grid points; must be >= 2.
    gamma : float
        Ratio of specific heats; must be > 1.
    print_progress : bool
        Print the iteration count and time after every step.

    Returns
    -------
    dict
        ``time`` (Nt,), ``x`` (Nt, N), ``U`` (Nt, 3, N) conserved variables
        [rho, rho*u, Et]; histories ``rho``, ``u``, ``Et``, ``a``, ``p``, ``M``
        (each (Nt, N)); ``Force`` = p[first] - p[last] per step;
        ``strike_time`` and ``shock_Mach_est`` (NaN when no strike is
        detected); and ``gamma``.

    Raises
    ------
    ValueError
        For invalid ``N_mesh``, ``CFL``, ``gamma`` or ``length_ratio``.
    RuntimeError
        If the time step becomes non-positive or NaN (e.g. a degenerate grid
        from ``gen_grid2`` or a non-physical state).
    """
    if N_mesh < 2:
        raise ValueError(f"N_mesh must be at least 2, got {N_mesh}")
    if not CFL > 0.0:  # also rejects NaN; CFL <= 0 would never advance time
        raise ValueError(f"CFL must be positive, got {CFL}")
    if not gamma > 1.0:
        raise ValueError(f"gamma must be greater than 1, got {gamma}")
    if length_ratio < 0.0:
        raise ValueError(f"length_ratio must be non-negative, got {length_ratio}")

    # ---------------------------
    # Initial uniform grid and driver/driven interface (MATLAB "Diaphragm")
    # ---------------------------
    x0 = np.linspace(0.0, 1.0 + length_ratio, N_mesh)
    diaphragm = int(np.floor(N_mesh * (1.0 - 1.0 / (1.0 + length_ratio))))

    # ---------------------------
    # Initial conditions: gas at rest, U = [rho, rho*u, Et] with Et = p/(gamma-1)
    # ---------------------------
    p_driven = 1.0 / gamma      # rho = 1, a = 1  ->  p = rho*a**2/gamma
    p_driver = p_ratio / gamma  # rho = p_ratio, same a as the driven gas
    U0 = np.zeros((3, N_mesh))
    U0[0, :diaphragm] = p_ratio
    U0[0, diaphragm:] = 1.0
    U0[2, :diaphragm] = p_driver / (gamma - 1.0)
    U0[2, diaphragm:] = p_driven / (gamma - 1.0)

    t_list = [0.0]
    x_list = [x0.copy()]
    U_list = [U0.copy()]  # each is (3, N_mesh)

    # ---------------------------
    # March in time
    # ---------------------------
    while t_list[-1] <= max_time:
        x_old = x_list[-1]
        U_old = U_list[-1]

        # (Potentially) new grid; dx = [dx(1), diff(x)] as in the MATLAB script.
        x_new = gen_grid2(x_old, U_old[0, :])
        dx = np.diff(x_new)
        dx = np.concatenate(([dx[0]], dx))

        # Transfer the conserved state onto the new grid and update THAT state,
        # so fluxes and the state they act on live on the same grid.
        U_cur = np.vstack([np.interp(x_new, x_old, row) for row in U_old])
        rho, mom, Et = U_cur
        u = mom / np.maximum(rho, 1e-15)
        a = np.sqrt(np.maximum(
            0.0, gamma * (gamma - 1.0) * (Et / np.maximum(rho, 1e-15) - 0.5 * u**2)
        ))

        # CFL limit on the fastest characteristic speed |u| + a.
        dt = CFL * np.min(dx / np.maximum(np.abs(u) + a, 1e-12))
        if not dt > 0.0:
            raise RuntimeError(
                f"Non-positive or NaN time step dt={dt!r} at t={t_list[-1]:.6f}"
            )
        t_new = t_list[-1] + dt

        # Finite-volume update: right-vertex flux minus left-vertex flux per cell.
        E_total = _vertex_fluxes(rho, u, a, gamma)
        dE = np.diff(E_total, axis=1)  # (3, N_mesh)
        U_new = U_cur - (dt / np.maximum(dx, 1e-15)) * dE

        t_list.append(t_new)
        x_list.append(x_new.copy())
        U_list.append(U_new)

        if print_progress:
            print(f"Iteration = {len(t_list)-1:5d}, Time = {t_new:.6f}")

    # ---------------------------
    # Post-processing
    # ---------------------------
    U_arr = np.stack(U_list, axis=0)  # (Nt, 3, N)
    x_arr = np.stack(x_list, axis=0)  # (Nt, N)
    t_arr = np.array(t_list)          # (Nt,)

    rho_hist = U_arr[:, 0, :]
    u_hist = U_arr[:, 1, :] / np.maximum(rho_hist, 1e-15)
    Et_hist = U_arr[:, 2, :]
    a_hist = np.sqrt(np.maximum(0.0, gamma * (gamma - 1.0) *
                                (Et_hist / np.maximum(rho_hist, 1e-15) - 0.5 * u_hist**2)))
    p_hist = rho_hist * a_hist**2 / gamma
    M_hist = u_hist / np.maximum(a_hist, 1e-15)

    # Net force acting on tube (first minus last cell pressure)
    Force = p_hist[:, 0] - p_hist[:, -1]

    # A drop in Force is attributed to the shock reaching the right wall.
    istrike, strike_time = _estimate_strike(Force, t_arr)
    if istrike is None:
        shock_Mach_est = np.nan
        print("Strike estimate unavailable (no drop in wall pressure difference detected).")
    else:
        shock_Mach_est = 1.0 / max(strike_time, 1e-12)
        print(f"Strike Time = {strike_time:.6f}")
        print(f"Shock Mach Number = {shock_Mach_est:.6f}")

    return {
        "time": t_arr, "x": x_arr, "U": U_arr,
        "rho": rho_hist, "u": u_hist, "Et": Et_hist,
        "a": a_hist, "p": p_hist, "M": M_hist,
        "Force": Force, "strike_time": strike_time,
        "shock_Mach_est": shock_Mach_est, "gamma": gamma
    }


def plot_results(results: dict) -> None:
    """
    Plot x-t maps of density, Mach number and pressure, plus conservation and
    net-force histories, from the dictionary returned by ``run_shock_tube``.

    The x-t contour plots use the first stored grid, ``results["x"][0, :]``;
    this is exact only while ``gen_grid2`` leaves the grid unchanged.
    """
    import matplotlib.pyplot as plt  # optional dependency, needed only here

    t = results["time"]
    Tm, Xm = np.meshgrid(t, results["x"][0, :], indexing="ij")
    for key, title in (("rho", "Density"), ("M", "Mach Number"), ("p", "Pressure")):
        plt.figure()
        plt.contourf(Xm, Tm, results[key], levels=20)
        plt.xlabel("x")
        plt.ylabel("Time")
        plt.title(title)
        plt.colorbar()

    # Conservation checks (plain sums; the initial grid is uniform)
    mass = np.sum(results["rho"], axis=1)
    energy = np.sum(results["Et"], axis=1)
    plt.figure()
    plt.plot(t, mass / mass[0], label="Mass/M0")
    plt.plot(t, energy / energy[0], label="Energy/E0")
    plt.legend()
    plt.xlabel("Time")
    plt.grid(True)

    # Net force, normalised by its initial value
    plt.figure()
    plt.plot(t, results["Force"] / results["Force"][0])
    plt.xlabel("Time")
    plt.ylabel("Force/Initial")
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    # Set to True to draw the plots after the run (requires matplotlib).
    SHOW_PLOTS = False

    # Defaults chosen to mirror your MATLAB script
    results = run_shock_tube(
        p_ratio=1.1,
        length_ratio=1.0,
        max_time=1.0,
        CFL=0.1,
        N_mesh=500,
        gamma=1.4,
        print_progress=True
    )

    if SHOW_PLOTS:
        plot_results(results)