[SigLIP] Add fast tokenizer #29969

Open: wants to merge 7 commits into base: main
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/siglip.md
@@ -124,6 +124,10 @@ If you're interested in submitting a resource to be included here, please feel f
- create_token_type_ids_from_sequences
- save_vocabulary

## SiglipTokenizerFast

[[autodoc]] SiglipTokenizerFast

## SiglipImageProcessor

[[autodoc]] SiglipImageProcessor
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1249,6 +1249,7 @@
_import_structure["models.roberta"].append("RobertaTokenizerFast")
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
_import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizerFast")
_import_structure["models.siglip"].append("SiglipTokenizerFast")
_import_structure["models.splinter"].append("SplinterTokenizerFast")
_import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
_import_structure["models.t5"].append("T5TokenizerFast")
@@ -6180,6 +6181,7 @@
from .models.roberta import RobertaTokenizerFast
from .models.roformer import RoFormerTokenizerFast
from .models.seamless_m4t import SeamlessM4TTokenizerFast
from .models.siglip import SiglipTokenizerFast
from .models.splinter import SplinterTokenizerFast
from .models.squeezebert import SqueezeBertTokenizerFast
from .models.t5 import T5TokenizerFast
38 changes: 38 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,6 +19,8 @@
allow to make our dependency on SentencePiece optional.
"""

import re
import string
import warnings
from typing import Dict, List, Tuple

@@ -1059,6 +1061,41 @@ def post_processor(self):
)


class SiglipConverter(SpmConverter):
def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

list_normalizers = []

if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())
list_normalizers.append(normalizers.Replace(Regex(r"[" + re.escape(string.punctuation) + "]"), ""))
list_normalizers.extend(
[
normalizers.Replace(Regex(r"\s+"), " "),
normalizers.Strip(),
]
)

if not precompiled_charsmap:
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
else:
list_normalizers.extend(
[normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
)

return normalizers.Sequence(list_normalizers)

Reviewer comment (Collaborator): As is, you are also going to get the MetaSpace pre-tokenizer, but I am guessing this is also wanted.

def post_processor(self):
return processors.TemplateProcessing(
single=["$A", "</s>"],
pair=["$A", "</s>", "$B", "</s>"],
special_tokens=[
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
],
)


class WhisperConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
@@ -1498,6 +1535,7 @@ def converted(self) -> Tokenizer:
"WhisperTokenizer": WhisperConverter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
"SiglipTokenizer": SiglipConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
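For context, a minimal sketch (not part of the PR) of what the normalizer chain built by `SiglipConverter` does to raw text, using the `tokenizers` library directly. The sample string is illustrative, `Lowercase()` is only added when `do_lower_case` is set, and the precompiled-charsmap branch is omitted:

```python
# Sketch of the SiglipConverter normalizer chain: lowercase (when enabled),
# strip punctuation, collapse whitespace, strip, then squeeze double spaces.
import re
import string

from tokenizers import Regex, normalizers

normalizer = normalizers.Sequence(
    [
        normalizers.Lowercase(),  # only when do_lower_case is True
        normalizers.Replace(Regex(r"[" + re.escape(string.punctuation) + "]"), ""),
        normalizers.Replace(Regex(r"\s+"), " "),
        normalizers.Strip(),
        normalizers.Replace(Regex(" {2,}"), " "),  # precompiled charsmap branch omitted
    ]
)

print(normalizer.normalize_str("Hello,   World!!"))  # -> "hello world"
```

The `post_processor` defined above then appends `</s>` via `TemplateProcessing`, matching the slow tokenizer's `X </s>` format.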
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -418,7 +418,13 @@
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
(
"siglip",
(
"SiglipTokenizer" if is_sentencepiece_available() else None,
"SiglipTokenizerFast" if is_tokenizers_available() else None,
),
),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
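A short sketch of what the expanded mapping enables, assuming this PR's branch is installed (the checkpoint name is taken from the test script further down):

```python
# With the updated tokenization_auto mapping, AutoTokenizer resolves SigLIP
# checkpoints to the fast tokenizer by default; use_fast=False keeps the
# slow, SentencePiece-backed one.
from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("google/siglip-so400m-patch14-384")
slow = AutoTokenizer.from_pretrained("google/siglip-so400m-patch14-384", use_fast=False)
print(type(fast).__name__)  # SiglipTokenizerFast
print(type(slow).__name__)  # SiglipTokenizer
```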
18 changes: 18 additions & 0 deletions src/transformers/models/siglip/__init__.py
@@ -17,6 +17,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
@@ -41,6 +42,15 @@
_import_structure["tokenization_siglip"] = ["SiglipTokenizer"]


try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_siglip_fast"] = ["SiglipTokenizerFast"]


try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
@@ -82,6 +92,14 @@
else:
from .tokenization_siglip import SiglipTokenizer

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_siglip_fast import SiglipTokenizerFast

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
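A brief sketch of the effect of the guarded import above, assuming the standard transformers optional-dependency machinery: the import itself always succeeds, but without the `tokenizers` backend the lazy module serves the dummy class from `dummy_tokenizers_objects.py`, which raises only when instantiated:

```python
# Assumed behavior of the optional-dependency guard: importing the name works
# either way; instantiating it without the `tokenizers` backend raises an
# ImportError that points at the missing extra.
from transformers import SiglipTokenizerFast
from transformers.utils import is_tokenizers_available

if is_tokenizers_available():
    tokenizer = SiglipTokenizerFast.from_pretrained("google/siglip-so400m-patch14-384")
else:
    try:
        SiglipTokenizerFast()  # dummy object from dummy_tokenizers_objects.py
    except ImportError as err:
        print(err)
```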
15 changes: 15 additions & 0 deletions src/transformers/models/siglip/test.py
Reviewer comment (Collaborator): this test file should be removed.

@@ -0,0 +1,15 @@
from transformers import SiglipTokenizer, SiglipTokenizerFast


slow_tokenizer = SiglipTokenizer.from_pretrained("google/siglip-so400m-patch14-384")

fast_tokenizer = SiglipTokenizerFast.from_pretrained("google/siglip-so400m-patch14-384")

text = "hello world"

inputs = slow_tokenizer(text, return_tensors="pt")

fast_inputs = fast_tokenizer(text, return_tensors="pt")

for k, v in inputs.items():
assert (v == fast_inputs[k]).all()
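
Since the reviewer asks for this ad-hoc script to be removed, here is a hedged sketch of how the same slow/fast parity check could live as a pytest test instead; the parametrized texts and the exact test location are assumptions:

```python
# Hypothetical pytest version of the parity check above (placement under
# tests/models/siglip/ and the sample texts are assumptions).
import pytest

from transformers import SiglipTokenizer, SiglipTokenizerFast

CHECKPOINT = "google/siglip-so400m-patch14-384"


@pytest.mark.parametrize("text", ["hello world", "A photo of 2 cats."])
def test_slow_and_fast_tokenizers_agree(text):
    slow = SiglipTokenizer.from_pretrained(CHECKPOINT)
    fast = SiglipTokenizerFast.from_pretrained(CHECKPOINT)
    assert slow(text).input_ids == fast(text).input_ids
```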
187 changes: 187 additions & 0 deletions src/transformers/models/siglip/tokenization_siglip_fast.py
@@ -0,0 +1,187 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fast tokenization class for SigLIP."""


import os
from shutil import copyfile
from typing import List, Optional, Tuple

from ...tokenization_utils_fast import AddedToken, PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
from .tokenization_siglip import SiglipTokenizer
else:
SiglipTokenizer = None


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}


class SiglipTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" SigLIP tokenizer (backed by HuggingFace's *tokenizers* library). Based on
[Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.

Args:
vocab_file (`str`, *optional*):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`, *optional*):
Path to tokenizer file.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.

<Tip>

When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.

</Tip>

unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"</s>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
"""

vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids"]
slow_tokenizer_class = SiglipTokenizer

prefix_tokens: List[int] = []

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
eos_token="</s>",
unk_token="<unk>",
pad_token="</s>",
additional_special_tokens=None,
**kwargs,
):
pad_token = (
AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(pad_token, str)
else pad_token
)
unk_token = (
AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(unk_token, str)
else unk_token
)
eos_token = (
AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
if isinstance(eos_token, str)
else eos_token
)

super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)

self.vocab_file = vocab_file

Reviewer comment (Collaborator): lots of `# Copied from` annotations are missing here as well.

@property
def can_save_slow_tokenizer(self) -> bool:
return os.path.isfile(self.vocab_file) if self.vocab_file else False

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)

if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)

if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
logger.info(f"Copy vocab file to {out_vocab_file}")

return (out_vocab_file,)

def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. A sequence has the following format:

- single sequence: `X </s>`
- pair of sequences: `A </s> B </s>`

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
token_ids_0 = token_ids_0 + [self.eos_token_id]
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0
else:
token_ids_1 = token_ids_1 + [self.eos_token_id]
return self.prefix_tokens + token_ids_0 + token_ids_1

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. SigLIP does not
make use of token type ids, therefore a list of zeros is returned.

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of zeros.
"""
eos = [self.eos_token_id]

if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
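
Finally, a minimal usage sketch of the new class, assuming this PR's branch is installed; the sample text and `max_length` value are illustrative:

```python
# The fast tokenizer appends </s> to each sequence (the `X </s>` format of
# build_inputs_with_special_tokens) and pads with </s>, since pad_token
# defaults to "</s>".
from transformers import SiglipTokenizerFast

tokenizer = SiglipTokenizerFast.from_pretrained("google/siglip-so400m-patch14-384")
encoding = tokenizer("a photo of 2 cats", padding="max_length", max_length=16)
print(encoding.input_ids)  # ends in the eos/pad id
print(tokenizer.decode(encoding.input_ids, skip_special_tokens=True))
```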
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tokenizers_objects.py
@@ -394,6 +394,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class SiglipTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class SplinterTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
