Add model checkpointing to push_to_hub and PushToHubCallback #14492

Merged (19 commits) on Nov 29, 2021
Changes from 1 commit
24 changes: 24 additions & 0 deletions src/transformers/file_utils.py
@@ -22,6 +22,7 @@
import io
import json
import os
import pickle
import re
import shutil
import subprocess
@@ -2200,6 +2201,8 @@ def push_to_hub(
        commit_message: Optional[str] = None,
        organization: Optional[str] = None,
        private: Optional[bool] = None,
        checkpoint: Optional[bool] = False,
        epoch: Optional[int] = -1,
        use_auth_token: Optional[Union[bool, str]] = None,
    ) -> str:
        """
@@ -2225,6 +2228,10 @@
                Organization in which you want to push your {object} (you must be a member of this organization).
            private (:obj:`bool`, `optional`):
                Whether or not the repository created should be private (requires a paying subscription).
            checkpoint (:obj:`bool`, `optional`):
                Whether to save a checkpoint (including epoch number and optimizer state) or not.
            epoch (:obj:`int`, `optional`):
                The current epoch number. Only used when saving checkpoints.
            use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
@@ -2274,7 +2281,11 @@
            use_auth_token=use_auth_token,
        )
        # Save the files in the cloned repo
        if checkpoint:
            checkpoint_dir = os.path.join(repo_path_or_name, "checkpoint")
            self.save_checkpoint(checkpoint_dir, epoch)
        self.save_pretrained(repo_path_or_name)

        # Commit and push!
        url = self._push_to_hub(repo, commit_message=commit_message)

@@ -2284,6 +2295,19 @@

        return url

    def save_checkpoint(self, checkpoint_dir, epoch):
        if not os.path.isdir(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        # We avoid tf.train.Checkpoint or saving weights in TF format, even though that includes optimizer
        # state for us, because it requires special handling for objects like custom losses, which we use
        # internally and which users are likely to use too.
        weights_path = os.path.join(checkpoint_dir, "weights.h5")
        self.save_weights(weights_path)
        extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
        extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
        with open(extra_data_path, "wb") as f:
            pickle.dump(extra_data, f)

    @staticmethod
    def _get_repo_url_from_name(
        repo_name: str,
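
Not part of this diff, but for orientation: with checkpoint=True, push_to_hub writes checkpoint/weights.h5 and checkpoint/extra_data.pickle into the cloned repo folder before committing, in addition to the usual save_pretrained() files. A minimal usage sketch, assuming a TF model, placeholder model/repo names, and a prior huggingface-cli login:

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

# Sketch only: the pretrained checkpoint and repo name below are placeholders, not from this PR.
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# ... model.fit(...) on your dataset for a couple of epochs ...

model.push_to_hub(
    "my-finetuned-bert",
    commit_message="End of epoch 2",
    checkpoint=True,  # also writes checkpoint/weights.h5 and checkpoint/extra_data.pickle
    epoch=2,          # recorded in extra_data.pickle so training can later resume from epoch 3
)
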
11 changes: 11 additions & 0 deletions src/transformers/keras_callbacks.py
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from time import sleep
from typing import Optional, Union
@@ -23,6 +24,7 @@ def __init__(
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
        checkpoint: Optional[bool] = False,
    ):
        """
        output_dir (:obj:`str`):
@@ -48,8 +50,13 @@ def __init__(
        hub_token (:obj:`str`, `optional`):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            :obj:`huggingface-cli login`.
        checkpoint (:obj:`bool`, `optional`):
            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
            resumed. Only usable when `save_strategy` is `epoch`.
        """
        super().__init__()
        if checkpoint and save_strategy != "epoch":
            raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
@@ -65,6 +72,7 @@ def __init__(
        self.repo = Repository(str(output_dir), clone_from=hub_model_id)
        self.tokenizer = tokenizer
        self.last_job = None
        self.checkpoint = checkpoint

    def on_train_batch_end(self, batch, logs=None):
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
@@ -84,6 +92,9 @@ def on_epoch_end(self, epoch, logs=None):
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            if self.checkpoint:
                checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                self.model.save_checkpoint(checkpoint_dir, epoch)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )
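
Also not part of the diff: a usage sketch of the new checkpoint flag on PushToHubCallback, with placeholder model, dataset, and output names, and assuming a prior huggingface-cli login. Note that save_strategy must be "epoch", otherwise __init__ raises the ValueError added above.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers.keras_callbacks import PushToHubCallback

# Sketch only: model name, output_dir, and the two-example toy dataset are placeholders.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

batch = tokenizer(["a positive example", "a negative example"], padding=True, return_tensors="np")
train_dataset = tf.data.Dataset.from_tensor_slices((dict(batch), [1, 0])).batch(2)

callback = PushToHubCallback(
    output_dir="./bert-checkpoint-demo",
    save_strategy="epoch",  # required when checkpoint=True
    checkpoint=True,        # also writes output_dir/checkpoint/ at the end of every epoch via save_checkpoint
    tokenizer=tokenizer,
)
model.fit(train_dataset, epochs=2, callbacks=[callback])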
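
The diff only covers the saving side; nothing here reads a checkpoint back. A rough sketch of how the two files written by save_checkpoint could be consumed to resume training; the paths and model name are placeholders, and get_weights()/set_weights() refers to the tf.keras optimizer API current at the time of this PR:

import os
import pickle

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

checkpoint_dir = "./bert-checkpoint-demo/checkpoint"  # placeholder path

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# weights.h5 was written by save_weights(), so load_weights() on the same architecture restores it.
model.load_weights(os.path.join(checkpoint_dir, "weights.h5"))

# extra_data.pickle holds {"epoch": ..., "optimizer_state": ...} as written by save_checkpoint().
with open(os.path.join(checkpoint_dir, "extra_data.pickle"), "rb") as f:
    extra_data = pickle.load(f)

# The optimizer's slot variables must exist before set_weights() can restore them;
# applying a single zero-gradient step is one simple way to create them.
zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
model.optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))
model.optimizer.set_weights(extra_data["optimizer_state"])

initial_epoch = extra_data["epoch"] + 1  # pass to model.fit(..., initial_epoch=initial_epoch) to resume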