
Commit

Delete soft_pmap as it has no users. Please use pjit or xmap if you do want soft_pmap.

`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.

PiperOrigin-RevId: 474145090
yashk2810 authored and jax authors committed Sep 13, 2022
1 parent dc7db8d commit da90234
Showing 4 changed files with 4 additions and 147 deletions.
CHANGELOG.md (3 changes: 3 additions & 0 deletions)
@@ -18,6 +18,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * Breaking changes
   * `jax._src` is no longer imported into the public `jax` namespace.
     This may break users that were using JAX internals.
+  * `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead.
+    `jax.soft_pmap` is undocumented. If it were documented, a deprecation period
+    would have been provided.

 ## jax 0.3.17 (Aug 31, 2022)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
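For readers taking the `pjit` route the changelog suggests, here is a minimal migration sketch, not part of this commit: it shards the leading axis of a former `soft_pmap` workload across all local devices using the 0.3-era `pjit` API. The one-axis mesh, the `3 * x` function, and the size `n` are illustrative assumptions.

# Hypothetical migration sketch (not from this commit).
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

n = 4 * jax.device_count()  # leading axis sized as a multiple of the device count
x = jnp.arange(n, dtype=jnp.float32)

# A one-dimensional mesh spanning all devices, with a single axis named 'devices'.
mesh = Mesh(np.array(jax.devices()), ('devices',))

# Shard axis 0 of the input and of the output across the 'devices' mesh axis.
f = pjit(lambda x: 3 * x,
         in_axis_resources=P('devices'),
         out_axis_resources=P('devices'))

with mesh:
  out = f(x)

np.testing.assert_allclose(out, 3 * np.arange(n))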
jax/__init__.py (1 change: 0 additions & 1 deletion)
@@ -112,7 +112,6 @@
   xla,  # TODO(phawkins): update users to avoid this.
   xla_computation as xla_computation,
 )
-from jax.experimental.maps import soft_pmap as soft_pmap
 from jax.version import __version__ as __version__
 from jax.version import __version_info__ as __version_info__

jax/experimental/maps.py (23 changes: 0 additions & 23 deletions)
@@ -1899,29 +1899,6 @@ class NoQuotesStr(str):
   __repr__ = str.__str__


-# -------- soft_pmap --------
-
-def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
-              ) -> Callable:
-  warn("soft_pmap is an experimental feature and probably has bugs!")
-  _check_callable(fun)
-  axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
-
-  if any(axis != 0 for axis in tree_leaves(in_axes)):
-    raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
-  proxy = object()
-  in_axes = _replace_nones(proxy, in_axes)
-  in_axes = tree_map(lambda i: {i: axis_name} if i is not proxy else {}, in_axes)
-
-
-  @wraps(fun)
-  def f_pmapped(*args, **kwargs):
-    mesh_devices = np.array(xb.local_devices())
-    with Mesh(mesh_devices, ['devices']):
-      return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
-                  axis_resources={axis_name: 'devices'})(*args, **kwargs)
-  return f_pmapped
-
 # -------- config flags --------

 def _thread_local_flag_unsupported(_):
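The deleted wrapper above makes the migration mechanical: `soft_pmap` simply mapped axis 0 of each argument along a fresh named axis with `xmap` and assigned that axis to a one-dimensional 'devices' mesh built from the local devices. A hedged sketch of writing that `xmap` call directly, with a batched `jnp.dot` chosen here purely for illustration:

# Illustrative sketch (not from this commit), mirroring the deleted wrapper.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap

n = 4 * jax.device_count()  # the named axis may be a multiple of the mesh size
xs = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n, 2, 3)
ys = jnp.arange(n * 3 * 4, dtype=jnp.float32).reshape(n, 3, 4)

with Mesh(np.array(jax.local_devices()), ('devices',)):
  batched_dot = xmap(jnp.dot,
                     in_axes=({0: 'i'}, {0: 'i'}),  # map axis 0 of each input as 'i'
                     out_axes={0: 'i'},             # reinsert 'i' as axis 0 of the output
                     axis_resources={'i': 'devices'})
  ans = batched_dot(xs, ys)

np.testing.assert_allclose(ans, np.einsum('nij,njk->nik', xs, ys), rtol=1e-6)

As in the wrapper, nothing requires the mapped axis size to equal the device count; `xmap` splits the named axis 'i' over the 'devices' resource, which is what made the mapping "soft".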
tests/pmap_test.py (124 changes: 1 addition & 123 deletions)
@@ -38,7 +38,7 @@
 from jax._src import api as src_api
 from jax import random
 from jax.core import ShapedArray
-from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
+from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
                  linearize, device_put)
 from jax._src import config as jax_config
 from jax._src import device_array
@@ -1545,128 +1545,6 @@ def testReshardInput(self):
     r_db = r.device_buffers
     self.assertEqual(len(r_db), 6)

-  @ignore_xmap_warning()
-  def testSoftPmapBatchMatmul(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
-    ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
-    ans = soft_pmap(jnp.dot, 'i')(xs, ys)
-    expected = np.einsum('nij,njk->nik', xs, ys)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapBatchMatmulJit(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
-    ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
-    ans = soft_pmap(jit(jnp.dot), 'i')(xs, ys)
-    expected = np.einsum('nij,njk->nik', xs, ys)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapPsumConstant(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    def f(_):
-      return lax.psum(1, 'i')
-    ans = soft_pmap(f, 'i')(jnp.ones(n))
-    expected = n * np.ones(n)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapPsum(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    def f(x):
-      return x / lax.psum(x, 'i')
-    ans = soft_pmap(f, 'i')(jnp.ones(n))
-    expected = np.ones(n) / n
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapAxisIndex(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    def f(x):
-      return x * lax.axis_index('i')
-    ans = soft_pmap(f, 'i')(2 * jnp.ones(n, dtype='int32'))
-    expected = 2 * np.arange(n)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapOfJit(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    def f(x):
-      return 3 * x
-    ans = soft_pmap(jit(f), 'i')(np.arange(n))
-    expected = 3 * np.arange(n)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  @unittest.skip("not implemented")  # TODO(mattjj): re-implement
-  def testSoftPmapNested(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-
-    @partial(soft_pmap, axis_name='i')
-    @partial(soft_pmap, axis_name='j')
-    def f(x):
-      i_size = lax.psum(1, 'i')
-      return x + lax.axis_index('i') + i_size * lax.axis_index('j')
-
-    ans = f(jnp.zeros((n, n)))
-    expected = np.arange(n ** 2).reshape(n, n).T
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  @unittest.skip("not implemented")  # TODO(mattjj): re-implement
-  def testGradOfSoftPmap(self):
-    n = 4 * jax.device_count()
-
-    @partial(soft_pmap, axis_name='i')
-    def f(x):
-      return x * lax.axis_index('i')
-
-    ans = grad(lambda x: jnp.sum(f(x)))(jnp.zeros((n, n)))
-    expected = np.repeat(np.arange(n)[:, None], n, axis=1)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  @ignore_xmap_warning()
-  def testSoftPmapDevicePersistence(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    device_count = jax.device_count()
-    shape = (2 * 2 * device_count, 2, 3)
-
-    # check that we can maintain device persistence across calls
-    x = np.arange(prod(shape)).reshape(shape)
-    x = soft_pmap(lambda x: x)(x)
-    self.assertIsInstance(x, pxla.ShardedDeviceArray)
-    x._npy_value = np.float32(np.nan)  # can't be coerced to ndarray for xfer
-    x = soft_pmap(lambda x: x)(x)  # doesn't crash
-    self.assertIsInstance(x, pxla.ShardedDeviceArray)
-
-  @unittest.skip("the underlying code here is broken")  # TODO(mattjj)
-  def testSoftPmapAllToAll(self):
-    if config.jax_array:
-      raise unittest.SkipTest('Does not work with `Array`.')
-    n = 4 * jax.device_count()
-    def f(x):
-      return lax.all_to_all(x, 'i', 0, 0)
-    ans = soft_pmap(f, 'i')(jnp.arange(n ** 2).reshape(n, n))
-    expected = np.arange(n ** 2).reshape(n, n).T
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
   def testShardedDeviceArrayBlockUntilReady(self):
     x = np.arange(jax.device_count())
     x = self.pmap(lambda x: x)(x)
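To make the replacement concrete for the deleted tests, here is `testSoftPmapPsum` rewritten against `xmap` directly; a sketch under the same single-'devices'-mesh assumption the old wrapper made, not code from this commit:

# Illustrative rewrite of the deleted testSoftPmapPsum (not from this commit).
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.maps import Mesh, xmap

n = 4 * jax.device_count()

def f(x):
  # Inside xmap, x is a scalar and 'i' is the named axis being mapped over,
  # so lax.psum(x, 'i') sums across all n mapped positions.
  return x / lax.psum(x, 'i')

with Mesh(np.array(jax.local_devices()), ('devices',)):
  ans = xmap(f, in_axes={0: 'i'}, out_axes={0: 'i'},
             axis_resources={'i': 'devices'})(jnp.ones(n))

np.testing.assert_allclose(ans, np.ones(n) / n, rtol=1e-6)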
