# Step 1: Constant-Property Fourier Conduction Verification

Solves the transient heat conduction equation:

    rho * cp * dT/dt = div(k * grad(T))

with constant k, zero velocity, no chemistry, no vibrational energy, and no species
diffusion. The density is not constant: for a fixed-pressure ideal gas it varies as
1/T. With constant k, however, the steady-state temperature profile is still linear.

## Geometry
- 2D channel, Nx=1 (a single cell column; the left/right boundaries use outflow conditions, so the solution is effectively 1D in y), Ny=40
- Bottom wall: Dirichlet T = T_bottom
- Top wall: Dirichlet T = T_top

## Expected steady-state solution
    T(y) = T_bottom + (T_top - T_bottom) * y / H  (exactly linear)
    q_y  = -k0 * (T_top - T_bottom) / H           (constant, equal to analytical flux)

## Verification criteria
1. T profile is linear (max-norm error vs. the analytical profile, normalized by T_top - T_bottom, is below the tolerance stated for each test)
2. Heat flux is constant and equals -k0 * (T_top - T_bottom) / H
3. Velocity remains zero
4. Pressure remains approximately uniform

## Tests
- **Test 1 (eigenstate):** Start from the exact linear profile and advance one step.
  The solution must not change.
- **Test 2 (transient):** Start from uniform T_init and run until steady state.
  Final profile must be linear.

In [3]:
%load_ext autoreload
%autoreload 2

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from pathlib import Path
import sys

jax.config.update("jax_enable_x64", True)

pio.templates.default = "plotly_white"

# Locate the repository root -- the nearest ancestor of the working directory
# that contains a pyproject.toml -- and put its src/ directory on sys.path.
_cwd = Path.cwd()
_search_dirs = [_cwd, *_cwd.parents]
repo_root = next(
    (directory for directory in _search_dirs if (directory / "pyproject.toml").exists()),
    _search_dirs[-1],  # no marker anywhere: end up at the filesystem root
)
src_path = repo_root / "src"
_src_entry = str(src_path)
if _src_entry not in sys.path:
    sys.path.insert(0, _src_entry)

from compressible_core import chemistry_utils, energy_models, transport_models_types
from compressible_2d import (
    equation_manager,
    equation_manager_types,
    equation_manager_utils,
    numerics_types,
)
from fourier_flows_helpers import build_channel_mesh

print(jax.default_backend())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cpu


## Problem Parameters

In [23]:
# Domain
H = 1e-3        # m, channel height
Lx = 1.0        # m, width (arbitrary: with one cell column the problem is 1D in y)
Nx = 1          # single cell column in x (left/right faces get outflow BCs below)
Ny = 40         # cells in y

# Boundary temperatures
T_bottom = 300.0   # K
T_top    = 600.0   # K
T_init   = 0.5 * (T_bottom + T_top)  # uniform IC

# Constant thermal conductivity (transport model)
k0 = 0.1   # W/(m*K)   (4x real air; chosen for fast convergence)

# Numerics
# Acoustic CFL: dt < CFL * dy / a
#   dy = H/Ny = 2.5e-5 m,  a ~ 430 m/s  =>  dt_CFL ~ 0.4 * 2.5e-5/430 ~ 2.3e-8 s
dt            = 1e-8    # s  (inside acoustic CFL; diffusive CFL is much looser)
# NOTE(review): tau_ss is not defined anywhere in this notebook -- state its
# formula so the "~9 time constants" claim can be checked.
t_final       = 5e-3    # s  (~9 * diffusion time constant tau_ss)
save_interval = 10    # steps per snapshot

# Number of time steps implied by t_final and dt.  round() is used instead of
# int(): t_final / dt is a floating-point quotient that can land a hair below
# the intended integer, and int() would then silently report one step fewer.
n_steps = round(t_final / dt)

# Analytical steady-state flux
q_analytical = -k0 * (T_top - T_bottom) / H
print(f"Domain: H={H:.1e} m, Ny={Ny}, dy={H/Ny:.2e} m")
print(f"T_bottom={T_bottom} K, T_top={T_top} K, T_init={T_init} K")
print(f"k0={k0} W/(m*K)")
print(f"Analytical heat flux q_y = {q_analytical:.4e} W/m^2")
print(f"dt={dt:.1e} s, t_final={t_final:.1e} s, n_steps={n_steps}")

