Expand dynamic supported objects to configs and tokenizers #14296

Merged · 6 commits · Nov 8, 2021
21 changes: 19 additions & 2 deletions src/transformers/models/auto/auto_factory.py
@@ -378,7 +378,24 @@ def __init__(self, *args, **kwargs):

@classmethod
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
trust_remote_code = kwargs.pop("trust_remote_code", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
"Loading this model requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

@@ -394,7 +411,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
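The `from_config` change above lets an Auto class instantiate a model whose class lives in the Hub repo itself, provided the config carries a matching `auto_map` entry and the caller opts in with `trust_remote_code=True`. A minimal usage sketch, assuming a hypothetical repo `my-user/my-custom-model` whose config maps `AutoModel` to a class defined in the repo's `modeling.py`:

```python
from transformers import AutoConfig, AutoModel

# trust_remote_code=True is needed for the config as well, since its class may
# also come from a file in the repo (see the configuration_auto.py diff below).
config = AutoConfig.from_pretrained("my-user/my-custom-model", trust_remote_code=True)

# The new branch in from_config looks up config.auto_map["AutoModel"], splits the
# "modeling.FakeModel"-style reference into a module file and a class name, loads
# that class with get_class_from_dynamic_module, and calls _from_config on it.
model = AutoModel.from_config(config, trust_remote_code=True)
print(type(model).__name__)  # the custom class from the downloaded modeling.py
```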
30 changes: 29 additions & 1 deletion src/transformers/models/auto/configuration_auto.py
@@ -21,8 +21,12 @@

from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME
from ...utils import logging
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)

CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
@@ -523,6 +527,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs(additional keyword arguments, `optional`):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -555,8 +563,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
{'foo': False}
"""
kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False)
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict:
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
"ensure no malicious code has been contributed in a newer revision."
)
class_ref = config_dict["auto_map"]["AutoConfig"]
module_file, class_name = class_ref.split(".")
config_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]]
return config_class.from_dict(config_dict, **kwargs)
else:
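On the config side, the dynamic path is triggered purely by the `auto_map` key in `config.json`: if it contains an `"AutoConfig"` entry and `trust_remote_code=True` was passed, the `"module.ClassName"` reference is split into a module file (`module.py`) and a class name and resolved with `get_class_from_dynamic_module`. A hedged sketch of the call, with a hypothetical repo name and an `auto_map` shape that mirrors the one used in the tests further down:

```python
from transformers import AutoConfig

# The repo's config.json is assumed to contain something like
#   "auto_map": {"AutoConfig": "configuration.FakeConfig"}
# pointing at class FakeConfig inside configuration.py in the same repo.
config = AutoConfig.from_pretrained(
    "my-user/my-custom-model",
    trust_remote_code=True,  # without this, the new branch raises ValueError
    revision="main",         # any explicit revision suppresses the logger warning
)
print(config.__class__.__name__)  # "FakeConfig", loaded from the dynamic module
```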
38 changes: 36 additions & 2 deletions src/transformers/models/auto/tokenization_auto.py
@@ -41,6 +41,7 @@
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -436,6 +441,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
@@ -464,17 +470,45 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = tokenizer_config.get("auto_map")

# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]

# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
if tokenizer_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)

if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]

module_file, class_name = class_ref.split(".")
tokenizer_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)

elif use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None:
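For tokenizers, the class reference comes either from the `auto_map` key of `tokenizer_config.json` (read directly as `tokenizer_auto_map`) or, as a fallback, from the model config's `auto_map["AutoTokenizer"]` entry. In both cases it is a two-element list: index 0 is the slow tokenizer reference and index 1 the fast one, and the fast entry wins when `use_fast=True` and it is not `None`. A usage sketch, with the repo name and `tokenization.py` class references assumed for illustration:

```python
from transformers import AutoTokenizer

# Assumed auto_map shape, mirroring tokenizer_auto_map[0] / tokenizer_auto_map[1]:
#   ["tokenization.MyTokenizer", "tokenization.MyTokenizerFast"]

# Default use_fast=True resolves the second entry (the fast class) when present.
fast_tok = AutoTokenizer.from_pretrained(
    "my-user/my-custom-model", trust_remote_code=True
)

# use_fast=False (or a None second entry) falls back to the first, slow reference.
slow_tok = AutoTokenizer.from_pretrained(
    "my-user/my-custom-model", use_fast=False, trust_remote_code=True
)
```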
3 changes: 3 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1784,6 +1784,7 @@ def _from_pretrained(
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("auto_map", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class
if getattr(self, "_auto_map", None) is not None:
tokenizer_config["auto_map"] = self._auto_map

with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
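The `tokenization_utils_base.py` change is the persistence half of the tokenizer support: on save, a non-`None` `_auto_map` attribute is written into `tokenizer_config.json` as `auto_map`, and on load the key is popped from `init_kwargs` so it is never forwarded to the tokenizer's `__init__`. A small round-trip sketch; how `_auto_map` normally gets populated is outside this diff, so it is assigned by hand here just to exercise the save path, and the checkpoint name is only a convenient public example:

```python
import json
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
# Hand-set for illustration only; the diff shows the attribute being saved,
# not how a custom tokenizer class would normally define it.
tok._auto_map = ["tokenization.MyTokenizer", "tokenization.MyTokenizerFast"]
tok.save_pretrained("my-local-tokenizer")

# The custom class references survive the round trip in tokenizer_config.json.
with open("my-local-tokenizer/tokenizer_config.json") as f:
    saved = json.load(f)
assert saved["auto_map"] == ["tokenization.MyTokenizer", "tokenization.MyTokenizerFast"]
```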
43 changes: 41 additions & 2 deletions tests/test_configuration_common.py
@@ -19,9 +19,9 @@
import tempfile
import unittest

from huggingface_hub import delete_repo, login
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config, is_torch_available
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test

@@ -190,6 +190,23 @@ def run_common_tests(self):
self.check_config_arguments_init()


class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)


# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig

class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
@@ -208,6 +225,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, name="test-dynamic-config")
except HTTPError:
pass

def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -238,6 +260,23 @@ def test_push_to_hub_in_organization(self):
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))

def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}

with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)

repo.push_to_hub()

new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.attribute, 42)


class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self):
74 changes: 72 additions & 2 deletions tests/test_modeling_common.py
@@ -30,7 +30,14 @@
import transformers
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers import (
AutoConfig,
AutoModel,
AutoModelForSequenceClassification,
PretrainedConfig,
is_torch_available,
logging,
)
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
@@ -67,7 +74,6 @@
AdaptiveEmbedding,
BertConfig,
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -2078,6 +2084,23 @@ def test_model_from_pretrained_torch_dtype(self):
self.assertEqual(model.dtype, torch.float16)


class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)


# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig

class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""


if is_torch_available():

class FakeModel(PreTrainedModel):
@@ -2140,6 +2163,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, name="test-dynamic-model-config")
except HTTPError:
pass

def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2185,5 +2213,47 @@ def test_push_to_hub_dynamic_model(self):
repo.push_to_hub()

new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")

def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)

repo.push_to_hub()

new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
self.assertEqual(new_model.config.attribute, 42)

# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")