diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index a479c21004..3787d74ee1 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -90,6 +90,12 @@ from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import ( XLMRobertaClassifier, ) +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm import ( + XLMRobertaMaskedLM, +) +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import ( XLMRobertaPreprocessor, ) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py new file mode 100644 index 0000000000..3b25655a1e --- /dev/null +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -0,0 +1,159 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""XLM-RoBERTa masked lm model.""" + +import copy + +from tensorflow import keras + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.masked_lm_head import MaskedLMHead +from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer +from keras_nlp.models.task import Task +from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) +from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets +from keras_nlp.utils.keras_utils import is_xla_compatible +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLM") +class XLMRobertaMaskedLM(Task): + """An end-to-end XLM-RoBERTa model for the masked language modeling task. + + This model will train XLM-RoBERTa on a masked language modeling task. + The model will predict labels for a number of masked tokens in the + input data. For usage of this model with pre-trained weights, see the + `from_preset()` method. + + This model can optionally be configured with a `preprocessor` layer, in + which case inputs can be raw string features during `fit()`, `predict()`, + and `evaluate()`. Inputs will be tokenized and dynamically masked during + training and evaluation. This is done by default when creating the model + with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. The underlying model is provided by a + third party and subject to a separate license, available + [here](https://github.com/facebookresearch/fairseq). + + Args: + backbone: A `keras_nlp.models.XLMRobertaBackbone` instance. + preprocessor: A `keras_nlp.models.XLMRobertaMaskedLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Example usage: + + Raw string inputs and pretrained backbone. + ```python + # Create a dataset with raw string features. Labels are inferred. + features = ["The quick brown fox jumped.", "I forgot my homework."] + + # Pretrained language model + # on an MLM task. + masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset( + "xlm_roberta_base_multi", + ) + masked_lm.fit(x=features, batch_size=2) + ``` + + # Re-compile (e.g., with a new learning rate). + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + # Access backbone programatically (e.g., to change `trainable`). + masked_lm.backbone.trainable = False + # Fit again. + masked_lm.fit(x=features, batch_size=2) + ``` + + Preprocessed integer data. + ```python + # Create a preprocessed dataset where 0 is the mask token. + features = { + "token_ids": tf.constant( + [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8) + ), + "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)) + } + # Labels are the original masked values. + labels = [[3, 5]] * 2 + + masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset( + "xlm_roberta_base_multi", + preprocessor=None, + ) + + masked_lm.fit(x=features, y=labels, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + inputs = { + **backbone.input, + "mask_positions": keras.Input( + shape=(None,), dtype="int32", name="mask_positions" + ), + } + backbone_outputs = backbone(backbone.input) + outputs = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + embedding_weights=backbone.token_embedding.embeddings, + intermediate_activation="gelu", + kernel_initializer=roberta_kernel_initializer(), + name="mlm_head", + )(backbone_outputs, inputs["mask_positions"]) + + # Instantiate using Functional API Model constructor. + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs, + ) + # All references to `self` below this line + self.backbone = backbone + self.preprocessor = preprocessor + + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + + @classproperty + def backbone_cls(cls): + return XLMRobertaBackbone + + @classproperty + def preprocessor_cls(cls): + return XLMRobertaMaskedLMPreprocessor + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py new file mode 100644 index 0000000000..2ed7baa0bf --- /dev/null +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py @@ -0,0 +1,184 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""XLM-RoBERTa masked language model preprocessor layer.""" + +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator +from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import ( + XLMRobertaPreprocessor, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLMPreprocessor") +class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor): + """XLM-RoBERTa preprocessing for the masked language modeling task. + + This preprocessing layer will prepare inputs for a masked language modeling + task. It is primarily intended for use with the + `keras_nlp.models.XLMRobertaMaskedLM` task model. Preprocessing will occur in + multiple steps. + + 1. Tokenize any number of input segments using the `tokenizer`. + 2. Pack the inputs together with the appropriate `""`, `""` and + `""` tokens, i.e., adding a single `""` at the start of the + entire sequence, `""` between each segment, + and a `""` at the end of the entire sequence. + 3. Randomly select non-special tokens to mask, controlled by + `mask_selection_rate`. + 4. Construct a `(x, y, sample_weight)` tuple suitable for training with a + `keras_nlp.models.XLMRobertaMaskedLM` task model. + + Args: + tokenizer: A `keras_nlp.models.XLMRobertaTokenizer` instance. + sequence_length: int. The length of the packed inputs. + truncate: string. The algorithm to truncate a list of batched segments + to fit within `sequence_length`. The value can be either + `round_robin` or `waterfall`: + - `"round_robin"`: Available space is assigned one token at a + time in a round-robin fashion to the inputs that still need + some, until the limit is reached. + - `"waterfall"`: The allocation of the budget is done using a + "waterfall" algorithm that allocates quota in a + left-to-right manner and fills up the buckets until we run + out of budget. It supports an arbitrary number of segments. + mask_selection_rate: float. The probability an input token will be + dynamically masked. + mask_selection_length: int. The maximum number of masked tokens + in a given sample. + mask_token_rate: float. The probability the a selected token will be + replaced with the mask token. + random_token_rate: float. The probability the a selected token will be + replaced with a random token from the vocabulary. A selected token + will be left as is with probability + `1 - mask_token_rate - random_token_rate`. + + Call arguments: + x: A tensor of single string sequences, or a tuple of multiple + tensor sequences to be packed together. Inputs may be batched or + unbatched. For single sequences, raw python inputs will be converted + to tensors. For multiple sequences, pass tensors directly. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset( + "xlm_roberta_base_multi" + ) + + # Tokenize and mask a single sentence. + preprocessor("The quick brown fox jumped.") + # Tokenize and mask a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + # Tokenize and mask sentence pairs. + # In this case, always convert input to tensors before calling the layer. + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + preprocessor((first, second)) + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset( + "xlm_roberta_base_multi" + ) + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + + # Map single sentences. + ds = tf.data.Dataset.from_tensor_slices(first) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map sentence pairs. + ds = tf.data.Dataset.from_tensor_slices((first, second)) + # Watch out for tf.data's default unpacking of tuples here! + # Best to invoke the `preprocessor` directly in this case. + ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + mask_selection_rate=0.15, + mask_selection_length=96, + mask_token_rate=0.8, + random_token_rate=0.1, + **kwargs, + ): + super().__init__( + tokenizer, + sequence_length=sequence_length, + truncate=truncate, + **kwargs, + ) + + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=mask_selection_rate, + mask_selection_length=mask_selection_length, + mask_token_rate=mask_token_rate, + random_token_rate=random_token_rate, + vocabulary_size=tokenizer.vocabulary_size(), + mask_token_id=tokenizer.mask_token_id, + unselectable_token_ids=[ + tokenizer.start_token_id, + tokenizer.end_token_id, + tokenizer.pad_token_id, + ], + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "mask_selection_rate": self.masker.mask_selection_rate, + "mask_selection_length": self.masker.mask_selection_length, + "mask_token_rate": self.masker.mask_token_rate, + "random_token_rate": self.masker.random_token_rate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + if y is not None or sample_weight is not None: + logging.warning( + f"{self.__class__.__name__} generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + + x = super().call(x) + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "padding_mask": padding_mask, + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py new file mode 100644 index 0000000000..2b573204db --- /dev/null +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py @@ -0,0 +1,174 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for XLM-RoBERTa masked language model preprocessor layer.""" + +import io +import os + +import pytest +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) +from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) + + +class XLMRobertaMaskedLMPreprocessorTest( + tf.test.TestCase, parameterized.TestCase +): + def setUp(self): + bytes_io = io.BytesIO() + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=12, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="", + eos_piece="", + user_defined_symbols="[MASK]", + ) + self.proto = bytes_io.getvalue() + + self.tokenizer = XLMRobertaTokenizer(proto=self.proto) + self.preprocessor = XLMRobertaMaskedLMPreprocessor( + tokenizer=self.tokenizer, + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + + def test_preprocess_strings(self): + input_data = " brown fox quick" + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 0, 0]) + self.assertAllEqual(y, [7, 9, 11, 0, 0]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 0.0, 0.0]) + + def test_preprocess_list_of_strings(self): + input_data = [" brown fox quick"] * 13 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [[0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1]] * 13 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]] * 13 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 0, 0]] * 13) + self.assertAllEqual(y, [[7, 9, 11, 0, 0]] * 13) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 0.0, 0.0]] * 13) + + def test_preprocess_dataset(self): + sentences = tf.constant([" brown fox quick"] * 13) + ds = tf.data.Dataset.from_tensor_slices(sentences) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(13).take(1).get_single_element() + self.assertAllEqual( + x["token_ids"], [[0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1]] * 13 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]] * 13 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 0, 0]] * 13) + self.assertAllEqual(y, [[7, 9, 11, 0, 0]] * 13) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 0.0, 0.0]] * 13) + + def test_mask_multiple_sentences(self): + sentence_one = tf.constant(" airplane") + sentence_two = tf.constant(" round") + + x, y, sw = self.preprocessor((sentence_one, sentence_two)) + self.assertAllEqual( + x["token_ids"], [0, 2, 2, 2, 13, 2, 1, 1, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [4, 0, 0, 0, 0]) + self.assertAllEqual(y, [12, 0, 0, 0, 0]) + self.assertAllEqual(sw, [1.0, 0.0, 0.0, 0.0, 0.0]) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = XLMRobertaMaskedLMPreprocessor( + self.preprocessor.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + input_data = " quick brown fox" + + x, y, sw = no_mask_preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [0, 11, 7, 9, 2, 1, 1, 1, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) + self.assertAllEqual(y, [0, 0, 0, 0, 0]) + self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0]) + + def test_serialization(self): + config = keras.utils.serialize_keras_object(self.preprocessor) + new_preprocessor = keras.utils.deserialize_keras_object(config) + self.assertEqual( + new_preprocessor.get_config(), + self.preprocessor.get_config(), + ) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + @pytest.mark.large + def test_saved_model(self, save_format, filename): + input_data = tf.constant([" quick brown fox"]) + + inputs = keras.Input(dtype="string", shape=()) + outputs = self.preprocessor(inputs) + model = keras.Model(inputs, outputs) + + path = os.path.join(self.get_temp_dir(), filename) + model.save(path, save_format=save_format) + + restored_model = keras.models.load_model(path) + outputs = model(input_data)[0]["token_ids"] + restored_outputs = restored_model(input_data)[0]["token_ids"] + self.assertAllEqual(outputs, restored_outputs) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py new file mode 100644 index 0000000000..168a219da3 --- /dev/null +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py @@ -0,0 +1,134 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for XLM-RoBERTa masked language model.""" + +import io +import os + +import pytest +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm import ( + XLMRobertaMaskedLM, +) +from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) +from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) + + +class XLMRobertaMaskedLMTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + bytes_io = io.BytesIO() + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the slow brown fox"] + ) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=5, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="", + eos_piece="", + user_defined_symbols="[MASK]", + ) + self.proto = bytes_io.getvalue() + + self.preprocessor = XLMRobertaMaskedLMPreprocessor( + XLMRobertaTokenizer(proto=self.proto), + sequence_length=5, + mask_selection_length=2, + ) + + self.backbone = XLMRobertaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_heads=2, + hidden_dim=2, + intermediate_dim=4, + max_sequence_length=self.preprocessor.packer.sequence_length, + ) + + self.masked_lm = XLMRobertaMaskedLM( + self.backbone, + preprocessor=self.preprocessor, + ) + + self.raw_batch = tf.constant( + ["the quick brown fox", "the slow brown fox"] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + self.raw_batch + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_masked_lm(self): + self.masked_lm(self.preprocessed_batch) + + def test_classifier_predict(self): + self.masked_lm.predict(self.raw_batch) + self.masked_lm.preprocessor = None + self.masked_lm.predict(self.preprocessed_batch) + + def test_classifier_fit(self): + self.masked_lm.fit(self.raw_dataset) + self.masked_lm.preprocessor = None + self.masked_lm.fit(self.preprocessed_dataset) + + def test_classifier_fit_no_xla(self): + self.masked_lm.preprocessor = None + self.masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), + jit_compile=False, + ) + self.masked_lm.fit(self.preprocessed_dataset) + + def test_serialization(self): + config = keras.utils.serialize_keras_object(self.masked_lm) + new_classifier = keras.utils.deserialize_keras_object(config) + self.assertEqual( + new_classifier.get_config(), + self.masked_lm.get_config(), + ) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + @pytest.mark.large + def test_saved_model(self, save_format, filename): + save_path = os.path.join(self.get_temp_dir(), filename) + self.masked_lm.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, XLMRobertaMaskedLM) + + model_output = self.masked_lm(self.preprocessed_batch) + restored_output = restored_model(self.preprocessed_batch) + + self.assertAllClose(model_output, restored_output) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py index d679ab9ba3..13a8a7bac5 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -100,10 +100,11 @@ def __init__(self, proto, **kwargs): self.pad_token_id = 1 # self.end_token_id = 2 # self.unk_token_id = 3 # + self.mask_token_id = self.vocabulary_size() - 1 # def vocabulary_size(self): """Get the size of the tokenizer vocabulary.""" - return super().vocabulary_size() + 1 + return super().vocabulary_size() + 2 def get_vocabulary(self): """Get the size of the tokenizer vocabulary.""" @@ -112,10 +113,14 @@ def get_vocabulary(self): tf.range(super().vocabulary_size()) ) ) - return self._vocabulary_prefix + vocabulary[3:] + return self._vocabulary_prefix + vocabulary[3:] + [""] def id_to_token(self, id): """Convert an integer id to a string token.""" + + if id == self.mask_token_id: + return "" + if id < len(self._vocabulary_prefix): return self._vocabulary_prefix[id] @@ -139,25 +144,9 @@ def tokenize(self, inputs): # Shift the tokens IDs right by one. return tf.add(tokens, 1) - def detokenize(self, inputs): - if inputs.dtype == tf.string: - return super().detokenize(inputs) - - # Shift the tokens IDs left by one. - tokens = tf.subtract(inputs, 1) - - # Correct `unk_token_id`, `end_token_id`, `start_token_id`, respectively. - # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the - # proto does not contain `pad_token_id`. This mapping of the pad token - # is done automatically by the above subtraction. - tokens = tf.where(tf.equal(tokens, self.unk_token_id - 1), 0, tokens) - tokens = tf.where(tf.equal(tokens, self.end_token_id - 1), 2, tokens) - tokens = tf.where(tf.equal(tokens, self.start_token_id - 1), 1, tokens) - - # Note: Even though we map `"" and `""` to the correct IDs, - # the `detokenize` method will return empty strings for these tokens. - # This is a vagary of the `sentencepiece` library. - return super().detokenize(tokens) + def detokenize(self, ids): + ids = tf.ragged.boolean_mask(ids, tf.not_equal(ids, self.mask_token_id)) + return super().detokenize(ids) @classproperty def presets(cls): diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py index 912403fc4e..6c216403c7 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py @@ -66,7 +66,7 @@ def test_unk_token(self): def test_detokenize(self): input_data = tf.constant([[4, 9, 5, 7]]) output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, tf.constant(["brown round earth is"])) def test_vocabulary(self): vocabulary = self.tokenizer.get_vocabulary() @@ -84,9 +84,10 @@ def test_vocabulary(self): "▁is", "▁quick", "▁round", + "", ], ) - self.assertEqual(self.tokenizer.vocabulary_size(), 11) + self.assertEqual(self.tokenizer.vocabulary_size(), 12) def test_id_to_token(self): print(self.tokenizer.id_to_token(9))