Add tokenizer and test
sgugger committed Nov 5, 2021
1 parent 6608efa commit 99d76d0
Showing 6 changed files with 122 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/transformers/models/auto/configuration_auto.py
@@ -24,6 +24,7 @@
from ...utils import logging
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)

CONFIG_MAPPING_NAMES = OrderedDict(
@@ -567,7 +568,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
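A minimal sketch of the behavior this reworded error guards (not part of the commit; the repo id is a placeholder): loading a config whose repo declares custom code only succeeds once trust_remote_code is set.

from transformers import AutoConfig

# Without trust_remote_code=True, AutoConfig.from_pretrained raises the ValueError above
# for a repo whose config declares an "auto_map" entry for AutoConfig.
config = AutoConfig.from_pretrained("username/repo-with-custom-config", trust_remote_code=True)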
38 changes: 36 additions & 2 deletions src/transformers/models/auto/tokenization_auto.py
@@ -41,6 +41,7 @@
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from .dynamic import get_class_from_dynamic_module


logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom tokenizers defined on the Hub in their own tokenization files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -436,6 +441,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)

# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
@@ -464,17 +470,45 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = tokenizer_config.get("auto_map")

# If that did not work, let's try to use the config.
if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]

# If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None:
tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
if tokenizer_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)

if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]

module_file, class_name = class_ref.split(".")
tokenizer_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)

elif use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None:
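For context, a hedged usage sketch of the new code path (not part of the diff; the repo id mirrors the test added below and the revision value is a placeholder): with this change, AutoTokenizer resolves a custom tokenizer class from the repo's own module when trust_remote_code is passed.

from transformers import AutoTokenizer

# Loads the fast custom class when one is mapped, otherwise the slow one.
tokenizer = AutoTokenizer.from_pretrained(
    "username/test-dynamic-tokenizer",
    trust_remote_code=True,
    revision="main",  # pinning a specific commit is encouraged, per the warning above
)
print(type(tokenizer).__name__)

# Force the slow (Python) implementation from the same repo.
slow_tokenizer = AutoTokenizer.from_pretrained(
    "username/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)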
3 changes: 3 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1784,6 +1784,7 @@ def _from_pretrained(
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("auto_map", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class
if getattr(self, "_auto_map", None) is not None:
tokenizer_config["auto_map"] = self._auto_map

with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
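To illustrate what these two additions persist (a sketch assuming this commit is applied, not part of the diff; the module and class names are placeholders): when a tokenizer instance carries an _auto_map, save_pretrained records it under "auto_map" in tokenizer_config.json, and _from_pretrained strips it from the init kwargs so it is never passed to __init__.

import json
import os
import tempfile

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)
    with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as f:
        saved_config = json.load(f)
    # The tuple is serialized as a JSON list:
    # ["tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast"]
    print(saved_config["auto_map"])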
4 changes: 2 additions & 2 deletions tests/test_configuration_common.py
@@ -190,7 +190,6 @@ def run_common_tests(self):
self.check_config_arguments_init()



class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
@@ -207,6 +206,7 @@ def __init__(self, attribute=1, **kwargs):
super().__init__(**kwargs)
"""


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
@@ -259,7 +259,7 @@ def test_push_to_hub_in_organization(self):
for k, v in config.__dict__.items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))

def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
10 changes: 7 additions & 3 deletions tests/test_modeling_common.py
@@ -67,7 +67,6 @@
AdaptiveEmbedding,
BertConfig,
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -2211,10 +2210,15 @@ def test_push_to_hub_dynamic_model(self):
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))

def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42, vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
73 changes: 72 additions & 1 deletion tests/test_tokenization_common.py
@@ -27,11 +27,12 @@
from itertools import takewhile
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

from huggingface_hub import delete_repo, login
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import (
AlbertTokenizer,
AlbertTokenizerFast,
AutoTokenizer,
BertTokenizer,
BertTokenizerFast,
PreTrainedTokenizer,
@@ -41,6 +42,7 @@
Trainer,
TrainingArguments,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
from transformers.testing_utils import (
@@ -3513,6 +3515,28 @@ def test_saving_tokenizer_trainer(self):
self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))


class FakeTokenizer(BertTokenizer):
pass


if is_tokenizers_available():

class FakeTokenizerFast(BertTokenizerFast):
pass


# Make sure this is synchronized with the tokenizers above.
FAKE_TOKENIZER_CODE = """
from transformers import BertTokenizer, BertTokenizerFast
class FakeTokenizer(BertTokenizer):
pass
class FakeTokenizerFast(BertTokenizerFast):
pass
"""


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@@ -3533,6 +3557,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, name="test-dynamic-tokenizer")
except HTTPError:
pass

def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
@@ -3562,6 +3591,48 @@ def test_push_to_hub_in_organization(self):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

def test_push_to_hub_dynamic_tokenizer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = FakeTokenizer(vocab_file)

# No fast custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir(tmp_dir))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)

repo.push_to_hub()

tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new tokenizer comes from the FakeTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")

# Fast and slow custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir(tmp_dir))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)

repo.push_to_hub()

tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new tokenizer comes from the FakeTokenizerFast class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)
# Can't make an isinstance check because the new tokenizer comes from the FakeTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")


class TrieTest(unittest.TestCase):
def test_trie(self):
