Skip to content

Commit

Permalink
jax.core: deprecate some inadvertent exports
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 11, 2023
1 parent 9ae5a43 commit e0944c9
Showing 1 changed file with 124 additions and 24 deletions.
148 changes: 124 additions & 24 deletions jax/core.py
Expand Up @@ -15,8 +15,6 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from __future__ import annotations

from jax._src.core import (
AbstractToken as AbstractToken,
AbstractValue as AbstractValue,
Expand All @@ -27,7 +25,6 @@
ConcreteArray as ConcreteArray,
ConcretizationTypeError as ConcretizationTypeError,
DShapedArray as DShapedArray,
DimSize as DimSize,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
Expand All @@ -50,7 +47,6 @@
OutputType as OutputType,
ParamDict as ParamDict,
Primitive as Primitive,
Shape as Shape,
ShapedArray as ShapedArray,
Sublevel as Sublevel,
TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
Expand All @@ -60,15 +56,11 @@
TraceStack as TraceStack,
TraceState as TraceState,
Tracer as Tracer,
TracerArrayConversionError as TracerArrayConversionError,
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
UnshapedArray as UnshapedArray,
Value as Value,
Var as Var,
abstract_token as abstract_token,
apply_todos as apply_todos,
as_hashable_function as as_hashable_function,
as_named_shape as as_named_shape,
aval_mapping_handlers as aval_mapping_handlers,
axis_frame as axis_frame,
Expand All @@ -82,7 +74,6 @@
check_type as check_type,
check_valid_jaxtype as check_valid_jaxtype,
closed_call_p as closed_call_p,
collections as collections,
concrete_aval as concrete_aval,
concrete_or_error as concrete_or_error,
concretization_function_error as concretization_function_error,
Expand All @@ -92,7 +83,6 @@
definitely_equal as definitely_equal, # TODO(necula): remove this API
dimension_as_value as dimension_as_value, # TODO(necula): remove this API
do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
dtypes as dtypes,
ensure_compile_time_eval as ensure_compile_time_eval,
escaped_tracer_error as escaped_tracer_error,
eval_context as eval_context,
Expand All @@ -114,22 +104,17 @@
lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types,
lu as lu,
map as map,
map_bind as map_bind,
map_bind_with_continuation as map_bind_with_continuation,
mapped_aval as mapped_aval,
maybe_find_leaked_tracers as maybe_find_leaked_tracers,
namedtuple as namedtuple,
new_base_main as new_base_main,
new_jaxpr_eqn as new_jaxpr_eqn,
new_main as new_main,
new_sublevel as new_sublevel,
no_axis_name as no_axis_name,
no_effects as no_effects,
outfeed_primitives as outfeed_primitives,
partial as partial,
pp as pp,
pp_aval as pp_aval,
pp_eqn as pp_eqn,
pp_eqn_rules as pp_eqn_rules,
Expand All @@ -150,11 +135,7 @@
raise_as_much_as_possible as raise_as_much_as_possible,
raise_to_shaped as raise_to_shaped,
raise_to_shaped_mappings as raise_to_shaped_mappings,
ref as ref,
reset_trace_state as reset_trace_state,
safe_map as safe_map,
safe_zip as safe_zip,
source_info_util as source_info_util,
stash_axis_env as stash_axis_env,
str_eqn_compact as str_eqn_compact,
subjaxprs as subjaxprs,
Expand All @@ -165,20 +146,139 @@
substitute_vars_in_output_ty as substitute_vars_in_output_ty,
thread_local_state as thread_local_state,
token as token,
total_ordering as total_ordering,
trace_state_clean as trace_state_clean,
traceback_util as traceback_util,
traverse_jaxpr_params as traverse_jaxpr_params,
tuple_delete as tuple_delete,
tuple_insert as tuple_insert,
typecheck as typecheck,
typecompat as typecompat,
typematch as typematch,
unmapped_aval as unmapped_aval,
used_axis_names as used_axis_names,
used_axis_names_jaxpr as used_axis_names_jaxpr,
valid_jaxtype as valid_jaxtype,
zip as zip,
)

symbolic_equal_dim = definitely_equal # TODO(necula): remove this API