Domain: H=1.0e-03 m, Ny=40, dy=2.50e-05 m
T_bottom=300.0 K, T_top=600.0 K, T_init=450.0 K
k0=0.1 W/(m*K)
Analytical heat flux q_y = -3.0000e+04 W/m^2
dt=1.0e-08 s, t_final=5.0e-03 s, n_steps=500000


## Species and Transport

In [5]:
data_dir = repo_root / "data"

# Energy-model settings: "bird" model with include_electronic=False, reading
# its parameters from data/air_5_bird_energy.json.
energy_model_config = energy_models.EnergyModelConfig(
    model="bird",
    include_electronic=False,
    data_path=str(data_dir / "air_5_bird_energy.json"),
)

# Single-species (N2) table built from data/species.json.
species = chemistry_utils.load_species_table(
    species_names=("N2",),
    general_data_path=str(data_dir / "species.json"),
    energy_model_config=energy_model_config,
)

# Constant-property transport: k=k0, mu=0, D_s=0, eta_v=0.
# This reduces the viscous flux to pure Fourier conduction in the translational mode.
def _constant_transport(T, T_v, p, Y_s, rho):
    """Return transport coefficients with conductivity k0 and everything else zero.

    Only the shapes of ``T`` (first axis = number of cells) and ``Y_s`` (second
    axis = number of species) are used; ``T_v``, ``p`` and ``rho`` are accepted
    for signature compatibility and ignored.

    Returns:
        Tuple ``(mu, eta_t, eta_r, eta_v, D_s)``: ``eta_t`` is filled with the
        notebook-level constant ``k0``; ``mu``, ``eta_r`` and ``eta_v`` are
        zero vectors of length n_cells; ``D_s`` is a zero (n_cells, n_species)
        array.
    """
    cell_count = T.shape[0]
    species_count = Y_s.shape[1]
    mu, eta_r, eta_v = (jnp.zeros(cell_count) for _ in range(3))
    eta_t = jnp.full(cell_count, k0)  # all conductivity in translational channel
    D_s = jnp.zeros((cell_count, species_count))
    return mu, eta_t, eta_r, eta_v, D_s


transport_model = transport_models_types.TransportModel(
    compute_transport_properties=_constant_transport,
)
print("Species and transport model ready.")

Species and transport model ready.


## Mesh and Equation Manager

In [24]:
# Boundary tags referenced by the boundary-condition map.
# NOTE(review): assumed to match the tags build_channel_mesh assigns -- confirm.
TAG_LEFT, TAG_RIGHT, TAG_BOTTOM, TAG_TOP = 1, 2, 3, 4

mesh = build_channel_mesh(Nx, Ny, Lx, H)

# Fixed time step, forward Euler, first-order reconstruction, HLLC flux.
_numerics_settings = dict(
    dt=dt,
    cfl=0.4,
    dt_mode="fixed",
    integrator_scheme="forward-euler",
    spatial_scheme="first_order",
    flux_scheme="hllc",
    axisymmetric=False,
    clipping=numerics_types.ClippingConfig2D(),
)
numerics_config = numerics_types.NumericsConfig2D(**_numerics_settings)

# Left/right faces: outflow.  Bottom/top faces: walls whose translational (Tw)
# and vibrational (Tvw) temperatures are both set to the wall temperature.
_outflow_bcs = {tag: {"type": "outflow"} for tag in (TAG_LEFT, TAG_RIGHT)}
_wall_bcs = {
    tag: {"type": "wall", "Tw": T_wall, "Tvw": T_wall}
    for tag, T_wall in ((TAG_BOTTOM, T_bottom), (TAG_TOP, T_top))
}
boundary_config = equation_manager_types.BoundaryConditionConfig2D(
    tag_to_bc={**_outflow_bcs, **_wall_bcs},
)

