Skip to content

Commit

Permalink
CLI: Add flag to push TF weights directly into main (#17720)
Browse files Browse the repository at this point in the history
* Add flag to push weights directly into main
  • Loading branch information
gante committed Jun 15, 2022
1 parent 6ebeeee commit c3c62b5
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/transformers/commands/pt_to_tf.py
Expand Up @@ -45,7 +45,7 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand
"""
return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push)


class PTtoTFCommand(BaseTransformersCLICommand):
Expand Down Expand Up @@ -76,13 +76,18 @@ def register_subcommand(parser: ArgumentParser):
default="",
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
)
train_parser.add_argument(
"--new-weights",
action="store_true",
help="Optional flag to create new TensorFlow weights, even if they already exist.",
)
train_parser.add_argument(
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
)
train_parser.add_argument(
"--new-weights",
"--push",
action="store_true",
help="Optional flag to create new TensorFlow weights, even if they already exist.",
help="Optional flag to push the weights directly to `main` (requires permissions)",
)
train_parser.set_defaults(func=convert_command_factory)

Expand Down Expand Up @@ -129,12 +134,13 @@ def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):

return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._no_pr = no_pr
self._new_weights = new_weights
self._no_pr = no_pr
self._push = push

def get_text_inputs(self):
tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
Expand Down Expand Up @@ -234,7 +240,12 @@ def run(self):
)
)

if not self._no_pr:
if self._push:
repo.git_add(auto_lfs_track=True)
repo.git_commit("Add TF weights")
repo.git_push(blocking=True) # this prints a progress bar with the upload
self._logger.warn(f"TF weights pushed into {self._model_name}")
elif not self._no_pr:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
Expand Down

0 comments on commit c3c62b5

Please sign in to comment.