From 1451f43647123219ab4ac8cd6ad5d015c83fd168 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Wed, 4 May 2022 01:28:39 +0530 Subject: [PATCH 1/6] Clamped Values --- .../tokenizers/unicode_character_tokenizer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py index 3d21e9ca85..b30c681ddc 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -45,6 +45,10 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer): One of The encoding of the input text. Defaults to "UTF-8". output_encoding: One of ("UTF-8", "UTF-16-BE", or "UTF-32-BE"). The encoding of the output text. Defaults to "UTF-8". + vocabulary_size: Set the vocabulary `vocabulary_size`, + by clamping all codepoints to the range [0, vocabulary_size). + Effectively this will make the `vocabulary_size - 1` id the + the OOV token. Examples: @@ -168,6 +172,7 @@ def __init__( replacement_char: int = 65533, input_encoding: str = "UTF-8", output_encoding: str = "UTF-8", + vocabulary_size: int = None, **kwargs, ) -> None: # Check dtype and provide a default. @@ -213,6 +218,7 @@ def __init__( self.replacement_char = replacement_char self.input_encoding = input_encoding self.output_encoding = output_encoding + self.vocabulary_size = vocabulary_size def get_config(self) -> Dict[str, Any]: config = super().get_config() @@ -225,10 +231,16 @@ def get_config(self) -> Dict[str, Any]: "replacement_char": self.replacement_char, "input_encoding": self.input_encoding, "output_encoding": self.output_encoding, + "vocabulary_size": self.vocabulary_size, } ) return config + def vocabulary_size(self) -> int: + """Get the size of the tokenizer vocabulary. 
None implies no vocabulary + size was provided""" + return self.vocabulary_size + def tokenize(self, inputs): if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): inputs = tf.convert_to_tensor(inputs) @@ -260,6 +272,12 @@ def tokenize(self, inputs): if scalar_input: tokens = tf.squeeze(tokens, 0) + + # Optionally clamps the output code point values to be in the + # range [0, vocabulary_size) + if self.vocabulary_size: + tokens = tf.clip_by_value(tokens, 0, self.vocabulary_size - 1) + return tokens def detokenize(self, inputs): @@ -271,3 +289,5 @@ def detokenize(self, inputs): output_encoding=self.output_encoding, ) return encoded_string + + From 9d62980601fa06bda8978b7a00165b48cb655497 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Wed, 4 May 2022 01:48:38 +0530 Subject: [PATCH 2/6] Added New Tests, Fixed Old Tests --- .../unicode_character_tokenizer_test.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py index 34df6a5094..3300398e28 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py @@ -59,6 +59,45 @@ def test_dense_output(self): ], ) + def test_tokenize_scalar_with_vocabulary_size(self): + input_data = "ninja" + tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 105) + call_output = tokenizer(input_data) + tokenize_output = tokenizer.tokenize(input_data) + + self.assertAllEqual(call_output, [104, 104, 104, 104, 97]) + self.assertAllEqual(tokenize_output, [104, 104, 104, 104, 97]) + + def test_tokenize_dense_with_vocabulary_size(self): + input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) + tokenizer = UnicodeCharacterTokenizer(sequence_length = 10, + vocabulary_size = 105) + call_output = tokenizer(input_data) + self.assertIsInstance(call_output, tf.Tensor) + self.assertAllEqual( + call_output, + [ + [104, 104, 104, 104, 97, 0, 0, 0, 0, 0], + [104, 97, 104, 104, 104, 97, 104, 0, 0, 0], + [104, 104, 104, 104, 0, 0, 0, 0, 0, 0], + ], + ) + + def test_tokenize_ragged_with_vocabulary_size(self): + input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) + tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 105) + call_output = tokenizer(input_data) + tokenize_output = tokenizer.tokenize(input_data) + self.assertIsInstance(call_output, tf.RaggedTensor) + exp_outputs = [ + [104, 104, 104, 104, 97], + [104, 97, 104, 104, 104, 97, 104], + [104, 104, 104, 104], + ] + for i in range(call_output.shape[0]): + self.assertAllEqual(call_output[i], exp_outputs[i]) + self.assertAllEqual(tokenize_output[i], exp_outputs[i]) + def test_detokenize(self): input_data = tf.ragged.constant( [ @@ -232,6 +271,7 @@ def test_load_model_with_config(self): sequence_length=11, normalization_form="NFC", errors="strict", + vocabulary_size= None, ) cloned_tokenizer = UnicodeCharacterTokenizer.from_config( original_tokenizer.get_config() @@ -255,6 +295,7 @@ def test_config(self): normalization_form="NFC", errors="ignore", replacement_char=0, + vocabulary_size= None, ) exp_config = { "dtype": "int32", @@ -267,6 +308,7 @@ def test_config(self): "input_encoding": "UTF-8", "output_encoding": "UTF-8", "trainable": True, + "vocabulary_size": None, } self.assertEqual(tokenizer.get_config(), exp_config) @@ -278,6 +320,7 @@ def test_config(self): replacement_char=0, input_encoding="UTF-16", output_encoding="UTF-16", + vocabulary_size= None, ) 
exp_config_different_encoding = { "dtype": "int32", @@ -290,12 +333,39 @@ def test_config(self): "input_encoding": "UTF-16", "output_encoding": "UTF-16", "trainable": True, + "vocabulary_size": None, } self.assertEqual( tokenize_different_encoding.get_config(), exp_config_different_encoding, ) + tokenize_different_vocabular_size = UnicodeCharacterTokenizer( + name="unicode_character_tokenizer_config_gen", + lowercase=False, + sequence_length=8, + errors="ignore", + replacement_char=0, + vocabulary_size= 100, + ) + exp_config_different_vocabular_size = { + "dtype": "int32", + "errors": "ignore", + "lowercase": False, + "name": "unicode_character_tokenizer_config_gen", + "normalization_form": None, + "replacement_char": 0, + "sequence_length": 8, + "input_encoding": "UTF-8", + "output_encoding": "UTF-8", + "trainable": True, + "vocabulary_size": 100, + } + self.assertEqual( + tokenize_different_vocabular_size.get_config(), + exp_config_different_vocabular_size, + ) + def test_saving(self): input_data = tf.constant(["ninjas and samurais", "time travel"]) @@ -305,6 +375,7 @@ def test_saving(self): sequence_length=20, normalization_form="NFKC", errors="replace", + vocabulary_size= None, ) inputs = keras.Input(dtype="string", shape=()) outputs = tokenizer(inputs) From 53714019b4dfb3208c2caac299d8a567bdf24d96 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Wed, 4 May 2022 01:53:05 +0530 Subject: [PATCH 3/6] Ran formatters --- .../tokenizers/unicode_character_tokenizer.py | 4 +-- .../unicode_character_tokenizer_test.py | 25 ++++++++++--------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py index b30c681ddc..90294c3ae6 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -273,7 +273,7 @@ def tokenize(self, inputs): if scalar_input: tokens = tf.squeeze(tokens, 0) - # Optionally clamps the output code point values to be in the + # Optionally clamps the output code point values to be in the # range [0, vocabulary_size) if self.vocabulary_size: tokens = tf.clip_by_value(tokens, 0, self.vocabulary_size - 1) @@ -289,5 +289,3 @@ def detokenize(self, inputs): output_encoding=self.output_encoding, ) return encoded_string - - diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py index 3300398e28..3359a8f948 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py @@ -61,17 +61,18 @@ def test_dense_output(self): def test_tokenize_scalar_with_vocabulary_size(self): input_data = "ninja" - tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 105) + tokenizer = UnicodeCharacterTokenizer(vocabulary_size=105) call_output = tokenizer(input_data) tokenize_output = tokenizer.tokenize(input_data) - self.assertAllEqual(call_output, [104, 104, 104, 104, 97]) - self.assertAllEqual(tokenize_output, [104, 104, 104, 104, 97]) + self.assertAllEqual(call_output, [104, 104, 104, 104, 97]) + self.assertAllEqual(tokenize_output, [104, 104, 104, 104, 97]) def test_tokenize_dense_with_vocabulary_size(self): input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) - tokenizer = UnicodeCharacterTokenizer(sequence_length = 10, - vocabulary_size = 105) + tokenizer = UnicodeCharacterTokenizer( + sequence_length=10, vocabulary_size=105 + ) call_output = 
tokenizer(input_data) self.assertIsInstance(call_output, tf.Tensor) self.assertAllEqual( @@ -82,10 +83,10 @@ def test_tokenize_dense_with_vocabulary_size(self): [104, 104, 104, 104, 0, 0, 0, 0, 0, 0], ], ) - + def test_tokenize_ragged_with_vocabulary_size(self): input_data = tf.constant(["ninja", "samurai", "▀▁▂▃"]) - tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 105) + tokenizer = UnicodeCharacterTokenizer(vocabulary_size=105) call_output = tokenizer(input_data) tokenize_output = tokenizer.tokenize(input_data) self.assertIsInstance(call_output, tf.RaggedTensor) @@ -271,7 +272,7 @@ def test_load_model_with_config(self): sequence_length=11, normalization_form="NFC", errors="strict", - vocabulary_size= None, + vocabulary_size=None, ) cloned_tokenizer = UnicodeCharacterTokenizer.from_config( original_tokenizer.get_config() @@ -295,7 +296,7 @@ def test_config(self): normalization_form="NFC", errors="ignore", replacement_char=0, - vocabulary_size= None, + vocabulary_size=None, ) exp_config = { "dtype": "int32", @@ -320,7 +321,7 @@ def test_config(self): replacement_char=0, input_encoding="UTF-16", output_encoding="UTF-16", - vocabulary_size= None, + vocabulary_size=None, ) exp_config_different_encoding = { "dtype": "int32", @@ -346,7 +347,7 @@ def test_config(self): sequence_length=8, errors="ignore", replacement_char=0, - vocabulary_size= 100, + vocabulary_size=100, ) exp_config_different_vocabular_size = { "dtype": "int32", @@ -375,7 +376,7 @@ def test_saving(self): sequence_length=20, normalization_form="NFKC", errors="replace", - vocabulary_size= None, + vocabulary_size=None, ) inputs = keras.Input(dtype="string", shape=()) outputs = tokenizer(inputs) From 8ce06b226b1700d6fbc99491707a0c783422eb26 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Wed, 4 May 2022 11:19:51 +0530 Subject: [PATCH 4/6] Fixes based on reviews --- .../tokenizers/unicode_character_tokenizer.py | 10 ++++++- .../unicode_character_tokenizer_test.py | 30 ++----------------- 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py index 90294c3ae6..1a21ae6d88 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -48,7 +48,7 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer): vocabulary_size: Set the vocabulary `vocabulary_size`, by clamping all codepoints to the range [0, vocabulary_size). Effectively this will make the `vocabulary_size - 1` id the - the OOV token. + the OOV value. Examples: @@ -134,6 +134,14 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer): numpy=array([[ 105, 32, 108, 105, 107], [2350, 2376, 2306, 32, 2325]], dtype=int32)> + Tokenization with vocabulary_size. + >>> input_data = "आप कैसे हैं" + >>> tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 592) + >>> tokenizer(input_data) + + Detokenization. 
>>> inputs = tf.constant([110, 105, 110, 106, 97], dtype=tf.int32) >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer() diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py index 3359a8f948..334a813961 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer_test.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer_test.py @@ -296,7 +296,7 @@ def test_config(self): normalization_form="NFC", errors="ignore", replacement_char=0, - vocabulary_size=None, + vocabulary_size=100, ) exp_config = { "dtype": "int32", @@ -309,7 +309,7 @@ def test_config(self): "input_encoding": "UTF-8", "output_encoding": "UTF-8", "trainable": True, - "vocabulary_size": None, + "vocabulary_size": 100, } self.assertEqual(tokenizer.get_config(), exp_config) @@ -341,32 +341,6 @@ def test_config(self): exp_config_different_encoding, ) - tokenize_different_vocabular_size = UnicodeCharacterTokenizer( - name="unicode_character_tokenizer_config_gen", - lowercase=False, - sequence_length=8, - errors="ignore", - replacement_char=0, - vocabulary_size=100, - ) - exp_config_different_vocabular_size = { - "dtype": "int32", - "errors": "ignore", - "lowercase": False, - "name": "unicode_character_tokenizer_config_gen", - "normalization_form": None, - "replacement_char": 0, - "sequence_length": 8, - "input_encoding": "UTF-8", - "output_encoding": "UTF-8", - "trainable": True, - "vocabulary_size": 100, - } - self.assertEqual( - tokenize_different_vocabular_size.get_config(), - exp_config_different_vocabular_size, - ) - def test_saving(self): input_data = tf.constant(["ninjas and samurais", "time travel"]) From 1b3113af442708c6b907787405133a8c48c9f2f9 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 4 May 2022 23:19:13 -0700 Subject: [PATCH 5/6] Show vocab cutoff on latin and non-latin chars --- keras_nlp/tokenizers/unicode_character_tokenizer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py index 1a21ae6d88..9934302114 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -135,9 +135,14 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer): [2350, 2376, 2306, 32, 2325]], dtype=int32)> Tokenization with vocabulary_size. - >>> input_data = "आप कैसे हैं" - >>> tokenizer = UnicodeCharacterTokenizer(vocabulary_size = 592) - >>> tokenizer(input_data) + >>> latin_ext_cutoff = 592 + >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( + ... vocabulary_size=latin_ext_cutoff) + >>> tokenizer("cómo estás") + + >>> tokenizer("आप कैसे हैं") From 35863ad11e5fcbd1e3261a31e754b757e6e32fb5 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 4 May 2022 23:24:34 -0700 Subject: [PATCH 6/6] Add puctuation --- keras_nlp/tokenizers/unicode_character_tokenizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py index 9934302114..9cdd68b2ff 100644 --- a/keras_nlp/tokenizers/unicode_character_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py @@ -138,9 +138,9 @@ class UnicodeCharacterTokenizer(tokenizer.Tokenizer): >>> latin_ext_cutoff = 592 >>> tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer( ... 
vocabulary_size=latin_ext_cutoff) - >>> tokenizer("cómo estás") + >>> tokenizer("¿Cómo estás?") >>> tokenizer("आप कैसे हैं")
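
The standalone sketch below is illustration only and is not part of the patch series. It shows, assuming only that TensorFlow is installed (keras_nlp itself is not needed for this), how the clamping these patches add to `tokenize()` behaves: codepoints are decoded and then clipped to the range [0, vocabulary_size), so every codepoint at or above `vocabulary_size` collapses to the `vocabulary_size - 1` OOV value. The values match the expected output in the new scalar test ("ninja" with vocabulary_size=105 gives [104, 104, 104, 104, 97]).

    import tensorflow as tf

    # Decode a string into Unicode codepoints, as UnicodeCharacterTokenizer
    # does internally before any clamping is applied.
    codepoints = tf.strings.unicode_decode("ninja", "UTF-8")  # [110, 105, 110, 106, 97]

    # Clamp codepoints to [0, vocabulary_size), mirroring the tf.clip_by_value
    # call added to tokenize() in PATCH 1. Anything >= vocabulary_size maps to
    # vocabulary_size - 1; codepoints already below the cutoff pass through.
    vocabulary_size = 105
    clamped = tf.clip_by_value(codepoints, 0, vocabulary_size - 1)
    print(clamped.numpy())  # [104 104 104 104  97]

With a larger cutoff such as the Latin Extended boundary used in the docstring example (vocabulary_size=592), Latin-script text is left unchanged while scripts with higher codepoints, such as Devanagari, all map to the single OOV value 591.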