Skip to content

Commit

Permalink
use CPU context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Jun 6, 2023
1 parent 6a94620 commit 18e7f33
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,9 @@ def restore_checkpoint_params(
"""
import flax
restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="params_")
params = flax.core.freeze(restored["params"])
params = jax.device_get(params)

params = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),param_spec, params)
with jax.default_device(jax.devices("cpu")[0]):
params = flax.core.freeze(restored["params"])
params = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),param_spec, params)

return params, restored["step"]

Expand All @@ -104,9 +103,10 @@ def restore_checkpoint_opt(opt_spec: Any, workdir: str
import flax
restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="opt_")

mu_pytree = jax.tree_map(
lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["mu"]
)
with jax.default_device(jax.devices("cpu")[0]):
mu_pytree = jax.tree_map(
lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["mu"]
)
mu_pytree = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),opt_spec[1][0].mu, flax.core.freeze(mu_pytree))

count_pytree = jax.tree_map(
Expand Down

0 comments on commit 18e7f33

Please sign in to comment.