In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

from corrector_src.model._cnn_mhd_corrector import CorrectorCNN
import jax.numpy as jnp
import jax
import equinox as eqx

In [5]:
# Create a test model
key = jax.random.PRNGKey(42)
test_model = CorrectorCNN(in_channels=8, hidden_channels=16, key=key)
test_model = eqx.tree_deserialise_leaves(
    "/export/home/jalegria/Thesis/jf1uids/experiments/experiment_1/2025-10-01_15-47-48_10/cnn_model.eqx",
    test_model,
)

# Test input
test_input = jnp.ones((8, 32, 32, 32))

# Test forward pass
output = test_model(test_input)
print(f"Output shape: {output.shape}")
print(f"Output norm: {jnp.linalg.norm(output)}")

# Test gradients
def simple_loss(model, x):
    return jnp.mean(model(x)**2)

grads = eqx.filter_grad(simple_loss)(test_model, test_input)
print(f"Gradient norms: {jax.tree.map(lambda x: jnp.linalg.norm(x) if hasattr(x, 'shape') else 0, grads)}")

Output shape: (8, 32, 32, 32)
Output norm: nan
Gradient norms: CorrectorCNN(
  layers=Sequential(
    layers=(
      Conv3d(
        num_spatial_dims=3,
        weight=f32[],
        bias=f32[],
        in_channels=8,
        out_channels=16,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=((1, 1), (1, 1), (1, 1)),
        dilation=(1, 1, 1),
        groups=1,
        use_bias=True,
        padding_mode='ZEROS'
      ),
      Lambda(fn=None),
      Conv3d(
        num_spatial_dims=3,
        weight=f32[],
        bias=f32[],
        in_channels=16,
        out_channels=16,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=((1, 1), (1, 1), (1, 1)),
        dilation=(1, 1, 1),
        groups=1,
        use_bias=True,
        padding_mode='ZEROS'
      ),
      Lambda(fn=None),
      Conv3d(
        num_spatial_dims=3,
        weight=f32[],
        bias=None,
        in_channels=16,
        out_channels=8,
        kernel_size=(3, 3, 3),
    

In [6]:
def print_gradient_norms(grads):
    """Properly print gradient norms for each layer"""
    def get_norm(x):
        if hasattr(x, 'shape') and x.size > 0:
            return float(jnp.linalg.norm(x))
        return 0.0
    
    print("Gradient norms by layer:")
    for i, layer in enumerate(grads.layers.layers):
        if hasattr(layer, 'weight') and layer.weight is not None:
            weight_norm = get_norm(layer.weight)
            bias_norm = get_norm(layer.bias) if layer.bias is not None else 0.0
            print(f"  Layer {i} - Weight: {weight_norm:.6f}, Bias: {bias_norm:.6f}")
        else:
            print(f"  Layer {i} - No gradients (activation layer)")

# Test with proper printing
grads = eqx.filter_grad(simple_loss)(test_model, test_input)
print_gradient_norms(grads)

Gradient norms by layer:
  Layer 0 - Weight: 0.000000, Bias: 0.000000
  Layer 1 - No gradients (activation layer)
  Layer 2 - Weight: nan, Bias: 0.000000
  Layer 3 - No gradients (activation layer)
  Layer 4 - Weight: nan, Bias: 0.000000


In [9]:
from corrector_src.training.training_config import TrainingConfig
from corrector_src.model._cnn_mhd_corrector_options import (
    CNNMHDParams,
    CNNMHDconfig,
)
from jf1uids.option_classes.simulation_config import finalize_config
from jf1uids import time_integration

import corrector_src.data.blast_creation as blast
from corrector_src.utils.downaverage import downaverage_states
from corrector_src.training.loss import mse_loss
from jf1uids._physics_modules._cnn_mhd_corrector._cnn_mhd_corrector import _cnn_mhd_corrector


n_look_behind = 10
epochs = 1
num_cells_hr = 128
downsampling_factor = 4

In [11]:
training_config = TrainingConfig(
    compute_intermediate_losses=True,
    n_look_behind=n_look_behind,
    loss_weights=None,
    use_relative_error=False,
)

neural_net_params, neural_net_static = eqx.partition(test_model, eqx.is_array)

cnn_mhd_corrector_config = CNNMHDconfig(
    cnn_mhd_corrector=True, network_static=neural_net_static
)

cnn_mhd_corrector_params = CNNMHDParams(network_params=neural_net_params)

snapshot_losses = []
epoch_losses = []
randomized_vars = [1, 1, 1]

(
    initial_state,
    config,
    params,
    helper_data,
    registered_variables,
    _,
) = blast.randomized_initial_blast_state(num_cells_hr, randomized_vars)

config = finalize_config(config, initial_state.shape)

config = config._replace(cnn_mhd_corrector_config=cnn_mhd_corrector_config)
params = params._replace(cnn_mhd_corrector_params=cnn_mhd_corrector_params)


AttributeError: 'list' object has no attribute 'debug'

In [7]:
def test_corrector_gradients():
    """Test if the corrector itself produces gradients"""
    
    # Create test data matching your actual shapes
    test_primitive_state = jnp.ones((8, 34, 34, 34))  # Adjust to your actual padded shape
    test_target = test_primitive_state + 0.01 * jax.random.normal(jax.random.PRNGKey(42), test_primitive_state.shape)
    
    def corrector_loss_fn(network_params):
        # Create updated params
        updated_params = params._replace(
            cnn_mhd_corrector_params=params.cnn_mhd_corrector_params._replace(
                network_params=network_params
            )
        )
        
        # Run corrector
        corrected_state = _cnn_mhd_corrector(
            test_primitive_state,
            config,
            registered_variables,
            updated_params,
            jnp.array(0.01)  # timestep
        )
        
        # Simple loss
        return jnp.mean((corrected_state - test_target)**2)
    
    # Get trainable params from your actual model
    trainable_params, _ = eqx.partition(neural_net_params, eqx.is_array)
    
    # Compute gradients
    loss_val = corrector_loss_fn(trainable_params)
    grads = jax.grad(corrector_loss_fn)(trainable_params)
    
    print(f"Corrector test - Loss: {loss_val:.6f}")
    print("Corrector gradients:")
    print_gradient_norms(grads)
    
    return grads

corrector_grads = test_corrector_gradients()

NameError: name 'neural_net_params' is not defined

In [None]:
from jf1uids.option_classes.simulation_config import (
    BACKWARDS,
    CARTESIAN,
    FORWARDS,
    STATE_TYPE,
)
from jf1uids.fluid_equations.total_quantities import (
    calculate_internal_energy,
    calculate_total_mass,
)
from jf1uids.fluid_equations.total_quantities import (
    calculate_total_energy,
    calculate_kinetic_energy,
    calculate_gravitational_energy,
)

from jf1uids.data_classes.simulation_snapshot_data import SnapshotData
from jf1uids.time_stepping._timestep_estimator import (
    _cfl_time_step,
    _source_term_aware_time_step,
)


In [18]:
def unpad_state(state):
    """Helper function to remove ghost cells."""
    if config.geometry == CARTESIAN:
        if config.dimensionality == 1:
            return jax.lax.slice_in_dim(state, 2, state.shape[1] - 2, axis=1)
        elif config.dimensionality == 2:
            unpad_state = jax.lax.slice_in_dim(state, 2, state.shape[1] - 2, axis=1)
            return jax.lax.slice_in_dim(
                unpad_state, 2, unpad_state.shape[2] - 2, axis=2
            )
        elif config.dimensionality == 3:
            unpad_state = jax.lax.slice_in_dim(state, 2, state.shape[1] - 2, axis=1)
            unpad_state = jax.lax.slice_in_dim(
                unpad_state, 2, unpad_state.shape[2] - 2, axis=2
            )
            return jax.lax.slice_in_dim(
                unpad_state, 2, unpad_state.shape[3] - 2, axis=3
            )
    return state


def update_simulation_data(time, state, sim_data, step_idx):
    """Update simulation data at given step."""
    unpadded_state = unpad_state(state)

    time_points = sim_data.time_points.at[step_idx].set(time)
    states = sim_data.states.at[step_idx].set(unpadded_state)

    total_mass = sim_data.total_mass.at[step_idx].set(
        calculate_total_mass(unpadded_state, helper_data, config)
    )
    total_energy = sim_data.total_energy.at[step_idx].set(
        calculate_total_energy(
            unpadded_state,
            helper_data,
            params.gamma,
            params.gravitational_constant,
            config,
            registered_variables,
        )
    )
    internal_energy = sim_data.internal_energy.at[step_idx].set(
        calculate_internal_energy(
            unpadded_state,
            helper_data,
            params.gamma,
            config,
            registered_variables,
        )
    )
    kinetic_energy = sim_data.kinetic_energy.at[step_idx].set(
        calculate_kinetic_energy(
            unpadded_state, helper_data, config, registered_variables
        )
    )

    if config.self_gravity:
        gravitational_energy = sim_data.gravitational_energy.at[step_idx].set(
            calculate_gravitational_energy(
                unpadded_state,
                helper_data,
                params.gravitational_constant,
                config,
                registered_variables,
            )
        )
    else:
        gravitational_energy = sim_data.gravitational_energy

    current_checkpoint = step_idx + 1

    return sim_data._replace(
        time_points=time_points,
        states=states,
        current_checkpoint=current_checkpoint,
        total_mass=total_mass,
        total_energy=total_energy,
        internal_energy=internal_energy,
        kinetic_energy=kinetic_energy,
        gravitational_energy=gravitational_energy,
    )


In [None]:
test_primitive_state = jnp.ones((8, 34, 34, 34))
test_target = test_primitive_state + 0.5 * jax.random.normal(jax.random.PRNGKey(42), test_primitive_state.shape)
if config.geometry == CARTESIAN:
    original_shape = initial_state.shape

    if config.dimensionality == 1:
        initial_state = jnp.pad(initial_state, ((0, 0), (2, 2)), mode="edge")
    elif config.dimensionality == 2:
        initial_state = jnp.pad(
            initial_state, ((0, 0), (2, 2), (2, 2)), mode="edge"
        )
    elif config.dimensionality == 3:
        initial_state = jnp.pad(
            initial_state, ((0, 0), (2, 2), (2, 2), (2, 2)), mode="edge"
        )


total_steps = 2
full_time_points = jnp.zeros(total_steps)
full_states = jnp.zeros((total_steps, *original_shape))
full_total_mass = jnp.zeros(total_steps)
full_total_energy = jnp.zeros(total_steps)
full_internal_energy = jnp.zeros(total_steps)
full_kinetic_energy = jnp.zeros(total_steps)

if config.self_gravity:
    full_gravitational_energy = jnp.zeros(total_steps)
else:
    full_gravitational_energy = None

full_sim_data = SnapshotData(
    time_points=full_time_points,
    states=full_states,
    total_mass=full_total_mass,
    total_energy=full_total_energy,
    internal_energy=full_internal_energy,
    kinetic_energy=full_kinetic_energy,
    gravitational_energy=full_gravitational_energy,
    current_checkpoint=0,
)


def test_corrector_with_larger_target(carry):
    (
        state,
        time,
        sim_data,
        network_params,
        lag_data,
        opt_state,
    ) = carry

    def corrector_loss_fn(network_params, carry):
        (
            state,
            time,
            sim_data,
            network_params,
            lag_data,
            opt_state,
        ) = carry

        if not config.fixed_timestep:
            if config.source_term_aware_timestep:
                dt = jax.lax.stop_gradient(
                    _source_term_aware_time_step(
                        state,
                        config,
                        updated_params,
                        helper_data,
                        registered_variables,
                        current_time,
                    )
                )
            else:
                dt = jax.lax.stop_gradient(
                    _cfl_time_step(
                        state,
                        config.grid_spacing,
                        params.dt_max,
                        params.gamma,
                        config,
                        registered_variables,
                        params.C_cfl,
                    )
                )
        else:
            dt = jnp.asarray(params.t_end / config.num_timesteps)

        updated_params = params._replace(
            cnn_mhd_corrector_params=params.cnn_mhd_corrector_params._replace(
                network_params=network_params
            )
        )
        
        corrected_state = _cnn_mhd_corrector(
            test_primitive_state,
            config,
            registered_variables,
            updated_params,
            jnp.array(0.1)  # Larger timestep
        )
        
        return jnp.mean((corrected_state - test_target)**2)
    
    trainable_params, _ = eqx.partition(neural_net_params, eqx.is_array)
    
    loss_val = corrector_loss_fn(trainable_params, carry)
    grads = jax.grad(corrector_loss_fn)(trainable_params)
    
    print(f"Larger target test - Loss: {loss_val:.6f}")
    print("Gradients with larger target:")
    print_gradient_norms(grads)

test_corrector_with_larger_target()

Larger target test - Loss: 0.250621
Gradients with larger target:
Gradient norms by layer:
  Layer 0 - Weight: 0.076470, Bias: 0.005137
  Layer 1 - No gradients (activation layer)
  Layer 2 - Weight: 0.050064, Bias: 0.010012
  Layer 3 - No gradients (activation layer)
  Layer 4 - Weight: 0.040318, Bias: 0.000398


In [23]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# Mock your data structures and functions
def mock_simulation_step(state, network_params, step_idx):
    """Mock version of your simulation step"""
    # Simulate what your physics modules do
    correction = network_params(state)  # Direct network call for simplicity
    corrected_state = state + 0.01 * correction  # time_step * correction
    return corrected_state

def mock_loss_function(pred_states, gt_states, training_config):
    """Mock your loss function"""
    return jnp.mean((pred_states - gt_states)**2)

# Create mock data
key = jax.random.PRNGKey(42)
state_shape = (8, 32, 32, 32)  # Adjust to your actual shape
data_lag = 3
total_steps = 5

# Mock initial state
initial_state = jnp.ones(state_shape)

# Mock target data (slightly different from what simulation would produce)
target_data = jnp.ones((total_steps, *state_shape))
for i in range(total_steps):
    noise = 0.1 * jax.random.normal(jax.random.split(key)[i], state_shape)
    target_data = target_data.at[i].set(initial_state + noise)

# Mock network (simplified version of your CorrectorCNN)
class MockCorrectorCNN(eqx.Module):
    layers: eqx.nn.Sequential
    
    def __init__(self, in_channels, hidden_channels, *, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = eqx.nn.Sequential([
            eqx.nn.Conv3d(in_channels, hidden_channels, 3, padding=1, key=key1),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Conv3d(hidden_channels, in_channels, 3, padding=1, key=key2),
        ])
    
    def __call__(self, x):
        return self.layers(x)

# Initialize network and optimizer
network = MockCorrectorCNN(8, 16, key=key)
trainable_params, static_params = eqx.partition(network, eqx.is_array)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(trainable_params)

# Storage for simulation data
simulation_states = jnp.zeros((total_steps, *state_shape))
current_state = initial_state

def print_gradient_norms(grads, step_name=""):
    """Helper to print gradient norms"""
    def get_norm(x):
        if hasattr(x, 'shape') and x.size > 0:
            return float(jnp.linalg.norm(x))
        return 0.0
    
    print(f"{step_name} Gradient norms:")
    layer_idx = 0
    for layer in grads.layers.layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            weight_norm = get_norm(layer.weight)
            bias_norm = get_norm(layer.bias) if layer.bias is not None else 0.0
            print(f"  Layer {layer_idx} - Weight: {weight_norm:.8f}, Bias: {bias_norm:.8f}")
            layer_idx += 1
        else:
            print(f"  Layer {layer_idx} - Activation layer (no gradients)")
            layer_idx += 1

# Single training step that mimics your original code
def single_training_step(step_idx, current_state, simulation_states, trainable_params, opt_state):
    
    print(f"\n=== STEP {step_idx} ===")
    
    def evolve_loss_fn(network_params_current):
        # Use the passed parameters directly
        
        # Reconstruct full model
        full_model = eqx.combine(network_params_current, static_params)
        
        # Simulate one step
        new_state = mock_simulation_step(current_state, full_model, step_idx)
        
        # Update simulation data storage
        updated_states = simulation_states.at[step_idx].set(new_state)
        
        # Compute loss using data_lag approach (like your original code)
        end_idx = step_idx + 1
        start_idx = jnp.maximum(0, end_idx - data_lag)
        actual_length = end_idx - start_idx
        
        # Get predicted and ground truth states
        predicted_states = jax.lax.dynamic_slice_in_dim(
            updated_states, start_idx, data_lag, axis=0
        )
        ground_truth_states = jax.lax.dynamic_slice_in_dim(
            target_data, start_idx, data_lag, axis=0
        )
        
        # Handle masking for early steps
        if actual_length < data_lag:
            mask = jnp.arange(data_lag) < actual_length
            expanded_mask = mask.reshape(-1, *([1] * (predicted_states.ndim - 1)))
            masked_pred = predicted_states * expanded_mask
            masked_gt = ground_truth_states * expanded_mask
            loss = mock_loss_function(masked_pred, masked_gt, None)
            print(f"  Using masked loss (length {actual_length}/{data_lag})")
        else:
            loss = mock_loss_function(predicted_states, ground_truth_states, None)
            print(f"  Using full loss (length {data_lag})")
        
        # Debug info (avoid printing traced values during gradient computation)
        # These will only print during the sensitivity test, not during grad computation
        state_norm = jnp.linalg.norm(new_state)
        correction_effect = jnp.linalg.norm(new_state - current_state)
        pred_norm = jnp.linalg.norm(predicted_states)
        gt_norm = jnp.linalg.norm(ground_truth_states)
        
        # Update state for next iteration (return new values)
        return loss, new_state, updated_states, state_norm, correction_effect, pred_norm, gt_norm
    
    # Test loss sensitivity to parameters
    print("Testing parameter sensitivity:")
    original_loss, current_state, simulation_states, state_norm, correction_effect, pred_norm, gt_norm = evolve_loss_fn(trainable_params)
    
    # Print debug info from sensitivity test
    print(f"  State norm: {state_norm:.6f}")
    print(f"  Correction effect: {correction_effect:.6f}")
    print(f"  Predicted norm: {pred_norm:.6f}")
    print(f"  Target norm: {gt_norm:.6f}")
    print(f"  Loss value: {original_loss:.8f}")
    
    # Reset state for perturbed test
    temp_state = current_state
    temp_sim_states = simulation_states
    
    # Perturb parameters slightly
    perturbed_params = jax.tree.map(
        lambda x: x + 0.001 * jax.random.normal(jax.random.PRNGKey(step_idx), x.shape), 
        trainable_params
    )
    perturbed_loss, _, _, _, _, _, _ = evolve_loss_fn(perturbed_params)
    
    # Restore state
    current_state = temp_state
    simulation_states = temp_sim_states
    
    print(f"  Original loss: {original_loss:.8f}")
    print(f"  Perturbed loss: {perturbed_loss:.8f}")
    print(f"  Sensitivity: {abs(perturbed_loss - original_loss):.8f}")
    
    if abs(perturbed_loss - original_loss) < 1e-12:
        print("  WARNING: Loss is not sensitive to parameter changes!")
    
    # Compute gradients (need a function that only returns loss)
    print("Computing gradients...")
    def loss_only_fn(network_params_current):
        loss, _, _, _, _, _, _ = evolve_loss_fn(network_params_current)
        return loss
        
    loss_value, grads = eqx.filter_value_and_grad(loss_only_fn)(trainable_params)
    
    # Print gradient information
    print_gradient_norms(grads, f"Step {step_idx}")
    
    # Check if gradients are zero
    total_grad_norm = jax.tree.reduce(
        lambda acc, x: acc + jnp.sum(x**2) if hasattr(x, 'shape') else acc, 
        grads, 0.0
    )
    print(f"  Total gradient norm: {jnp.sqrt(total_grad_norm):.8f}")
    
    if jnp.sqrt(total_grad_norm) < 1e-10:
        print("  WARNING: Gradients are essentially zero!")
    
    # Update parameters
    updates, new_opt_state = optimizer.update(grads, opt_state, trainable_params)
    new_trainable_params = eqx.apply_updates(trainable_params, updates)
    
    # Check parameter change
    param_diffs = jax.tree.map(lambda x, y: jnp.sum((x - y)**2) if hasattr(x, 'shape') else 0.0, 
                               trainable_params, new_trainable_params)
    param_change = jax.tree.reduce(lambda acc, x: acc + x, param_diffs, 0.0)
    print(f"  Parameter change norm: {jnp.sqrt(param_change):.8f}")
    
    # Update global state (return new values)
    return loss_value, total_grad_norm, new_trainable_params, new_opt_state, current_state, simulation_states

# Run one training step
print("Running single training step debug...")
loss, grad_norm, trainable_params, opt_state, current_state, simulation_states = single_training_step(
    0, current_state, simulation_states, trainable_params, opt_state
)

print(f"\nFinal Summary:")
print(f"Loss: {loss:.8f}")
print(f"Total gradient norm: {jnp.sqrt(grad_norm):.8f}")

# Test a few more steps to see evolution
for step in range(1, min(3, total_steps)):
    loss, grad_norm, trainable_params, opt_state, current_state, simulation_states = single_training_step(
        step, current_state, simulation_states, trainable_params, opt_state
    )
    print(f"Step {step} - Loss: {loss:.8f}, Grad norm: {jnp.sqrt(grad_norm):.8f}")

Running single training step debug...

=== STEP 0 ===
Testing parameter sensitivity:
  Using masked loss (length 1/3)
  State norm: 512.425329
  Correction effect: 0.652386
  Predicted norm: 512.425329
  Target norm: 891.236440
  Loss value: 0.00333105
  Using masked loss (length 1/3)
  Original loss: 0.00333105
  Perturbed loss: 0.00333268
  Sensitivity: 0.00000163
Computing gradients...
  Using masked loss (length 1/3)
Step 0 Gradient norms:
  Layer 0 - Weight: 0.00004431, Bias: 0.00000312
  Layer 1 - Activation layer (no gradients)
  Layer 2 - Weight: 0.00002060, Bias: 0.00000597
  Total gradient norm: 0.00004933
  Parameter change norm: 0.05416974

Final Summary:
Loss: 0.00333271
Total gradient norm: 0.00004933

=== STEP 1 ===
Testing parameter sensitivity:
  Using masked loss (length 2/3)
  State norm: 512.640266
  Correction effect: 0.335387
  Predicted norm: 724.830850
  Target norm: 891.236440
  Loss value: 0.00667915
  Using masked loss (length 2/3)
  Original loss: 0.00667915