diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py
index 3d21e9ca85..9cdd68b2ff 100644
--- a/keras_nlp/tokenizers/unicode_character_tokenizer.py
+++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py
@@ -45,6 +45,10 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer):
             One of The encoding of the input text. Defaults to "UTF-8".
         output_encoding: One of ("UTF-8", "UTF-16-BE", or "UTF-32-BE").
             The encoding of the output text. Defaults to "UTF-8".
+        vocabulary_size: Set the vocabulary `vocabulary_size`,
+            by clamping all codepoints to the range [0, vocabulary_size).
+            Effectively this will make the `vocabulary_size - 1` id
+            the OOV value.
 
     Examples:
 
@@ -130,6 +134,19 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer):
     numpy=array([[ 105,   32,  108,  105,  107],
            [2350, 2376, 2306,   32, 2325]], dtype=int32)>
 
+    Tokenization with vocabulary_size.
+    >>> latin_ext_cutoff = 592
+    >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer(
+    ...     vocabulary_size=latin_ext_cutoff)
+    >>> tokenizer("¿Cómo estás?")
+    <tf.Tensor: shape=(12,), dtype=int32,
+    numpy=array([191,  67, 243, 109, 111,  32, 101, 115, 116, 225, 115,  63],
+    dtype=int32)>
+
+    >>> tokenizer("आप कैसे हैं")
+    <tf.Tensor: shape=(11,), dtype=int32,
+    numpy=array([591, 591,  32, 591, 591, 591, 591,  32, 591, 591, 591], dtype=int32)>
+
     Detokenization.
     >>> inputs = tf.constant([110, 105, 110, 106, 97], dtype=tf.int32)
     >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer()
@@ -168,6 +185,7 @@ def __init__(
         replacement_char: int = 65533,
         input_encoding: str = "UTF-8",
         output_encoding: str = "UTF-8",
+        vocabulary_size: int = None,
         **kwargs,
     ) -> None:
         # Check dtype and provide a default.
@@ -213,6 +231,7 @@ def __init__(
         self.replacement_char = replacement_char
         self.input_encoding = input_encoding
         self.output_encoding = output_encoding
+        self._vocabulary_size = vocabulary_size
 
     def get_config(self) -> Dict[str, Any]:
         config = super().get_config()
@@ -225,10 +244,16 @@ def get_config(self) -> Dict[str, Any]:
                 "replacement_char": self.replacement_char,
                 "input_encoding": self.input_encoding,
                 "output_encoding": self.output_encoding,
+                "vocabulary_size": self._vocabulary_size,
             }
         )
         return config
 
+    def vocabulary_size(self) -> int:
+        """Get the size of the tokenizer vocabulary. None implies no
+        vocabulary size was provided."""
+        return self._vocabulary_size
+
     def tokenize(self, inputs):
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)
@@ -260,6 +285,12 @@ def tokenize(self, inputs):
 
         if scalar_input:
             tokens = tf.squeeze(tokens, 0)
+
+        # Optionally clamp the output codepoint values to be in the
+        # range [0, vocabulary_size).
+        if self._vocabulary_size:
+            tokens = tf.clip_by_value(tokens, 0, self._vocabulary_size - 1)
+
         return tokens
 
     def detokenize(self, inputs):
diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py
index 34df6a5094..334a813961 100644
--- a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py
+++ b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py
@@ -59,6 +59,46 @@ def test_dense_output(self):
             ],
         )
 
+    def test_tokenize_scalar_with_vocabulary_size(self):
+        input_data = "ninja"
+        tokenizer = UnicodeCharacterTokenizer(vocabulary_size=105)
+        call_output = tokenizer(input_data)
+        tokenize_output = tokenizer.tokenize(input_data)
+
+        self.assertAllEqual(call_output, [104, 104, 104, 104, 97])
+        self.assertAllEqual(tokenize_output, [104, 104, 104, 104, 97])
+
+    def test_tokenize_dense_with_vocabulary_size(self):
+        input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"])
+        tokenizer = UnicodeCharacterTokenizer(
+            sequence_length=10, vocabulary_size=105
+        )
+        call_output = tokenizer(input_data)
+        self.assertIsInstance(call_output, tf.Tensor)
+        self.assertAllEqual(
+            call_output,
+            [
+                [104, 104, 104, 104, 97, 0, 0, 0, 0, 0],
+                [104, 97, 104, 104, 104, 97, 104, 0, 0, 0],
+                [104, 104, 104, 104, 0, 0, 0, 0, 0, 0],
+            ],
+        )
+
+    def test_tokenize_ragged_with_vocabulary_size(self):
+        input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"])
+        tokenizer = UnicodeCharacterTokenizer(vocabulary_size=105)
+        call_output = tokenizer(input_data)
+        tokenize_output = tokenizer.tokenize(input_data)
+        self.assertIsInstance(call_output, tf.RaggedTensor)
+        exp_outputs = [
+            [104, 104, 104, 104, 97],
+            [104, 97, 104, 104, 104, 97, 104],
+            [104, 104, 104, 104],
+        ]
+        for i in range(call_output.shape[0]):
+            self.assertAllEqual(call_output[i], exp_outputs[i])
+            self.assertAllEqual(tokenize_output[i], exp_outputs[i])
+
     def test_detokenize(self):
         input_data = tf.ragged.constant(
             [
@@ -232,6 +272,7 @@ def test_load_model_with_config(self):
             sequence_length=11,
             normalization_form="NFC",
             errors="strict",
+            vocabulary_size=None,
         )
         cloned_tokenizer = UnicodeCharacterTokenizer.from_config(
             original_tokenizer.get_config()
         )
@@ -255,6 +296,7 @@ def test_config(self):
             normalization_form="NFC",
             errors="ignore",
             replacement_char=0,
+            vocabulary_size=100,
         )
         exp_config = {
             "dtype": "int32",
@@ -267,6 +309,7 @@ def test_config(self):
             "input_encoding": "UTF-8",
             "output_encoding": "UTF-8",
             "trainable": True,
+            "vocabulary_size": 100,
         }
         self.assertEqual(tokenizer.get_config(), exp_config)
 
@@ -278,6 +321,7 @@ def test_config(self):
             replacement_char=0,
             input_encoding="UTF-16",
             output_encoding="UTF-16",
+            vocabulary_size=None,
         )
         exp_config_different_encoding = {
             "dtype": "int32",
@@ -290,6 +334,7 @@ def test_config(self):
             "input_encoding": "UTF-16",
             "output_encoding": "UTF-16",
             "trainable": True,
+            "vocabulary_size": None,
         }
         self.assertEqual(
             tokenize_different_encoding.get_config(),
@@ -305,6 +350,7 @@ def test_saving(self):
             sequence_length=20,
             normalization_form="NFKC",
             errors="replace",
+            vocabulary_size=None,
         )
         inputs = keras.Input(dtype="string", shape=())
         outputs = tokenizer(inputs)
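
Not part of the patch: a minimal standalone sketch of the clamping behavior the diff adds to tokenize(), built only from the TensorFlow ops used above (tf.strings.unicode_decode and tf.clip_by_value). The 256 cutoff and the sample string are illustrative, not taken from the change.

import tensorflow as tf

# Decode a string into Unicode codepoints, then clamp to [0, vocabulary_size),
# mirroring the optional clamping step added to tokenize(). Every codepoint at
# or above the cutoff collapses to vocabulary_size - 1, the OOV id.
vocabulary_size = 256  # illustrative Latin-1 cutoff
tokens = tf.strings.unicode_decode("né नम", "UTF-8")
clamped = tf.clip_by_value(tokens, 0, vocabulary_size - 1)

print(tokens.numpy())   # codepoints: 110, 233, 32, 2344, 2350
print(clamped.numpy())  # clamped:    110, 233, 32,  255,  255

Everything at or above the cutoff becomes indistinguishable after tokenization, which is why the docstring describes the `vocabulary_size - 1` id as the OOV value.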