In [1]:
import jax, jax.numpy as jp
import flax.struct
import flax.linen as nn
import math

# jax.config.update('jax_platform_name', 'cpu')


class Network(nn.Module):
    @nn.compact
    def __call__(self, x: jp.ndarray):
        x = nn.Dense(512)(x)
        for i in range(int(math.log2(x.shape[1])) - 2):
            x = nn.Conv(512 // 2 ** (i + 1), (3, 3), (1, 1), padding='SAME')(x)
            x = nn.Conv(512 // 2 ** (i + 1), (3, 3), (2, 2), padding='VALID')(x)
            x = nn.LayerNorm()(x)
            x = nn.gelu(x)
        
        x = x.reshape(-1, 1)
        x = nn.Dense(1)(x)
        x = nn.tanh(x)
        return x
    

model = Network()
rng = jax.random.PRNGKey(0)
out, param = model.init_with_output(rng, jp.ones((1, 128, 128, 3)))
print(out.shape)

(144, 1)


In [2]:
@jax.jit
def step(params, x):
    return model.apply(params, x)

x = jax.random.uniform(rng, (1, 128, 128, 3))
%timeit -n 1 -r 1 step(param, x)

483 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [3]:
%timeit -n 1 -r 1 step(param, x)

477 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [4]:
from functools import partial
from typing import Any, Dict
nontypree_filed = partial(flax.struct.field, pytree_node=False)

class Tmp(flax.struct.PyTreeNode):
    log: Dict[str, Any] = nontypree_filed()

    @jax.jit
    def step(self, params, x):
        out = model.apply(params, x)
        self.log['out'].append(out.mean())
        return out

rng, _ = jax.random.split(rng)
x = jax.random.uniform(rng, (1, 128, 128, 3))
log = {'out': []}
tmp = Tmp(log=log)
%timeit -n 1 -r 1 tmp.step(param, x)


154 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [5]:
%timeit -n 1 -r 1 tmp.step(param, x)

490 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [6]:
tmp.log

{'out': [Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>]}

In [13]:
tmp.log['out'][0].max()

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was step at /tmp/ipykernel_125546/2928496050.py:8 traced for jit.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_125546/2928496050.py:11:31 (Tmp.step). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:198:11 (_run_module_as_main)
<frozen runpy>:88:4 (_run_code)
/tmp/ipykernel_125546/2928496050.py:18 (<module>)
<magic-timeit>:1 (inner)
/tmp/ipykernel_125546/2928496050.py:11:31 (Tmp.step)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [15]:
tmp = tmp.replace(x=jax.random.uniform(jax.random.split(rng)[0], (1, 64, 64, 3)))
%timeit -n 1 -r 1 tmp.step()

130 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [17]:
tmp = tmp.replace(dummpy=False)
%timeit -n 1 -r 1 tmp.step()

22.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [18]:
param['params']['Conv_0']['kernel'] += 1
tmp = tmp.replace(params=param)
%timeit -n 1 -r 1 tmp.step()

19.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
