# 0d Heatbath

Set up a heat bath consisting of a single cell, initialized with a single species in strong thermal non-equilibrium (translational and vibrational temperatures differ). This follows the verification strategy of Casseau, V. (2021).

## Profiling

In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp

from compressible_1d import (
    equation_manager as equation_manager_module,
    equation_manager_types,
    equation_manager_utils,
    chemistry_types,
    chemistry_utils,
    numerics_types,
)
from jaxtyping import Array, Float


import time


def profile_run(
    U_init: Float[Array, "n_cells n_variables"],
    equation_manager: equation_manager_types.EquationManager,
    t_final: float,
    save_interval: int = 100,
    use_jit: bool = True,
) -> tuple[
    Float[Array, "n_snapshots n_cells n_variables"], Float[Array, "n_snapshots"]
]:
    """Run the simulation from t=0 to t=t_final, printing the wall time of every step.

    Args:
        U_init: Initial condition [n_cells, n_variables]
        equation_manager: Contains all configuration; the time step is read from
            ``equation_manager.numerics_config.dt``.
        t_final: Final simulation time
        save_interval: Save solution every N steps (must be positive)
        use_jit: If True (default), wrap ``advance_one_step`` in ``jax.jit``.
            Set to False to profile the un-compiled step function.

    Returns:
        U_history: Solution snapshots [n_snapshots, n_cells, n_variables]
        t_history: Time values [n_snapshots]

    Raises:
        ValueError: If ``save_interval`` is not positive.

    Note:
        With ``use_jit=True`` the time reported for step 1 includes tracing and
        compilation of the step function.
    """
    if save_interval <= 0:
        raise ValueError(f"save_interval must be positive, got {save_interval}")

    dt = equation_manager.numerics_config.dt

    # Check diffusive CFL at start (warns if violated)
    # check_diffusive_cfl(U_init, equation_manager)

    # The tiny offset keeps a ratio that lands just below an integer because of
    # floating-point round-off from being truncated to one step too few.
    n_steps = int(t_final / dt + 1e-9)
    n_snapshots = n_steps // save_interval + 1

    n_cells, n_variables = U_init.shape
    U_history = jnp.zeros((n_snapshots, n_cells, n_variables)).at[0].set(U_init)
    t_history = jnp.zeros(n_snapshots)  # entry 0 already holds t = 0

    advance_one_step = equation_manager_module.advance_one_step
    if use_jit:
        advance_one_step = jax.jit(advance_one_step)

    U = U_init
    snapshot_idx = 1
    for step in range(1, n_steps + 1):
        start_time = time.perf_counter()
        # JAX dispatches work asynchronously; block on the result so the
        # measured time covers the computation, not only its dispatch.
        U = jax.block_until_ready(advance_one_step(U, equation_manager))
        # Derive t from the step count instead of accumulating dt, which would
        # accumulate floating-point round-off over many steps.
        t = step * dt

        if step % save_interval == 0 and snapshot_idx < n_snapshots:
            U_history = U_history.at[snapshot_idx, :, :].set(U)
            t_history = t_history.at[snapshot_idx].set(t)
            snapshot_idx += 1

        end_time = time.perf_counter()
        print(
            f"Step {step}/{n_steps} completed in {end_time - start_time:.6f} seconds."
        )

    return U_history, t_history


T_init = 7000.0  # K
T_V_init = 300.0  # K

# NOTE: machine-specific absolute path; adjust when running elsewhere.
data_dir = "/home/hhoechter/tum/jaxfluids_internship/data/"
general_species_data_path = data_dir + "air_5_gnoffo.json"
equilibrium_enthalpy_data_path = data_dir + "air_5_gnoffo_equilibrium_enthalpy.json"
species_names = ["N2"]  # single diatomic species

dt = 1e-8  # s
t_final = 2 * dt  # s -> two time steps
save_interval = 1  # every x steps
dx = 1e-4  # m  # TODO: show that this has no impact on the results in 0D

# ------------------------------ Simulation setup ------------------------------

