# Colab Notebook: Debugging JAX & Flax NNX - Exercises

Welcome! This notebook contains exercises to help you practice the JAX and Flax NNX debugging techniques discussed in the lecture. If you're a PyTorch user, you'll find some concepts familiar, while others are specific to JAX's compiled nature. Remember to run the setup cells first!

First, let's install the necessary libraries and import them.

In [None]:
# Start by updating the protobuf version, which may require a restart

!pip install -U protobuf

In [None]:
!pip install -Uq flax jax jaxlib chex optax

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import flax
from flax import nnx
import chex
import pdb # Python's built-in debugger
import functools # For functools.partial
import optax # For optimizers, though we won't train deeply

chex.set_n_cpu_devices(8) # Fake an environment with 8 CPUs.  This must be done before any JAX operations
print(f"Fake devices: {jax.devices()}")

# NOTE for Flax v0.11+: The flax.nnx.Optimizer API has changed.
# It now requires a `wrt` argument at construction (e.g., wrt=nnx.Param)
# and the update call is now `optimizer.update(model, grads)` instead of `optimizer.update(grads)`.

# Helper to clear chex trace counter for repeatable examples
chex.clear_trace_counter()

print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}") # NNX is part of flax
print(f"Chex version: {chex.__version__}")
print(f"Device: {jax.devices()}")

## 1. "printf Debugging" in JAX: jax.debug.print()

JAX's JIT compilation means standard Python print() behaves differently inside JITted functions. It sees tracers during compilation, not runtime values. jax.debug.print() is the JAX-aware alternative.

### Exercise 1.1:
1. Find the line # YOUR CODE HERE in the compute_and_print function below.
2. Add a jax.debug.print() statement to display the runtime value of z.
3. Run the cell. Observe the outputs.
 - What does the standard print(y) show?
 - What do the jax.debug.print statements show for y and z? Why is this different?

In [None]:
@jit
def compute_and_print(x):
  y = x * 10
  print("Standard print (sees tracer):", y)
  jax.debug.print("jax.debug.print (sees runtime value for y): {y_val}", y_val=y, ordered=True)

  z = y / 2
  # Exercise 1.1: Add another jax.debug.print here to see the runtime value of 'z'
  # Make sure to give it a descriptive message and use the ordered=True argument.
  # YOUR CODE HERE

  return z

input_val = jnp.array(5.0)
print(f"Input value: {input_val}\n")
output_val = compute_and_print(input_val)
print(f"\nFinal output: {output_val}")

### Solution (for Exercise 1.1, after attempting):

In [None]:
# @jit
# def compute_and_print_solution(x):
#   y = x * 10
#   print("Standard print (sees tracer):", y)
#   jax.debug.print("jax.debug.print (sees runtime value for y): {y_val}", y_val=y, ordered=True)

#   z = y / 2
#   jax.debug.print("jax.debug.print (sees runtime value for z): {z_val}", z_val=z, ordered=True) # SOLUTION

#   return z

# input_val = jnp.array(5.0)
# print(f"Input value: {input_val}\n")
# output_val = compute_and_print_solution(input_val)
# print(f"\nFinal output: {output_val}")

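Standard print shows a tracer object (something like Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(...)>; the exact text varies between JAX versions) because it executes during JAX's tracing phase, and only then: call the function again with an input of the same shape and dtype and the standard print won't run at all, because JAX reuses the cached compiled program. jax.debug.print shows the concrete numerical values (50.0 for y, 25.0 for z) because it's embedded into the compiled computation graph and executes with runtime data on every call.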
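A quick way to see the trace-time vs. run-time split is to count traces. The sketch below (the function scale and the list trace_log are illustrative names) appends to a Python list instead of printing, since any plain Python side effect inside a jitted function behaves like print(): it runs once per trace, while jax.debug.print runs on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to by plain Python code, which runs only while JAX traces

@jax.jit
def scale(x):
    trace_log.append(x.shape)  # trace-time side effect, just like print()
    y = x * 10
    jax.debug.print("runtime y = {y}", y=y)  # staged into the compiled program
    return y

scale(jnp.array(5.0))         # first call: trace + compile + run
scale(jnp.array(6.0))         # same shape and dtype: cached program, no retrace
scale(jnp.array([1.0, 2.0]))  # new shape: traced again

print(trace_log)  # [(), (2,)]
```

Expect jax.debug.print to fire three times (once per call), while trace_log records only two entries: the second call reuses the program compiled for the scalar input.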

## 2. Interactive Debugging in JIT: jax.debug.breakpoint()
jax.debug.breakpoint() is JAX's equivalent of pdb.set_trace() for use inside transformed functions. It pauses execution and gives you a (jdb) prompt.

### Exercise 2.1:
1. Find the line # YOUR CODE HERE in the interact_with_values function below.
2. Add jax.debug.breakpoint() where indicated.
3. Run the cell.
4. When execution pauses at the (jdb) prompt:
 - Inspect the value of y by typing p y and pressing Enter.
 - Continue execution by typing c and pressing Enter.
5. Note that jdb supports only a subset of pdb commands; for example, stepping with n or s is not available.
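A breakpoint that fires on every call quickly gets tedious. One option, sketched below under the assumption that you only care about non-finite values, is to put jax.debug.breakpoint() inside one branch of jax.lax.cond, so the (jdb) prompt only opens when the condition holds. A Python if would not work here because the condition is a traced value. The helper names breakpoint_if_nonfinite and safe_log are hypothetical.

```python
import jax
import jax.numpy as jnp

def breakpoint_if_nonfinite(x):
    # The condition is a traced boolean, so branch with lax.cond, not Python `if`.
    is_finite = jnp.isfinite(x).all()

    def true_fn(x):
        pass  # everything is finite: do nothing

    def false_fn(x):
        jax.debug.breakpoint()  # opens the (jdb) prompt only on this branch

    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def safe_log(x):
    y = jnp.log(x)
    breakpoint_if_nonfinite(y)
    return y

result = safe_log(jnp.array(2.0))  # finite result: no prompt
# safe_log(jnp.array(-1.0))        # log(-1) is NaN: would pause at (jdb)
```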

In [None]:
@jit
def interact_with_values(x):
  y = jnp.sin(x)
  jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)

  # Exercise 2.1: Place the breakpoint here.
  # YOUR CODE HERE

  z = jnp.cos(y)
  jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
  return z

input_angle = jnp.array(0.75)
print("Calling interact_with_values...")
result = interact_with_values(input_angle)
print(f"Result: {result}")

### Solution (for Exercise 2.1, after attempting):

In [None]:
# @jit
# def interact_with_values_solution(x):
#   y = jnp.sin(x)
#   jax.debug.print("Value of y before breakpoint: {y_val}", y_val=y)

#   jax.debug.breakpoint() # SOLUTION

#   z = jnp.cos(y)
#   jax.debug.print("Value of z after breakpoint: {z_val}", z_val=z)
#   return z

# input_angle = jnp.array(0.75)
# print("Calling interact_with_values...")
# result = interact_with_values_solution(input_angle)
# print(f"Result: {result}")

## 3. Back to Basics: Temporarily Disabling JIT with jax.disable_jit()
Sometimes, you need the full power of standard Python debugging tools. jax.disable_jit() allows JAX functions to execute eagerly.

### Exercise 3.1 & 3.2:
1. In `complex_calculation`, add pdb.set_trace() where indicated (# YOUR CODE HERE for 3.1).
2. First, uncomment Scenario 1 and run the cell (leave Scenario 2's pass in place for now). Observe what happens when a JITted function contains Python control flow on a traced value and a pdb.set_trace() call.
3. Then, comment out Scenario 1.
4. In Scenario 2, within the with jax.disable_jit(): block, call `complex_calculation` (where # YOUR CODE HERE for 3.2 is) with value (try 0.1 first, then 5.0 to ensure the conditional is met) and threshold=0.5.
5. When pdb triggers:
 - Inspect a, b, and c.
 - Type c to continue.
6. Reflect: When would you use jax.disable_jit() over jax.debug.breakpoint()?

In [None]:
@jit
def complex_calculation(x, threshold):
  a = x * 2.0
  b = jnp.log(a)
  c = b + x
  # Imagine 'c' sometimes becomes NaN, and it's hard to see why.
  # We want to inspect 'a', 'b', and 'c' using standard pdb.
  if c > threshold: # Python `if` on a traced array value: only works when JIT is disabled
      # Exercise 3.1: Add a pdb.set_trace() here.
      # It will only work if JIT is disabled for this function call.
      # YOUR CODE HERE
      print("Inside conditional pdb trace") # This will print if pdb is hit
  d = jnp.sqrt(jnp.abs(c)) # abs to avoid NaNs from sqrt of negative
  return d

value = jnp.array(0.1) # Try with 0.1 then with 5.0

# Scenario 1: JIT enabled (the Python `if` on a traced value raises an error before pdb is reached)
# print("--- Running WITH JIT (expect a tracer error) ---")
# try:
#   result_jit = complex_calculation(value, threshold=0.5)
#   print(f"Result with JIT: {result_jit}")
# except Exception as e:
#   print(f"Scenario 1 Error:\n{e}\n")

# Scenario 2: JIT disabled
print("\n--- Running with JIT DISABLED for this block ---")
with jax.disable_jit():
  # Exercise 3.2: Call complex_calculation here with value and threshold=0.5
  # so that your pdb.set_trace() (from Ex 3.1) gets triggered.
  # YOUR CODE HERE
  pass # remove this pass

print("Finished disable_jit block.")

### Solution (for Exercise 3.1 & 3.2, after attempting):

In [None]:
# @jit
# def complex_calculation_solution(x, threshold):
#   a = x * 2.0
#   b = jnp.log(a)
#   c = b + x
#   if c > threshold:
#       pdb.set_trace() # SOLUTION 3.1
#       print("Inside conditional pdb trace")
#   d = jnp.sqrt(jnp.abs(c))
#   return d

# value_for_pdb = jnp.array(5.0) # This value will trigger the condition c > threshold

# # Scenario 1: JIT enabled (the Python `if` on a traced value raises an error before pdb is reached)
# print("--- Running WITH JIT (expect a tracer error) ---")
# try:
#   result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)
#   print(f"Result with JIT: {result_jit}")
# except Exception as e:
#   print(f"Scenario 1 Error:\n{e}\n")

# print("\n--- Running with JIT DISABLED for this block ---")
# with jax.disable_jit():
#   result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # SOLUTION 3.2
#   print(f"Result with JIT disabled: {result_no_jit}")
# print("Finished disable_jit block.")

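You'd use jax.disable_jit() when jax.debug.breakpoint() is insufficient: when you need full pdb features (like stepping), want to use an IDE debugger, need Python control flow on array values (like the if in complex_calculation) to run, or when jax.debug.breakpoint() itself doesn't give enough context. The trade-off is performance: every operation runs eagerly, one at a time, so reserve this for debugging on small inputs.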
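The following sketch isolates the problem in complex_calculation: a Python if on an array value. It shows the jit failure, the same code running under jax.disable_jit(), and a jit-compatible rewrite with jnp.where. The function names are hypothetical, and catching jax.errors.ConcretizationTypeError assumes a recent JAX, where the tracer-to-bool error (TracerBoolConversionError) is a subclass of it.

```python
import jax
import jax.numpy as jnp

@jax.jit
def shifted_log(x, threshold):
    c = jnp.log(x * 2.0) + x
    if c > threshold:  # Python control flow on a traced value
        c = c - threshold
    return c

# Under jit, `if` needs a concrete bool, but c > threshold is a tracer.
try:
    shifted_log(jnp.array(5.0), 0.5)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print(f"Under jit: {type(err).__name__}")

# With jit disabled, the same function runs eagerly and `if` sees real values.
with jax.disable_jit():
    eager_result = shifted_log(jnp.array(5.0), 0.5)

# A jit-friendly rewrite expresses the branch as data flow instead.
@jax.jit
def shifted_log_where(x, threshold):
    c = jnp.log(x * 2.0) + x
    return jnp.where(c > threshold, c - threshold, c)

jit_result = shifted_log_where(jnp.array(5.0), 0.5)
print(eager_result, jit_result)
```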

## 4. Automatic NaN Hunting: jax_debug_nans Flag
NaNs can be a nightmare. jax_debug_nans helps JAX pinpoint the exact operation causing them.

### Exercise 4.1 & 4.2:
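1. In Scenario 1, uncomment the example call or create your own call to problematic_function_for_nans that results in a NaN (e.g., x = jnp.array(-1.0), divisor = jnp.array(1.0), since the log of a negative number is NaN). Note that x = jnp.array(1.0), divisor = jnp.array(0.0) produces inf rather than NaN, which jax_debug_nans does not flag. Run and observe the result: is any error raised?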
2. In Scenario 2:
 - Uncomment the line jax.config.update("jax_debug_nans", True).
 - Uncomment the example call or use the same NaN-causing inputs as in 4.1.
 - Run and observe the error message. Is it more helpful in pinpointing the source of the NaN?
 - Uncomment the jax.config.update("jax_debug_nans", False) line in the finally block so the flag is turned off again after use.
3. Why is jax_debug_nans not enabled by default?

In [None]:
@jit
def problematic_function_for_nans(x, divisor):
  y = x * 100
  # NaN if x < 0 (log of a negative number); inf if divisor is 0 and log(y) is nonzero
  z = jnp.log(y) / divisor # Potential NaN source
  return z + y

# Scenario 1: Run without jax_debug_nans
print("--- Scenario 1: Running without jax_debug_nans ---")
try:
  # Exercise 4.1: Call problematic_function_for_nans with inputs that cause a NaN
  # For example, x = jnp.array(-1.0), divisor = jnp.array(1.0)
  # (x = jnp.array(1.0), divisor = jnp.array(0.0) gives inf, not NaN)
  # Observe the result. Is any error raised?
  # YOUR CODE HERE
  # result1 = problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))
  # print(f"Result 1: {result1}")
  pass # remove this
except Exception as e:
  print(f"Caught exception (without jax_debug_nans): {e}\n")


# Scenario 2: Run WITH jax_debug_nans
print("--- Scenario 2: Running WITH jax_debug_nans ---")
# jax.config.update("jax_debug_nans", True) # Enable NaN debugging

try:
  # Exercise 4.2: Call problematic_function_for_nans again with the SAME NaN-causing inputs.
  # Observe the error now. Is it more specific?
  # YOUR CODE HERE
  # result2 = problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))
  # print(f"Result 2: {result2}")
  pass # remove this
except Exception as e:
  print(f"Caught exception (WITH jax_debug_nans): {e}\n")
finally:
  # jax.config.update("jax_debug_nans", False) # Exercise 4.2: uncomment to disable the flag after use
  print("Finished Scenario 2 (jax_debug_nans is disabled only if the line above is uncommented).")

### Solution (for Exercise 4.1 & 4.2, after attempting):

In [None]:
# @jit
# def problematic_function_for_nans_solution(x, divisor):
#   y = x * 100
#   z = jnp.log(y) / divisor
#   return z + y

# # Scenario 1: Run without jax_debug_nans
# print("--- Scenario 1: Running without jax_debug_nans ---")
# try:
#   # x < 0 makes log() return NaN
#   # (x = jnp.array(1.0), divisor = jnp.array(0.0) would give inf, not NaN)
#   result1 = problematic_function_for_nans_solution(jnp.array(-1.0), jnp.array(1.0)) # SOLUTION for 4.1
#   print(f"Result 1: {result1}")
# except Exception as e:
#   print(f"Caught exception (without jax_debug_nans): {e}\n")


# # Scenario 2: Run WITH jax_debug_nans
# print("--- Scenario 2: Running WITH jax_debug_nans ---")
# jax.config.update("jax_debug_nans", True) # Enable NaN debugging

# try:
#   result2 = problematic_function_for_nans_solution(jnp.array(-1.0), jnp.array(1.0)) # SOLUTION for 4.2
#   print(f"Result 2: {result2}")
# except Exception as e:
#   print(f"Caught exception (WITH jax_debug_nans): {e}\n")
# finally:
#   jax.config.update("jax_debug_nans", False) # Disable after use
#   print("jax_debug_nans has been disabled.")

Without jax_debug_nans, no exception is raised: the function silently returns nan, which would then propagate into any later computation (a loss, gradients, parameter updates). With jax_debug_nans enabled, JAX detects the NaN in the jitted function's output, re-runs the function op by op in de-optimized mode, and raises a FloatingPointError at the exact primitive that produced it (here, the log), making it much easier to find. It's not enabled by default because the checks and potential re-runs add overhead that significantly slows down execution.
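jax_debug_nans only reacts to NaN. Dividing a nonzero finite value by zero gives inf instead, which passes silently; JAX has a companion flag, jax_debug_infs, for that case. A minimal sketch, assuming the flag raises FloatingPointError the same way jax_debug_nans does (log_ratio is a hypothetical stand-in for problematic_function_for_nans):

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_ratio(x, divisor):
    return jnp.log(x * 100) / divisor

# With jax_debug_infs on, the inf from the division raises an error.
jax.config.update("jax_debug_infs", True)
try:
    log_ratio(jnp.array(1.0), jnp.array(0.0))
    caught = None
except FloatingPointError as err:
    caught = err
    print(f"Caught: {err}")
finally:
    jax.config.update("jax_debug_infs", False)  # always turn the check off again

# Without the flag, the same call silently returns inf.
plain = log_ratio(jnp.array(1.0), jnp.array(0.0))
print(plain)
```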

## 5. Inspecting Flax NNX Models: nnx.display()

**A Note on NNX Modules in Flax v0.11+:** In this version, `nnx.Module` and other NNX objects are now registered as JAX Pytrees. This means JAX transformations like `jax.jit` and `jax.vmap` can be used on them directly. However, if you use functions like `jax.tree.map` on a data structure containing NNX modules, they will be traversed by default. To treat them as leaves (the old behavior), you must use the `is_leaf` argument: `is_leaf=lambda x: isinstance(x, nnx.Pytree)`.

`nnx.display()` provides a clear view of your NNX Module's structure, parameters, and state.

### Exercise 5.1 - 5.4:
1. 5.1: In `SimpleNNXModel.__init__`, add a second `nnx.Linear` layer named `self.dense2` that maps from `dhidden` to `dout` features. Remember to provide `rngs`.
2. 5.2: In `SimpleNNXModel.__call__`, pass the intermediate `x` through `self.dense2` (if you added it).
3. 5.3: When instantiating `SimpleNNXModel`, ensure `din`, `dhidden`, and `dout` match your intended architecture (e.g., `dout=5` if your `dense2` outputs 5 features).
4. 5.4: Use `nnx.display(model)` to print the structure. Examine the output. Can you see both dense layers and their parameters (`kernel`, `bias`)? Note that `self.activation` holds `nnx.relu`, a plain function, so it adds no parameters.

In [None]:
class SimpleNNXModel(nnx.Module):
  def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
    self.dense1 = nnx.Linear(din, dhidden, rngs=rngs)
    # Exercise 5.1: Add another Linear layer called 'dense2' (dhidden -> dout)
    # YOUR CODE HERE
    self.activation = nnx.relu # A plain activation function: it has no parameters

  def __call__(self, x):
    x = self.dense1(x)
    x = nnx.relu(x)
    # Exercise 5.2: Pass x through 'dense2' if you added it.
    # YOUR CODE HERE
    x = self.activation(x)
    return x

# Initialize RNGs for parameters
key = jax.random.key(0)
model_rngs = nnx.Rngs(params=key)

# Instantiate the model
# Exercise 5.3: Update din, dhidden, dout if you changed the model structure
model = SimpleNNXModel(din=10, dhidden=20, dout=5, rngs=model_rngs)

# Display the model structure
print("--- Model Structure using nnx.display() ---")
# Exercise 5.4: Use nnx.display() to show the model's structure
# YOUR CODE HERE

# If you have treescope installed and are in a compatible environment (like Colab default),
# nnx.display() will give an interactive tree. Otherwise, it falls back to print.

### Solution (for Exercise 5.1-5.4, after attempting):

In [None]:
# class SimpleNNXModelSolution(nnx.Module):
#   def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
#     self.dense1 = nnx.Linear(din, dhidden, rngs=rngs)
#     self.dense2 = nnx.Linear(dhidden, dout, rngs=rngs) # SOLUTION 5.1
#     self.activation = nnx.relu

#   def __call__(self, x):
#     x = self.dense1(x)
#     x = nnx.relu(x)
#     x = self.dense2(x) # SOLUTION 5.2
#     x = self.activation(x)
#     return x

# # Initialize RNGs for parameters
# key = jax.random.key(0)
# model_rngs = nnx.Rngs(params=key)

# # Instantiate the model
# model_solution = SimpleNNXModelSolution(din=10, dhidden=20, dout=5, rngs=model_rngs) # SOLUTION 5.3 (dout adjusted)

# # Display the model structure
# print("--- Model Structure using nnx.display() ---")
# nnx.display(model_solution) # SOLUTION 5.4

## 6. Capturing Intermediate Values: nnx.sow()
Module.sow() allows you to "plant" intermediate values during the forward pass for later retrieval.

### Exercise 6.1 & 6.2:
1. 6.1: In ModelWithSow.__call__, after x1_act is computed, use self.sow(nnx.Intermediate, 'activation_layer1', x1_act) to store it.
2. 6.2: After running the model, retrieve the sown value. It will be an attribute on sow_model named activation_layer1. Access its .value (note that this may be a container rather than a bare array) and print the shape of the stored activation.
3. What would happen if you called sow multiple times with the same name within one forward pass (e.g., inside a loop)?

In [None]:
class ModelWithSow(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.dense1 = nnx.Linear(5, 10, rngs=rngs)
    self.dense2 = nnx.Linear(10, 3, rngs=rngs)

  def __call__(self, x):
    x1_act = self.dense1(x)
    x1_act = nnx.relu(x1_act)

    # Exercise 6.1: Use self.sow() to store the value of x1_act.
    # Use nnx.Intermediate as the variable_type and 'activation_layer1' as the name.
    # YOUR CODE HERE

    x2_out = self.dense2(x1_act)
    return x2_out

# Setup
key = jax.random.key(1)
model_sow_rngs = nnx.Rngs(params=key)
sow_model = ModelWithSow(rngs=model_sow_rngs)
dummy_input = jnp.ones((1, 5))

# Run the model
output = sow_model(dummy_input)

# Retrieve the sown value
# Exercise 6.2: Retrieve the 'activation_layer1' value from the sow_model instance.
# Remember it's stored as an attribute, and the actual data is in its .value property.
# Print the shape of the retrieved value.
# YOUR CODE HERE
# retrieved_activation = ...
# print(f"Shape of retrieved activation: {retrieved_activation.shape}") # Adjust if it's a tuple

### Solution (for Exercise 6.1 & 6.2, after attempting):

In [None]:
# class ModelWithSowSolution(nnx.Module):
#   def __init__(self, *, rngs: nnx.Rngs):
#     self.dense1 = nnx.Linear(5, 10, rngs=rngs)
#     self.dense2 = nnx.Linear(10, 3, rngs=rngs)

#   def __call__(self, x):
#     x1_act = self.dense1(x)
#     x1_act = nnx.relu(x1_act)

#     self.sow(nnx.Intermediate, 'activation_layer1', x1_act) # SOLUTION 6.1

#     x2_out = self.dense2(x1_act)
#     return x2_out

# # Setup
# key = jax.random.key(1)
# model_sow_rngs = nnx.Rngs(params=key)
# sow_model_solution = ModelWithSowSolution(rngs=model_sow_rngs)
# dummy_input = jnp.ones((1, 5))

# # Run the model
# output = sow_model_solution(dummy_input)

# # Retrieve the sown value
# retrieved_sown_value_obj = sow_model_solution.activation_layer1 # This is the Variable object
# retrieved_activation = retrieved_sown_value_obj.value          # This is the actual data (often a tuple)
# print(f"Retrieved activation (raw): {retrieved_activation}")
# # By default, sow appends to a tuple, so value is a 1-tuple holding a (1, 10) array
# print(f"Shape of retrieved activation (first element): {retrieved_activation[0].shape}") # SOLUTION 6.2

If sow is called multiple times with the same name in one forward pass, by default, it appends each new value to a tuple stored in the .value property of the sown attribute.

## 7. Robustness with Chex Assertions
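Chex provides assertions for JAX code. Static assertions about ranks, shapes, and dtypes run while the function is traced, so they also work inside jitted functions.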

### Exercise 7.1.1:
1. Fill in the # YOUR CODE HERE section in process_image_data with the specified Chex static assertions.
2. Run the cell and observe how the assertions catch the errors for wrong_shape_data and wrong_type_data.

In [None]:
@jit
def process_image_data(image_batch: chex.Array):
  # Exercise 7.1.1: Add Chex assertions to verify:
  # 1. image_batch has a rank of 4 (e.g., Batch, Height, Width, Channels).
  # 2. image_batch has a dtype of jnp.float32.
  # 3. image_batch has a specific shape, e.g., (32, 224, 224, 3).
  #    You can use a placeholder for batch_size if needed: chex.assert_shape(image_batch, (None, 224, 224, 3))
  # YOUR CODE HERE

  # Dummy computation
  processed = image_batch * 2.0 - 1.0
  return processed

# Test cases
correct_data = jnp.ones((32, 224, 224, 3), dtype=jnp.float32)
wrong_shape_data = jnp.ones((32, 224, 3), dtype=jnp.float32) # Missing a dim
wrong_type_data = jnp.ones((32, 224, 224, 3), dtype=jnp.int32)

print("--- Testing with correct data ---")
try:
  _ = process_image_data(correct_data)
  print("Correct data processed successfully!")
except Exception as e:
  print(f"Error with correct data:\n{e}")

print("\n--- Testing with wrong shape data ---")
try:
  _ = process_image_data(wrong_shape_data)
  print("Wrong shape data processed successfully (this shouldn't happen if assertions are correct).")
except AssertionError as e:
  print(f"Caught expected AssertionError for wrong shape:\n{e}")

print("\n--- Testing with wrong type data ---")
try:
  _ = process_image_data(wrong_type_data)
  print("Wrong type data processed successfully (this shouldn't happen if assertions are correct).")
except AssertionError as e:
  print(f"Caught expected AssertionError for wrong type:\n{e}")

### Solution (for Exercise 7.1.1, after attempting):

In [None]:
# @jit
# def process_image_data_solution(image_batch: chex.Array):
#   chex.assert_rank(image_batch, 4)                             # SOLUTION
#   chex.assert_type(image_batch, jnp.float32)                   # SOLUTION
#   chex.assert_shape(image_batch, (None, 224, 224, 3))          # SOLUTION (using None for batch)

#   processed = image_batch * 2.0 - 1.0
#   return processed

# # Test cases (same as above)
# correct_data_sol = jnp.ones((32, 224, 224, 3), dtype=jnp.float32)
# wrong_shape_data_sol = jnp.ones((32, 224, 3), dtype=jnp.float32)
# wrong_type_data_sol = jnp.ones((32, 224, 224, 3), dtype=jnp.int32)

# print("--- SOLUTION: Testing with correct data ---")
# try:
#   _ = process_image_data_solution(correct_data_sol)
#   print("Correct data processed successfully!")
# except Exception as e:
#   print(f"Error with correct data:\n{e}")

# print("\n--- SOLUTION: Testing with wrong shape data ---")
# try:
#   _ = process_image_data_solution(wrong_shape_data_sol)
# except AssertionError as e:
#   print(f"Caught expected AssertionError for wrong shape:\n{e}")

# print("\n--- SOLUTION: Testing with wrong type data ---")
# try:
#   _ = process_image_data_solution(wrong_type_data_sol)
# except AssertionError as e:
#   print(f"Caught expected AssertionError for wrong type:\n{e}")

## 7.2. Performance Debugging: @chex.assert_max_traces()
Unintended JIT recompilations kill performance. @chex.assert_max_traces(n=N) helps detect this.

### Exercise 7.2.1 & 7.2.2:
1. 7.2.1: In process_dynamic_shape, add the @chex.assert_max_traces(n=1) decorator.
2. 7.2.2: Uncomment and complete the second call to process_dynamic_shape using an input array with a different shape (e.g., jnp.ones((3,3))).
3. Run the cell.
 - Observe that Scenario 1 (with static_argnums) passes: both calls use the same input shape and dtype and the same static shape_tuple, so the second call reuses the cached compilation instead of re-tracing.
 - Observe that Scenario 2 should raise an AssertionError. Why does this happen?

In [None]:
chex.clear_trace_counter() # Reset counter for this specific example

# Scenario 1: Function with static argument for shape
@functools.partial(jit, static_argnums=(1,)) # shape_tuple is static
@chex.assert_max_traces(n=1)
def process_fixed_shape_staticarg(x: chex.Array, shape_tuple: tuple):
    chex.assert_shape(x, shape_tuple) # Check the shape matches
    return x * 2.0

print("--- Scenario 1: Static argnum, consistent shape tuple ---")
fixed_shape = (3, 4)
input_data_s1_c1 = jnp.ones(fixed_shape)
input_data_s1_c2 = jnp.zeros(fixed_shape) # Same shape, different values
_ = process_fixed_shape_staticarg(input_data_s1_c1, fixed_shape) # First call, traces
print("First call to process_fixed_shape_staticarg successful (traces).")
_ = process_fixed_shape_staticarg(input_data_s1_c2, fixed_shape) # Second call, reuses cache
print("Second call to process_fixed_shape_staticarg successful (reuses cache).")


# Scenario 2: Function where input shape might vary, leading to retracing if not handled
chex.clear_trace_counter() # Reset for this scenario

@jit
# Exercise 7.2.1: Add @chex.assert_max_traces(n=1) here
# YOUR CODE HERE
def process_dynamic_shape(x: chex.Array):
    # This function will be re-traced if 'x' shape changes between calls
    return x + jnp.sum(x) # Example op

print("\n--- Scenario 2: Varying input shapes ---")
try:
    print("Calling process_dynamic_shape with (2, 2)...")
    _ = process_dynamic_shape(jnp.ones((2, 2))) # First call, traces
    print("First call to process_dynamic_shape successful.")

    # Exercise 7.2.2: Call process_dynamic_shape with a DIFFERENT shape, e.g., (3,3).
    # This should trigger an AssertionError if assert_max_traces is working.
    print("Calling process_dynamic_shape with (3, 3)...")
    # YOUR CODE HERE
    # _ = process_dynamic_shape(jnp.ones((3, 3)))
    print("Second call to process_dynamic_shape successful (UNEXPECTED if shapes differ and max_traces=1).")

except AssertionError as e:
    print(f"\nCaught EXPECTED AssertionError for too many traces:\n{e}")
except Exception as e:
    print(f"Caught unexpected error: {e}")

### Solution (for Exercise 7.2.1 & 7.2.2, after attempting):

In [None]:
# chex.clear_trace_counter() # Reset counter for this specific example

# # Scenario 1: Function with static argument for shape
# @functools.partial(jit, static_argnums=(1,))
# @chex.assert_max_traces(n=1)
# def process_fixed_shape_staticarg_sol(x: chex.Array, shape_tuple: tuple):
#     chex.assert_shape(x, shape_tuple)
#     return x * 2.0

# print("--- SOLUTION: Scenario 1: Static argnum, consistent shape tuple ---")
# fixed_shape_sol = (3, 4)
# input_data_s1_c1_sol = jnp.ones(fixed_shape_sol)
# input_data_s1_c2_sol = jnp.zeros(fixed_shape_sol)
# _ = process_fixed_shape_staticarg_sol(input_data_s1_c1_sol, fixed_shape_sol)
# print("First call to process_fixed_shape_staticarg_sol successful (traces).")
# _ = process_fixed_shape_staticarg_sol(input_data_s1_c2_sol, fixed_shape_sol)
# print("Second call to process_fixed_shape_staticarg_sol successful (reuses cache).")


# chex.clear_trace_counter() # Reset for this scenario

# @jit
# @chex.assert_max_traces(n=1) # SOLUTION 7.2.1
# def process_dynamic_shape_sol(x: chex.Array):
#     return x + jnp.sum(x)

# print("\n--- SOLUTION: Scenario 2: Varying input shapes ---")
# try:
#     print("Calling process_dynamic_shape_sol with (2, 2)...")
#     _ = process_dynamic_shape_sol(jnp.ones((2, 2)))
#     print("First call to process_dynamic_shape_sol successful.")

#     print("Calling process_dynamic_shape_sol with (3, 3)...")
#     _ = process_dynamic_shape_sol(jnp.ones((3, 3))) # SOLUTION 7.2.2
#     print("Second call to process_dynamic_shape_sol successful (UNEXPECTED if shapes differ and max_traces=1).")

# except AssertionError as e:
#     print(f"\nCaught EXPECTED AssertionError for too many traces:\n{e}")
# except Exception as e:
#     print(f"Caught unexpected error: {e}")

In Scenario 2, the AssertionError happens because process_dynamic_shape is JIT-compiled based on the shape of its input x. When called the second time with a different shape, JAX needs to re-trace and re-compile the function for this new shape. @chex.assert_max_traces(n=1) detects this second trace and raises an error, alerting you to a potential performance issue due to recompilation.

## 8. Monitoring with TensorBoard
TensorBoard is excellent for visualizing training metrics. The setup is similar to PyTorch: tensorboardX's SummaryWriter offers the same add_scalar interface as torch.utils.tensorboard.

### Exercise 8.1 - 8.3:
1. 8.1: Create a tensorboardX.SummaryWriter instance, saving logs to LOG_DIR.
2. 8.2: Inside the loop, use writer.add_scalar() to log dummy_loss and dummy_accuracy. Crucially, convert them to Python scalars using .item().
3. 8.3: After the loop, close the writer using writer.close().
4. Run the cell.
5. If you are in Colab:
 - Uncomment the lines %load_ext tensorboard and %tensorboard --logdir {LOG_DIR} at the end of the cell.
 - Run the cell again. TensorBoard should appear in the output. Navigate to the SCALARS tab.
6. If running locally:
 - Open your terminal.
 - Navigate to the directory containing the logs folder (i.e., the parent of LOG_DIR).
 - Run tensorboard --logdir logs.
 - Open the URL (usually http://localhost:6006) in your browser.
7. Explore the TensorBoard dashboards and the profiler (XProf) data in the Profile tab.

In [None]:
# !pip install -Uq tensorboardX tensorboard tensorboard_plugin_profile
!pip install -Uq tensorboardX tensorboard_plugin_profile
!pip install -U protobuf

In [None]:
# For TensorBoard
from tensorboardX import SummaryWriter
import shutil # For cleaning up log directories

In [None]:
# Clean up previous logs if any
import os
LOG_DIR = "logs/jax_debug_run"
if os.path.exists(LOG_DIR):
    shutil.rmtree(LOG_DIR)
    print(f"Removed old log directory: {LOG_DIR}")

# Exercise 8.1: Create a SummaryWriter from tensorboardX
# Point it to the LOG_DIR defined above.
# YOUR CODE HERE
# writer = ...

# Dummy training loop
print("\nSimulating training loop...")
jax.profiler.start_trace(LOG_DIR) # Capturing trace for xprof

for epoch in range(10):
  # Simulate loss and accuracy (JAX arrays)
  dummy_loss = jnp.array(1.0 / (epoch + 1))
  dummy_accuracy = jnp.array(1.0 - dummy_loss)

  # Exercise 8.2: Log dummy_loss as 'Loss/train' and dummy_accuracy as 'Accuracy/validation'
  # Remember to use .item() to convert JAX arrays to Python scalars before logging.
  # Use 'epoch' as the global_step.
  # YOUR CODE HERE

  if (epoch + 1) % 2 == 0:
    print(f"Epoch {epoch+1}: Loss = {dummy_loss.item():.4f}, Acc = {dummy_accuracy.item():.4f}")

jax.profiler.stop_trace()
# Exercise 8.3: Close the writer
# YOUR CODE HERE

print(f"\nTensorBoard logs saved to: {LOG_DIR}")
print("To view in TensorBoard, run the following in your terminal (if local):")
print(f"tensorboard --logdir={LOG_DIR.split('/')[0]}") # Get base 'logs' dir
print("Or, if in Colab, you can use the %tensorboard magic:")
# %load_ext tensorboard
# %tensorboard --logdir {LOG_DIR}

### Solution (for Exercise 8.1-8.3, after attempting):

In [None]:
# # Clean up previous logs if any
# import os
# LOG_DIR_SOL = "logs/jax_debug_run_solution" # Use a different dir for solution
# if os.path.exists(LOG_DIR_SOL):
#     shutil.rmtree(LOG_DIR_SOL)
#     print(f"Removed old log directory: {LOG_DIR_SOL}")

# writer = SummaryWriter(LOG_DIR_SOL) # SOLUTION 8.1
# print(f"TensorBoard writer initialized. Logging to: {LOG_DIR_SOL}")

# # Dummy training loop
# print("\nSimulating training loop...")
# # Write the trace under LOG_DIR_SOL for XProf; create_perfetto_link=False skips generating a Perfetto UI link
# jax.profiler.start_trace(LOG_DIR_SOL, create_perfetto_link=False) # Capturing trace for xprof

# for epoch in range(10):
#   dummy_loss = jnp.array(1.0 / (epoch + 1))
#   dummy_loss.block_until_ready() # Ensure the array is ready
#   dummy_accuracy = jnp.array(1.0 - dummy_loss)
#   dummy_accuracy.block_until_ready() # Ensure the array is ready

#   writer.add_scalar('Loss/train', dummy_loss.item(), global_step=epoch) # SOLUTION 8.2
#   writer.add_scalar('Accuracy/validation', dummy_accuracy.item(), global_step=epoch) # SOLUTION 8.2

#   if (epoch + 1) % 2 == 0:
#     print(f"Epoch {epoch+1}: Loss = {dummy_loss.item():.4f}, Acc = {dummy_accuracy.item():.4f}")

# jax.profiler.stop_trace()
# writer.close() # SOLUTION 8.3

# print(f"\nTensorBoard logs saved to: {LOG_DIR_SOL}")
# print("To view in TensorBoard, run the following in your terminal (if local):")
# print(f"tensorboard --logdir={LOG_DIR_SOL.split('/')[0]}")
# print("Or, if in Colab, you can use the %tensorboard magic:")
# %load_ext tensorboard
# %tensorboard --logdir {LOG_DIR_SOL} # --general_plugin_dir "{LOG_DIR_SOL}/plugins"

# Profiling with XProf

Profiling is also essential for understanding and improving your code's performance. XProf is a great tool for profiling JAX and Flax NNX code, and it integrates with TensorBoard. The exercise above already captured a small XProf trace with jax.profiler.start_trace, but let's look at a more interesting example: we'll download some profiling data from an MNIST model.

In [None]:
# git clone the xprof repo so we have access to the demo data there
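!git clone https://github.com/openxla/xprof

# Load the TensorBoard notebook extension (harmless if it is already loaded), then launch
# TensorBoard and navigate to the Profile tab to view the performance profile
%load_ext tensorboard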
%tensorboard --logdir=xprof/demo
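For your own code, a trace is easier to read if you name the regions you care about. A sketch, assuming a writable logs/ directory: jax.profiler.trace is a context-manager form of start_trace/stop_trace, and jax.profiler.TraceAnnotation labels a span on the host timeline. The step function and the log_dir path are illustrative.

```python
import os
import jax
import jax.numpy as jnp

log_dir = os.path.join("logs", "annotated_trace")  # hypothetical location

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))
step(x).block_until_ready()  # compile before profiling so the trace shows steady-state steps

with jax.profiler.trace(log_dir):
    for i in range(3):
        # Each named span shows up on the host timeline in the trace viewer.
        with jax.profiler.TraceAnnotation(f"train_step_{i}"):
            loss = step(x).block_until_ready()
```

Point TensorBoard at the logs directory and open the Profile tab to look for the train_step_* spans.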

## 9. Visualizing Data Layout: jax.debug.visualize_array_sharding

Understanding data sharding is crucial for multi-device training. `jax.debug.visualize_array_sharding` helps visualize this.

Demonstrating this properly requires a multi-device setup (e.g., multiple GPUs or TPUs arranged in a Mesh). A standard Colab CPU or single-GPU runtime has only one device, so there is nothing to shard across. We can still see how the function works by faking a multi-device environment with `chex.set_n_cpu_devices`, which we did at the beginning of this Colab: JAX then sees 8 CPU devices and can split arrays across them, even though they all run on the same physical CPU.
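
For reference, here is a minimal sketch of what faking devices involves, assuming (based on our reading of chex) that `chex.set_n_cpu_devices(8)` works by setting XLA's `--xla_force_host_platform_device_count` flag. The flag only takes effect if it is set before JAX initializes its CPU backend, so this sketch is meant for a fresh runtime; in this Colab, the call at the top has already done the equivalent.

```python
import os

# Must be set before JAX initializes its CPU backend (i.e., before the first JAX operation).
# Assumption: this is the same flag chex.set_n_cpu_devices(8) sets for you.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax

# All 8 "devices" are threads on the same host CPU, but JAX treats them as separate devices.
cpu_devices = jax.devices("cpu")
print(f"{len(cpu_devices)} CPU devices: {cpu_devices}")
```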

### Exercise 9.1:
1. Run the cell below.
2. Observe the output of `jax.debug.visualize_array_sharding`. With the 8 fake CPU devices, the array is split across the mesh; on a genuine single device, the diagram shows one block, i.e. no sharding.
3. Think: If you had a Mesh of 4 devices arranged in a 2x2 grid (`Mesh(devices, ('dp', 'mp'))`) and an array `arr` of shape (8, 1024), how might you define a `PartitionSpec` to shard `arr` across data parallelism (dp) for the first dimension and model parallelism (mp) for the second? What would you expect `visualize_array_sharding(arr)` to show?

In [None]:
import jax
import jax.numpy as jnp
from jax import jit
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

# Build a 1-D mesh over all available devices (the 8 fake CPU devices set up at the start of this Colab).
try:
  if len(jax.devices()) >= 2:
    device_mesh = mesh_utils.create_device_mesh((len(jax.devices()),))  # Use all available devices
    mesh = Mesh(devices=device_mesh, axis_names=('data',))
    print(f"Created a mesh with shape: {mesh.devices.shape} and names: {mesh.axis_names}")
  else:
    print("Not enough devices to create a meaningful mesh for the sharding demo. Will run on a single device.")
    mesh = None
except Exception as e:
  print(f"Could not create mesh: {e}")
  mesh = None


@jit
def sharded_computation_demo(x):
  # Python print() runs only while JAX traces the function (the first call, or a retrace),
  # not on every execution.
  print("--- Input sharding ---")
  # visualize_array_sharding also works inside jit: it reports the sharding chosen during compilation.
  jax.debug.visualize_array_sharding(x)

  y = x * 2.0

  # An elementwise op like this typically keeps the input's sharding.
  print("--- Output sharding ---")
  jax.debug.visualize_array_sharding(y)
  return y

an_array = jnp.arange(8.0)
print(f"Original array: {an_array}")

if mesh is not None:
  # Shard the first (only) axis over the 'data' mesh axis; with 8 devices, each holds one element.
  sharding_spec = NamedSharding(mesh, PartitionSpec('data'))
  an_array_sharded = jax.device_put(an_array, sharding_spec)
  print(f"Array sharding: {an_array_sharded.sharding}")
  output = sharded_computation_demo(an_array_sharded)
else:
  print("No mesh, running the demo on a single device.")
  output = sharded_computation_demo(an_array)
print(f"Output: {output}")

# For comparison: the same array on a single device shows one block (no sharding).
print("\n--- Unsharded array, for comparison ---")
jax.debug.visualize_array_sharding(an_array)
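
If you want to check a sharding programmatically (for example, in a test) rather than read a diagram, `jax.debug.visualize_array_sharding` has a sibling, `jax.debug.inspect_array_sharding`, which hands the `Sharding` object to a callback you supply. The sketch below assumes the same 8-fake-device setup; the function `double` and the `recorded` list are hypothetical names for illustration. Inside `jit`, the callback fires when the function is compiled, not on every call.

```python
import os
# Fake 8 CPU devices if JAX has not been initialized yet (no effect in this Colab, where chex already did it).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, PartitionSpec("data")))

recorded = []  # Shardings reported by the callback


@jax.jit
def double(v):
  # Passes v's sharding to recorded.append during compilation.
  jax.debug.inspect_array_sharding(v, callback=recorded.append)
  return v * 2.0


out = double(x)
sharding = recorded[0]
print(f"Devices involved: {len(sharding.device_set)}")
print(f"Per-device shard shape: {sharding.shard_shape(x.shape)}")
```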

### Answer (for Conceptual Exercise 9.1):
- You might define `P = PartitionSpec('dp', 'mp')` and apply it with `jax.device_put(arr, NamedSharding(mesh, P))`, so the first dimension is split over the 'dp' mesh axis and the second over 'mp'.
- `jax.debug.visualize_array_sharding(arr)` would then draw a 2x2 grid of blocks, one per device: the 8 rows are split in two along 'dp' (i.e., 4 rows per block) and the 1024 columns in two along 'mp' (i.e., 512 columns per block). Each device in the 2x2 mesh would hold a (4, 512) slice of the original array.
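
The answer can be checked with a short sketch. It assumes the 8 fake CPU devices from the start of this Colab and simply uses the first four of them as the 2x2 mesh; the zero-filled `arr` is a placeholder for real data. Rather than relying only on the diagram, it reads the shape of each device's shard from `arr.addressable_shards`.

```python
import os
# Fake 8 CPU devices if JAX has not been initialized yet (this Colab already did it via chex).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange 4 devices as a 2x2 grid: mesh rows are the 'dp' axis, mesh columns the 'mp' axis.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("dp", "mp"))

P = PartitionSpec("dp", "mp")  # first dim over 'dp', second dim over 'mp'
arr = jax.device_put(jnp.zeros((8, 1024)), NamedSharding(mesh, P))

# Map each device to the shape of the slice it holds.
shard_shapes = {shard.device: shard.data.shape for shard in arr.addressable_shards}
for device, shape in shard_shapes.items():
  print(device, shape)

try:
  jax.debug.visualize_array_sharding(arr)  # draws a 2x2 grid of device blocks
except (ValueError, ImportError) as err:  # raised if the optional `rich` package is missing
  print(f"Skipping diagram: {err}")
```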

## Conclusion & Key Takeaways
You've now practiced with several key JAX and Flax NNX debugging tools!
- jax.debug.print() & jax.debug.breakpoint(): Your go-to tools for inspecting values inside JITted code.
- jax.disable_jit(): The "escape hatch" to use standard Python debuggers (pdb, IDEs) at the cost of performance.
- jax_debug_nans: Invaluable for automatically finding the source of NaNs.
- nnx.display(): Essential for understanding your NNX model's architecture and state.
- nnx.sow(): Useful for capturing intermediate activations without altering function signatures.
- Chex assertions (assert_shape, assert_tree_all_finite, assert_max_traces): Build robust and performant code by catching errors early and detecting recompilations.
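- TensorBoard: Standard for monitoring training, works seamlessly with JAX.
- XProf: Profiles JAX and Flax NNX programs and shows the results in TensorBoard's Profile tab.
- jax.debug.visualize_array_sharding(): Shows how an array is laid out across devices.

Debugging in JAX's compiled world requires adapting your PyTorch habits, but with these tools, you're well-equipped to tackle issues effectively!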

Please send us feedback at https://goo.gle/jax-training-feedback