From 1cc453d33c5d0be01eaf3050082c125ce87491aa Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 15 Nov 2021 16:38:02 -0500 Subject: [PATCH] Allow per-version configurations (#14344) * Allow per-version configurations * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_utils.py | 76 +++++++++++++++++++++++-- tests/test_configuration_common.py | 39 +++++++++++++ 2 files changed, 109 insertions(+), 6 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 56789dc3f2a9..3ed6f65b9cda 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -19,8 +19,11 @@ import copy import json import os +import re import warnings -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union + +from packaging import version from . import __version__ from .file_utils import ( @@ -28,6 +31,7 @@ PushToHubMixin, cached_path, copy_func, + get_list_of_files, hf_bucket_url, is_offline_mode, is_remote_url, @@ -37,6 +41,8 @@ logger = logging.get_logger(__name__) +FULL_CONFIGURATION_FILE = "config.json" +_re_configuration_file = re.compile(r"config\.(.*)\.json") class PretrainedConfig(PushToHubMixin): @@ -536,15 +542,23 @@ def get_config_dict( local_files_only = True pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path else: - config_file = hf_bucket_url( - pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None + configuration_file = get_configuration_file( + pretrained_model_name_or_path, + revision=revision, + use_auth_token=use_auth_token, + local_files_only=local_files_only, ) + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, configuration_file) + else: + config_file = hf_bucket_url( + pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None + ) + try: # Load from URL or cache if already cached resolved_config_file = cached_path( @@ -796,6 +810,56 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] +def get_configuration_file( + path_or_repo: Union[str, os.PathLike], + revision: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, +) -> str: + """ + Get the configuration file to use for this version of transformers. + + Args: + path_or_repo (:obj:`str` or :obj:`os.PathLike`): + Can be either the id of a repo on huggingface.co or a path to a `directory`. + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. + + Returns: + :obj:`str`: The configuration file to use. + """ + # Inspect all files from the repo/folder. + all_files = get_list_of_files( + path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only + ) + configuration_files_map = {} + for file_name in all_files: + search = _re_configuration_file.search(file_name) + if search is not None: + v = search.groups()[0] + configuration_files_map[v] = file_name + available_versions = sorted(configuration_files_map.keys()) + + # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions. + configuration_file = FULL_CONFIGURATION_FILE + transformers_version = version.parse(__version__) + for v in available_versions: + if version.parse(v) <= transformers_version: + configuration_file = configuration_files_map[v] + else: + # No point going further since the versions are sorted. + break + + return configuration_file + + PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub) PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( object="config", object_class="AutoConfig", object_files="configuration file" diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 66c7652f3996..78672d44f0b6 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -16,8 +16,10 @@ import copy import json import os +import shutil import tempfile import unittest +import unittest.mock from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError @@ -306,3 +308,40 @@ def test_config_common_kwargs_is_complete(self): "The following keys are set with the default values in `test_configuration_common.config_common_kwargs` " f"pick another value for them: {', '.join(keys_with_defaults)}." ) + + +class ConfigurationVersioningTest(unittest.TestCase): + def test_local_versioning(self): + configuration = AutoConfig.from_pretrained("bert-base-cased") + + with tempfile.TemporaryDirectory() as tmp_dir: + configuration.save_pretrained(tmp_dir) + configuration.hidden_size = 2 + json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w")) + + # This should pick the new configuration file as the version of Transformers is > 4.0.0 + new_configuration = AutoConfig.from_pretrained(tmp_dir) + self.assertEqual(new_configuration.hidden_size, 2) + + # Will need to be adjusted if we reach v42 and this test is still here. + # Should pick the old configuration file as the version of Transformers is < 4.42.0 + shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json")) + new_configuration = AutoConfig.from_pretrained(tmp_dir) + self.assertEqual(new_configuration.hidden_size, 768) + + def test_repo_versioning_before(self): + # This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower. + repo = "microsoft/layoutxlm-base" + + import transformers as new_transformers + + new_transformers.configuration_utils.__version__ = "v5.0.0" + new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo) + self.assertEqual(new_configuration.tokenizer_class, None) + + # Testing an older version by monkey-patching the version in the module it's used. + import transformers as old_transformers + + old_transformers.configuration_utils.__version__ = "v3.0.0" + old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo) + self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")