diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index fb0dffe294..d857ec5f63 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -1,6 +1,6 @@ import math import numbers -from typing import List, Optional, Sequence, Tuple, Union, cast +from typing import Callable, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -10,6 +10,11 @@ from captum.optim._utils.typing import IntSeqOrIntType, NumSeqOrTensorOrProbDistType from packaging import version +try: + from torchtext.transforms import CLIPTokenizer as CLIPTokenizer_TorchText +except ImportError: + print("torchtext >=0.12.0 is required to use Captum's Optim CLIPTokenizer") + class BlendAlpha(nn.Module): r"""Blends a 4 channel input parameterization into an RGB image. @@ -1304,6 +1309,289 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class CLIPTokenizer(torch.nn.Module): + """ + This module allows individuals to use torchtext's CLIP tokenizer with a wrapper + that handles context_length padding, special start and end tokens, truncation, and + to tensor conversions. This module also supports JIT, and can decode tokens. + + Note that this module does not implement preprocessing like whitespace cleaning, + HTML to unicode conversions, or heuristic unicode correction. + + Example:: + + >>> clip_tokenizer = opt.transforms.CLIPTokenizer(pretrained_merges=True) + >>> tokens = clip_tokenizer("An example sentence.") + >>> print(tokens[0][:6]) + tensor([49406, 550, 6228, 12737, 269, 49407], dtype=torch.int32) + >>> decoded_str = clip_tokenizer.decode(tokens) + >>> print(decoded_str) + ['an example sentence .'] + + See here for more details: + https://pytorch.org/text/main/transforms.html#torchtext.transforms.CLIPTokenizer + + The torchtext CLIPTokenizer is based on these implementations: + https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py + https://github.com/mlfoundations/open_clip/blob/main/src/clip/tokenizer.py + """ + + __constants__ = [ + "context_length", + "start_token", + "end_token", + "_merges_path", + "_num_merges", + "padding_value", + "truncate", + ] + + def __init__( + self, + merges_path: Optional[str] = None, + context_length: int = 77, + start_token: Optional[str] = "<|startoftext|>", + end_token: Optional[str] = "<|endoftext|>", + pretrained_merges: bool = True, + num_merges: Optional[int] = None, + padding_value: int = 0, + truncate: bool = False, + preprocessing_module: Optional[Callable[[List[str]], List[str]]] = None, + ) -> None: + """ + Args: + + merges_path (str, optional): Path to file containing the merges, or where + to save the merges file if pretrained_merges is set to ``True``. The + :func:`torch.hub.get_dir()` function will be used to get the directory + if set to ``None``, resulting in a path of: /vocab. + Default: ``None`` + context_length (int, optional): The required context length for the model. + Inputs with lengths less than ``context_length`` will be padded with + zeros. + Default: ``77`` + start_token (str, optional): The starting token to place in front of each + text input. Set to ``None`` for no start token. + Default: ``"<|startoftext|>"`` + end_token (str, optional): The ending token to place at the end of each + text input. Set to ``None`` for no end token. + Default: ``"<|endoftext|>"`` + pretrained_merges (bool, optional): Whether or not to download merges for + the pretrained CLIP model. 
+                Default: ``True``
+            num_merges (int, optional): The number of lines to use from the merges
+                file. Set to None for all lines.
+                Default: ``None``
+            padding_value (int, optional): An integer value to use for padding token
+                sets to the desired ``context_length``.
+                Default: ``0``
+            truncate (bool, optional): Whether or not to truncate outputs larger than
+                ``context_length``.
+                Default: ``False``
+            preprocessing_module (callable, optional): An optional function that takes
+                a list of str and returns a list of str. This can be used to implement
+                whitespace cleaning, HTML to unicode conversions, or heuristic unicode
+                correction. Set to ``None`` for no text str preprocessing.
+                Default: ``None``
+        """
+        super().__init__()
+        self.context_length = context_length
+        self.start_token = start_token
+        self.end_token = end_token
+
+        if pretrained_merges:
+            merges_path = self._download_clip_bpe_merges(file_dir=merges_path)
+        else:
+            assert merges_path is not None
+
+        self._num_merges = num_merges
+        self._merges_path = merges_path
+        self.clip_tokenizer_module = CLIPTokenizer_TorchText(
+            merges_path=merges_path, num_merges=num_merges
+        )
+        self.padding_value = padding_value
+        self.truncate = truncate
+        self.preprocessing_module = preprocessing_module
+
+    @staticmethod
+    @torch.jit.ignore
+    def _download_clip_bpe_merges(file_dir: Optional[str] = None) -> str:
+        """
+        Download a copy of CLIP's BPE merges for the first 48895 lines of the
+        'bpe_simple_vocab_16e6.txt.gz' file from: https://github.com/openai/CLIP.
+
+        The BPE merges file will not be downloaded if it already exists in the
+        specified directory.
+
+        Args:
+
+            file_dir (str, optional): Optionally provide a location to save the
+                file to. The :func:`torch.hub.get_dir()` function will be used to get
+                the directory if set to None, resulting in a path
+                of: /vocab.
+                Default: ``None``
+
+        Returns:
+            filename (str): The path to the downloaded file with the filename.
+        """
+        from os import path, makedirs
+
+        import requests
+
+        url = (
+            "https://pytorch.s3.amazonaws.com/models/captum/"
+            + "clip_bpe_simple_vocab_48895.txt"
+        )
+        if file_dir is None:
+            file_dir = path.join(torch.hub.get_dir(), "vocab")
+        else:
+            assert path.splitext(path.basename(file_dir))[1] == ""
+
+        filename = path.join(file_dir, path.basename(url))
+
+        # Create dir if it doesn't exist
+        if file_dir != "" and not path.isdir(file_dir):
+            makedirs(file_dir)
+
+        if not path.isfile(filename):
+            print("Downloading: '{}' to '{}'\n".format(path.basename(url), file_dir))
+            file = requests.get(url)
+            with open(filename, "wb") as f:
+                f.write(file.content)
+        return filename
+
+    @torch.jit.ignore
+    def decode(
+        self,
+        x: Union[torch.Tensor, List[int], List[List[int]]],
+        include_special_tokens: bool = False,
+    ) -> List[str]:
+        """
+        Decode token values into their corresponding string values.
+
+        Based on the implementations used by OpenAI & TorchText:
+        https://github.com/openai/gpt-2/blob/master/src/encoder.py
+        https://github.com/pytorch/text/blob/main/torchtext/transforms.py
+
+        Args:
+
+            x (torch.Tensor, list of int, or list of list of int): A set of tokens
+                stacked across the batch dimension, a list of tokens, or a list of
+                lists of tokens.
+            include_special_tokens (bool, optional): Whether or not to include added
+                special tokens in the output.
+                Default: ``False``
+
+        Returns:
+            token_str (list of str): A set of strings that correspond to the
+                token values in the input.
+        """
+        if isinstance(x, torch.Tensor):
+            x = x.unsqueeze(0) if x.dim() == 1 else x
+            assert x.dim() == 2
+            x = [[t.tolist() for t in b] for b in x]
+        elif isinstance(x, (tuple, list)):
+            if any([isinstance(v, (tuple, list)) for v in x]):
+                assert all([all([isinstance(t, int) for t in ts]) for ts in x])
+            else:
+                assert all([isinstance(t, int) for t in x])
+                x = [x]
+
+        with open(self._merges_path, "r", encoding="utf-8") as f:
+            bpe_merges = f.read().split("\n")[1:]
+        num_merges = self._num_merges or len(bpe_merges)
+
+        # Setup vocab Unicode values
+        # Unicode values from "!" to "~", "¡" to "¬", "®" to "ÿ"
+        # Lowercase & uppercase are treated as the same character
+        bpe_v = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256))
+        bpe_keys = bpe_v + list(range(0, 33)) + list(range(127, 161)) + [173]
+        bpe_vocab = [chr(c) for c in bpe_v + [256 + n for n in list(range(0, 68))]]
+        byte_decoder = dict(zip(bpe_vocab, bpe_keys))
+
+        bpe_vocab += [v + "</w>" for v in bpe_vocab]
+        # Add vocab merges from file
+        bpe_vocab += [
+            "".join(merge_pair.split()) for merge_pair in bpe_merges[:num_merges]
+        ]
+
+        # Handle special tokens
+        if self.start_token:
+            bpe_vocab += [self.start_token]
+        if self.end_token:
+            bpe_vocab += [self.end_token]
+
+        decoder = dict(zip(range(len(bpe_vocab)), bpe_vocab))
+
+        # Decode tokens
+        x = [[i for i in b if i != self.padding_value] for b in x]
+        token_str = ["".join([decoder[t] for t in ts]) for ts in x]
+        token_str = [bytearray([byte_decoder[t] for t in ts]) for ts in token_str]
+        token_str = [
+            ts.decode("utf-8", errors="replace").replace("</w>", " ").strip()
+            for ts in token_str
+        ]
+        if self.start_token and not include_special_tokens:
+            token_str = [s.replace(self.start_token, "") for s in token_str]
+        if self.end_token and not include_special_tokens:
+            token_str = [s.replace(self.end_token, "") for s in token_str]
+        return [s.strip() for s in token_str]
+
+    def forward(self, x: Union[str, List[str]]) -> torch.Tensor:
+        """
+        Args:
+
+            x (str or list of str): Text values to be converted to tokenized tensors.
+
+        Returns:
+            tokens (torch.Tensor): A tensor containing each set of tokens stacked
+                across the batch dimension.
+ """ + x = [x] if isinstance(x, str) else x + + if self.preprocessing_module is not None: + x = self.preprocessing_module(x) + + # Optionally add start & end tokens to inputs + if self.start_token: + x = [self.start_token + " " + s for s in x] + if self.end_token: + x = [s + " " + self.end_token for s in x] + + # Tokenize the text strings + tokens = self.clip_tokenizer_module(x) + + # Refine 'tokens' Type from Any to List[List[str]] in JIT + assert torch.jit.isinstance(tokens, List[List[str]]) + + # Optionally truncate inputs + if self.truncate: + if self.end_token: + tokens = [ + token_set[: self.context_length - 1] + [token_set[-1]] + if len(token_set) > self.context_length + else token_set + for token_set in tokens + ] + else: + tokens = [ + token_set[: self.context_length] + if len(token_set) > self.context_length + else token_set + for token_set in tokens + ] + + assert all([len(t) <= self.context_length for t in tokens]) + + # Convert str tokens to tensor values & apply zeros padding + p = self.padding_value + tokens = [ + [int(t) for t in token_set] + ([p] * (self.context_length - len(token_set))) + for token_set in tokens + ] + return torch.as_tensor(tokens).int() + + __all__ = [ "BlendAlpha", "IgnoreAlpha", @@ -1321,4 +1609,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: "NChannelsToRGB", "RandomCrop", "TransformationRobustness", + "CLIPTokenizer", ] diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 385006a7ac..1164016d5f 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import unittest +from os import path from typing import List import captum.optim._param.image.transforms as transforms @@ -10,6 +11,13 @@ from tests.helpers.basic import BaseTest, assertTensorAlmostEqual from tests.optim.helpers import numpy_transforms +try: + from torchtext.transforms import CLIPTokenizer as CLIPTokenizer_TorchText + + _torchtext_has_clip_tokenizer = True +except ImportError: + _torchtext_has_clip_tokenizer = False + class TestRandomScale(BaseTest): def test_random_scale_init(self) -> None: @@ -2008,3 +2016,402 @@ def test_transform_robustness_forward_padding_crop_output_jit_module(self) -> No test_input = torch.ones(1, 3, 224, 224) test_output = transform_robustness(test_input) self.assertEqual(test_output.shape, test_input.shape) + + +class TestCLIPTokenizer(BaseTest): + def test_clip_tokenizer_pretrained_download(self) -> None: + file_path = path.join( + torch.hub.get_dir(), "vocab", "clip_bpe_simple_vocab_48895.txt" + ) + merges_path = transforms.CLIPTokenizer._download_clip_bpe_merges(None) + self.assertEqual(merges_path, file_path) + + def test_clip_tokenizer_pretrained_download_custom_path(self) -> None: + custom_path = path.join(torch.hub.get_dir(), "vocab_test") + file_path = path.join(custom_path, "clip_bpe_simple_vocab_48895.txt") + merges_path = transforms.CLIPTokenizer._download_clip_bpe_merges(custom_path) + self.assertEqual(merges_path, file_path) + + def test_clip_tokenizer_pretrained_download_assert_error(self) -> None: + file_path = path.join("vocab", "clip_bpe_simple_vocab_48895.txt") + with self.assertRaises(AssertionError): + _ = transforms.CLIPTokenizer._download_clip_bpe_merges(file_path) + + def test_clip_tokenizer_init(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer init test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + 
+ self.assertEqual(clip_tokenizer.context_length, 77) + self.assertEqual(clip_tokenizer.start_token, "<|startoftext|>") + self.assertEqual(clip_tokenizer.end_token, "<|endoftext|>") + self.assertIsNone(clip_tokenizer._num_merges) + self.assertEqual(clip_tokenizer.padding_value, 0) + self.assertFalse(clip_tokenizer.truncate) + self.assertIsNone(clip_tokenizer.preprocessing_module) + + file_path = path.join( + torch.hub.get_dir(), "vocab", "clip_bpe_simple_vocab_48895.txt" + ) + self.assertEqual(clip_tokenizer._merges_path, file_path) + self.assertIsInstance( + clip_tokenizer.clip_tokenizer_module, CLIPTokenizer_TorchText + ) + + def test_clip_tokenizer_str_input(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + text_input_str = "this is a test!" + + text_output = clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [49406, 589, 533, 320, 1628, 256, 49407] + padding = [0] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + def test_clip_tokenizer_str_input_context_length_54(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input" + + " context_length test" + ) + context_length = 54 + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, context_length=context_length + ) + text_input_str = "this is a test!" + + text_output = clip_tokenizer(text_input_str) + + token_ids = [49406, 589, 533, 320, 1628, 256, 49407] + padding = [0] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + def test_clip_tokenizer_str_input_context_length_padding(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input test" + ) + padding_value = -1 + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, padding_value=padding_value + ) + text_input_str = "this is a test!" 
+ + text_output = clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [49406, 589, 533, 320, 1628, 256, 49407] + padding = [padding_value] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + def test_clip_tokenizer_list_str_input(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer list str input" + + " test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + text_input_str = ["this is a test!", "a picture of a cat."] + + text_output = clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [ + [49406, 589, 533, 320, 1628, 256, 49407], + [49406, 320, 1674, 539, 320, 2368, 269, 49407], + ] + + self.assertEqual(list(text_output.shape), [2, context_length]) + for b, t in enumerate(token_ids): + padding = [0] * (context_length - len(t)) + token_set = t + padding + self.assertEqual(text_output[b].tolist(), token_set) + + def test_clip_tokenizer_str_input_decode(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input" + + " decode test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + text_input_str = "This is a test!" + + text_output = clip_tokenizer(text_input_str) + text_output_str = clip_tokenizer.decode(text_output) + + expected_ouput_str = ["this is a test !"] + self.assertEqual(text_output_str, expected_ouput_str) + + def test_clip_tokenizer_str_input_decode_special_tokens(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input" + + " decode include_special_tokens test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + text_input_str = "This is a test!" + + text_output = clip_tokenizer(text_input_str) + text_output_str = clip_tokenizer.decode( + text_output, include_special_tokens=True + ) + + expected_ouput_str = ["<|startoftext|>this is a test ! 
<|endoftext|>"] + self.assertEqual(text_output_str, expected_ouput_str) + + def test_clip_tokenizer_list_decode(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer list decode" + + " test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + + token_list = [49406, 589, 533, 320, 1628, 256, 49407, 0] + + str_output = clip_tokenizer.decode(token_list) + expected_str = ["this is a test !"] + self.assertEqual(str_output, expected_str) + + def test_clip_tokenizer_list_of_list_decode(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer list of list" + + " decode test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + + token_list = [ + [49406, 589, 533, 320, 1628, 256, 49407], + [49406, 320, 1674, 539, 320, 2368, 269, 49407, 0, 0], + ] + + str_output = clip_tokenizer.decode(token_list) + expected_str = ["this is a test !", "a picture of a cat ."] + self.assertEqual(str_output, expected_str) + + def test_clip_tokenizer_no_special_tokens(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer no special" + + " tokens test" + ) + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, start_token=None, end_token=None + ) + text_input_str = "This is a test!" + + text_output = clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [589, 533, 320, 1628, 256] + padding = [0] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + text_output_str = clip_tokenizer.decode( + text_output, include_special_tokens=True + ) + + expected_ouput_str = ["this is a test !"] + self.assertEqual(text_output_str, expected_ouput_str) + + def test_clip_tokenizer_pretrained_merges_false(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer pretrained" + + " merges False test" + ) + merges_path = transforms.CLIPTokenizer._download_clip_bpe_merges(None) + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=False, merges_path=merges_path + ) + text_input_str = "This is a test!" + + text_output = clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [49406, 589, 533, 320, 1628, 256, 49407] + padding = [0] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + text_output_str = clip_tokenizer.decode(text_output) + + expected_ouput_str = ["this is a test !"] + self.assertEqual(text_output_str, expected_ouput_str) + + def test_clip_tokenizer_str_input_jit(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer str input JIT" + + " test" + ) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + text_input_str = "this is a test!" 
+ + jit_clip_tokenizer = torch.jit.script(clip_tokenizer) + text_output = jit_clip_tokenizer(text_input_str) + + context_length = 77 + token_ids = [49406, 589, 533, 320, 1628, 256, 49407] + padding = [0] * (context_length - len(token_ids)) + + token_set = token_ids + padding + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), token_set) + + def test_clip_tokenizer_unicode_encode(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer unicode test" + ) + + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, context_length=376 + ) + + bpe_v = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256)) + bpe_vocab = [chr(c) for c in bpe_v + [256 + n for n in list(range(0, 68))]] + bpe_vocab_str = " ".join(bpe_vocab) + txt_output = clip_tokenizer(bpe_vocab_str) + + # fmt: off + expected_tokens = [ + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, + 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, + 286, 287, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, + 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, + 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, + 346, 347, 348, 349, 10830, 41359, 1950, 126, 353, 20199, 126, 355, 126, + 356, 126, 357, 5811, 126, 359, 14434, 126, 361, 8436, 43593, 6858, 126, + 365, 41175, 126, 367, 12559, 126, 369, 126, 370, 14844, 126, 372, 126, + 373, 28059, 7599, 126, 376, 33613, 126, 378, 17133, 21259, 22229, 127, + 351, 127, 352, 47276, 127, 354, 127, 355, 127, 356, 37761, 4166, 127, 359, + 40520, 127, 361, 23928, 127, 362, 127, 363, 127, 364, 127, 365, 127, 366, + 27733, 127, 368, 127, 369, 37423, 16165, 45598, 127, 373, 36019, 127, 375, + 47177, 127, 377, 127, 378, 127, 509, 21259, 22229, 127, 351, 127, 352, + 47276, 127, 354, 127, 355, 127, 356, 37761, 4166, 127, 359, 40520, 127, + 361, 23928, 127, 362, 127, 363, 127, 364, 127, 365, 127, 366, 27733, 127, + 368, 127, 369, 37423, 127, 371, 45598, 127, 373, 36019, 127, 375, 47177, + 127, 377, 127, 378, 127, 379, 128, 479, 128, 479, 128, 481, 128, 481, 128, + 483, 128, 483, 31719, 31719, 128, 487, 128, 487, 128, 489, 128, 489, 128, + 491, 128, 491, 128, 493, 128, 493, 128, 495, 128, 495, 128, 497, 128, 497, + 128, 499, 128, 499, 128, 501, 128, 501, 128, 503, 128, 503, 128, 505, 128, + 505, 128, 507, 128, 507, 128, 509, 128, 509, 128, 350, 128, 350, 128, 352, + 128, 352, 128, 354, 128, 354, 128, 356, 128, 356, 128, 358, 128, 358, 128, + 360, 128, 360, 128, 511, 128, 511, 128, 363, 128, 363, 328, 16384, 41901, + 128, 367, 128, 367, 128, 369, 128, 369, 128, 371, 128, 371, 128, 372, 128, + 374, 128, 374, 128, 376, 128, 376, 128, 378, 128, 378, 129, 478, 129, 478, + 129, 480, 129, 480, 129, 482, + ] + # fmt: on + + self.assertEqual(txt_output[0].tolist()[1:-1], expected_tokens) + + def test_clip_tokenizer_unicode_decode(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer unicode decode" + + " test" + ) + + # fmt: off + input_tokens = [ + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, + 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, + 286, 287, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, + 333, 334, 
335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, + 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, + 346, 347, 348, 349, 10830, 41359, 1950, 126, 353, 20199, 126, 355, 126, + 356, 126, 357, 5811, 126, 359, 14434, 126, 361, 8436, 43593, 6858, 126, + 365, 41175, 126, 367, 12559, 126, 369, 126, 370, 14844, 126, 372, 126, + 373, 28059, 7599, 126, 376, 33613, 126, 378, 17133, 21259, 22229, 127, + 351, 127, 352, 47276, 127, 354, 127, 355, 127, 356, 37761, 4166, 127, 359, + 40520, 127, 361, 23928, 127, 362, 127, 363, 127, 364, 127, 365, 127, 366, + 27733, 127, 368, 127, 369, 37423, 16165, 45598, 127, 373, 36019, 127, 375, + 47177, 127, 377, 127, 378, 127, 509, 21259, 22229, 127, 351, 127, 352, + 47276, 127, 354, 127, 355, 127, 356, 37761, 4166, 127, 359, 40520, 127, + 361, 23928, 127, 362, 127, 363, 127, 364, 127, 365, 127, 366, 27733, 127, + 368, 127, 369, 37423, 127, 371, 45598, 127, 373, 36019, 127, 375, 47177, + 127, 377, 127, 378, 127, 379, 128, 479, 128, 479, 128, 481, 128, 481, 128, + 483, 128, 483, 31719, 31719, 128, 487, 128, 487, 128, 489, 128, 489, 128, + 491, 128, 491, 128, 493, 128, 493, 128, 495, 128, 495, 128, 497, 128, 497, + 128, 499, 128, 499, 128, 501, 128, 501, 128, 503, 128, 503, 128, 505, 128, + 505, 128, 507, 128, 507, 128, 509, 128, 509, 128, 350, 128, 350, 128, 352, + 128, 352, 128, 354, 128, 354, 128, 356, 128, 356, 128, 358, 128, 358, 128, + 360, 128, 360, 128, 511, 128, 511, 128, 363, 128, 363, 328, 16384, 41901, + 72, 329, 72, 329, 128, 369, 128, 369, 128, 371, 128, 371, 128, 372, 128, + 374, 128, 374, 128, 376, 128, 376, 128, 378, 128, 378, 129, 478, 129, 478, + 129, 480, 129, 480, 129, 482, + ] + # fmt: on + + input_tokens = torch.as_tensor([input_tokens]) + clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True) + txt_output_str = clip_tokenizer.decode(input_tokens) + + expected_str = ( + """!"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\]^_`abcd""" # noqa: W605,E501 + + """efghijklmnopqrstuvwxyz{|}~¡¢£¤¥¦§¨©ª«¬®¯°±²³´µ¶·¸¹º»¼½¾¿àáâãäåæçèé""" + + """êëìíîïðñòóôõö×øùúûüýþßàáâãäåæçèéêëìíîïðñòóôõö÷øùúûüýþÿāāăăąąććĉĉċċ""" + + """ččďďđđēēĕĕėėęęěěĝĝğğġġģģĥĥħħĩĩīīĭĭįįi̇ıijijĵĵķķĸĺĺļļľľŀŀłłń""" + ) + self.assertEqual(len(txt_output_str), 1) + self.assertEqual(txt_output_str[0].replace(" ", ""), expected_str) + + def test_clip_tokenizer_truncate(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer truncate test" + ) + context_length = 5 + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, context_length=context_length, truncate=True + ) + text_input_str = "this is a test!" + text_output = clip_tokenizer(text_input_str) + + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), [49406, 589, 533, 320, 49407]) + + def test_clip_tokenizer_truncate_no_end_token(self) -> None: + if not _torchtext_has_clip_tokenizer: + raise unittest.SkipTest( + "torchtext >=0.12.0 not found, skipping ClipTokenizer truncate no" + + " end token test" + ) + context_length = 5 + clip_tokenizer = transforms.CLIPTokenizer( + pretrained_merges=True, + context_length=context_length, + end_token=None, + truncate=True, + ) + text_input_str = "this is a test!" 
+ text_output = clip_tokenizer(text_input_str) + + self.assertEqual(list(text_output.shape), [1, context_length]) + self.assertEqual(text_output[0].tolist(), [49406, 589, 533, 320, 1628])
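
For reviewers who want to exercise the new transform end to end, here is a minimal usage sketch that mirrors the tests above. It is not part of the patch, and it assumes torchtext >= 0.12.0 is installed along with network access (or a cached copy) for the one-time BPE merges download.

# Hedged usage sketch (not part of the patch). Assumes torchtext >= 0.12.0 is
# installed and that the pretrained BPE merges file can be downloaded, or is
# already cached under torch.hub.get_dir()/vocab.
import captum.optim._param.image.transforms as transforms

# Build the tokenizer; pretrained_merges=True fetches the merges file on first use.
clip_tokenizer = transforms.CLIPTokenizer(pretrained_merges=True)

# Encode a batch of strings into a [batch, context_length] int tensor, padded
# with padding_value (0 by default) up to context_length (77 by default).
tokens = clip_tokenizer(["this is a test!", "a picture of a cat."])
print(tokens.shape)  # torch.Size([2, 77])

# Decode back to lowercased, whitespace-separated token strings
# (special tokens are stripped unless include_special_tokens=True).
print(clip_tokenizer.decode(tokens))
# ['this is a test !', 'a picture of a cat .']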