# Debugging Re-JIT Bugs

Occasionally, you may notice a function is being re-jitted against your intentions. Oftentimes, this is due to a PyTree shape or structure mismatch; however, it can be challenging to diagnose exactly which component is to blame (especially if you have nested PyTrees!).

This notebook introduces a decorator that helps diagnose re-jitting bugs: `@debug_rejit`.

By simply adding this decorator to the problematic function (either above the `@jit` decorator OR replacing it), the wrapped function will print diagnostics when it detects a change in the PyTree shape/structure from one function call to the next. Please see the real-life example code cells below to see it in action!

__Note: these debugging utility functions can be found in `ssm.utils`__

## Quickstart

```python
@debug_rejit    # <--- ADD DEBUG DECORATOR HERE
@jax
def f(x)
    # problematic function is being re-jitted!
    ...
    return y

...

# run f(x) as you normally would
```

In [1]:
from jax import tree_util
import inspect
import copy

# terminal color macros
CRED = '\033[91m'
CEND = '\033[0m'

def check_pytree_structure_match(obj_a, obj_b, mode="input", sig=None):
    """Checks whether pytrees A and B have the same structure.
    Used for debugging re-jit problems (see debug_rejit decorator).

    Args:
        obj_a (jaxlib.xla_extension.PyTreeDef): pytree obj A (prev)
        obj_b (jaxlib.xla_extension.PyTreeDef): pytree obj B (curr)
        mode (str, optional): "input" or "output". Defaults to "input".
        sig (inspect.FullArgSpec, optional): optional function signature.
            Used for better debug description. Defaults to None.
    """
    struct_a = tree_util.tree_structure(obj_a)
    struct_b = tree_util.tree_structure(obj_b)
    if struct_a != struct_b:
        for i, (a, b) in enumerate(zip(struct_a.children(), struct_b.children())):
            if a != b:
                print(f"{CRED}[[structure mismatch found for {mode} at index {i}"\
                        f"{f' (arg={sig.args[i]})' if sig is not None else ''}]]{CEND}")
                print(f"prev={a}\ncurr={b}")
        
def check_pytree_shape_match(obj_a, obj_b, mode="input", sig=None):
    """Checks whether pytrees A and B have the same leaf shapes.
    Used for debugging re-jit problems (see debug_rejit decorator).

    Args:
        obj_a (jaxlib.xla_extension.PyTreeDef): pytree obj A (prev)
        obj_b (jaxlib.xla_extension.PyTreeDef): pytree obj B (curr)
        mode (str, optional): "input" or "output". Defaults to "input".
        sig (inspect.FullArgSpec, optional): optional function signature.
            Used for better debug description. Defaults to None.
    """
    shape_a = [x.shape for x in tree_util.tree_leaves(obj_a)]
    shape_b = [x.shape for x in tree_util.tree_leaves(obj_b)]
    if shape_a != shape_b:
        for i, (a, b) in enumerate(zip(shape_a, shape_b)):
            if a != b:
                print(f"{CRED}[[shape mismatch found for {mode} at index {i}"\
                        f"{f' (arg={sig.args[i]})' if sig is not None else ''}]]{CEND}")
                print(f"prev={a}\ncurr={b}")

def check_pytree_match(obj_a,
                       obj_b,
                       mode: str="input",
                       sig: inspect.FullArgSpec=None):
    """Checks whether pytrees A and B are the same by checking shape AND structure.
    Used for debugging re-jit problems (see debug_rejit decorator).

    Args:
        obj_a (jaxlib.xla_extension.PyTreeDef): pytree structure A (prev)
        obj_b (jaxlib.xla_extension.PyTreeDef): pytree structure B (curr)
        mode (str, optional): "input" or "output". Defaults to "input".
        sig (inspect.FullArgSpec, optional): optional function signature.
            Used for better debug description. Defaults to None.
    """
    check_pytree_shape_match(obj_a, obj_b, mode, sig)
    check_pytree_structure_match(obj_a, obj_b, mode, sig)

def debug_rejit(func):
    """Decorator to debug re-jitting errors.
    
    Checks if input and output pytrees are consistent across multiple 
    calls to func (else: func will need to be re-compiled).
    
    Example:
    
        @debug_rejit
        @jit
        def fn(inputs):
            return outputs

        ==> will print out useful description when input/output
            pytrees mismatch (i.e. when fn will re-jit)
    """
    def wrapper(*args, **kwargs):
        
        # get tree structure for args and kwargs
        inputs = list(args) + list(kwargs.values())
        if wrapper.prev_in is None:
            wrapper.prev_in = copy.deepcopy(inputs)
        
        # run the function
        outputs = func(*args, **kwargs)

        # get tree structure for output (this works for tuple outputs too)
        if wrapper.prev_out is None:
            wrapper.prev_out = copy.deepcopy(outputs)

        # check whether the input and output structures match w/ prev fn call
        check_pytree_match(inputs, wrapper.prev_in, mode="input", sig=wrapper.sig)
        check_pytree_match(outputs, wrapper.prev_out, mode="output")
        
        # store for next fn call
        wrapper.prev_in = inputs
        wrapper.prev_out = outputs
        
        # return the output
        return outputs
    
    wrapper.sig = inspect.getfullargspec(func)
    wrapper.prev_in = None
    wrapper.prev_out = None
    return wrapper

# Examples

## 1) Shape mismatch

In [2]:
import jax.random as jr
import jax.numpy as np
from jax import jit

@debug_rejit
@jit
def f(x):
    print("jit compiling!")
    x += 1
    return x

x = np.array([0, 1, 2])
x = f(x)  # NOTE: compiles for input shape (3,) and output shape (3,)
x = f(x)  # runs the compiled function

x = np.array([0, 1])
x = f(x)  # NOTE: has to re-compile for input shape (2,) and output shape (2,)
x = f(x)  # runs the new-compiled function



jit compiling!
jit compiling!
[91m[[shape mismatch found for input at index 0 (arg=x)]][0m
prev=(2,)
curr=(3,)
[91m[[shape mismatch found for output at index 0]][0m
prev=(2,)
curr=(3,)


## 2) PyTree Structure Mismatch

In [10]:
import tensorflow_probability.substrates.jax as tfp

# initialize distribution
my_dist = tfp.distributions.Categorical(logits=np.ones((5,)))  # NOTE: parameterize using logits

@debug_rejit
@jit
def m_step(my_dist, rng):
    print("jit compiling!")
    new_probs = jr.normal(rng, (5,))
    return tfp.distributions.Categorical(probs=new_probs)  # NOTE: parameterize using probs!
m_step = jit(m_step)

rng = jr.PRNGKey(0)
num_updates = 3
for _ in range(num_updates):
    this_rng, rng = jr.split(rng, 2)
    my_dist = m_step(my_dist, this_rng)

jit compiling!
jit compiling!
[91m[[structure mismatch found for input at index 0 (arg=my_dist)]][0m
prev=PyTreeDef(CustomNode(<class 'tensorflow_probability.substrates.jax.distributions.categorical.Categorical'>[(('logits', 'probs'), {'dtype': <class 'jax._src.numpy.lax_numpy.int32'>, 'validate_args': False, 'allow_nan_stats': True, 'name': 'Categorical'})], [None, *]))
curr=PyTreeDef(CustomNode(<class 'tensorflow_probability.substrates.jax.distributions.categorical.Categorical'>[(('logits', 'probs'), {'dtype': <class 'jax._src.numpy.lax_numpy.int32'>, 'validate_args': False, 'allow_nan_stats': True, 'name': 'Categorical'})], [*, None]))
