Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model checkpointing to push_to_hub and PushToHubCallback #14492

Merged
merged 19 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/transformers/keras_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from time import sleep
from typing import Optional, Union
Expand All @@ -23,6 +24,7 @@ def __init__(
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
):
"""
output_dir (:obj:`str`):
Expand All @@ -48,8 +50,13 @@ def __init__(
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when `save_strategy` is `epoch`.
"""
super().__init__()
if checkpoint and save_strategy != "epoch":
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
if isinstance(save_strategy, str):
save_strategy = IntervalStrategy(save_strategy.lower())
self.save_strategy = save_strategy
Expand All @@ -65,6 +72,7 @@ def __init__(
self.repo = Repository(str(output_dir), clone_from=hub_model_id)
self.tokenizer = tokenizer
self.last_job = None
self.checkpoint = checkpoint

def on_train_batch_end(self, batch, logs=None):
if self.save_strategy == IntervalStrategy.STEPS and batch + 1 % self.save_steps == 0:
Expand All @@ -84,6 +92,9 @@ def on_epoch_end(self, epoch, logs=None):
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
if self.checkpoint:
checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
self.model._save_checkpoint(checkpoint_dir, epoch)
_, self.last_job = self.repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False
)
Expand Down
70 changes: 70 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import functools
import inspect
import os
import pickle
import re
import warnings
from typing import Dict, List, Optional, Union
Expand All @@ -30,6 +31,8 @@
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format

from huggingface_hub import Repository, list_repo_files

from .configuration_utils import PretrainedConfig
from .file_utils import (
DUMMY_INPUTS,
Expand Down Expand Up @@ -753,6 +756,73 @@ def get_input_embeddings(self) -> tf.keras.layers.Layer:
else:
raise NotImplementedError

def _save_checkpoint(self, checkpoint_dir, epoch):
if not os.path.isdir(checkpoint_dir):
os.mkdir(checkpoint_dir)
# We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
# state for us, because it requires special handling for objects like custom losses, which we use
# internally and which users are likely to use too
weights_path = os.path.join(checkpoint_dir, "weights.h5")
self.save_weights(weights_path)
extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
with open(extra_data_path, "wb") as f:
pickle.dump(extra_data, f)

def load_repo_checkpoint(self, repo_path_or_name):
"""
Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
the checkpoint was made.

Args:
repo_path_or_name (:obj:`str`):
Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
the repository will have the name of that local folder).

Returns:
:obj:`dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
"""
if getattr(self, "optimizer", None) is None:
raise RuntimeError(
"Checkpoint loading failed as no optimizer is attached to the model. "
"This is most likely caused by the model not being compiled."
)
if not os.path.isdir(repo_path_or_name):
# If this isn't a local path, check that the remote repo exists and has a checkpoint in it
repo_files = list_repo_files(repo_path_or_name)
for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
if file not in repo_files:
raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
if "/" not in repo_path_or_name:
model_id = repo_path_or_name
repo_path_or_name = self.get_full_repo_name(repo_path_or_name)
else:
model_id = repo_path_or_name.split("/")[-1]
repo = Repository(model_id, clone_from=f"https://huggingface.co/{repo_path_or_name}")
local_dir = repo.local_dir
else:
local_dir = repo_path_or_name

# Now make sure the repo actually has a checkpoint in it.
checkpoint_dir = os.path.join(local_dir, "checkpoint")
weights_file = os.path.join(checkpoint_dir, "weights.h5")
if not os.path.isfile(weights_file):
raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
if not os.path.isfile(extra_data_file):
raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")

# Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
# The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
self.load_weights(weights_file)
with open(extra_data_file, "rb") as f:
extra_data = pickle.load(f)
self.optimizer.set_weights(extra_data["optimizer_state"])

# Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
# set it directly, but the user can pass it to fit().
return {"epoch": extra_data["epoch"]}

def compile(
self,
optimizer="rmsprop",
Expand Down