Skip to content

Commit

Permalink
jaxpr staging: only one tracer per jaxpr variable
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jan 25, 2022
1 parent 04d8b35 commit 98816f3
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,20 +1193,29 @@ def _assert_live(self) -> None:
if not self._trace.main.jaxpr_stack: # type: ignore
raise core.escaped_tracer_error(self, None)

TracerId = int
ConstId = int
class JaxprStackFrame:
__slots__ = ['gensym', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns', 'invars']
gensym: Callable[[AbstractValue], Var]
tracer_to_var: Dict[TracerId, Var]
constid_to_tracer: Dict[ConstId, Tracer]
constvar_to_val: Dict[Var, Any]
tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers
eqns: List[JaxprEqn]
invars: List[Var]

def __init__(self):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_var = {}
self.constid_to_tracer = {}
self.constvar_to_val = {}
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main
self.invars = []

def to_jaxpr(self, out_tracers):
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns)
Expand Down Expand Up @@ -1322,12 +1331,15 @@ def new_arg(self, aval):
self.frame.invars.append(var)
return tracer

def new_const(self, val):
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(aval, val)
self.frame.constvar_to_val[var] = val
def new_const(self, c):
tracer = self.frame.constid_to_tracer.get(id(c))
if tracer is None:
aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
self.frame.constid_to_tracer[id(c)] = tracer
self.frame.constvar_to_val[var] = c
return tracer

pure = lift = sublift = new_const
Expand All @@ -1345,12 +1357,6 @@ def makevar(self, tracer):
var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
return var

def getconstvar(self, aval, c):
var = self.frame.constid_to_var.get(id(c))
if var is None:
var = self.frame.constid_to_var[id(c)] = self.frame.newvar(aval)
return var

def instantiate_const(self, val):
if (isinstance(val, Tracer) and val._trace.main is self.main
and val._trace.sublevel == self.sublevel):
Expand Down Expand Up @@ -1648,7 +1654,6 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial


AvalId = int
TracerId = int
def _avals_to_tracers(
trace: DynamicJaxprTrace, in_avals: Sequence[AbstractValue]
) -> Sequence[Tracer]:
Expand Down

0 comments on commit 98816f3

Please sign in to comment.