from jax._src import core as _src_core
_deprecations = {
# Added Oct 11, 2023:
"DimSize": (
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
_src_core.DimSize,
),
"Shape": (
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
_src_core.Shape,
),
"TracerArrayConversionError": (
"jax.core.TracerArrayConversionError is deprecated. Use jax.errors.TracerArrayConversionError",
_src_core.TracerArrayConversionError,
),
"TracerIntegerConversionError": (
"jax.core.TracerIntegerConversionError is deprecated. Use jax.errors.TracerIntegerConversionError",
_src_core.TracerIntegerConversionError,
),
"UnexpectedTracerError": (
"jax.core.UnexpectedTracerError is deprecated. Use jax.errors.UnexpectedTracerError",
_src_core.UnexpectedTracerError,
),
"as_hashable_function": (
"jax.core.as_hashable_function is deprecated. Use jax.util.as_hashable_function directly.",
_src_core.as_hashable_function,
),
"collections": (
"jax.core.collections is deprecated. Use the collections module directly.",
_src_core.collections,
),
"dtypes": (
"jax.core.dtypes is deprecated. Use jax.dtypes directly.",
_src_core.dtypes,
),
"lu": (
"jax.core.lu is deprecated. Use lu = jax.extend.linear_util",
_src_core.lu,
),
"map": (
"jax.core.map is deprecated. Use the built-in map function.",
_src_core.map,
),
"namedtuple": (
"jax.core.namedtuple is deprecated. Use collections.namedtuple directly.",
_src_core.namedtuple,
),
"partial": (
"jax.core.partial is deprecated. Use functools.partial directly.",
_src_core.partial,
),
"pp": (
"jax.core.pp is deprecated. jax._src.pretty_printer is a non-public API.",
_src_core.pp,
),
"ref": (
"jax.core.ref is deprecated. Use weakref.ref directly.",
_src_core.ref,
),
"safe_map": (
"jax.core.safe_map is deprecated. Use jax.util.safe_map directly.",
_src_core.safe_map,
),
"safe_zip": (
"jax.core.safe_zip is deprecated. Use jax.util.safe_zip directly.",
_src_core.safe_zip,
),
"source_info_util": (
"jax.core.source_info_util is deprecated. jax._src.source_info_util is a non-public API.",
_src_core.source_info_util,
),
"total_ordering": (
"jax.core.total_ordering is deprecated. Use functools.total_ordering directly.",
_src_core.total_ordering,
),
"traceback_util": (
"jax.core.traceback_util is deprecated. jax._src.traceback_util is a non-public API.",
_src_core.traceback_util,
),
"tuple_delete": (
"jax.core.tuple_delete is deprecated. Use tuple_delete = lambda t, i: (*t[:i], *t[i+1:])",
_src_core.tuple_delete,
),
"tuple_insert": (
"jax.core.tuple_insert is deprecated. Use tuple_insert = lambda t, v, i: (*t[:i], v, *t[i:])",
_src_core.tuple_insert,
),
"zip": (
"jax.core.zip is deprecated. Use the built-in zip function.",
_src_core.zip,
),
}

import typing
if typing.TYPE_CHECKING:
DimSize = _src_core.DimSize
Shape = _src_core.Shape
TracerArrayConversionError = _src_core.TracerArrayConversionError
TracerIntegerConversionError = _src_core.TracerIntegerConversionError
UnexpectedTracerError = _src_core.UnexpectedTracerError
as_hashable_function = _src_core.as_hashable_function
collections = _src_core.collections
dtypes = _src_core.dtypes
lu = _src_core.lu
map = _src_core.map
namedtuple = _src_core.namedtuple
partial = _src_core.partial
pp = _src_core.pp
ref = _src_core.ref
safe_map = _src_core.safe_map
safe_zip = _src_core.safe_zip
source_info_util = _src_core.source_info_util
total_ordering = _src_core.total_ordering
traceback_util = _src_core.traceback_util
tuple_delete = _src_core.tuple_delete
tuple_insert = _src_core.tuple_insert
zip = _src_core.zip
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _src_core

0 comments on commit e0944c9

Please sign in to comment.