Add model checkpointing to push_to_hub and PushToHubCallback (#14492)
* Add checkpointing to push_to_hub and PushToHubCallback

* Add checkpoint loading

* Add missing default value

* Correct method name

* make style

* Moving everything to the right location

* make style

* Revert changes to file_utils.py

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding docstrings and comments to clarify code

* make style

* Fix organization positional arg

* Fix load_repo_checkpoint to no longer accidentally create empty repos

* make style

* Remove unnecessary 'organization' argument in load_repo_checkpoint

* Avoid private `_create_or_get_repo` method

* make style

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Rocketknight1 and sgugger committed Nov 29, 2021
1 parent 8332327 commit 6fc38ad
Showing 2 changed files with 81 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/transformers/keras_callbacks.py
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from time import sleep
from typing import Optional, Union
@@ -23,6 +24,7 @@ def __init__(
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
):
"""
output_dir (:obj:`str`):
@@ -48,8 +50,13 @@ def __init__(
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when `save_strategy` is `epoch`.
"""
super().__init__()
if checkpoint and save_strategy != "epoch":
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
if isinstance(save_strategy, str):
save_strategy = IntervalStrategy(save_strategy.lower())
self.save_strategy = save_strategy
@@ -65,6 +72,7 @@ def __init__(
self.repo = Repository(str(output_dir), clone_from=hub_model_id)
self.tokenizer = tokenizer
self.last_job = None
self.checkpoint = checkpoint

def on_train_batch_end(self, batch, logs=None):
if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
@@ -84,6 +92,9 @@ def on_epoch_end(self, epoch, logs=None):
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
if self.checkpoint:
checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
self.model._save_checkpoint(checkpoint_dir, epoch)
_, self.last_job = self.repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False
)
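A minimal usage sketch of the new flag (all names here are placeholders, not part of this commit): with checkpoint=True and save_strategy="epoch", the callback saves a resumable checkpoint under output_dir/checkpoint and pushes it to the Hub at the end of each epoch.

from transformers.keras_callbacks import PushToHubCallback

# Hypothetical names for illustration; `model` is a compiled TF model and
# `tokenizer`/`train_dataset` are assumed to already exist.
callback = PushToHubCallback(
    output_dir="./model_output",
    save_strategy="epoch",                 # checkpoint=True is only valid with epoch saves
    tokenizer=tokenizer,
    hub_model_id="my-username/my-model",   # hypothetical repo id
    checkpoint=True,                       # also upload epoch count and optimizer state
)
model.fit(train_dataset, epochs=5, callbacks=[callback])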
70 changes: 70 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -18,6 +18,7 @@
import functools
import inspect
import os
import pickle
import re
import warnings
from typing import Dict, List, Optional, Union
@@ -30,6 +31,8 @@
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format

from huggingface_hub import Repository, list_repo_files

from .configuration_utils import PretrainedConfig
from .file_utils import (
DUMMY_INPUTS,
@@ -753,6 +756,73 @@ def get_input_embeddings(self) -> tf.keras.layers.Layer:
else:
raise NotImplementedError

def _save_checkpoint(self, checkpoint_dir, epoch):
if not os.path.isdir(checkpoint_dir):
os.mkdir(checkpoint_dir)
# We avoid tf.train.Checkpoint or saving weights in TF format, even though that includes optimizer
# state for us, because it requires special handling for objects like custom losses, which we use
# internally and which users are likely to use too
weights_path = os.path.join(checkpoint_dir, "weights.h5")
self.save_weights(weights_path)
extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
with open(extra_data_path, "wb") as f:
pickle.dump(extra_data, f)

def load_repo_checkpoint(self, repo_path_or_name):
"""
Loads a saved checkpoint (model weights and optimizer state) from a repo, so that training can be
resumed from the epoch at which the checkpoint was made.

Args:
    repo_path_or_name (:obj:`str`):
        Can either be a repository name for your model in the Hub or a path to a local folder (in which
        case the repository will have the name of that local folder).

Returns:
    :obj:`dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
"""
if getattr(self, "optimizer", None) is None:
raise RuntimeError(
"Checkpoint loading failed as no optimizer is attached to the model. "
"This is most likely caused by the model not being compiled."
)
if not os.path.isdir(repo_path_or_name):
# If this isn't a local path, check that the remote repo exists and has a checkpoint in it
repo_files = list_repo_files(repo_path_or_name)
for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
if file not in repo_files:
raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
if "/" not in repo_path_or_name:
model_id = repo_path_or_name
repo_path_or_name = self.get_full_repo_name(repo_path_or_name)
else:
model_id = repo_path_or_name.split("/")[-1]
repo = Repository(model_id, clone_from=f"https://huggingface.co/{repo_path_or_name}")
local_dir = repo.local_dir
else:
local_dir = repo_path_or_name

# Now make sure the repo actually has a checkpoint in it.
checkpoint_dir = os.path.join(local_dir, "checkpoint")
weights_file = os.path.join(checkpoint_dir, "weights.h5")
if not os.path.isfile(weights_file):
raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
if not os.path.isfile(extra_data_file):
raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")

# Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
# The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
self.load_weights(weights_file)
with open(extra_data_file, "rb") as f:
extra_data = pickle.load(f)
self.optimizer.set_weights(extra_data["optimizer_state"])

# Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
# set it directly, but the user can pass it to fit().
return {"epoch": extra_data["epoch"]}

def compile(
self,
optimizer="rmsprop",
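And a sketch of the resuming side, under the same placeholder names: load_repo_checkpoint needs a compiled model (so an optimizer is attached), restores the weights and optimizer state pushed above, and returns the metadata dict, whose "epoch" entry can be handed back to fit() as initial_epoch.

# Hypothetical resume flow; the repo id matches the callback example above.
model.compile(optimizer="adam")  # compile first so an optimizer is attached
extra_data = model.load_repo_checkpoint("my-username/my-model")
# The saved value is the zero-based index of the last completed epoch,
# so training resumes with the following one.
model.fit(train_dataset, epochs=10, initial_epoch=extra_data["epoch"] + 1)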
