Skip to content

Commit

Permalink
add a memory leak test for jit jaxpr construction
Browse files Browse the repository at this point in the history
Tweak implementation for `_inline_literals` not to include a class
defined in a function, since that seemed to cause leaking!
  • Loading branch information
mattjj committed Mar 17, 2021
1 parent 52bd306 commit 0181d03
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
17 changes: 7 additions & 10 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,11 +986,8 @@ def find_progenitors(self, tracer):
def _inline_literals(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
newvar = core.gensym()
class var(dict):
def __missing__(self, v):
new_v = self[v] = newvar(v.aval)
return new_v
var = var()
newvars = {}
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))

def lit(var: core.Var) -> Optional[Any]:
val = consts.get(var)
Expand All @@ -1000,14 +997,14 @@ def lit(var: core.Var) -> Optional[Any]:
return None

used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_invars = [var[v] for v in jaxpr.invars]
new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
[var[v] if v in used else dropvar for v in eqn.outvars],
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = [new_jaxpr_eqn([lit(v) or var(v) for v in eqn.invars],
[var(v) if v in used else dropvar for v in eqn.outvars],
eqn.primitive, eqn.params, eqn.source_info)
for eqn in jaxpr.eqns]
new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals

Expand Down
19 changes: 18 additions & 1 deletion tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,24 @@ def f(x):
try:
fn(params)
gc.set_debug(gc.DEBUG_SAVEALL)
self.assertEqual(gc.collect(), 0)
self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
finally:
gc.set_debug(debug)

def test_reference_cycles_jit(self):
gc.collect()

def f(x):
return x.sum()

fn = jit(f)
params = jnp.zeros([])

debug = gc.get_debug()
try:
fn(params).block_until_ready()
gc.set_debug(gc.DEBUG_SAVEALL)
self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
finally:
gc.set_debug(debug)

Expand Down

0 comments on commit 0181d03

Please sign in to comment.