# Debugging Re-JIT Bugs

Occasionally, you may notice that a function is being re-jitted (re-traced and re-compiled) when you did not intend it to be. Often this is caused by a mismatch in the shape or structure of a PyTree argument. However, it can be hard to diagnose exactly which component is to blame, especially with nested PyTrees!
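For context, `jit` caches one compiled version of a function for each combination of argument PyTree structure and leaf shapes/dtypes. Any new combination triggers a fresh trace and compile. The sketch below counts traces with a Python side effect, which only runs while JAX traces the function. `double` and `trace_count` are illustrative names, not part of any library.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while jit traces the function


@jax.jit
def double(tree):
    global trace_count
    trace_count += 1
    return jax.tree_util.tree_map(lambda v: v * 2, tree)


double({"a": jnp.ones(3)})                   # trace 1: first call
double({"a": jnp.ones(3)})                   # cache hit: same structure, shape, dtype
double({"a": jnp.ones(2)})                   # trace 2: new leaf shape
double({"a": jnp.ones(3, dtype=jnp.int32)})  # trace 3: new leaf dtype
double({"b": jnp.ones(3)})                   # trace 4: new PyTree structure (key name)
print(trace_count)  # 4
```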

This notebook introduces a decorator that helps diagnose re-jitting bugs: `@debug_rejit`.

Simply add this decorator to the problematic function, either above the `@jit` decorator OR in place of it. The wrapped function then prints diagnostics whenever it detects a change in PyTree shape or structure from one call to the next. The real-life example cells below show it in action!

__Note: these debugging utility functions can be found in `ssm.utils`__
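To give a rough idea of how such a decorator can work, here is a minimal sketch that assumes it only needs to compare the flattened arguments of consecutive calls. `debug_rejit_sketch` is a hypothetical name: this is not the `ssm.utils` implementation, whose output format (seen in the examples below) differs, and unlike the real decorator it does not inspect outputs.

```python
import functools

import jax
import jax.numpy as jnp


def _leaf_signature(leaf):
    """(shape, dtype, weak_type) of one leaf: roughly the per-leaf part of jit's cache key."""
    aval = getattr(leaf, "aval", None)  # jax arrays and tracers carry an abstract value
    weak_type = bool(getattr(aval, "weak_type", False))
    return (jnp.shape(leaf), getattr(leaf, "dtype", type(leaf)), weak_type)


def debug_rejit_sketch(fn):
    """Hypothetical re-jit debugger: NOT the ssm.utils.debug_rejit source."""
    prev = {}  # treedef and leaf signatures from the previous call

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        tree = (args, kwargs)
        treedef = jax.tree_util.tree_structure(tree)
        sigs = [_leaf_signature(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
        if prev and prev["treedef"] != treedef:
            msg = f"structure mismatch: prev={prev['treedef']} curr={treedef}"
            wrapper.mismatches.append(msg)
            print(msg)
        elif prev:
            # Same structure, so leaves line up one-to-one; compare each signature.
            for i, (p, c) in enumerate(zip(prev["sigs"], sigs)):
                if p != c:
                    msg = f"leaf {i} mismatch: prev={p} curr={c}"
                    wrapper.mismatches.append(msg)
                    print(msg)
        prev.update(treedef=treedef, sigs=sigs)
        return fn(*args, **kwargs)

    wrapper.mismatches = []  # collected messages, handy for tests
    return wrapper


# Usage: stack it above jit, just like the real decorator.
@debug_rejit_sketch
@jax.jit
def g(x):
    return jax.tree_util.tree_map(lambda v: v + 1, x)


g(jnp.ones(3))
g(jnp.ones(3))         # same signature: silent
g(jnp.ones(2))         # prints a leaf 0 shape mismatch
g({"a": jnp.ones(2)})  # prints a structure mismatch
```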

## Quickstart

```python
from jax import jit
from ssm.utils import debug_rejit

@debug_rejit    # <--- ADD DEBUG DECORATOR HERE
@jit
def f(x):
    # problematic function that is being re-jitted!
    y = ...  # your computation here
    return y

...

# run f(x) as you normally would
```

```python
from ssm.utils import debug_rejit
```

# Examples

## 1) Shape Mismatch

```python
import jax.random as jr
import jax.numpy as np
from jax import jit

@debug_rejit
@jit
def f(x):
    print("jit compiling!")  # Python side effect: prints only while tracing
    x += 1
    return x

x = np.array([0, 1, 2])
x = f(x)  # NOTE: compiles for input shape (3,) and output shape (3,)
x = f(x)  # runs the cached compiled function

x = np.array([0, 1])
x = f(x)  # NOTE: has to re-compile for input shape (2,) and output shape (2,)
x = f(x)  # runs the newly compiled function
```

```
jit compiling!
jit compiling!
[[PyTree Leaf Shape mismatch found for input at index 0]]
prev=(2,)
curr=(3,)
[input pytree leaf [0]]
prev= DeviceArray([0, 1], dtype=int32)
curr= DeviceArray([1, 2, 3], dtype=int32)
[[PyTree Leaf Shape mismatch found for output at index 0]]
prev=(2,)
curr=(3,)
[output pytree leaf [0]]
prev= DeviceArray([1, 2], dtype=int32)
curr= DeviceArray([2, 3, 4], dtype=int32)
```
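If inputs genuinely vary in length, one common workaround (not something `debug_rejit` does for you) is to pad every input to a fixed maximum length and carry a boolean mask, so `jit` always sees the same shape. The sketch below assumes a known upper bound on the length; `MAX_LEN`, `pad_to`, and `f_padded` are hypothetical names.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8      # hypothetical upper bound on the input length
trace_count = 0  # incremented only while jit traces


@jax.jit
def f_padded(x, mask):
    global trace_count
    trace_count += 1
    # Only update the valid (unpadded) entries.
    return jnp.where(mask, x + 1, x)


def pad_to(x, length=MAX_LEN):
    """Hypothetical helper: zero-pad a 1-D array to `length`, plus a validity mask."""
    n = x.shape[0]
    padded = jnp.zeros((length,), dtype=x.dtype).at[:n].set(x)
    mask = jnp.arange(length) < n
    return padded, mask


results = []
for arr in (jnp.array([0, 1, 2]), jnp.array([0, 1])):
    padded, mask = pad_to(arr)
    out = f_padded(padded, mask)         # always shape (MAX_LEN,): no retrace
    results.append(out[: arr.shape[0]])  # slice back to the true length

print(trace_count)  # 1
```

The trade-off is wasted work on the padded entries in exchange for a single compilation.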


## 2) PyTree Structure Mismatch

```python
import tensorflow_probability.substrates.jax as tfp

# initialize distribution
my_dist = tfp.distributions.Categorical(logits=np.ones((5,)))  # NOTE: parameterize using logits

@debug_rejit
@jit
def m_step(my_dist, rng):
    print("jit compiling!")
    new_probs = jr.dirichlet(rng, np.ones((5,)))  # random probability vector that sums to 1
    return tfp.distributions.Categorical(probs=new_probs)  # NOTE: parameterize using probs!

# NOTE: this extra jit is redundant, since the decorators above already apply jit.
# Because debug_rejit then runs while the outer jit traces, the values it prints below are tracers.
m_step = jit(m_step)

rng = jr.PRNGKey(0)
num_updates = 3
for _ in range(num_updates):
    this_rng, rng = jr.split(rng, 2)
    my_dist = m_step(my_dist, this_rng)
```

```
jit compiling!
jit compiling!
[[PyTreeDef Structure mismatch found for input at index 0 (arg=my_dist)]]
prev=PyTreeDef(CustomNode(<class 'tensorflow_probability.substrates.jax.distributions.categorical.Categorical'>[(('logits', 'probs'), {'dtype': <class 'jax._src.numpy.lax_numpy.int32'>, 'validate_args': False, 'allow_nan_stats': True, 'name': 'Categorical'})], [None, *]))
curr=PyTreeDef(CustomNode(<class 'tensorflow_probability.substrates.jax.distributions.categorical.Categorical'>[(('logits', 'probs'), {'dtype': <class 'jax._src.numpy.lax_numpy.int32'>, 'validate_args': False, 'allow_nan_stats': True, 'name': 'Categorical'})], [*, None]))
[input pytree structure [0]]
prev= Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
curr= Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
```
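The structure mismatch comes from which field holds `None`. Below is a rough illustration with plain dicts standing in for the two `Categorical` parameterizations. This is an assumption for illustration: the real TFP distribution flattens through its own registered node (the `CustomNode` above), but the `None` pattern is the same. Because `None` is an empty subtree, moving it changes the treedef. One way to avoid the re-jit, sketched with a hypothetical `m_step_consistent`, is to always return the same parameterization the function received (with TFP, that would mean returning `Categorical(logits=...)`).

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Plain-dict stand-ins for the two parameterizations (hypothetical).
by_logits = {"logits": jnp.zeros(5), "probs": None}
by_probs = {"logits": None, "probs": jnp.ones(5) / 5}

# None is an empty subtree, so moving it changes the treedef even though
# both trees hold exactly one float32[5] leaf.
print(tree_util.tree_structure(by_logits) == tree_util.tree_structure(by_probs))  # False

trace_count = 0  # incremented only while jit traces


@jax.jit
def m_step_consistent(params, key):
    global trace_count
    trace_count += 1
    new_probs = jax.random.dirichlet(key, jnp.ones(5))
    # Always return the *same* parameterization as the input: logits.
    return {"logits": jnp.log(new_probs), "probs": None}


params, key = by_logits, jax.random.PRNGKey(0)
for _ in range(3):
    this_key, key = jax.random.split(key)
    params = m_step_consistent(params, this_key)
print(trace_count)  # 1
```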


## 3) Weak Type Promotion Issue

```python
from jax import jit
from jax.flatten_util import ravel_pytree
from ssm.utils import debug_rejit

print("Arrays are equal? ", np.array(5) == np.array(5, dtype="float32"))

print("\n----- Under the Hood ------")
print(repr(np.array(5)))
print(repr(np.array(5, dtype=np.float32)))

# doesn't re-jit
print("\n----- Simple Jitted Function (1x jit) ------")
@jit
def f(x):
    print("jit!")
    return x

x = np.array(5)
x = f(x)
x = f(x)
x = f(x)

# the ravel_pytree round trip appears to drop the weak_type flag, forcing a second compile
print("\n----- Unravel Jitted Function (2x jit) ------")
@jit
def f(x):
    print("jit!")
    flat_x, unflatten_fn = ravel_pytree(x)
    return unflatten_fn(flat_x)

x = np.array(5)
x = f(x)
x = f(x)
x = f(x)
```

```
Arrays are equal?  True

----- Under the Hood ------
DeviceArray(5, dtype=int32, weak_type=True)
DeviceArray(5., dtype=float32)

----- Simple Jitted Function (1x jit) ------
jit!

----- Unravel Jitted Function (2x jit) ------
jit!
jit!
```


```python
# let's use our debugger to find what's going wrong!
@debug_rejit
@jit
def f(x):
    print("jit!")
    flat_x, unflatten_fn = ravel_pytree(x)
    return unflatten_fn(flat_x)

x = np.array(5)
x = f(x)
x = f(x)
x = f(x)
```

```
jit!
jit!
[[Pytree Leaf Device Array Weak Type mismatch found for input at index 0]]
prev=False
curr=True
[input pytree leaf [0]]
prev= DeviceArray(5, dtype=int32)
curr= DeviceArray(5, dtype=int32, weak_type=True)
```
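The diagnostic points at the input's weak type changing between calls. One workaround, assuming you create the input yourself, is to give it an explicit dtype so it is strongly typed from the very first call. The sketch below counts traces to check that the function then compiles only once; `trace_count` is an illustrative name.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

trace_count = 0  # incremented only while jit traces


@jax.jit
def f(x):
    global trace_count
    trace_count += 1
    flat_x, unflatten_fn = ravel_pytree(x)
    return unflatten_fn(flat_x)


# An explicit dtype makes the input strongly typed from the start,
# so its abstract type is identical on every call.
x = jnp.array(5, dtype=jnp.int32)
print(x.aval.weak_type)  # False
for _ in range(3):
    x = f(x)
print(trace_count)  # 1
```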
