Skip to content

Commit

Permalink
Fix sphinx syntax error.
Browse files Browse the repository at this point in the history
  • Loading branch information
nouiz committed May 5, 2023
1 parent de57b4f commit 916ad35
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions jax/experimental/multihost_utils.py
Expand Up @@ -257,35 +257,31 @@ def host_local_array_to_global_array_impl(

def host_local_array_to_global_array(
local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
"""Converts a host local value to a globally sharded `jax.Array`.
r"""Converts a host local value to a globally sharded jax.Array.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
You can use this function to transition to jax.Array. Using jax.Array with
pjit has the same semantics of using GDA with pjit i.e. all jax.Array
inputs to pjit should be globally shaped.
If you are currently passing host local values to pjit, you can use this
function to convert your host local values to global Arrays and then pass that
to pjit.
Example usage:
to pjit. Example usage.
```
from jax.experimental import multihost_utils
>>> from jax.experimental import multihost_utils
global_inputs = multihost_utils.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
>>> host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
>>> with mesh:
>>> global_out = pjitted_fun(global_inputs)
host_local_output = multihost_utils.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
>>> global_out, mesh, out_pspecs)
Args:
local_inputs: A Pytree of host local values.
global_mesh: A ``jax.sharding.Mesh`` object.
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
global_mesh: A jax.sharding.Mesh object.
pspecs: A Pytree of jax.sharding.PartitionSpec's.
"""
flat_inps, in_tree = tree_flatten(local_inputs)
in_pspecs = _flatten_pspecs('input pspecs', in_tree,
Expand Down Expand Up @@ -357,36 +353,32 @@ def global_array_to_host_local_array_impl(

def global_array_to_host_local_array(
global_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
"""Converts a global `jax.Array` to a host local `jax.Array`.
r"""Converts a global `jax.Array` to a host local `jax.Array`.
You can use this function to transition to `jax.Array`. Using `jax.Array` with
`pjit` has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped and the output from `pjit` will also
be globally shaped `jax.Array`s
pjit has the same semantics of using GDA with pjit i.e. all `jax.Array`
inputs to pjit should be globally shaped and the output from pjit will also
be globally shaped jax.Array's
You can use this function to convert the globally shaped `jax.Array` output
from pjit to host local values again so that the transition to jax.Array can
be a mechanical change.
be a mechanical change. Example usage
Example usage:
>>> from jax.experimental import multihost_utils
```
from jax.experimental import multihost_utils
global_inputs = multihost_utils.host_local_array_to_global_array(
host_local_inputs, global_mesh, in_pspecs)
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
>>> host_local_inputs, global_mesh, in_pspecs)
with mesh:
global_out = pjitted_fun(global_inputs)
>>> with mesh:
>>> global_out = pjitted_fun(global_inputs)
host_local_output = multihost_utils.global_array_to_host_local_array(
global_out, mesh, out_pspecs)
```
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
>>> global_out, mesh, out_pspecs)
Args:
global_inputs: A Pytree of global `jax.Array`s.
global_mesh: A ``jax.sharding.Mesh`` object.
pspecs: A Pytree of ``jax.sharding.PartitionSpec``s.
global_inputs: A Pytree of global jax.Array's.
global_mesh: A jax.sharding.Mesh object.
pspecs: A Pytree of jax.sharding.PartitionSpec's.
"""
flat_inps, out_tree = tree_flatten(global_inputs)
out_pspecs = _flatten_pspecs('output pspecs', out_tree,
Expand Down

0 comments on commit 916ad35

Please sign in to comment.