From 48ef1e3a6214ed38dacef4fb48822b349c2dca34 Mon Sep 17 00:00:00 2001 From: Benjamin Fattori Date: Tue, 6 Jun 2023 17:52:44 +0100 Subject: [PATCH] wrap all resumes in cpu context --- main.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 4c89a14..17c6151 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ from time import time from typing import Any, Callable, Tuple, Union +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -87,20 +88,18 @@ def restore_checkpoint_params( """ Restores the most recent parameter dict """ - import flax restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="params_") with jax.default_device(jax.devices("cpu")[0]): params = flax.core.freeze(restored["params"]) params = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),param_spec, params) - return params, restored["step"] + return params, restored["step"] def restore_checkpoint_opt(opt_spec: Any, workdir: str ) -> Tuple[Any, Any, int]: """ Restores the most recent opt state dict. """ - import flax restored = checkpoints.restore_checkpoint(workdir, target=None, prefix="opt_") with jax.default_device(jax.devices("cpu")[0]): @@ -109,24 +108,24 @@ def restore_checkpoint_opt(opt_spec: Any, workdir: str ) mu_pytree = jax.tree_map(lambda x, y: nn.Partitioned(value = jnp.array(y['value']), names = x, mesh = None),opt_spec[1][0].mu, flax.core.freeze(mu_pytree)) - count_pytree = jax.tree_map( - lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["count"] - ) + count_pytree = jax.tree_map( + lambda x: jnp.array(x), restored["opt_state"]["1"]["0"]["count"] + ) - restoredlionstate = optax.ScaleByLionState( - count_pytree, flax.core.FrozenDict(mu_pytree) - ) + restoredlionstate = optax.ScaleByLionState( + count_pytree, flax.core.FrozenDict(mu_pytree) + ) - opt_state = ( - optax.EmptyState(), - ( - restoredlionstate, - optax.MaskedState(inner_state=optax.EmptyState()), - optax.ScaleByScheduleState(count=jnp.array(restored["step"])), - ), - ) - return opt_state + opt_state = ( + optax.EmptyState(), + ( + restoredlionstate, + optax.MaskedState(inner_state=optax.EmptyState()), + optax.ScaleByScheduleState(count=jnp.array(restored["step"])), + ), + ) + return opt_state def create_train_state( rng: jax.random.PRNGKey,