From 69b38a99a7ac446a97d1bfbc49eab526996f1ef7 Mon Sep 17 00:00:00 2001 From: Benjamin Fattori Date: Tue, 6 Jun 2023 15:15:57 +0100 Subject: [PATCH] NamedSharding doesn't work in multihost --- main.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/main.py b/main.py index 12b3de3..bfa0142 100644 --- a/main.py +++ b/main.py @@ -293,13 +293,6 @@ def compute_tokens_seen(absolute_step, max_context): import gc gc.collect() - named_sharding_params = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), param_spec) - named_sharding_opt = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), opt_state_spec) - - params = jax.device_put(params, named_sharding_params) - opt_state = jax.device_put(opt_state, named_sharding_opt) - - else: raise NotImplementedError( "Checkpointing not currently implemented for GPU."