# create species table
all_species = chemistry_utils.load_species_from_gnoffo(
    general_data_path=general_species_data_path,
    equilibrium_enthalpy=equilibrium_enthalpy_data_path,
)
selected_species = [s for s in all_species if s.name in species_names]
species = chemistry_types.SpeciesTable.from_species_list(selected_species)

# create NumericsConfig
numerics_config = numerics_types.NumericsConfig(
    dt=dt,
    dx=dx,
    integrator_scheme="forward-euler",
    spatial_scheme="first_order",
    flux_scheme="hllc",
    n_halo_cells=1,
)

# create EquationManager (no collision integrals, no reactions)
equation_manager = equation_manager_types.EquationManager(
    species=species,
    collision_integrals=None,
    reactions=None,
    numerics_config=numerics_config,
    boundary_condition="periodic",
)

# initial condition
# NOTE(review): the Heating section passes the translational temperature as
# `T_tr=`; confirm which keyword compute_U_from_primitives currently expects.
U_init = equation_manager_utils.compute_U_from_primitives(
    Y_s=jnp.array([1.0]),
    rho=jnp.array([1.0]),
    u=jnp.array([0.0]),
    T=jnp.array([T_init]),
    T_V=jnp.array([T_V_init]),
    equation_manager=equation_manager,
)

# run simulation (prints the wall time of every step)
U_field, t = profile_run(
    U_init=U_init,
    equation_manager=equation_manager,
    t_final=t_final,
    save_interval=save_interval,
)

## Heating (T > T_V)

In [1]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp

from compressible_1d import (
    equation_manager_types,
    equation_manager_utils,
    chemistry_types,
    chemistry_utils,
    numerics_types,
)

In [2]:
# This section's import cell only imports jax.numpy; these names are needed
# below and must not rely on the Profiling section having been run first.
import jax
from plotly import graph_objects as go

from compressible_1d import equation_manager as equation_manager_module

T_init = 10000.0  # K
T_V_init = 1000.0  # K

# NOTE: machine-specific absolute path; adjust when running elsewhere.
data_dir = "/home/hhoechter/tum/jaxfluids_internship/data/"
general_species_data_path = data_dir + "air_5_gnoffo.json"
equilibrium_enthalpy_data_path = data_dir + "air_5_gnoffo_equilibrium_enthalpy.json"
species_names = ["N2"]  # single diatomic species

dt = 1e-9  # s
t_final = 1e-5  # s
save_interval = 1  # every x steps
dx = 1e-4  # m  # TODO: show that this has no impact on the results in 0D

# ------------------------------ Simulation setup ------------------------------

# create species table
all_species = chemistry_utils.load_species_from_gnoffo(
    general_data_path=general_species_data_path,
    equilibrium_enthalpy=equilibrium_enthalpy_data_path,
)
selected_species = [s for s in all_species if s.name in species_names]
species = chemistry_types.SpeciesTable.from_species_list(selected_species)

# create NumericsConfig
numerics_config = numerics_types.NumericsConfig(
    dt=dt,
    dx=dx,
    integrator_scheme="forward-euler",
    spatial_scheme="first_order",
    flux_scheme="hllc",
    n_halo_cells=1,
)

# create EquationManager (no collision integrals, no reactions)
equation_manager = equation_manager_types.EquationManager(
    species=species,
    collision_integrals=None,
    reactions=None,
    numerics_config=numerics_config,
    boundary_condition="periodic",
)

# initial condition: T_tr far above T_V (heating of the vibrational mode)
U_init = equation_manager_utils.compute_U_from_primitives(
    Y_s=jnp.array([1.0]),
    rho=jnp.array([1.0 / 25.0]),
    u=jnp.array([0.0]),
    T_tr=jnp.array([T_init]),
    T_V=jnp.array([T_V_init]),
    equation_manager=equation_manager,
)

# run simulation
U_field, t = equation_manager_module.run(
    U_init=U_init,
    equation_manager=equation_manager,
    t_final=t_final,
    save_interval=save_interval,
)
print("Simulation completed.")

