Skip to content

Commit

Permalink
Clarifying docstring for devices argument of pmap.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 383486168
  • Loading branch information
james-martens authored and jax authors committed Jul 7, 2021
1 parent 56087dc commit f925b62
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,8 +1423,10 @@ def pmap(
static. Defaults to ().
devices: This is an experimental feature and the API is likely to change.
Optional, a sequence of Devices to map over. (Available devices can be
retrieved via jax.devices()). If specified, the size of the mapped axis
must be equal to the number of local devices in the sequence. Nested
retrieved via jax.devices()). Must be given identically for each process
in multi-process settings (and will therefore include devices across
processes). If specified, the size of the mapped axis must be equal to
the number of devices in the sequence local to the given process. Nested
:py:func:`pmap` s with ``devices`` specified in either the inner or outer
:py:func:`pmap` are not yet supported.
backend: This is an experimental feature and the API is likely to change.
Expand Down

0 comments on commit f925b62

Please sign in to comment.