Skip to content

Commit

Permalink
[attrs] allow passing a jax-attrs object to jit functions
Browse files Browse the repository at this point in the history
currently we don't get any interesting cache hits; only on object identity
match
  • Loading branch information
mattjj committed Feb 14, 2024
1 parent 2adefe9 commit a45cc43
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
26 changes: 25 additions & 1 deletion jax/_src/api_util.py
Expand Up @@ -18,7 +18,7 @@
import inspect
import operator
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Type
import warnings

import numpy as np
Expand Down Expand Up @@ -675,3 +675,27 @@ def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
assert dbg.result_paths is None
res_paths_ = HashableFunction(res_paths, closure=())
return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))


def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
if type(x) in _class_with_attrs:
objs.append(_HashableByObjectId(x))
else:
idxs.append(i)
flat_args_.append(x)
return _argnums_partial(f, tuple(idxs), tuple(objs)), flat_args_

class _HashableByObjectId:
__slots__ = ['val']
def __init__(self, val):
self.val = val
def __hash__(self):
return id(self.val)
def __eq__(self, other):
return self.val is other.val

def register_class_with_attrs(t: Type) -> None:
_class_with_attrs.add(t)
_class_with_attrs: set[Type] = set()
14 changes: 9 additions & 5 deletions jax/_src/pjit.py
Expand Up @@ -44,7 +44,8 @@
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info)
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
hoist_obj_attrs)
from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
Expand Down Expand Up @@ -426,13 +427,13 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, static_argnums, args,
allow_invalid=True)
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
del args

f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args)

if (donate_argnums or donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector(
Expand Down Expand Up @@ -918,7 +919,9 @@ def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals,
in_layouts_flat = flatten_axis_resources(
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)

if not config.dynamic_shapes.value:
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
if not config.dynamic_shapes.value and not attrs_tracked:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names,
"pjit arguments", allow_uneven_sharding=False)
Expand All @@ -936,7 +939,8 @@ def explain_tracing_cache_miss(

def unpack(key):
transforms, (), _, (in_type, debug_info, _, inline), *_, ctx = key
(_, (in_tree,)), (_, ()) = transforms
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
_, (_, (in_tree,)), (_, ()) = transforms
return in_tree, in_type, debug_info, inline.val, ctx
in_tree, in_type, debug_info, inline, ctx = unpack(key)
if inline: return
Expand Down
3 changes: 3 additions & 0 deletions jax/experimental/attrs.py
Expand Up @@ -17,6 +17,7 @@
from typing import Any

from jax._src import core
from jax._src import api_util
from jax._src import linear_util as lu
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
Expand All @@ -26,6 +27,8 @@

JaxVal = Any

register = api_util.register_class_with_attrs

class GetAttrPrimitive(core.Primitive):
def bind_with_trace(self, trace, args, params):
() = args
Expand Down
18 changes: 18 additions & 0 deletions tests/attrs_test.py
Expand Up @@ -39,6 +39,8 @@
class Thing:
x: float

attrs.register(Thing)

class AttrsTest(jtu.JaxTestCase):

@parameterized.parameters([True, False])
Expand Down Expand Up @@ -208,6 +210,22 @@ def body(x, __):
jax.grad(double_it_10)(1.0)
self.assertAllClose(thing.x, 1024., check_dtypes=False)

def test_arg_to_jit(self):
thing = Thing(1.0)
count = 0

@jax.jit
def f(obj, x):
nonlocal count
count += 1
jax_setattr(obj, 'x', x)

f(thing, 2.0) # don't crash!
self.assertAllClose(thing.x, 2.0, check_dtypes=False)
f(thing, 3.0)
self.assertAllClose(thing.x, 3.0, check_dtypes=False)
self.assertEqual(count, 1)


class AttrsJVPTest(jtu.JaxTestCase):

Expand Down

0 comments on commit a45cc43

Please sign in to comment.