# JIT-compiled version of extract_primitives_from_U (needs to be executed once per kernel uptime)
extract_primitives_from_U_jitted = jax.jit(
    equation_manager_utils.extract_primitives_from_U
)

# extract primitives for every snapshot (vmap over the leading snapshot axis)
Y_s, rho, T, T_V, p = jax.vmap(
    extract_primitives_from_U_jitted,
    in_axes=(0, None),
)(U_field, equation_manager)

fig = go.Figure()
fig.add_trace(go.Scatter(x=t, y=T_V[:, 0], mode="lines+markers", name="T_V"))
fig.add_trace(go.Scatter(x=t, y=T[:, 0], mode="lines+markers", name="T"))
fig.update_xaxes(type="log")
fig.update_yaxes(range=[0, 10500])
# equilibrium temperature reported by Casseau for this case
fig.add_hline(y=7623.3, line_dash="dash", line_color="blue", name="T_eq Casseau")

fig.update_layout(
    title="Vibrational and Translational Temperature vs Time",
    xaxis_title="Time (s)",
    yaxis_title="Temperature (K)",
    showlegend=True,
)

Simulation completed.


In [60]:
# Sanity check of the equilibrium temperature: turn the vibrational-electronic
# energy at the initial T_V and at the last snapshot's T_V into an effective
# number of vibrational degrees of freedom,
#     dof_vib = 2 * e_vib * M * 1e-3 / (R_universal * T_V).
from compressible_1d import thermodynamic_relations
from compressible_1d import constants

# [initial T_V, T_V of the last saved snapshot in cell 0]
T_V_check = jnp.array([T_V_init, T_V[-1, 0]])
e_vib = thermodynamic_relations.compute_e_vib_electronic(
    T_V=T_V_check,
    T_ref=298.0,
    molar_masses=species.molar_masses,
    parameters=species.enthalpy_coeffs,
    is_monoatomic=species.is_monoatomic,
    T_limit_low=species.T_limit_low,
    T_limit_high=species.T_limit_high,
)
# NOTE(review): the 1e-3 factor presumably converts the molar mass from g/mol
# to kg/mol; confirm against the species table units.
dof_vib = 2 * e_vib * species.molar_masses[0] * 1e-3 / (constants.R_universal * T_V_check)
print("Degrees of freedom vibrational at init and final:", dof_vib)

Degrees of freedom vibrational at init and final: [[0.2087732 1.6868243]]


In [4]:
# Vibrational relaxation time at the initial state.
# The primitives get a `_0` suffix so this cell does NOT overwrite the simulation
# history arrays `Y_s`, `rho`, `T`, `T_V`, `p` that the plotting cells index as
# `[:, 0]`.
from compressible_1d import source_terms, equation_manager_utils

Y_s_0, rho_0, T_0, T_V_0, p_0 = extract_primitives_from_U_jitted(
    U_init, equation_manager
)

tau_s = source_terms.compute_relaxation_time(
    Y_s_0, rho_0, T_0, T_V_0, p_0, equation_manager
)

print(tau_s)

# TODO: also inspect source_terms.compute_vibrational_relaxation(U_init, equation_manager)

[[1.7390674e-08]]


In [21]:
from plotly import graph_objects as go
import pandas as pd

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=t,
        y=T_V[:, 0],
        mode="lines",
        name="T_V",
        line=dict(shape="spline", smoothing=1.0),
    )
)
fig.add_trace(
    go.Scatter(
        x=t, y=T[:, 0], mode="lines", name="T", line=dict(shape="spline", smoothing=1.0)
    )
)

fig.add_hline(y=7623.3, name="T_eq Casseau", line_dash="dash", line_color="black")

# Load and plot reference data from CSV. The file has two header rows: the first
# names each dataset (one name per X/Y column pair, blank cells in between), the
# second repeats "X,Y,X,Y,..." and is used as the column header.
csv_path = "/home/hhoechter/tum/jaxfluids_internship/experiments/heatbath_0d/casseau_figure_3_1.csv"
df = pd.read_csv(csv_path, skiprows=1)

