diff --git a/jax/core.py b/jax/core.py index 218a8b25fa80..8845a7c9a0f8 100644 --- a/jax/core.py +++ b/jax/core.py @@ -15,8 +15,6 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 -from __future__ import annotations - from jax._src.core import ( AbstractToken as AbstractToken, AbstractValue as AbstractValue, @@ -27,7 +25,6 @@ ConcreteArray as ConcreteArray, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, - DimSize as DimSize, DropVar as DropVar, Effect as Effect, Effects as Effects, @@ -50,7 +47,6 @@ OutputType as OutputType, ParamDict as ParamDict, Primitive as Primitive, - Shape as Shape, ShapedArray as ShapedArray, Sublevel as Sublevel, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, @@ -60,15 +56,11 @@ TraceStack as TraceStack, TraceState as TraceState, Tracer as Tracer, - TracerArrayConversionError as TracerArrayConversionError, - TracerIntegerConversionError as TracerIntegerConversionError, - UnexpectedTracerError as UnexpectedTracerError, UnshapedArray as UnshapedArray, Value as Value, Var as Var, abstract_token as abstract_token, apply_todos as apply_todos, - as_hashable_function as as_hashable_function, as_named_shape as as_named_shape, aval_mapping_handlers as aval_mapping_handlers, axis_frame as axis_frame, @@ -82,7 +74,6 @@ check_type as check_type, check_valid_jaxtype as check_valid_jaxtype, closed_call_p as closed_call_p, - collections as collections, concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, concretization_function_error as concretization_function_error, @@ -92,7 +83,6 @@ definitely_equal as definitely_equal, # TODO(necula): remove this API dimension_as_value as dimension_as_value, # TODO(necula): remove this API do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr, - dtypes as dtypes, ensure_compile_time_eval as ensure_compile_time_eval, escaped_tracer_error as escaped_tracer_error, eval_context as eval_context, @@ -114,13 +104,10 @@ lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, - lu as lu, - map as map, map_bind as map_bind, map_bind_with_continuation as map_bind_with_continuation, mapped_aval as mapped_aval, maybe_find_leaked_tracers as maybe_find_leaked_tracers, - namedtuple as namedtuple, new_base_main as new_base_main, new_jaxpr_eqn as new_jaxpr_eqn, new_main as new_main, @@ -128,8 +115,6 @@ no_axis_name as no_axis_name, no_effects as no_effects, outfeed_primitives as outfeed_primitives, - partial as partial, - pp as pp, pp_aval as pp_aval, pp_eqn as pp_eqn, pp_eqn_rules as pp_eqn_rules, @@ -150,11 +135,7 @@ raise_as_much_as_possible as raise_as_much_as_possible, raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, - ref as ref, reset_trace_state as reset_trace_state, - safe_map as safe_map, - safe_zip as safe_zip, - source_info_util as source_info_util, stash_axis_env as stash_axis_env, str_eqn_compact as str_eqn_compact, subjaxprs as subjaxprs, @@ -165,12 +146,8 @@ substitute_vars_in_output_ty as substitute_vars_in_output_ty, thread_local_state as thread_local_state, token as token, - total_ordering as total_ordering, trace_state_clean as trace_state_clean, - traceback_util as traceback_util, traverse_jaxpr_params as traverse_jaxpr_params, - tuple_delete as tuple_delete, - tuple_insert as tuple_insert, typecheck as typecheck, typecompat as typecompat, typematch as typematch, @@ -178,7 +155,130 @@ used_axis_names as used_axis_names, used_axis_names_jaxpr as used_axis_names_jaxpr, valid_jaxtype as valid_jaxtype, - zip as zip, ) symbolic_equal_dim = definitely_equal # TODO(necula): remove this API + +from jax._src import core as _src_core +_deprecations = { + # Added Oct 11, 2023: + "DimSize": ( + "jax.core.DimSize is deprecated. Use DimSize = int | Any.", + _src_core.DimSize, + ), + "Shape": ( + "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", + _src_core.Shape, + ), + "TracerArrayConversionError": ( + "jax.core.TracerArrayConversionError is deprecated. Use jax.errors.TracerArrayConversionError", + _src_core.TracerArrayConversionError, + ), + "TracerIntegerConversionError": ( + "jax.core.TracerIntegerConversionError is deprecated. Use jax.errors.TracerIntegerConversionError", + _src_core.TracerIntegerConversionError, + ), + "UnexpectedTracerError": ( + "jax.core.UnexpectedTracerError is deprecated. Use jax.errors.UnexpectedTracerError", + _src_core.UnexpectedTracerError, + ), + "as_hashable_function": ( + "jax.core.as_hashable_function is deprecated. Use jax.util.as_hashable_function directly.", + _src_core.as_hashable_function, + ), + "collections": ( + "jax.core.collections is deprecated. Use the collections module directly.", + _src_core.collections, + ), + "dtypes": ( + "jax.core.dtypes is deprecated. Use jax.dtypes directly.", + _src_core.dtypes, + ), + "lu": ( + "jax.core.lu is deprecated. Use lu = jax.extend.linear_util", + _src_core.lu, + ), + "map": ( + "jax.core.map is deprecated. Use the built-in map function.", + _src_core.map, + ), + "namedtuple": ( + "jax.core.namedtuple is deprecated. Use collections.namedtuple directly.", + _src_core.namedtuple, + ), + "partial": ( + "jax.core.partial is deprecated. Use functools.partial directly.", + _src_core.partial, + ), + "pp": ( + "jax.core.pp is deprecated. jax._src.pretty_printer is a non-public API.", + _src_core.pp, + ), + "ref": ( + "jax.core.ref is deprecated. Use weakref.ref directly.", + _src_core.ref, + ), + "safe_map": ( + "jax.core.safe_map is deprecated. Use jax.util.safe_map directly.", + _src_core.safe_map, + ), + "safe_zip": ( + "jax.core.safe_zip is deprecated. Use jax.util.safe_zip directly.", + _src_core.safe_zip, + ), + "source_info_util": ( + "jax.core.source_info_util is deprecated. jax._src.source_info_util is a non-public API.", + _src_core.source_info_util, + ), + "total_ordering": ( + "jax.core.total_ordering is deprecated. Use functools.total_ordering directly.", + _src_core.total_ordering, + ), + "traceback_util": ( + "jax.core.traceback_util is deprecated. jax._src.traceback_util is a non-public API.", + _src_core.traceback_util, + ), + "tuple_delete": ( + "jax.core.tuple_delete is deprecated. Use tuple_delete = lambda t, i: (*t[:i], *t[i+1:])", + _src_core.tuple_delete, + ), + "tuple_insert": ( + "jax.core.tuple_insert is deprecated. Use tuple_insert = lambda t, v, i: (*t[:i], v, *t[i:])", + _src_core.tuple_insert, + ), + "zip": ( + "jax.core.zip is deprecated. Use the built-in zip function.", + _src_core.zip, + ), +} + +import typing +if typing.TYPE_CHECKING: + DimSize = _src_core.DimSize + Shape = _src_core.Shape + TracerArrayConversionError = _src_core.TracerArrayConversionError + TracerIntegerConversionError = _src_core.TracerIntegerConversionError + UnexpectedTracerError = _src_core.UnexpectedTracerError + as_hashable_function = _src_core.as_hashable_function + collections = _src_core.collections + dtypes = _src_core.dtypes + lu = _src_core.lu + map = _src_core.map + namedtuple = _src_core.namedtuple + partial = _src_core.partial + pp = _src_core.pp + ref = _src_core.ref + safe_map = _src_core.safe_map + safe_zip = _src_core.safe_zip + source_info_util = _src_core.source_info_util + total_ordering = _src_core.total_ordering + traceback_util = _src_core.traceback_util + tuple_delete = _src_core.tuple_delete + tuple_insert = _src_core.tuple_insert + zip = _src_core.zip +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _src_core