# No collision integrals and no reactions: transport comes solely from the
# user-supplied transport_model.
eq_manager = equation_manager_utils.build_equation_manager(
    mesh,
    species=species,
    collision_integrals=None,
    reactions=None,
    numerics_config=numerics_config,
    boundary_config=boundary_config,
    transport_model=transport_model,
)
print("Equation manager ready.")

Equation manager ready.


## Helper: Build Initial Condition

In [7]:
y_cells = np.asarray(mesh.cell_centroids[:, 1])  # y-coordinates of all cells
n_cells = len(y_cells)

# Initial pressure: 1 atm
p_init = 101325.0  # Pa
M_N2   = float(species.molar_masses[0])
# NOTE(review): 8.314 is a truncated universal gas constant and this assumes
# molar_masses is in kg/mol -- confirm both against the solver's own values,
# otherwise the solver-side pressure will differ slightly from p_init.
R_N2   = 8.314 / M_N2   # J/(kg*K)

def build_U(T_profile: np.ndarray) -> jnp.ndarray:
    """Build the conserved state for a given per-cell temperature profile.

    Density is set from the ideal-gas law so that pressure equals ``p_init``
    in every cell, velocity is zero, the single species has mass fraction 1,
    and T_v = T_tr (no thermal non-equilibrium).

    The number of cells is taken from ``T_profile`` itself rather than from
    the notebook-level ``n_cells``, so the helper cannot silently build
    mismatched arrays after the mesh cell has been re-run with a different Ny.

    Args:
        T_profile: 1-D array-like of strictly positive temperatures [K], one
            entry per cell.

    Returns:
        Whatever ``equation_manager_utils.compute_U_from_primitives`` returns
        for these primitives (the conserved-state array).

    Raises:
        ValueError: if ``T_profile`` is not one-dimensional or contains
            non-positive or NaN temperatures (which would give a meaningless
            density).
    """
    T_cells = np.asarray(T_profile, dtype=float)
    if T_cells.ndim != 1:
        raise ValueError(
            f"T_profile must be one-dimensional, got shape {T_cells.shape}"
        )
    # `not all(> 0)` also rejects NaN, which `any(<= 0)` would let through.
    if not np.all(T_cells > 0.0):
        raise ValueError("T_profile must contain only positive temperatures")

    cell_count = T_cells.shape[0]
    rho_profile = p_init / (R_N2 * T_cells)
    return equation_manager_utils.compute_U_from_primitives(
        Y_s=jnp.ones((cell_count, 1)),
        rho=jnp.asarray(rho_profile),
        u=jnp.zeros(cell_count),
        v=jnp.zeros(cell_count),
        T_tr=jnp.asarray(T_cells),
        T_V=jnp.asarray(T_cells),
        equation_manager=eq_manager,
    )


