In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "9"
import jax
import jax.numpy as jnp
from corrector_src.model._cnn_mhd_corrector import CorrectorCNN
from corrector_src.model._cnn_mhd_corrector_options import (
    CNNMHDParams,
    CNNMHDconfig,
)
import equinox as eqx
import corrector_src.data.blast_creation as blast
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from jf1uids.option_classes.simulation_config import finalize_config
from jf1uids import time_integration
from corrector_src.utils.downaverage import downaverage_states
 

In [2]:
with initialize(version_base=None, config_path="../../../configs"):
    cfg = compose(
        config_name="config",
    )
    print(cfg)


{'experiment': {'name': 'experiment_1'}, 'training': {'epochs': 300, 'n_look_behind': 5, 'learning_rate': 0.0001, 'precomputed_data': True, 'return_full_sim': False, 'return_full_sim_epoch_interval': 10, 'rng_key': 60}, 'data': {'hr_res': 64, 'downscaling_factor': 2, 'randomizer_1': [0.5, 1.5], 'randomizer_2': [0.5, 1.5], 'randomizer_3': [0.5, 1.5], 'num_snapshots': 50, 'num_checkpoints': 50}, 'models': {'_target_': 'corrector_src.model._cnn_mhd_corrector.CorrectorCNN', 'in_channels': 8, 'hidden_channels': 16}}


In [3]:
model = CorrectorCNN(
    in_channels=8,
    hidden_channels=16,
    key=jax.random.PRNGKey(42),
)
model = eqx.tree_deserialise_leaves(
    os.path.abspath(
        "/export/home/jalegria/Thesis/jf1uids/experiments/experiment_1/2025-09-29_14-38-20_10/cnn_model.eqx"
    ),
    model,
)

neural_net_params, neural_net_static = eqx.partition(model, eqx.is_array)

cnn_mhd_corrector_config = CNNMHDconfig(
    cnn_mhd_corrector=True, network_static=neural_net_static
)

cnn_mhd_corrector_params = CNNMHDParams(network_params=neural_net_params)


In [14]:
(
    initial_state,
    config,
    params,
    helper_data,
    registered_variables,
    randomized_output_vars,
) = blast.randomized_initial_blast_state(
    cfg.data.hr_res, cfg.data, [1, 1, 1]
)

config = finalize_config(config, initial_state.shape)
final_states_hr = time_integration(
    initial_state, config, params, helper_data, registered_variables
)
ne_states = final_states_hr.states
ne_states = downaverage_states(ne_states, cfg.data.downscaling_factor)


KeyboardInterrupt: 

In [15]:
from jf1uids import SimulationConfig
from jf1uids.option_classes.simulation_config import BACKWARDS, HLL, BoundarySettings
from jf1uids import get_helper_data
from jf1uids import get_registered_variables
from jf1uids import CodeUnits
from astropy import units as u
from jf1uids import SimulationParams


In [6]:
adiabatic_index = 5 / 3
box_size = 1.0
fixed_timestep = True
dt_max = 0.1
mhd = True
randomizers = [1, 1, 1]
# setup simulation config
config = SimulationConfig(
    runtime_debugging=False,
    first_order_fallback=False,
    progress_bar=False,
    dimensionality=3,
    num_ghost_cells=2,
    box_size=box_size,
    num_cells=32,
    mhd=mhd,
    fixed_timestep=fixed_timestep,
    differentiation_mode=BACKWARDS,
    riemann_solver=HLL,
    limiter=0,
    return_snapshots=True,
    num_snapshots=cfg.data.num_snapshots,
    boundary_settings=BoundarySettings(),
    num_checkpoints=cfg.data.num_checkpoints,
    # boundary_settings=BoundarySettings(
    #    x=BoundarySettings1D(PERIODIC_BOUNDARY, PERIODIC_BOUNDARY),
    #    y=BoundarySettings1D(PERIODIC_BOUNDARY, PERIODIC_BOUNDARY),
    #    z=BoundarySettings1D(PERIODIC_BOUNDARY, PERIODIC_BOUNDARY),
    # ),
)

helper_data = get_helper_data(config)
registered_variables = get_registered_variables(config)

# setup the unit system
code_length = 3 * u.parsec
code_mass = 1 * u.M_sun
code_velocity = 100 * u.km / u.s
code_units = CodeUnits(code_length, code_mass, code_velocity)

# time domain
C_CFL = 0.4  # Courant-Friedrichs-Lewy number
t_final = 1.0 * 1e4 * u.yr
t_end = t_final.to(code_units.code_time).value

# set the simulation parameters
params = SimulationParams(
    C_cfl=C_CFL,
    dt_max=dt_max,
    gamma=adiabatic_index,
    t_end=t_end,
)

