Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions brax/training/agents/sac/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def _unpmap(v):
def _init_training_state(
key: PRNGKey,
obs_size: int,
local_devices_to_use: int,
sac_network: sac_networks.SACNetworks,
alpha_optimizer: optax.GradientTransformation,
policy_optimizer: optax.GradientTransformation,
Expand Down Expand Up @@ -109,16 +108,7 @@ def _init_training_state(
alpha_params=log_alpha,
normalizer_params=normalizer_params,
)
devices = jax.local_devices()[:local_devices_to_use]
mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',))
sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded'))

def _replicate(x):
if isinstance(x, jax.Array):
return jax.device_put(jnp.stack([x] * len(devices)), sharding)
return jax.device_put(np.stack([x] * len(devices)), sharding)

return jax.tree_util.tree_map(_replicate, training_state)
return training_state


def train(
Expand Down Expand Up @@ -491,7 +481,6 @@ def training_epoch_with_timing(
training_state = _init_training_state(
key=global_key,
obs_size=obs_size,
local_devices_to_use=local_devices_to_use,
sac_network=sac_network,
alpha_optimizer=alpha_optimizer,
policy_optimizer=policy_optimizer,
Expand All @@ -506,6 +495,19 @@ def training_epoch_with_timing(
policy_params=params[1],
)

# Replicate training state across devices AFTER checkpoint restoration
# so that restored params have the correct per-device shape. Fixes #659.
devices = jax.local_devices()[:local_devices_to_use]
mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',))
sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded'))

def _replicate(x):
if isinstance(x, jax.Array):
return jax.device_put(jnp.stack([x] * len(devices)), sharding)
return jax.device_put(np.stack([x] * len(devices)), sharding)

training_state = jax.tree_util.tree_map(_replicate, training_state)

local_key, rb_key, env_key, eval_key = jax.random.split(local_key, 4)

# Env init
Expand Down Expand Up @@ -624,4 +626,4 @@ def training_epoch_with_timing(
pmap.assert_is_replicated(training_state)
logging.info('total steps: %s', total_steps)
pmap.synchronize_hosts()
return (make_policy, params, metrics)
return (make_policy, params, metrics)