def extract_T_v(U: jnp.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Convert a conserved state to primitives and return ``(T, v)`` as NumPy arrays."""
    primitives = equation_manager_utils.extract_primitives(U, eq_manager)
    temperature = np.asarray(primitives.T)
    v_component = np.asarray(primitives.v)
    return temperature, v_component


# Analytical steady-state profile at cell centers
# (linear ramp from T_bottom at y=0 to T_top at y=H, sampled at the centroids,
# so the printed range lies strictly inside [T_bottom, T_top]).
# The expression is left exactly as written: Test 1 compares against T_exact at
# rounding level, so the floating-point evaluation order matters.
T_exact = T_bottom + (T_top - T_bottom) * y_cells / H

# Sanity printout of the centroid extent and the temperature range.
print(f"n_cells={n_cells}, y range=[{y_cells.min():.3e}, {y_cells.max():.3e}] m")
print(f"T_exact range: [{T_exact.min():.1f}, {T_exact.max():.1f}] K")

n_cells=40, y range=[1.250e-05, 9.875e-04] m
T_exact range: [303.8, 596.2] K


---
## Test 1 — Eigenstate Preservation

The linear temperature profile is the unique steady-state of the discrete system
with the correct Dirichlet ghost-cell implementation.  Starting from this exact
profile and advancing one step must leave T unchanged to within rounding errors
(since the flux divergence is exactly zero for a linear profile with constant k).

**Pass criterion:** max |T_after - T_exact| / (T_top - T_bottom) < 1e-8

In [21]:
PASS_TOL = 1e-8

# Start exactly on the analytical steady state and take a single time step.
U_linear = build_U(T_exact)
U_after = equation_manager.advance_one_step(U_linear, mesh, eq_manager, dt)
T_after, v_after = extract_T_v(U_after)

# The initial temperature is T_exact by construction.
T_before = T_exact

# Temperature drift (absolute, and normalized by the wall temperature
# difference) and the magnitude of any induced wall-normal velocity.
T_err_abs = np.abs(T_after - T_before)
T_err_rel = T_err_abs / (T_top - T_bottom)
v_err_abs = np.abs(v_after)

# Only the temperature drift decides pass/fail; the velocity is reported.
passed1 = bool(T_err_rel.max() < PASS_TOL)

print(
    "=== Test 1: Eigenstate Preservation ===",
    f"  Max |T_after - T_exact| / dT = {T_err_rel.max():.3e}  (tolerance {PASS_TOL:.0e})",
    f"  Max |v| after 1 step         = {v_err_abs.max():.3e} m/s",
    f"  Result: {'PASS' if passed1 else 'FAIL'}",
    sep="\n",
)

=== Test 1: Eigenstate Preservation ===
  Max |T_after - T_exact| / dT = 3.790e-16  (tolerance 1e-08)
  Max |v| after 1 step         = 2.033e-13 m/s
  Result: PASS


In [22]:
# Plot temperature error after 1 step.
# The x-axis is logarithmic, and a log axis cannot display zeros: cells whose
# error is exactly 0 (the expected outcome of this test) would silently vanish
# from the figure.  They are therefore drawn at a small positive floor that is
# well below the rounding-level errors and is named in the axis title.
T1_LOG_FLOOR = 1e-18
T_err_rel_plot = np.maximum(T_err_rel, T1_LOG_FLOOR)

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=T_err_rel_plot,
    y=y_cells / H,
    mode="lines+markers",
    name="|T_after - T_exact| / dT",
))
fig.update_xaxes(
    title_text=f"Relative temperature error (exact zeros drawn at {T1_LOG_FLOOR:.0e})",
    type="log",
)
fig.update_yaxes(title_text="y / H")
fig.update_layout(
    title="Test 1 — Temperature error after 1 step from linear IC",
    width=600, height=450,
)
fig.show()

---
## Test 2 — Steady-State Convergence

Start from a uniform temperature T_init = (T_bottom + T_top) / 2 and run until
the solution has converged to steady state.  The final temperature profile must
be linear and the heat flux must be constant.

**Pass criteria:**
1. max |T_final(y) - T_exact(y)| / (T_top - T_bottom) < 1e-4
2. max |q_y + k0*(T_top-T_bottom)/H| / |q_analytical| < 1e-4
3. max |v| < 1e-3 m/s

In [25]:
# Initial condition for the transient test: every cell at the uniform
# temperature T_init (zero velocity; density is set inside build_U).
U_uniform = build_U(np.full(n_cells, T_init))

# Integrate from t=0 to t_final, keeping a snapshot every `save_interval` steps.
# The time step is not passed here; presumably run_scan takes it from the
# numerics configuration carried by eq_manager (dt_mode="fixed") -- confirm.
# This is the expensive cell of the notebook (t_final/dt time steps).
print("Running transient simulation...")
U_hist, t_hist = equation_manager.run_scan(
    U_uniform, mesh, eq_manager,
    t_final=t_final,
    save_interval=save_interval,
)
# First axis of U_hist indexes the saved snapshots.
print(f"Done. {U_hist.shape[0]} snapshots saved.")

Running transient simulation...
Done. 50001 snapshots saved.


In [26]:
# Pass thresholds: relative temperature error, relative heat-flux error,
# and absolute velocity in m/s.
PASS_T, PASS_Q, PASS_V = 1e-4, 1e-4, 1e-3

# Last saved snapshot is taken as the steady-state candidate.
U_final = U_hist[-1]
T_final, v_final = extract_T_v(U_final)

# Temperature error relative to the wall temperature difference.
T_err_final = np.abs(T_final - T_exact) / (T_top - T_bottom)

# Heat-flux diagnostic: order the cells by y, then difference the temperature
# between neighbouring centroids.  Each difference is a gradient estimate at
# the interior face between two cells; multiplying by -k0 gives q_y there.
y_order = np.argsort(y_cells)
y_sorted, T_sorted = y_cells[y_order], T_final[y_order]
dy_sorted = np.diff(y_sorted)
dT_dy = np.diff(T_sorted) / dy_sorted
q_computed = -k0 * dT_dy   # heat flux at interior faces

q_err_rel = np.abs(q_computed - q_analytical) / np.abs(q_analytical)
v_max = np.abs(v_final).max()

passed2_T = bool(T_err_final.max() < PASS_T)
passed2_q = bool(q_err_rel.max() < PASS_Q)
passed2_v = bool(v_max < PASS_V)
passed2 = passed2_T and passed2_q and passed2_v

print(
    "=== Test 2: Steady-State Convergence ===",
    f"  t_final = {float(t_hist[-1]):.3e} s",
    f"  Max |T_final - T_exact| / dT = {T_err_final.max():.3e}  (tol {PASS_T:.0e})",
    f"  Max rel heat-flux error      = {q_err_rel.max():.3e}  (tol {PASS_Q:.0e})",
    f"  Max |v|                      = {v_max:.3e} m/s  (tol {PASS_V:.0e})",
    f"  T profile : {'PASS' if passed2_T else 'FAIL'}",
    f"  Heat flux : {'PASS' if passed2_q else 'FAIL'}",
    f"  Velocity  : {'PASS' if passed2_v else 'FAIL'}",
    f"  Overall   : {'PASS' if passed2 else 'FAIL'}",
    sep="\n",
)

=== Test 2: Steady-State Convergence ===
  t_final = 5.000e-03 s
  Max |T_final - T_exact| / dT = 1.807e-03  (tol 1e-04)
  Max rel heat-flux error      = 1.430e-01  (tol 1e-04)
  Max |v|                      = 2.162e-06 m/s  (tol 1e-03)
  T profile : FAIL
  Heat flux : FAIL
  Velocity  : PASS
  Overall   : FAIL


In [27]:
# Plot: temperature and heat-flux profiles at a handful of evenly spaced snapshots
N_SAMPLES = 6
n_snaps   = U_hist.shape[0]
indices   = np.unique(np.round(np.linspace(0, n_snaps - 1, N_SAMPLES)).astype(int))

# One shade of blue per snapshot, fading in towards the final time.
colors = [f"rgba(31,119,180,{a:.2f})" for a in np.linspace(0.2, 1.0, len(indices))]

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Temperature profiles", "Heat-flux profiles"),
    shared_yaxes=True,
)

# Ordinates are the same for every snapshot: centroids for T, and the
# midpoints between neighbouring centroids for the differenced flux.
y_T = y_sorted / H
y_q = 0.5 * (y_sorted[:-1] + y_sorted[1:]) / H  # midpoints

last_position = len(indices) - 1
for position, (color, idx) in enumerate(zip(colors, indices)):
    T_snap, _ = extract_T_v(U_hist[idx])
    T_s = T_snap[y_order]
    q_s = -k0 * np.diff(T_s) / dy_sorted

    is_final = position == last_position
    label = f"t={float(t_hist[idx]):.2e}s" + (" (final)" if is_final else "")
    line_style = dict(color=color, width=2 if is_final else 1)

    fig.add_trace(
        go.Scatter(x=T_s, y=y_T, mode="lines", name=label, showlegend=True, line=line_style),
        row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(x=q_s, y=y_q, mode="lines", showlegend=False, line=line_style),
        row=1, col=2,
    )

# Analytical reference curves (dashed red) on both panels.
reference_line = dict(color="red", dash="dash", width=2)
fig.add_trace(
    go.Scatter(
        x=T_exact[y_order], y=y_T,
        mode="lines", name="T_exact (analytical)",
        line=reference_line,
    ),
    row=1, col=1,
)
fig.add_trace(
    go.Scatter(
        x=[q_analytical, q_analytical], y=[0, 1],
        mode="lines", name="q_analytical",
        line=reference_line,
        showlegend=True,
    ),
    row=1, col=2,
)

fig.update_xaxes(title_text="T [K]", row=1, col=1)
fig.update_xaxes(title_text="q_y [W/m^2]", row=1, col=2)
fig.update_yaxes(title_text="y / H", row=1, col=1)
fig.update_layout(
    title="Test 2 — Transient Fourier conduction",
    width=1000, height=500,
    legend=dict(x=0.5, y=-0.2, xanchor="center", orientation="h"),
)
fig.show()

In [28]:
# Plot: relative temperature error vs y at final time (log-scaled error axis)
fig = go.Figure(
    data=go.Scatter(
        x=T_err_final[y_order],
        y=y_sorted / H,
        mode="lines+markers",
        name="|T_final - T_exact| / dT",
    ),
)
fig.update_layout(
    title="Test 2 — Relative temperature error at steady state",
    width=600,
    height=450,
)
fig.update_xaxes(title_text="Relative error", type="log")
fig.update_yaxes(title_text="y / H")
fig.show()

In [29]:
# Plot: velocity magnitude at final time
# Primitives of the final state; prim_final is reused by the pressure plot.
prim_final = equation_manager_utils.extract_primitives(U_final, eq_manager)
u_final_np = np.asarray(prim_final.u)
v_final_np = np.asarray(prim_final.v)
u_mag_final = np.sqrt(u_final_np**2 + v_final_np**2)

fig = go.Figure(
    data=go.Scatter(
        x=u_mag_final[y_order],
        y=y_sorted / H,
        mode="lines+markers",
        name="|u| [m/s]",
    ),
)
fig.update_layout(
    title="Test 2 — Velocity magnitude at steady state (should be ~0)",
    width=600,
    height=450,
)
fig.update_xaxes(title_text="Velocity magnitude [m/s]", type="log")
fig.update_yaxes(title_text="y / H")
fig.show()

In [30]:
# Plot: pressure at final time (should remain approximately uniform)
# Uses prim_final from the velocity-plot cell; the mean goes into the title.
p_final = np.asarray(prim_final.p)
p_mean = p_final.mean()

fig = go.Figure(
    data=go.Scatter(
        x=p_final[y_order],
        y=y_sorted / H,
        mode="lines+markers",
        name="p [Pa]",
    ),
)
fig.update_layout(
    title=f"Test 2 — Pressure profile at steady state (mean={p_mean:.1f} Pa)",
    width=600,
    height=450,
)
fig.update_xaxes(title_text="Pressure [Pa]")
fig.update_yaxes(title_text="y / H")
fig.show()

---
## Summary

In [31]:
# Consolidated pass/fail table for both tests.
print(
    "==============================",
    " VERIFICATION SUMMARY",
    "==============================",
    f"  Test 1 (eigenstate) : {'PASS' if passed1  else 'FAIL'}",
    f"  Test 2 (transient)  : {'PASS' if passed2  else 'FAIL'}",
    f"    - T profile       : {'PASS' if passed2_T else 'FAIL'}",
    f"    - Heat flux       : {'PASS' if passed2_q else 'FAIL'}",
    f"    - Velocity        : {'PASS' if passed2_v else 'FAIL'}",
    sep="\n",
)

# Blank separator line, then the verdict and hints for the failing checks.
print()
if passed1 and passed2:
    print("All checks passed. Diffusion operator and wall BC are correct.")
else:
    print("One or more checks FAILED. Investigate the plots above.")
    if not passed1:
        print(
            "  * Eigenstate failure: the wall BC ghost-cell temperature is wrong.",
            "    Expected: T_ghost = 2*Tw - T_interior (Dirichlet on face)",
            sep="\n",
        )
    if not (passed2_T and passed2_q):
        print("  * Steady-state failure: the diffusion operator or BC may be incorrect.")

 VERIFICATION SUMMARY
  Test 1 (eigenstate) : PASS
  Test 2 (transient)  : FAIL
    - T profile       : FAIL
    - Heat flux       : FAIL
    - Velocity        : PASS

One or more checks FAILED. Investigate the plots above.
  * Steady-state failure: the diffusion operator or BC may be incorrect.
