Skip to content

Commit

Permalink
NamedSharding doesn't work in multihost
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Jun 6, 2023
1 parent b6e35d0 commit 69b38a9
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,6 @@ def compute_tokens_seen(absolute_step, max_context):
import gc
gc.collect()

named_sharding_params = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), param_spec)
named_sharding_opt = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), opt_state_spec)

params = jax.device_put(params, named_sharding_params)
opt_state = jax.device_put(opt_state, named_sharding_opt)


else:
raise NotImplementedError(
"Checkpointing not currently implemented for GPU."
Expand Down

0 comments on commit 69b38a9

Please sign in to comment.