grid_spacing = config.box_size / config.num_cells
x = jnp.linspace(
    grid_spacing / 2, config.box_size - grid_spacing / 2, config.num_cells
)
y = jnp.linspace(
    grid_spacing / 2, config.box_size - grid_spacing / 2, config.num_cells
)
z = jnp.linspace(
    grid_spacing / 2, config.box_size - grid_spacing / 2, config.num_cells
)

X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")

r = helper_data.r
# Initialize state
rho = jnp.ones_like(X)
P = jnp.ones_like(X) * 0.1
r_inj = 0.1 * box_size * randomizers[0]
p_inj = 10.0 * randomizers[1]
P = jnp.where(r**2 < r_inj**2, p_inj, P)

u_x = jnp.zeros_like(X)
u_y = jnp.zeros_like(X)
u_z = jnp.zeros_like(X)

B_0 = 1 / jnp.sqrt(2) * randomizers[2]
B_x = B_0 * jnp.ones_like(X)
B_y = B_0 * jnp.ones_like(X)
B_z = jnp.zeros_like(X)
print(
    jnp.argwhere(jnp.isnan(P)),
    jnp.argwhere(jnp.isnan(u_x)),
    jnp.argwhere(jnp.isnan(u_y)),
    jnp.argwhere(jnp.isnan(u_z)),
    jnp.argwhere(jnp.isnan(B_x)),
    jnp.argwhere(jnp.isnan(B_y)),
    jnp.argwhere(jnp.isnan(B_z)),
    jnp.argwhere(jnp.isnan(rho)),
)



[] [] [] [] [] [] [] []


In [9]:
if jnp.isnan(P).any():
    print("nan found in initial state")

In [33]:
(
    initial_state,
    config,
    params,
    helper_data,
    registered_variables,
    randomized_output_vars,
) = blast.randomized_initial_blast_state(
    cfg.data.hr_res // cfg.data.downscaling_factor , cfg.data, [1, 1, 1]
)
config = finalize_config(config, initial_state.shape)

config = config._replace(cnn_mhd_corrector_config=cnn_mhd_corrector_config)
params = params._replace(cnn_mhd_corrector_params=cnn_mhd_corrector_params)

final_states_lr = time_integration(
    initial_state, config, params, helper_data, registered_variables
)
e_states = final_states_lr.states


In [4]:
(
    initial_state,
    config,
    params,
    helper_data,
    registered_variables,
    randomized_output_vars,
) = blast.randomized_initial_blast_state(
    cfg.data.hr_res // cfg.data.downscaling_factor, cfg.data, [1, 1, 1]
)

config = finalize_config(config, initial_state.shape)

config = config._replace(cnn_mhd_corrector_config=cnn_mhd_corrector_config)
params = params._replace(cnn_mhd_corrector_params=cnn_mhd_corrector_params)


In [25]:
from jf1uids._physics_modules.run_physics_modules import _run_physics_modules


In [6]:
dt = jnp.asarray(params.t_end / config.num_timesteps)


In [12]:
from jf1uids._physics_modules._cnn_mhd_corrector._cnn_mhd_corrector import (
    _cnn_mhd_corrector,
)
from corrector_src.model._cnn_mhd_corrector import _cnn_mhd_corrector_2d, _cnn_mhd_corrector_3d

primitive_state = _cnn_mhd_corrector_2d(
    initial_state, config, registered_variables, params, dt)

print(jnp.mean(primitive_state - initial_state))
print(jnp.shape(primitive_state))

-7.555437734282688e-07
(8, 32, 32, 32)


In [27]:
state = _run_physics_modules(
    initial_state,
    dt,
    config,
    params,
    helper_data,
    registered_variables,
    dt,
)

if jnp.isnan(state).any():
    print("found nan")

In [28]:
from jf1uids._state_evolution.evolve_state import _evolve_state


In [29]:
# update state
state = _evolve_state(
    state,
    dt,
    params.gamma,
    params.gravitational_constant,
    config,
    helper_data,
    registered_variables,
)
if jnp.isnan(state).any():
    print("found nan")

In [54]:
jnp.shape(e_states[0])
jnp.argwhere(jnp.isnan(e_states[0]))


Array([], shape=(0, 4), dtype=int64)

In [None]:
for i in jnp.arange(len(e_states)):
    if jnp.isnan(e_states[i]).any()


SyntaxError: incomplete input (2018819236.py, line 2)

In [11]:
from jf1uids._state_evolution.evolve_state import _evolve_state
from jf1uids._physics_modules.run_physics_modules import _run_physics_modules


corrector_src.model._cnn_mhd_corrector.CorrectorCNN

In [18]:
import jax
import jax.numpy as jnp
import equinox as eqx

# Flatten all parameters to a list
all_params = jax.tree_util.tree_leaves(neural_net_params)

# Check for NaNs in any parameter
nan_found = any(jnp.isnan(p).any() for p in all_params)

print("NaNs in parameters:", nan_found)


NaNs in parameters: False
