# Repro: gradient is NaN under `jit` but not without it

In [1]:
from collections.abc import Mapping

import json
import jax
import jax.numpy as np
from jax import tree_util, grad, jit, vmap
import haiku as hk

In [2]:
def tree_isnan(tree):
    """Report whether at least one leaf of the pytree contains a NaN entry.

    Args:
        tree: any pytree; its leaves are obtained with ``tree_util.tree_leaves``.

    Returns:
        ``True`` if some leaf has a NaN element, ``False`` otherwise
        (including when the tree has no leaves).
    """
    leaves = tree_util.tree_leaves(tree)
    # One flag per leaf: does that leaf hold any NaN? Every leaf is checked
    # before the flags are combined.
    leaf_has_nan = [np.any(np.isnan(leaf)) for leaf in leaves]
    return any(leaf_has_nan)

def tree_to_numpy(tree):
    """Convert every non-mapping value of a nested mapping into an array.

    Nested mappings are rebuilt recursively as plain dicts; any other value
    (for example a nested list) is passed to ``np.array``. Note that ``np``
    in this notebook is ``jax.numpy``.

    Args:
        tree: a mapping whose values are either mappings or array-like data.

    Returns:
        A new dict with the same keys and nesting, with arrays as leaves.
    """
    return {
        key: tree_to_numpy(value) if isinstance(value, Mapping) else np.array(value)
        for key, value in tree.items()
    }

In [3]:
def mlp_fn(x):
    """Build a Haiku MLP with output sizes [32, 32, 1] and apply it to ``x``.

    The network uses swish activations between layers (none after the last
    layer) and variance-scaling weight initialisation with scale 2.0.
    """
    mlp_config = dict(
        output_sizes=[32, 32, 1],
        w_init=hk.initializers.VarianceScaling(scale=2.0),
        activation=jax.nn.swish,
        activate_final=False,
        name="MLP",
    )
    network = hk.nets.MLP(**mlp_config)
    return network(x)


# Haiku-transformed version of mlp_fn; evaluated via mlp.apply(params, None, x).
mlp = hk.transform(mlp_fn)

In [5]:
# Eight scalar inputs, stored as a float32 column of shape (8, 1).
samples = np.array(
    [-2.8753953, -4.018667, -3.0718434, -3.9645574,
     3.7780683, -2.2248127, 3.4319901, -2.9572632],
    dtype=np.float32,
).reshape(-1, 1)

# Network parameters are read from a JSON file (path relative to the current
# working directory) and their non-mapping values converted to arrays.
with open("./net_params.json", mode="r") as file:
    params_dict = json.load(file)

params = tree_to_numpy(params_dict)

def l2_norm(params, samples):
    """Mean over ``samples`` of the norm of the MLP output for each sample.

    Args:
        params: parameter tree handed to ``mlp.apply``.
        samples: batch of inputs; ``vmap`` maps the per-sample function over it.

    Returns:
        The mean of the per-sample output norms.
    """
    def output_norm(x):
        # Evaluate the network on one sample (apply's second argument is None)
        # and take the norm of the result.
        return np.linalg.norm(mlp.apply(params, None, x))

    per_sample_norms = vmap(output_norm)(samples)
    return np.mean(per_sample_norms)

# Build the gradient function once, then evaluate it both without and with jit
# on identical inputs so the two results can be compared.
grad_l2_norm = grad(l2_norm)
g = grad_l2_norm(params, samples)
g_jit = jit(grad_l2_norm)(params, samples)

print("Is gradient NaN when not using jit?", tree_isnan(g))  # False
print("Is gradient NaN when using jit?", tree_isnan(g_jit))  # True

Is gradient NaN when not using jit? False
Is gradient NaN when using jit? True
