
Commit

test
ArthurZucker committed May 14, 2024
1 parent c609619 commit 46fd0e4
Showing 3 changed files with 119 additions and 90 deletions.
200 changes: 117 additions & 83 deletions src/transformers/tokenization_utils_base.py
@@ -1484,13 +1484,6 @@ def all_special_ids(self) -> List[int]:
high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the
low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the
associated pretrained vocabulary file.
- **max_model_input_sizes** (`Dict[str, Optional[int]]`) -- A dictionary with, as keys, the `short-cut-names`
of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model,
or `None` if the model has no maximum input size.
- **pretrained_init_configuration** (`Dict[str, Dict[str, Any]]`) -- A dictionary with, as keys, the
`short-cut-names` of the pretrained models, and as associated values, a dictionary of specific arguments to
pass to the `__init__` method of the tokenizer class for this pretrained model when loading the tokenizer
with the [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`] method.
- **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model.
- **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
Should be `'right'` or `'left'`.
@@ -1561,8 +1554,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):

vocab_files_names: Dict[str, str] = {}
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
max_model_input_sizes: Dict[str, Optional[int]] = {}
_auto_class: Optional[str] = None

# first name has to correspond to main model input name
@@ -1610,6 +1601,10 @@ def __init__(self, **kwargs):

# Stores a Jinja template that formats chat histories into tokenizable strings
self.chat_template = kwargs.pop("chat_template", None)
if isinstance(self.chat_template, (list, tuple)):
# Chat templates are stored as lists of dicts with fixed key names,
# we reconstruct that into a single dict while loading them.
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}

super().__init__(**kwargs)

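Note: the `chat_template` handling added to `__init__` above rebuilds a dict of named templates from the serialized list-of-dicts form. A minimal sketch of that reconstruction, using made-up template names and bodies:

```python
# Minimal sketch of the reconstruction performed in __init__ above.
# The names and template strings below are illustrative placeholders.
serialized = [
    {"name": "default", "template": "{% for m in messages %}{{ m['content'] }}{% endfor %}"},
    {"name": "tool_use", "template": "{% for m in messages %}[TOOL] {{ m['content'] }}{% endfor %}"},
]

# list of {"name", "template"} entries -> single dict keyed by template name
chat_template = {entry["name"]: entry["template"] for entry in serialized}
assert set(chat_template) == {"default", "tool_use"}
```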
@@ -1688,7 +1683,7 @@ def get_vocab(self) -> Dict[str, int]:

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
@@ -1697,16 +1692,17 @@ def apply_chat_template(
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
**tokenizer_kwargs,
) -> Union[str, List[int]]:
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
"""
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
to the default_chat_template specified at the class level.
Args:
conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
with "role" and "content" keys, representing the chat history so far.
chat_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the model's default chat template will be used instead.
@@ -1730,56 +1726,111 @@ def apply_chat_template(
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
return_dict (`bool`, *optional*, defaults to `False`):
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
tokenizer_kwargs (`Dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
Returns:
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`.
`Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
set, will return a dict of tokenizer outputs instead.
"""

if hasattr(conversation, "messages"):
# Indicates it's a Conversation object
conversation = conversation.messages
if return_dict and not tokenize:
raise ValueError(
"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
"of tokenizer outputs to return."
)

if tokenizer_kwargs is None:
tokenizer_kwargs = {}

# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if chat_template is None:
using_default_template = False

# First, handle the cases when the model has a dict of multiple templates
if isinstance(self.chat_template, dict) or (
self.chat_template is None and isinstance(self.default_chat_template, dict)
):
if self.chat_template is not None:
template_dict = self.chat_template
using_default_dict = False
else:
template_dict = self.default_chat_template
using_default_dict = True
if chat_template is not None and chat_template in template_dict:
# The user can pass the name of a template to the chat template argument instead of an entire template
chat_template = template_dict[chat_template]
if using_default_dict:
using_default_template = True
elif chat_template is None and "default" in template_dict:
chat_template = template_dict["default"]
if using_default_dict:
using_default_template = True
elif chat_template is None:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template_dict.keys())}."
)
elif chat_template is None:
# These are the cases when the model has a single template
            # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if self.chat_template is not None:
chat_template = self.chat_template
else:
chat_template = self.default_chat_template
using_default_template = True

if using_default_template:
logger.warning_once(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)

rendered = compiled_template.render(
messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
)
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
):
conversations = conversation
is_batched = True
else:
conversations = [conversation]
is_batched = False

rendered = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
for chat in conversations:
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
rendered_chat = compiled_template.render(
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
)
rendered.append(rendered_chat)

if not is_batched:
rendered = rendered[0]

if padding is True:
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
if tokenize:
out = self(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
if return_dict:
return self(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
return out
else:
return self.encode(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
return out["input_ids"]
else:
return rendered

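Note: the reworked method above accepts either a single conversation or a batch of conversations, and when the tokenizer holds a dict of named templates, a template can be selected by name through the `chat_template` argument. A hedged usage sketch (the checkpoint name is just an example of a chat-capable tokenizer, not something exercised by this commit):

```python
# Illustrative usage of the updated apply_chat_template; not taken from this commit's tests.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # any chat tokenizer

chat = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]

# Single conversation -> a single rendered string (or a list of ids if tokenize=True).
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Batch of conversations -> one output per conversation.
batch = tokenizer.apply_chat_template([chat, chat], tokenize=False)
assert isinstance(batch, list) and len(batch) == 2

# With a dict of templates, the template *name* can be passed instead of a full Jinja
# string, e.g. tokenizer.apply_chat_template(chat, chat_template="tool_use") -- assuming
# the tokenizer defines a template under that name.
```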
@@ -1792,7 +1843,7 @@ def _compile_jinja_template(self, chat_template):
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")

if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
if version.parse(jinja2.__version__) < version.parse("3.0.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.0.0 to be installed. Your version is " f"{jinja2.__version__}."
)
Expand All @@ -1810,12 +1861,6 @@ def default_chat_template(self):
This template formats inputs in the standard ChatML format. See
https://github.com/openai/openai-python/blob/main/chatml.md
"""
logger.warning_once(
"\nNo chat template is defined for this tokenizer - using a default chat template "
"that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for "
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
return (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
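For reference, the ChatML layout this default template produces looks roughly like the sketch below (illustrative only; the final generation-prompt line appears only when `add_generation_prompt=True`):

```python
# Rough shape of the ChatML output for a two-turn chat rendered by the default template above.
rendered = (
    "<|im_start|>user\nHello!<|im_end|>\n"
    "<|im_start|>assistant\nHi there.<|im_end|>\n"
    "<|im_start|>assistant\n"  # generation prompt, only with add_generation_prompt=True
)
```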
@@ -1859,9 +1904,9 @@ def from_pretrained(
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download the vocabulary files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Attempt to resume the download if such a file
exists.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1916,7 +1961,7 @@ def from_pretrained(
# Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
assert tokenizer.unk_token == "<unk>"
```"""
resume_download = kwargs.pop("resume_download", False)
resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
subfolder = kwargs.pop("subfolder", None)
@@ -2185,23 +2230,6 @@ def _from_pretrained(
# Update with newly provided kwargs
init_kwargs.update(kwargs)

# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings

model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
if model_max_length is not None and isinstance(model_max_length, (int, float)):
model_max_length = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
# TODO(PVP) - uncomment following line in Transformers v5
# init_kwargs["model_max_length"] = model_max_length
# TODO(PVP) - remove in Transformers v5
# ---
init_kwargs["model_max_length"] = cls._eventually_correct_t5_max_length(
pretrained_model_name_or_path, model_max_length, init_kwargs.get("model_max_length")
)
# ---

# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
@@ -2426,7 +2454,12 @@ def save_pretrained(
tokenizer_config.update(self.special_tokens_map)

if self.chat_template is not None:
tokenizer_config["chat_template"] = self.chat_template
if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
# They will be reconstructed as a single dict during loading.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
else:
tokenizer_config["chat_template"] = self.chat_template

if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
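Note: together with the `__init__` change earlier in this diff, the serialization above lets a dict-valued `chat_template` round-trip through `tokenizer_config.json`. A hedged sketch (checkpoint and directory names are placeholders):

```python
# Round-trip sketch: a dict chat_template is saved as a list of {"name", "template"}
# entries and rebuilt into a dict on load. Checkpoint and paths below are placeholders.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.chat_template = {
    "default": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
    "tool_use": "{% for m in messages %}[TOOL] {{ m['content'] }}{% endfor %}",
}
tok.save_pretrained("./tmp_tokenizer")
reloaded = AutoTokenizer.from_pretrained("./tmp_tokenizer")
assert reloaded.chat_template == tok.chat_template
```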
Expand Down Expand Up @@ -2872,7 +2905,7 @@ def _call_one(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
**kwargs,
) -> BatchEncoding:
# Input type checking for clearer error
@@ -3061,7 +3094,7 @@ def _encode_plus(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
**kwargs,
) -> BatchEncoding:
raise NotImplementedError
@@ -3092,7 +3125,7 @@ def batch_encode_plus(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
**kwargs,
) -> BatchEncoding:
"""
@@ -3137,7 +3170,8 @@ def batch_encode_plus(
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,split_special_tokens=split_special_tokens,
verbose=verbose,
split_special_tokens=split_special_tokens,
**kwargs,
)

@@ -3166,7 +3200,7 @@ def _batch_encode_plus(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
**kwargs,
) -> BatchEncoding:
raise NotImplementedError
@@ -3624,7 +3658,7 @@ def truncate_sequences(
ids = ids[ids_to_move:]
pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None
else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
raise ValueError(f"invalid truncation strategy:{self.truncation_side}")

elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
@@ -3636,7 +3670,7 @@
overflowing_tokens = pair_ids[:window_len]
pair_ids = pair_ids[num_tokens_to_remove:]
else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
raise ValueError(f"invalid truncation strategy:{self.truncation_side}")
else:
logger.error(
f"We need to remove {num_tokens_to_remove} to truncate the input "
@@ -3720,7 +3754,7 @@ def _pad(
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
raise ValueError(f"Invalid padding strategy:{self.padding_side}")

return encoded_inputs

4 changes: 2 additions & 2 deletions src/transformers/tokenization_utils_fast.py
@@ -486,7 +486,7 @@ def _batch_encode_plus(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
) -> BatchEncoding:
if not isinstance(batch_text_or_text_pairs, (tuple, list)):
raise TypeError(
@@ -574,7 +574,7 @@ def _encode_plus(
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
split_special_tokens:bool = False,
split_special_tokens: bool = False,
**kwargs,
) -> BatchEncoding:
batched_input = [(text, text_pair)] if text_pair else [text]