# Read first row to get the dataset names
with open(csv_path, "r") as f:
    header_line = f.readline().strip()
dataset_names = [name for name in header_line.split(",") if name]


def _dataset_prefix(dataset_name: str) -> str:
    """Return the dataset name without its `_t_v` / `_t_tr` marker.

    The T_V and T_tr curves of the same source share a prefix and hence a color.
    """
    for marker in ("_t_v", "_t_tr"):
        if marker in dataset_name:
            return dataset_name.replace(marker, "")
    return dataset_name


# Unique prefixes in order of first appearance (e.g. "modified", "hy2foam_default", "monaco")
prefixes = list(dict.fromkeys(_dataset_prefix(name) for name in dataset_names))

# Define colors for each prefix
colors = ["green", "red", "purple", "orange", "brown", "pink"]
prefix_colors = {prefix: colors[i % len(colors)] for i, prefix in enumerate(prefixes)}

# Plot each dataset (columns come in pairs: X, Y for each dataset)
for i, name in enumerate(dataset_names):
    x_col, y_col = 2 * i, 2 * i + 1
    if y_col >= len(df.columns):
        continue

    # Coerce both columns together and drop every row where EITHER value is
    # missing, so each x stays paired with the y from the same row. Dropping
    # NaNs per column and truncating would shift the pairs if gaps differ.
    xy = df.iloc[:, [x_col, y_col]].apply(pd.to_numeric, errors="coerce").dropna()
    x_data = xy.iloc[:, 0].to_numpy()
    y_data = xy.iloc[:, 1].to_numpy()

    fig.add_trace(
        go.Scatter(
            x=x_data,
            y=y_data * 1000,  # Scale Y values (they appear to be in units of 1000 K)
            mode="markers+lines",
            line=dict(dash="dot", smoothing=1.0, shape="spline"),
            name=name,
            marker=dict(color=prefix_colors.get(_dataset_prefix(name), "gray")),
        )
    )

fig.update_xaxes(type="log", exponentformat="power", showexponent="all", showgrid=True)
fig.update_yaxes(range=[0, 10500], showgrid=True)
fig.update_layout(
    template="simple_white",
    title="Energy Relaxation Correlation with Casseau for nonreacting N2",
    xaxis_title="Time (s)",
    yaxis_title="Temperature (K)",
    legend=dict(
        x=0.95,
        y=0.05,
        xanchor="right",
        yanchor="bottom",
        borderwidth=1,
        bordercolor="black",
    ),
    showlegend=True,
    width=800,
    height=600,
)

### Conclusion
The results agree well with those published by Casseau. The time-varying offset between $T_V$ and the default hy2foam configuration likely stems from hy2foam holding the pressure constant at 1 atm, whereas in this setup the density is held constant and the pressure varies between 1.2 and 0.9 atm. The discontinuity in $T_V$ at $t \approx 7 \times 10^{-7}$ s originates from a discontinuity in the enthalpy curve fits provided by Gnoffo.

In [27]:
# Pressure history of the heat bath (the density is held constant in this setup).
# NOTE(review): assumes `t` and `p` still hold the snapshot history from the
# Heating simulation cell, i.e. `p` is (n_snapshots, n_cells).
from plotly import graph_objects as go

fig = go.Figure()
fig.add_trace(go.Scatter(x=t, y=p[:, 0], mode="lines", name="Pressure"))
fig.update_xaxes(type="log", exponentformat="power", showexponent="all", showgrid=True)
fig.update_yaxes(showgrid=True)
fig.update_layout(
    template="simple_white",
    title="Pressure vs Time",
    xaxis_title="Time (s)",
    yaxis_title="Pressure (Pa)",
    legend=dict(
        x=0.95,
        y=0.95,
        xanchor="right",
        # Anchor the legend's TOP edge at y=0.95 so it stays inside the plot
        # area; a bottom anchor at 0.95 would push it above the axes.
        yanchor="top",
        borderwidth=1,
        bordercolor="black",
    ),
    showlegend=True,
    width=800,
    height=600,
)
fig.show()

## Cooling (T < T_V)