
Add model checkpointing to push_to_hub and PushToHubCallback #14492

Merged: 19 commits, Nov 29, 2021
Changes from 8 commits
11 changes: 11 additions & 0 deletions src/transformers/keras_callbacks.py
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from time import sleep
from typing import Optional, Union
@@ -23,6 +24,7 @@ def __init__(
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
        checkpoint: Optional[bool] = False,
    ):
        """
        output_dir (:obj:`str`):
@@ -48,8 +50,13 @@ def __init__(
        hub_token (:obj:`str`, `optional`):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            :obj:`huggingface-cli login`.
        checkpoint (:obj:`bool`, `optional`):
            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
            resumed. Only usable when :obj:`save_strategy` is :obj:`"epoch"`.
        """
        super().__init__()
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        if checkpoint and save_strategy != IntervalStrategy.EPOCH:
            raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
        self.save_strategy = save_strategy
@@ -65,6 +72,7 @@ def __init__(
        self.repo = Repository(str(output_dir), clone_from=hub_model_id)
        self.tokenizer = tokenizer
        self.last_job = None
        self.checkpoint = checkpoint

    def on_train_batch_end(self, batch, logs=None):
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
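A note on the `on_train_batch_end` condition: Python's `%` binds tighter than `+`, so the intended every-`save_steps` schedule requires the parenthesized form `(batch + 1) % self.save_steps == 0`. A standalone sketch of that schedule (plain Python; `should_push` is a hypothetical helper, not part of the PR):

```python
def should_push(batch, save_steps):
    # Push at the end of every save_steps-th batch (batch is 0-indexed).
    return (batch + 1) % save_steps == 0

# With save_steps=2, pushes fire after batches 1, 3, 5 (i.e. steps 2, 4, 6).
triggered = [b for b in range(6) if should_push(b, save_steps=2)]
print(triggered)  # prints [1, 3, 5]
```

Without the parentheses, `batch + 1 % save_steps` evaluates as `batch + (1 % save_steps)`, which is never zero for positive batch numbers.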
@@ -84,6 +92,9 @@ def on_epoch_end(self, epoch, logs=None):
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            if self.checkpoint:
                checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                self.model._save_checkpoint(checkpoint_dir, epoch)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )
34 changes: 34 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -18,6 +18,7 @@
import functools
import inspect
import os
import pickle
import re
import warnings
from typing import Dict, List, Optional, Union
@@ -753,6 +754,39 @@ def get_input_embeddings(self) -> tf.keras.layers.Layer:
        else:
            raise NotImplementedError

    def _save_checkpoint(self, checkpoint_dir, epoch):
        if not os.path.isdir(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
        # state for us, because it requires special handling for objects like custom losses, which we use
        # internally and which users are likely to use too
        weights_path = os.path.join(checkpoint_dir, "weights.h5")
        self.save_weights(weights_path)
        extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
        extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
        with open(extra_data_path, "wb") as f:
            pickle.dump(extra_data, f)

    def load_repo_checkpoint(self, repo_path_or_name, organization=None):
        if getattr(self, "optimizer", None) is None:
            raise RuntimeError(
                "Checkpoint loading failed as no optimizer is attached to the model. "
                "This is most likely caused by the model not being compiled."
            )
        repo = self._create_or_get_repo(repo_path_or_name, organization=organization)
Collaborator: I think we said to use Repository here to avoid creating a new repo?

Member (author): I tried this, but it resulted in some code duplication in order to make everything work - I kept it as-is but added a comment to explain instead. If you'd still prefer to avoid the call, let me know!

Collaborator: Let's see what @LysandreJik thinks then.

Member (@LysandreJik, Nov 24, 2021): Are you sure it won't actually create a remote repository if it doesn't exist? If you're certain of it, it's fine for me to go with self._create_or_get_repo, but I'm quite unsure that it won't create a remote repo.

Also, which method are you using? Is it this method here?

    def _create_or_get_repo(
        cls,
        repo_path_or_name: Optional[str] = None,
        repo_url: Optional[str] = None,
        organization: Optional[str] = None,
        private: bool = None,
        use_auth_token: Optional[Union[bool, str]] = None,
    ) -> Repository:
        if repo_path_or_name is None and repo_url is None:
            raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
        if use_auth_token is None and repo_url is None:
            use_auth_token = True
        if repo_path_or_name is None:
            repo_path_or_name = repo_url.split("/")[-1]
        if repo_url is None and not os.path.exists(repo_path_or_name):
            repo_name = Path(repo_path_or_name).name
            repo_url = cls._get_repo_url_from_name(
                repo_name, organization=organization, private=private, use_auth_token=use_auth_token
            )
        # Create a working directory if it does not exist.
        if not os.path.exists(repo_path_or_name):
            os.makedirs(repo_path_or_name)
        repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
        repo.git_pull()
        return repo

Because if so, don't call it with positional arguments like you do here, as you're passing the organization parameter as repo_url.
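To illustrate the positional-argument hazard flagged in this comment, here is a tiny standalone sketch; `create_or_get_repo` is a hypothetical stand-in that only mirrors the first three parameters of `_create_or_get_repo`:

```python
# Hypothetical stand-in with the same parameter order as _create_or_get_repo.
def create_or_get_repo(repo_path_or_name=None, repo_url=None, organization=None):
    return {"repo_path_or_name": repo_path_or_name, "repo_url": repo_url, "organization": organization}

# Positional call as in the diff: the organization value lands in the repo_url slot.
bad = create_or_get_repo("my-model", "my-org")
print(bad["repo_url"], bad["organization"])  # prints: my-org None

# Keyword call puts it where intended.
good = create_or_get_repo("my-model", organization="my-org")
print(good["repo_url"], good["organization"])  # prints: None my-org
```

This is exactly why keyword arguments are the safer choice for signatures with several optional parameters of the same type.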

Member (author): Good spot on the organization positional arg, fixed!

Also, the _create_or_get_repo method just does this, after sanity-checking the arguments:

repo = Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token)
repo.git_pull()
return repo

So I think as long as we don't push it afterwards, we won't create a new remote repo if one doesn't exist already.

Collaborator: I think this discussion is real proof that we need to just put the code used here and not try to use some magic of the Mixin. The library always prefers an explicit approach at the cost of duplicate code, and I think this is an instance where we should just apply the code needed.

The Mixin might also change in the future, following the RFC on huggingface_hub, so those private methods might actually introduce some behavior we don't want later.

Member: And the Repository(repo_path_or_name, clone_from=repo_url, use_auth_token=use_auth_token) line actually creates the repo if you try to clone it and it doesn't exist:

https://github.com/huggingface/huggingface_hub/blob/10c69146969ad7f9e1add075c1ef4ec15e42e85f/src/huggingface_hub/repository.py#L542-L549

Member: The push event is unrelated to the creation of the repository: no push to a remote repository can be done if the repository does not exist. The repository is created at creation time, not at update (push) time.

Member (author): Ah, I'm sorry, I didn't realize that's how it worked! And yes, you're right, this is probably a sign that we need to be explicit. I'll try to figure out some code that just checks without accidentally creating a new repo.

        checkpoint_dir = os.path.join(repo.local_dir, "checkpoint")
        weights_file = os.path.join(checkpoint_dir, "weights.h5")
        if not os.path.isfile(weights_file):
            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
        if not os.path.isfile(extra_data_file):
            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")
        self.load_weights(weights_file)
        with open(extra_data_file, "rb") as f:
            extra_data = pickle.load(f)
        self.optimizer.set_weights(extra_data["optimizer_state"])
        return {"epoch": extra_data["epoch"]}

    def compile(
        self,
        optimizer="rmsprop",
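The checkpoint format introduced in modeling_tf_utils.py above is just a weights file plus a pickled dict of extra data. A minimal framework-free sketch of the same save/load round-trip (hypothetical helper names; a plain pickle of the weights stands in for `self.save_weights` and `self.load_weights`):

```python
import os
import pickle
import tempfile

def save_checkpoint(checkpoint_dir, weights, epoch, optimizer_state):
    # Same layout as _save_checkpoint: weights.h5 plus extra_data.pickle.
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "weights.h5"), "wb") as f:
        pickle.dump(weights, f)  # stand-in for self.save_weights(weights_path)
    extra_data = {"epoch": epoch, "optimizer_state": optimizer_state}
    with open(os.path.join(checkpoint_dir, "extra_data.pickle"), "wb") as f:
        pickle.dump(extra_data, f)

def load_checkpoint(checkpoint_dir):
    # Mirrors load_repo_checkpoint: fail loudly if either file is missing.
    weights_file = os.path.join(checkpoint_dir, "weights.h5")
    extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
    for path in (weights_file, extra_data_file):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Could not find checkpoint file {path}!")
    with open(weights_file, "rb") as f:
        weights = pickle.load(f)  # stand-in for self.load_weights(weights_file)
    with open(extra_data_file, "rb") as f:
        extra_data = pickle.load(f)
    return weights, extra_data

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "checkpoint")
    save_checkpoint(ckpt, weights=[1.0, 2.0], epoch=3, optimizer_state=[0.5])
    weights, extra = load_checkpoint(ckpt)
    print(weights, extra["epoch"])  # prints: [1.0, 2.0] 3
```

As in the PR, the training epoch and optimizer state travel alongside the weights so that training can resume where it left off; the real implementation stores TF optimizer slot variables from `self.optimizer.get_weights()` in place of the toy list used here.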