
Saving and restoring checkpoint hangs in Jupyter Notebook. #263

Closed · zjniu opened this issue Mar 28, 2023 · 4 comments

zjniu commented Mar 28, 2023

I have a training loop as shown below, running Python 3.10.10 with the latest versions of JAX (0.4.7), Flax (0.6.7), and Orbax (0.1.6). The restore and save calls cause the code to hang in Jupyter Notebook: when I call the train_model function, the cell freezes at either restore or save, but execution resumes if I run another cell. I suspect it may have something to do with asyncio, but I am not sure. I recently switched over from flax.checkpoints, where this wasn't an issue. Any help would be appreciated!

# Standard-library and third-party imports used by the snippet below
import json
from functools import partial
from pathlib import Path

import numpy as np
import optax
import orbax.checkpoint
from jax import random

# Flax imports
from flax import serialization
from flax.training import orbax_utils, train_state

...

def train_model(model_path, dataset_path, dataset_adjustment='normalize',
                epochs=200, random_seed=0, batch_size=4, learning_rate=0.001,
                warmup_epochs=10, decay_epochs=100, decay_rate=0.5, decay_transition_epochs=10,
                optimizer=None, loss_weights=None):

    model_path = Path(model_path)
    model_parent_path = model_path.parent
    model_name = model_path.stem
    checkpoint_path = model_parent_path.joinpath(f'{model_name}_ckpts')
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    batch_metrics_log_path = model_parent_path.joinpath(f'{model_name}_batch_metrics_log')
    epoch_metrics_log_path = model_parent_path.joinpath(f'{model_name}_epoch_metrics_log')

    if batch_metrics_log_path.is_file():
        with open(batch_metrics_log_path, 'r') as f_batch_metrics_log:
            batch_metrics_log = json.load(f_batch_metrics_log)
    else:
        batch_metrics_log = []
    if epoch_metrics_log_path.is_file():
        with open(epoch_metrics_log_path, 'r') as f_epoch_metrics_log:
            epoch_metrics_log = json.load(f_epoch_metrics_log)
    else:
        epoch_metrics_log = []

    print('Loading datasets...\n')
    ds = load_datasets(dataset_path, adjustment=dataset_adjustment)
    train_images_shape = ds['train']['images'].shape
    input_size = train_images_shape[1:3]
    coords_max_length = \
        max([len(coords) for coords in ds['train']['coords']] + [len(coords) for coords in ds['valid']['coords']])

    rng = random.PRNGKey(random_seed)

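    # Piecewise learning-rate schedule: linear warmup, then a constant
    # plateau, then staircase decay every decay_transition_epochs epochs.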
    warmup = [learning_rate * i / warmup_epochs for i in range(warmup_epochs)]
    constant = [learning_rate] * (epochs - warmup_epochs - decay_epochs)
    decay = [learning_rate * decay_rate ** np.ceil(i / decay_transition_epochs) for i in range(1, decay_epochs + 1)]
    schedule = warmup + constant + decay

    if optimizer is None:
        optimizer = partial(optax.adabelief, eps=1e-8)
    tx = optax.inject_hyperparams(optimizer)(learning_rate=learning_rate)

    if loss_weights is None:
        loss_weights = {
            'rmse': 0.4,
            'bce': 0.2,
            'smoothf1': 1
        }

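    # Checkpoint manager that keeps only the two most recent checkpoints.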
    mgr_options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2)
    handlers = {'state': orbax.checkpoint.PyTreeCheckpointer()}
    ckpt_mgr = orbax.checkpoint.CheckpointManager(
        directory=checkpoint_path,
        checkpointers=handlers,
        options=mgr_options
    )

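    # If no checkpoints have been written yet, initialize from a previously
    # saved model file (if one exists).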
    if (next(checkpoint_path.iterdir(), None) is None) and model_path.is_file():
        print(f'Loading existing model weights from {model_path}...\n')
        ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
        variables = ckptr.restore(model_path, item=None)
    else:
        variables = None

    print('Creating new TrainState...\n')
    state = create_train_state(rng, input_size, tx, variables)
    latest_epoch = ckpt_mgr.latest_step()
    if latest_epoch is not None:
        print(f'Loading latest checkpoint from {checkpoint_path}...\n')
        restore_args = orbax_utils.restore_args_from_target(state, mesh=None)
        state = ckpt_mgr.restore(
            step=latest_epoch,
            items={'state': state},
            restore_kwargs={'state': {'restore_args': restore_args}}
        )['state']

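    # Main training loop: one iteration per scheduled learning rate, i.e. one per epoch.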
    for epoch_learning_rate in schedule:

        state, batch_metrics, epoch_metrics = \
            train_epoch(state, ds, batch_size, loss_weights, epoch_learning_rate, input_size, coords_max_length)

        batch_metrics_log += batch_metrics
        epoch_metrics_log += [epoch_metrics]

        save_args = orbax_utils.save_args_from_target(state)
        ckpt_mgr.save(
            step=state.epoch,
            items={'state': state},
            save_kwargs={'state': {'save_args': save_args}}
        )

        with open(batch_metrics_log_path, 'w') as f_batch_metrics_log:
            json.dump(batch_metrics_log, f_batch_metrics_log, indent=4)
        with open(epoch_metrics_log_path, 'w') as f_epoch_metrics_log:
            json.dump(epoch_metrics_log, f_epoch_metrics_log, indent=4)

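    # Export the final weights as a msgpack byte string via flax.serialization.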
    variables = {'params': state.params, 'batch_stats': state.batch_stats, 'input_size': input_size}
    bytes_model = serialization.to_bytes(variables)

    with open(model_path, 'wb') as f_model:
        f_model.write(bytes_model)

cpgaffney1 (Collaborator) commented

I have a suspicion it may be a bad interaction between different dependencies. Could you try stripping out everything unrelated to checkpointing, without removing any imports, and see if it still hangs?

If it does, start removing imports one at a time; the hang should eventually disappear (at worst once you're down to importing only Orbax).
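
For concreteness, a minimal repro might look something like the sketch below; the directory is a placeholder, and this assumes the same Orbax 0.1.x API your snippet uses:

import numpy as np
import orbax.checkpoint

# Only Orbax and numpy: no Flax, no model, no training loop.
ckpt_mgr = orbax.checkpoint.CheckpointManager(
    directory='/tmp/orbax_repro_ckpts',  # placeholder path
    checkpointers={'state': orbax.checkpoint.PyTreeCheckpointer()},
    options=orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2),
)

tree = {'w': np.ones((4, 4))}
ckpt_mgr.save(step=0, items={'state': tree})                # does this hang?
restored = ckpt_mgr.restore(step=0, items={'state': tree})  # or does this?
print(restored['state']['w'].shape)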

Also, a few nits:
- handlers = {'state': orbax.checkpoint.PyTreeCheckpointer()}: this variable should be named checkpointers. "Handler" is a different concept in Orbax, and confusing the two will rapidly lead to bugs.
- orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) is equivalent to orbax.checkpoint.PyTreeCheckpointer().
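
To spell out that second nit in code:

import orbax.checkpoint

# These two lines construct equivalent checkpointers:
ckptr_a = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
ckptr_b = orbax.checkpoint.PyTreeCheckpointer()  # shorthand for the line above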

zjniu (Author) commented Mar 28, 2023

Thank you for the suggestions! It turns out that these lines in my notebook were causing the issue:

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('QtAgg')
%matplotlib qt

Is there a way to solve this other than not importing them?

cpgaffney1 (Collaborator) commented

I'm afraid your guess is as good as mine; I'm not familiar with this particular backend, and matplotlib alone doesn't seem to cause any problem without overriding the backend. Is it possible to skip that part?
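
If you do need the plots, one thing that might be worth trying (untested on my end) is a non-interactive backend, which avoids starting a GUI event loop in the notebook:

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: renders to buffers/files, no GUI event loop
import matplotlib.pyplot as plt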

zjniu (Author) commented Mar 29, 2023

Yep, matplotlib by itself is not the issue. The Qt backend is just a convenience, not at all necessary, so I will disable it for now. Thanks for the help!
