[SigLIP] Add fast tokenizer #29969

Status: Open
Wants to merge 29 commits into main
Showing changes from 4 commits

Commits (29)
d5d67b7  First draft (NielsRogge, Mar 30, 2024)
cbde88a  Fix more tests (NielsRogge, Mar 31, 2024)
de444e9  Add test (NielsRogge, Mar 31, 2024)
009fdc6  Remove print statements (NielsRogge, Apr 1, 2024)
f714af0  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Apr 22, 2024)
6cd05c2  Address comments (NielsRogge, Apr 22, 2024)
d67e40f  Use regex (NielsRogge, Apr 22, 2024)
f576078  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Aug 22, 2024)
de04050  Rebase (NielsRogge, Aug 22, 2024)
844c95c  Fix more tests (NielsRogge, Mar 31, 2024)
50500a5  remove strip in tokenize, keep characters used in special tokens, fix… (itazap, Aug 23, 2024)
8ba6e0b  ruff and FRAMEWORK error fix (itazap, Aug 24, 2024)
bf4f6db  remove unnecessary assertNotEqual from t5 (and siglip), add copied from) (itazap, Aug 26, 2024)
d850451  rm copied from (itazap, Aug 26, 2024)
e73fa01  typo (itazap, Aug 26, 2024)
cbe0a31  removing fast class (itazap, Sep 24, 2024)
d2b2339  updated tests for fast (itazap, Sep 30, 2024)
6379f9d  remove dev test file (Sep 30, 2024)
6f55733  Merge branch 'main' into add_siglip_fast_tokenizer_bis (itazap, Sep 30, 2024)
05f8b5c  Update src/transformers/models/auto/tokenization_auto.py (ArthurZucker, Oct 1, 2024)
e296021  Update tests/models/llama/test_tokenization_llama.py (itazap, Oct 1, 2024)
0f9669b  Update src/transformers/models/siglip/__init__.py (itazap, Oct 1, 2024)
0bca141  Update src/transformers/models/siglip/__init__.py (itazap, Oct 2, 2024)
80d2f46  add auto test (Oct 2, 2024)
2133809  fix test not to try importing Sigliptokenizerfast (Oct 2, 2024)
3a0e825  import pretrained instead of siglip (Oct 3, 2024)
e975d96  rm llama change (Oct 18, 2024)
dc8ed15  Merge remote-tracking branch 'upstream/main' into add_siglip_fast_tok… (NielsRogge, Oct 21, 2024)
dfada5a  Make fixup (NielsRogge, Oct 21, 2024)
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/siglip.md
@@ -124,6 +124,10 @@ If you're interested in submitting a resource to be included here, please feel f
- create_token_type_ids_from_sequences
- save_vocabulary

## SiglipTokenizerFast

[[autodoc]] SiglipTokenizerFast

## SiglipImageProcessor

[[autodoc]] SiglipImageProcessor
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1242,6 +1242,7 @@
_import_structure["models.roberta"].append("RobertaTokenizerFast")
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
_import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizerFast")
_import_structure["models.siglip"].append("SiglipTokenizerFast")
_import_structure["models.splinter"].append("SplinterTokenizerFast")
_import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
_import_structure["models.t5"].append("T5TokenizerFast")
@@ -6109,6 +6110,7 @@
from .models.roberta import RobertaTokenizerFast
from .models.roformer import RoFormerTokenizerFast
from .models.seamless_m4t import SeamlessM4TTokenizerFast
from .models.siglip import SiglipTokenizerFast
from .models.splinter import SplinterTokenizerFast
from .models.squeezebert import SqueezeBertTokenizerFast
from .models.t5 import T5TokenizerFast
37 changes: 37 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
@@ -19,6 +19,7 @@
allow to make our dependency on SentencePiece optional.
"""

import string
import warnings
from typing import Dict, List, Tuple

@@ -1050,6 +1051,41 @@ def post_processor(self):
)


class SiglipConverter(SpmConverter):
    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        list_normalizers = []

        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())
        list_normalizers.extend([normalizers.Replace(i, "") for i in string.punctuation])
        list_normalizers.extend(
            [
                normalizers.Replace(Regex(r"\s+"), " "),
                normalizers.Strip(),
            ]
        )

        if not precompiled_charsmap:
            list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        else:
            list_normalizers.extend(
                [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
            )

        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class WhisperConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
@@ -1489,6 +1525,7 @@ def converted(self) -> Tokenizer:
"WhisperTokenizer": WhisperConverter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
"SiglipTokenizer": SiglipConverter,
"SplinterTokenizer": SplinterConverter,
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
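For context on what the SiglipConverter.normalizer above produces, here is a minimal sketch, assuming the standalone tokenizers package and ignoring the precompiled-charsmap branch; it builds the same pipeline (lowercase, drop punctuation, collapse whitespace, strip, squeeze repeated spaces) and runs it on a sample string:

import string

from tokenizers import Regex, normalizers

# Same steps as SiglipConverter.normalizer when do_lower_case is True and no
# precompiled charsmap is present: lowercase, remove punctuation, collapse
# whitespace, strip, then squeeze any remaining runs of spaces.
normalizer = normalizers.Sequence(
    [normalizers.Lowercase()]
    + [normalizers.Replace(ch, "") for ch in string.punctuation]
    + [
        normalizers.Replace(Regex(r"\s+"), " "),
        normalizers.Strip(),
        normalizers.Replace(Regex(" {2,}"), " "),
    ]
)

print(normalizer.normalize_str("Hello,   World!!"))  # expected: "hello world"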
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -400,7 +400,13 @@
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
(
"siglip",
(
"SiglipTokenizer" if is_sentencepiece_available() else None,
"SiglipTokenizerFast" if is_tokenizers_available() else None,
),
),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
18 changes: 18 additions & 0 deletions src/transformers/models/siglip/__init__.py
@@ -17,6 +17,7 @@
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
)
@@ -41,6 +42,15 @@
    _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]


try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_siglip_fast"] = ["SiglipTokenizerFast"]


try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
@@ -82,6 +92,14 @@
    else:
        from .tokenization_siglip import SiglipTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_siglip_fast import SiglipTokenizerFast

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
15 changes: 15 additions & 0 deletions src/transformers/models/siglip/test.py
Collaborator: should be removed

Collaborator: to remove
@@ -0,0 +1,15 @@
from transformers import SiglipTokenizer, SiglipTokenizerFast


slow_tokenizer = SiglipTokenizer.from_pretrained("google/siglip-so400m-patch14-384")

fast_tokenizer = SiglipTokenizerFast.from_pretrained("google/siglip-so400m-patch14-384")

text = "hello world"

inputs = slow_tokenizer(text, return_tensors="pt")

fast_inputs = fast_tokenizer(text, return_tensors="pt")

for k, v in inputs.items():
    assert (v == fast_inputs[k]).all()
171 changes: 171 additions & 0 deletions src/transformers/models/siglip/tokenization_siglip_fast.py
Collaborator: Not sure we even need that class, no? We could just use a PreTrainedTokenizerFast.

Collaborator: Oh, I didn't consider we could do that! Is there an example I can reference? I'm not sure what to do with the functions copied over from T5 here.

Also, looking more into the functions here, would it be better to move common functions like save_vocabulary (duplicated in 15 _fast files) to PreTrainedTokenizerFast?

Contributor Author: Just a question: wouldn't that confuse users if there's no dedicated class with the name of the model?

Collaborator: @NielsRogge yes, I see your point; I had the same thought. Do we have other fast models we support without a dedicated class? @ArthurZucker

Collaborator: We have quite a lot of other models that just use PreTrainedTokenizerFast, including Llama (Llama 3), all the Mambas, etc. Tokenizers are more prone to change than models (you could have Mamba with LlamaTokenizer), so it makes more sense to deprecate slow and model-specific ones.

Collaborator: You don't need them; they are built on top of PreTrainedTokenizerFast, plus we can embed stuff inside the fast tokenizer itself.

Collaborator: @ArthurZucker Sorry, I'm not quite following: do you mean to just remove this file entirely and not worry about the functions, or do we need to embed them somewhere?

Collaborator: Yeah, just remove it entirely.

Collaborator: @ArthurZucker and would we have to add a tokenizer.json to the Hub?
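To make the thread above concrete, here is a hedged sketch of the "just use PreTrainedTokenizerFast" route and of producing a tokenizer.json for the Hub. The repo id is the one used in the tests below, taken only for illustration; whether that checkpoint already ships a tokenizer.json is an assumption, which is why the conversion path is shown as well (it relies on the SiglipConverter registered in this PR plus the sentencepiece and protobuf extras):

from transformers import AutoTokenizer, PreTrainedTokenizerFast, SiglipTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

repo = "google/siglip-base-patch16-224"

# Option 1: let AutoTokenizer resolve a fast tokenizer for the checkpoint
# (falls back to the slow SentencePiece tokenizer if no fast files are available).
tok = AutoTokenizer.from_pretrained(repo, use_fast=True)

# Option 2: convert the slow tokenizer once and wrap the result in the generic
# PreTrainedTokenizerFast, with no dedicated SiglipTokenizerFast class at all.
slow = SiglipTokenizer.from_pretrained(repo)
fast = PreTrainedTokenizerFast(
    tokenizer_object=convert_slow_tokenizer(slow),
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="</s>",
)
fast.save_pretrained("siglip-fast")  # writes tokenizer.json, which could then be pushed to the Hub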
@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fast tokenization class for SigLIP."""


import os
from shutil import copyfile
from typing import List, Optional, Tuple

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_siglip import SiglipTokenizer
else:
    SiglipTokenizer = None


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}


class SiglipTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" SigLIP tokenizer (backed by HuggingFace's *tokenizers* library). Based on
[Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.

Args:
vocab_file (`str`, *optional*):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`, *optional*):
Path to tokenizer file.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.

<Tip>

When building a sequence using special tokens, this is not the token that is used for the end of sequence.
The token used is the `sep_token`.

</Tip>

unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"</s>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
"""

vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids"]
slow_tokenizer_class = SiglipTokenizer

prefix_tokens: List[int] = []

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
eos_token="</s>",
unk_token="<unk>",
pad_token="</s>",
additional_special_tokens=None,
**kwargs,
):
super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        self.vocab_file = vocab_file

    @property
Collaborator: lots of "Copied from" comments are missing here as well
    def can_save_slow_tokenizer(self) -> bool:
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
            logger.info(f"Copy vocab file to {out_vocab_file}")

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format:

        - single sequence: `X </s>`
        - pair of sequences: `A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0
        else:
            token_ids_1 = token_ids_1 + [self.eos_token_id]
            return self.prefix_tokens + token_ids_0 + token_ids_1

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
        return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
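A short usage sketch of the class above, illustrating the `X </s>` single-sequence and `A </s> B </s>` pair formats documented in build_inputs_with_special_tokens. The checkpoint name matches the one used in the tests; loading it as a fast tokenizer assumes the conversion from the slow SentencePiece model succeeds in your environment:

from transformers import SiglipTokenizerFast

tokenizer = SiglipTokenizerFast.from_pretrained("google/siglip-base-patch16-224")

single = tokenizer("a photo of a cat")
# A single sequence is encoded as "X </s>", so the last id should be the eos id.
assert single.input_ids[-1] == tokenizer.eos_token_id

pair = tokenizer("a photo of a cat", "a photo of a dog")
# The pair form is "A </s> B </s>"; token type ids, when requested, are all zeros (T5-style).
print(tokenizer.convert_ids_to_tokens(pair.input_ids))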
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tokenizers_objects.py
@@ -394,6 +394,13 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])


class SiglipTokenizerFast(metaclass=DummyObject):
    _backends = ["tokenizers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])


class SplinterTokenizerFast(metaclass=DummyObject):
    _backends = ["tokenizers"]

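For context, the dummy class above is what `from transformers import SiglipTokenizerFast` resolves to when the tokenizers backend is not installed; instantiating it raises an import error that names the missing dependency. A small sketch of that behaviour (the exact error text is an assumption):

# In an environment without the `tokenizers` package, the top-level import succeeds
# but returns the dummy class defined above.
from transformers import SiglipTokenizerFast

try:
    SiglipTokenizerFast()
except ImportError as err:
    print(err)  # requires_backends points at the missing `tokenizers` dependency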
15 changes: 8 additions & 7 deletions tests/models/siglip/test_tokenization_siglip.py
@@ -18,9 +18,9 @@
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer
from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer, SiglipTokenizerFast
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available
from transformers.utils import cached_property, is_torch_available

from ...test_tokenization_common import TokenizerTesterMixin

@@ -29,18 +29,15 @@

if is_torch_available():
    FRAMEWORK = "pt"
elif is_tf_available():
    FRAMEWORK = "tf"
else:
    FRAMEWORK = "jax"


@require_sentencepiece
@require_tokenizers
class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "google/siglip-base-patch16-224"
    tokenizer_class = SiglipTokenizer
    test_rust_tokenizer = False
    rust_tokenizer_class = SiglipTokenizerFast
    test_rust_tokenizer = True
    test_sentencepiece = True
    test_sentencepiece_ignore_case = True

@@ -139,6 +136,10 @@ def siglip_tokenizer(self):
    def get_tokenizer(self, **kwargs) -> SiglipTokenizer:
        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.get_rust_tokenizer with T5->Siglip
    def get_rust_tokenizer(self, **kwargs) -> SiglipTokenizerFast:
        return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    # Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_rust_and_python_full_tokenizers with T5->Siglip
    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer: