diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index dc612dbe391e..cf09e27ed0d1 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -18,7 +18,7 @@ import inspect import operator from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Type import warnings import numpy as np @@ -675,3 +675,27 @@ def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None, assert dbg.result_paths is None res_paths_ = HashableFunction(res_paths, closure=()) return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_)) + + +def hoist_obj_attrs(f, flat_args): + idxs, objs, flat_args_ = [], [], [] + for i, x in enumerate(flat_args): + if type(x) in _class_with_attrs: + objs.append(_HashableByObjectId(x)) + else: + idxs.append(i) + flat_args_.append(x) + return _argnums_partial(f, tuple(idxs), tuple(objs)), flat_args_ + +class _HashableByObjectId: + __slots__ = ['val'] + def __init__(self, val): + self.val = val + def __hash__(self): + return id(self.val) + def __eq__(self, other): + return self.val is other.val + +def register_class_with_attrs(t: Type) -> None: + _class_with_attrs.add(t) +_class_with_attrs: set[Type] = set() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index de6195636b03..a08a8a8a7f5f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -44,7 +44,8 @@ from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, shaped_abstractify, check_callable, resolve_argnums, - argnames_partial_except, debug_info, result_paths, jaxpr_debug_info) + argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, + hoist_obj_attrs) from jax._src.errors import JAXTypeError from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec @@ -426,13 +427,13 @@ def common_infer_params(pjit_info_args, *args, **kwargs): dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames) f = lu.wrap_init(fun) f, res_paths = result_paths(f) - f, dyn_args = argnums_partial_except(f, static_argnums, args, - allow_invalid=True) + f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) del args f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs) explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) + flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) if (donate_argnums or donate_argnames) and not config.debug_nans.value: donated_invars = donation_vector( @@ -918,7 +919,9 @@ def _process_in_axis_resources(in_shardings_thunk, in_layouts_thunk, in_avals, in_layouts_flat = flatten_axis_resources( "pjit in_layouts", in_tree, in_layouts, tupled_args=True) - if not config.dynamic_shapes.value: + # TODO(dougalm,mattjj): enable debug info with attrs_tracked + attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals) + if not config.dynamic_shapes.value and not attrs_tracked: pjit_check_aval_sharding(in_shardings_flat, in_avals, None if debug_info is None else debug_info.arg_names, "pjit arguments", allow_uneven_sharding=False) @@ -936,7 +939,8 @@ def explain_tracing_cache_miss( def unpack(key): transforms, (), _, (in_type, debug_info, _, inline), *_, ctx = key - (_, (in_tree,)), (_, ()) = transforms + # TODO(dougalm,mattjj): enable cache miss explanation with attrs + _, (_, (in_tree,)), (_, ()) = transforms return in_tree, in_type, debug_info, inline.val, ctx in_tree, in_type, debug_info, inline, ctx = unpack(key) if inline: return diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 1f28f1f55a3b..ac24bd72f04a 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -17,6 +17,7 @@ from typing import Any from jax._src import core +from jax._src import api_util from jax._src import linear_util as lu from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad @@ -26,6 +27,8 @@ JaxVal = Any +register = api_util.register_class_with_attrs + class GetAttrPrimitive(core.Primitive): def bind_with_trace(self, trace, args, params): () = args diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 17dd5cbde757..339af904c832 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -39,6 +39,8 @@ class Thing: x: float +attrs.register(Thing) + class AttrsTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -208,6 +210,22 @@ def body(x, __): jax.grad(double_it_10)(1.0) self.assertAllClose(thing.x, 1024., check_dtypes=False) + def test_arg_to_jit(self): + thing = Thing(1.0) + count = 0 + + @jax.jit + def f(obj, x): + nonlocal count + count += 1 + jax_setattr(obj, 'x', x) + + f(thing, 2.0) # don't crash! + self.assertAllClose(thing.x, 2.0, check_dtypes=False) + f(thing, 3.0) + self.assertAllClose(thing.x, 3.0, check_dtypes=False) + self.assertEqual(count, 1) + class AttrsJVPTest(jtu.JaxTestCase):