Skip to content

Commit

Permalink
wrap all resumes in cpu context
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Jun 6, 2023
1 parent fc7c8b7 commit 48ef1e3
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from time import time
from typing import Any, Callable, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -87,20 +88,18 @@ def restore_checkpoint_params(
"""
Restores the most recent parameter dict
"""
import flax
restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="params_")
with jax.default_device(jax.devices("cpu")[0]):
params = flax.core.freeze(restored["params"])
params = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),param_spec, params)

return params, restored["step"]
return params, restored["step"]

def restore_checkpoint_opt(opt_spec: Any, workdir: str
) -> Tuple[Any, Any, int]:
"""
Restores the most recent opt state dict.
"""
import flax
restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="opt_")

with jax.default_device(jax.devices("cpu")[0]):
Expand All @@ -109,24 +108,24 @@ def restore_checkpoint_opt(opt_spec: Any, workdir: str
)
mu_pytree = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),opt_spec[1][0].mu, flax.core.freeze(mu_pytree))

count_pytree = jax.tree_map(
lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["count"]
)
count_pytree = jax.tree_map(
lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["count"]
)

restoredlionstate = optax.ScaleByLionState(
count_pytree, flax.core.FrozenDict(mu_pytree)
)
restoredlionstate = optax.ScaleByLionState(
count_pytree, flax.core.FrozenDict(mu_pytree)
)


opt_state = (
optax.EmptyState(),
(
restoredlionstate,
optax.MaskedState(inner_state=optax.EmptyState()),
optax.ScaleByScheduleState(count=jnp.array(restored["step"])),
),
)
return opt_state
opt_state = (
optax.EmptyState(),
(
restoredlionstate,
optax.MaskedState(inner_state=optax.EmptyState()),
optax.ScaleByScheduleState(count=jnp.array(restored["step"])),
),
)
return opt_state

def create_train_state(
rng: jax.random.PRNGKey,
Expand Down

0 comments on commit 48ef1e3

Please sign in to comment.