MAINT Drop underscore from the name of externally-referenced state objects
superbobry committed Oct 13, 2023
1 parent 16061e6 commit f9087ab
Showing 9 changed files with 33 additions and 33 deletions.
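The renamed objects are config states that are read from modules other than the ones defining them, which is why the leading underscore is dropped. For reference, a minimal sketch of the access pattern after this commit; the import paths are assumptions based on the files touched below and are not part of the change itself.

# Sketch only (assumed import paths): external call sites read the states
# through their new, non-underscored names.
from jax._src import maps
from jax._src import test_util as jtu

spmd = maps.SPMD_LOWERING.value            # was maps._SPMD_LOWERING.value
manual = maps.SPMD_LOWERING_MANUAL.value   # was maps._SPMD_LOWERING_MANUAL.value
n_cases = jtu.NUM_GENERATED_CASES.value    # was jtu._NUM_GENERATED_CASES.value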
12 changes: 6 additions & 6 deletions jax/_src/maps.py
@@ -670,7 +670,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
name=name),
source_info_util.new_source_info(), resource_env, {})
jaxpr = plan.subst_axes_with_resources(jaxpr)
- use_spmd_lowering = _SPMD_LOWERING.value
+ use_spmd_lowering = SPMD_LOWERING.value
ensure_fixed_sharding = _ENSURE_FIXED_SHARDING.value
if use_spmd_lowering and ensure_fixed_sharding:
jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
@@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
tiling_method: pxla.TilingMethod
- if _SPMD_LOWERING_MANUAL.value:
+ if SPMD_LOWERING_MANUAL.value:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
@@ -1284,7 +1284,7 @@ def out_axes_transform(out_axes):

def _xmap_lowering_rule(ctx, *args, **kwargs):
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
- if _SPMD_LOWERING_MANUAL.value:
+ if SPMD_LOWERING_MANUAL.value:
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
else:
return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
@@ -1839,21 +1839,21 @@ def _clear_compilation_cache(_):

def _ensure_spmd_and(f):
def update(v):
- if v and not _SPMD_LOWERING.value:
+ if v and not SPMD_LOWERING.value:
raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
return f(v)
return update


- _SPMD_LOWERING = config.define_bool_state(
+ SPMD_LOWERING = config.define_bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmap computations will be compiled through "
"the XLA SPMD partitioner instead of explicit cross-replica collectives. "
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
- _SPMD_LOWERING_MANUAL = config.define_bool_state(
+ SPMD_LOWERING_MANUAL = config.define_bool_state(
name="experimental_xmap_spmd_lowering_manual",
default=False,
help=("When set, multi-device xmap computations will be compiled using "
8 changes: 4 additions & 4 deletions jax/_src/test_util.py
@@ -67,7 +67,7 @@
'Describes the device under test in case special consideration is required.'
)

- _NUM_GENERATED_CASES = config.DEFINE_integer(
+ NUM_GENERATED_CASES = config.DEFINE_integer(
'jax_num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
@@ -762,7 +762,7 @@ def assert_dot_preferred_element_type(expected, fun, *args, **kwargs):

def cases_from_gens(*gens):
sizes = [1, 3, 10]
- cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
+ cases_per_size = int(NUM_GENERATED_CASES.value / len(sizes)) + 1
for size in sizes:
for i in range(cases_per_size):
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
@@ -775,7 +775,7 @@ def choose_one(x):
if not isinstance(x, (list, tuple)):
x = list(x)
return [x[rng.randint(len(x))]]
- while (len(seen) < _NUM_GENERATED_CASES.value and
+ while (len(seen) < NUM_GENERATED_CASES.value and
retries < _MAX_CASES_SAMPLING_RETRIES.value):
retries += 1
cases = list(gen(choose_one))
@@ -804,7 +804,7 @@ def sample_product_testcases(*args, **kw):
kw = [(k, list(v)) for k, v in kw.items()]
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
testcases = []
- for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
+ for i in _choice(n, min(n, NUM_GENERATED_CASES.value)):
testcase = {}
for a in args:
testcase.update(a[i % len(a)])
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/sharding_test.py
@@ -77,7 +77,7 @@ def setUpModule():
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
global prev_spmd_lowering_flag
- prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
+ prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)


2 changes: 1 addition & 1 deletion tests/multiprocess_gpu_test.py
@@ -259,7 +259,7 @@ def create_2d_non_contiguous_mesh(self):

def setUp(self):
super().setUp()
- self.xmap_spmd_lowering_enabled = maps._SPMD_LOWERING.value
+ self.xmap_spmd_lowering_enabled = maps.SPMD_LOWERING.value
jax.config.update("experimental_xmap_spmd_lowering", True)

def tearDown(self):
2 changes: 1 addition & 1 deletion tests/pjit_test.py
@@ -78,7 +78,7 @@ def setUpModule():
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
global prev_spmd_lowering_flag
- prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
+ prev_spmd_lowering_flag = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)

def tearDownModule():
4 changes: 2 additions & 2 deletions tests/python_callback_test.py
@@ -632,8 +632,8 @@ def test_can_pjit_pure_callback_under_hard_xmap(self):
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')

- spmd_lowering = maps._SPMD_LOWERING.value
- spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
+ spmd_lowering = maps.SPMD_LOWERING.value
+ spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
12 changes: 6 additions & 6 deletions tests/shard_map_test.py
@@ -1382,7 +1382,7 @@ def make_mesh(mesh_shape):
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))

@parameterized.named_parameters(
- sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+ sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@@ -1391,7 +1391,7 @@ def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
self.assertAllClose(expected, out, check_dtypes=False)

@parameterized.named_parameters(
- sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+ sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@@ -1401,7 +1401,7 @@ def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):

@parameterized.named_parameters(
(name + f'_check_rep={check_rep}', *params, check_rep)
- for (name, *params) in sample(jtu._NUM_GENERATED_CASES.value, sample_shmap)
+ for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
@@ -1414,7 +1414,7 @@ def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(
- sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
+ sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)
@@ -1433,7 +1433,7 @@ def g(*args):
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(
- sample(jtu._NUM_GENERATED_CASES.value,
+ sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@@ -1456,7 +1456,7 @@ def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

@parameterized.named_parameters(
- sample(jtu._NUM_GENERATED_CASES.value,
+ sample(jtu.NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)
12 changes: 6 additions & 6 deletions tests/state_test.py
@@ -831,7 +831,7 @@ class StateHypothesisTest(jtu.JaxTestCase):

@hp.given(get_vmap_params())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_get_vmap(self, get_vmap_param: GetVmapParams):

indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims
@@ -870,7 +870,7 @@ def f(ref, *non_slice_idx):

@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_set_vmap(self, set_vmap_param: SetVmapParams):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Scatter is nondeterministic on GPU")
@@ -915,7 +915,7 @@ def f(ref, val, *non_slice_idx):

@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):

indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims
@@ -1538,7 +1538,7 @@ class RunStateHypothesisTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_jvp(self, data):

spec = data.draw(func_spec())
@@ -1563,7 +1563,7 @@ def ref(x):
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_linearize(self, data):

spec = data.draw(func_spec())
@@ -1589,7 +1589,7 @@ def ref(x):
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
- max_examples=jtu._NUM_GENERATED_CASES.value)
+ max_examples=jtu.NUM_GENERATED_CASES.value)
def test_vjp(self, data):

spec = data.draw(func_spec())
12 changes: 6 additions & 6 deletions tests/xmap_test.py
@@ -246,7 +246,7 @@ class XMapTestCase(jtu.BufferDonationTestCase):
class SPMDTestMixin:
def setUp(self):
super().setUp()
- self.spmd_lowering = maps._SPMD_LOWERING.value
+ self.spmd_lowering = maps.SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)

def tearDown(self):
@@ -258,8 +258,8 @@ def setUp(self):
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
raise SkipTest
super().setUp()
- self.spmd_lowering = maps._SPMD_LOWERING.value
- self.spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
+ self.spmd_lowering = maps.SPMD_LOWERING.value
+ self.spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)

@@ -436,7 +436,7 @@ def h(y):
m_size = math.prod([2] + [2] * (len(mesh) - 2))
self.assertListEqual(y_op_sharding.tile_assignment_dimensions(),
[2, 1, 1, m_size])
- if maps._SPMD_LOWERING.value:
+ if maps.SPMD_LOWERING.value:
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
if xla_extension_version >= 180:
@@ -749,7 +749,7 @@ def testLowerPartitionsAttribute(self):
axis_resources={'i': 'x'})
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
hlo = f.lower(x).as_text(dialect='stablehlo')
- if maps._SPMD_LOWERING.value:
+ if maps.SPMD_LOWERING.value:
self.assertIn("mhlo.num_partitions = 2", hlo)
self.assertIn("mhlo.num_replicas = 1", hlo)
else:
@@ -1204,7 +1204,7 @@ def testGatherPositional(self):

@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
- if axis_resources and not maps._SPMD_LOWERING.value:
+ if axis_resources and not maps.SPMD_LOWERING.value:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3
