Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify config filename in HfArgumentParser #6626

Merged
merged 1 commit into from
Aug 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
self.add_argument(field_name, **kwargs)

def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
) -> Tuple[DataClass, ...]:
"""
Parse command-line args into instances of the specified dataclass types.
Expand All @@ -107,6 +107,9 @@ def parse_args_into_dataclasses(
If true, will look for a ".args" file with the same base name
as the entry point script for this process, and will append its
potential content to the command line args.
args_filename:
If not None, will uses this file instead of the ".args" file
specified in the previous argument.
Returns:
Tuple consisting of:
Expand All @@ -118,8 +121,12 @@ def parse_args_into_dataclasses(
- The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args)
"""
if look_for_args_file and len(sys.argv):
args_file = Path(sys.argv[0]).with_suffix(".args")
if args_filename or (look_for_args_file and len(sys.argv)):
if args_filename:
args_file = Path(args_filename)
else:
args_file = Path(sys.argv[0]).with_suffix(".args")

if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
Expand Down