diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index a479c21004..3787d74ee1 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -90,6 +90,12 @@
from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import (
XLMRobertaClassifier,
)
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm import (
+ XLMRobertaMaskedLM,
+)
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ XLMRobertaMaskedLMPreprocessor,
+)
from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import (
XLMRobertaPreprocessor,
)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py
new file mode 100644
index 0000000000..3b25655a1e
--- /dev/null
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""XLM-RoBERTa masked lm model."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.masked_lm_head import MaskedLMHead
+from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer
+from keras_nlp.models.task import Task
+from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ XLMRobertaMaskedLMPreprocessor,
+)
+from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets
+from keras_nlp.utils.keras_utils import is_xla_compatible
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLM")
+class XLMRobertaMaskedLM(Task):
+ """An end-to-end XLM-RoBERTa model for the masked language modeling task.
+
+ This model will train XLM-RoBERTa on a masked language modeling task.
+ The model will predict labels for a number of masked tokens in the
+ input data. For usage of this model with pre-trained weights, see the
+ `from_preset()` method.
+
+ This model can optionally be configured with a `preprocessor` layer, in
+ which case inputs can be raw string features during `fit()`, `predict()`,
+ and `evaluate()`. Inputs will be tokenized and dynamically masked during
+ training and evaluation. This is done by default when creating the model
+ with `from_preset()`.
+
+ Disclaimer: Pre-trained models are provided on an "as is" basis, without
+ warranties or conditions of any kind. The underlying model is provided by a
+ third party and subject to a separate license, available
+ [here](https://github.com/facebookresearch/fairseq).
+
+ Args:
+ backbone: A `keras_nlp.models.XLMRobertaBackbone` instance.
+ preprocessor: A `keras_nlp.models.XLMRobertaMaskedLMPreprocessor` or
+ `None`. If `None`, this model will not apply preprocessing, and
+ inputs should be preprocessed before calling the model.
+
+ Example usage:
+
+ Raw string inputs and pretrained backbone.
+ ```python
+ # Create a dataset with raw string features. Labels are inferred.
+ features = ["The quick brown fox jumped.", "I forgot my homework."]
+
+    # Create an XLMRobertaMaskedLM with a pretrained backbone and further
+    # train on an MLM task.
+ masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset(
+ "xlm_roberta_base_multi",
+ )
+ masked_lm.fit(x=features, batch_size=2)
+
+ # Re-compile (e.g., with a new learning rate).
+ masked_lm.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ optimizer=keras.optimizers.Adam(5e-5),
+ jit_compile=True,
+ )
+    # Access backbone programmatically (e.g., to change `trainable`).
+ masked_lm.backbone.trainable = False
+ # Fit again.
+ masked_lm.fit(x=features, batch_size=2)
+ ```
+
+ Preprocessed integer data.
+ ```python
+ # Create a preprocessed dataset where 0 is the mask token.
+ features = {
+ "token_ids": tf.constant(
+ [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
+ ),
+ "padding_mask": tf.constant(
+ [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
+ ),
+ "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
+ }
+ # Labels are the original masked values.
+ labels = [[3, 5]] * 2
+
+ masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset(
+ "xlm_roberta_base_multi",
+ preprocessor=None,
+ )
+
+ masked_lm.fit(x=features, y=labels, batch_size=2)
+ ```
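+
+    Custom backbone and vocabulary, without pretrained weights. This is only
+    a minimal sketch: the sentencepiece proto path and the layer sizes below
+    are illustrative placeholders.
+    ```python
+    # Placeholder path to a trained sentencepiece model.
+    tokenizer = keras_nlp.models.XLMRobertaTokenizer(proto="proto.spm")
+    preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_nlp.models.XLMRobertaBackbone(
+        vocabulary_size=tokenizer.vocabulary_size(),
+        num_layers=2,
+        num_heads=2,
+        hidden_dim=64,
+        intermediate_dim=128,
+        max_sequence_length=128,
+    )
+    masked_lm = keras_nlp.models.XLMRobertaMaskedLM(
+        backbone,
+        preprocessor=preprocessor,
+    )
+    masked_lm.fit(x=["The quick brown fox jumped."], batch_size=1)
+    ```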
+ """
+
+ def __init__(
+ self,
+ backbone,
+ preprocessor=None,
+ **kwargs,
+ ):
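+        # The model takes the backbone's inputs plus a `mask_positions` tensor
+        # selecting which token positions the MLM head should predict.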
+ inputs = {
+ **backbone.input,
+ "mask_positions": keras.Input(
+ shape=(None,), dtype="int32", name="mask_positions"
+ ),
+ }
+ backbone_outputs = backbone(backbone.input)
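+        # The MLM head gathers the encoded tokens at `mask_positions` and
+        # projects them onto the vocabulary with the tied token embedding
+        # weights.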
+ outputs = MaskedLMHead(
+ vocabulary_size=backbone.vocabulary_size,
+ embedding_weights=backbone.token_embedding.embeddings,
+ intermediate_activation="gelu",
+ kernel_initializer=roberta_kernel_initializer(),
+ name="mlm_head",
+ )(backbone_outputs, inputs["mask_positions"])
+
+ # Instantiate using Functional API Model constructor.
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ include_preprocessing=preprocessor is not None,
+ **kwargs,
+ )
+ # All references to `self` below this line
+ self.backbone = backbone
+ self.preprocessor = preprocessor
+
+ self.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ optimizer=keras.optimizers.Adam(5e-5),
+ weighted_metrics=keras.metrics.SparseCategoricalAccuracy(),
+ jit_compile=is_xla_compatible(self),
+ )
+
+ @classproperty
+ def backbone_cls(cls):
+ return XLMRobertaBackbone
+
+ @classproperty
+ def preprocessor_cls(cls):
+ return XLMRobertaMaskedLMPreprocessor
+
+ @classproperty
+ def presets(cls):
+ return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py
new file mode 100644
index 0000000000..2ed7baa0bf
--- /dev/null
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py
@@ -0,0 +1,184 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XLM-RoBERTa masked language model preprocessor layer."""
+
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
+from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import (
+ XLMRobertaPreprocessor,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLMPreprocessor")
+class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor):
+ """XLM-RoBERTa preprocessing for the masked language modeling task.
+
+ This preprocessing layer will prepare inputs for a masked language modeling
+ task. It is primarily intended for use with the
+ `keras_nlp.models.XLMRobertaMaskedLM` task model. Preprocessing will occur in
+ multiple steps.
+
+ 1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and
+       `"<pad>"` tokens, i.e., adding a single `"<s>"` at the start of the
+       entire sequence, `"</s></s>"` between each segment,
+       and a `"</s>"` at the end of the entire sequence.
+ 3. Randomly select non-special tokens to mask, controlled by
+ `mask_selection_rate`.
+ 4. Construct a `(x, y, sample_weight)` tuple suitable for training with a
+ `keras_nlp.models.XLMRobertaMaskedLM` task model.
+
+ Args:
+ tokenizer: A `keras_nlp.models.XLMRobertaTokenizer` instance.
+ sequence_length: int. The length of the packed inputs.
+ truncate: string. The algorithm to truncate a list of batched segments
+ to fit within `sequence_length`. The value can be either
+ `round_robin` or `waterfall`:
+ - `"round_robin"`: Available space is assigned one token at a
+ time in a round-robin fashion to the inputs that still need
+ some, until the limit is reached.
+ - `"waterfall"`: The allocation of the budget is done using a
+ "waterfall" algorithm that allocates quota in a
+ left-to-right manner and fills up the buckets until we run
+ out of budget. It supports an arbitrary number of segments.
+ mask_selection_rate: float. The probability an input token will be
+ dynamically masked.
+ mask_selection_length: int. The maximum number of masked tokens
+ in a given sample.
+        mask_token_rate: float. The probability that a selected token will be
+            replaced with the mask token.
+        random_token_rate: float. The probability that a selected token will be
+ replaced with a random token from the vocabulary. A selected token
+ will be left as is with probability
+ `1 - mask_token_rate - random_token_rate`.
+
+ Call arguments:
+ x: A tensor of single string sequences, or a tuple of multiple
+ tensor sequences to be packed together. Inputs may be batched or
+ unbatched. For single sequences, raw python inputs will be converted
+ to tensors. For multiple sequences, pass tensors directly.
+ y: Label data. Should always be `None` as the layer generates labels.
+ sample_weight: Label weights. Should always be `None` as the layer
+ generates label weights.
+
+ Examples:
+ ```python
+ # Load the preprocessor from a preset.
+ preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset(
+ "xlm_roberta_base_multi"
+ )
+
+ # Tokenize and mask a single sentence.
+ preprocessor("The quick brown fox jumped.")
+ # Tokenize and mask a batch of single sentences.
+ preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+ # Tokenize and mask sentence pairs.
+ # In this case, always convert input to tensors before calling the layer.
+ first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+ second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+ preprocessor((first, second))
+ ```
+
+ Mapping with `tf.data.Dataset`.
+ ```python
+ preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset(
+ "xlm_roberta_base_multi"
+ )
+ first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+ second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+
+ # Map single sentences.
+ ds = tf.data.Dataset.from_tensor_slices(first)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map sentence pairs.
+ ds = tf.data.Dataset.from_tensor_slices((first, second))
+ # Watch out for tf.data's default unpacking of tuples here!
+ # Best to invoke the `preprocessor` directly in this case.
+ ds = ds.map(
+ lambda first, second: preprocessor(x=(first, second)),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
+ ```
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ sequence_length=512,
+ truncate="round_robin",
+ mask_selection_rate=0.15,
+ mask_selection_length=96,
+ mask_token_rate=0.8,
+ random_token_rate=0.1,
+ **kwargs,
+ ):
+ super().__init__(
+ tokenizer,
+ sequence_length=sequence_length,
+ truncate=truncate,
+ **kwargs,
+ )
+
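+        # Randomly selects tokens at `mask_selection_rate` and replaces them
+        # with the mask token (or a random token), recording the chosen
+        # positions, original ids, and weights.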
+ self.masker = MaskedLMMaskGenerator(
+ mask_selection_rate=mask_selection_rate,
+ mask_selection_length=mask_selection_length,
+ mask_token_rate=mask_token_rate,
+ random_token_rate=random_token_rate,
+ vocabulary_size=tokenizer.vocabulary_size(),
+ mask_token_id=tokenizer.mask_token_id,
+ unselectable_token_ids=[
+ tokenizer.start_token_id,
+ tokenizer.end_token_id,
+ tokenizer.pad_token_id,
+ ],
+ )
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "mask_selection_rate": self.masker.mask_selection_rate,
+ "mask_selection_length": self.masker.mask_selection_length,
+ "mask_token_rate": self.masker.mask_token_rate,
+ "random_token_rate": self.masker.random_token_rate,
+ }
+ )
+ return config
+
+ def call(self, x, y=None, sample_weight=None):
+ if y is not None or sample_weight is not None:
+ logging.warning(
+ f"{self.__class__.__name__} generates `y` and `sample_weight` "
+ "based on your input data, but your data already contains `y` "
+ "or `sample_weight`. Your `y` and `sample_weight` will be "
+ "ignored."
+ )
+
+ x = super().call(x)
+ token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+ masker_outputs = self.masker(token_ids)
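+        # The masked `token_ids` and `mask_positions` become model inputs,
+        # while the original ids and their weights become labels and sample
+        # weights.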
+ x = {
+ "token_ids": masker_outputs["token_ids"],
+ "padding_mask": padding_mask,
+ "mask_positions": masker_outputs["mask_positions"],
+ }
+ y = masker_outputs["mask_ids"]
+ sample_weight = masker_outputs["mask_weights"]
+ return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py
new file mode 100644
index 0000000000..2b573204db
--- /dev/null
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py
@@ -0,0 +1,174 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for XLM-RoBERTa masked language model preprocessor layer."""
+
+import io
+import os
+
+import pytest
+import sentencepiece
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ XLMRobertaMaskedLMPreprocessor,
+)
+from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import (
+ XLMRobertaTokenizer,
+)
+
+
+class XLMRobertaMaskedLMPreprocessorTest(
+ tf.test.TestCase, parameterized.TestCase
+):
+ def setUp(self):
+ bytes_io = io.BytesIO()
+ vocab_data = tf.data.Dataset.from_tensor_slices(
+ ["the quick brown fox", "the earth is round"]
+ )
+ sentencepiece.SentencePieceTrainer.train(
+ sentence_iterator=vocab_data.as_numpy_iterator(),
+ model_writer=bytes_io,
+ vocab_size=12,
+ model_type="WORD",
+ pad_id=0,
+ unk_id=1,
+ bos_id=2,
+ eos_id=3,
+            pad_piece="<pad>",
+            unk_piece="<unk>",
+            bos_piece="<s>",
+            eos_piece="</s>",
+ user_defined_symbols="[MASK]",
+ )
+ self.proto = bytes_io.getvalue()
+
+ self.tokenizer = XLMRobertaTokenizer(proto=self.proto)
+ self.preprocessor = XLMRobertaMaskedLMPreprocessor(
+ tokenizer=self.tokenizer,
+            # Simplify our testing by masking every available token.
+ mask_selection_rate=1.0,
+ mask_token_rate=1.0,
+ random_token_rate=0.0,
+ mask_selection_length=5,
+ sequence_length=12,
+ )
+
+ def test_preprocess_strings(self):
+ input_data = " brown fox quick"
+
+ x, y, sw = self.preprocessor(input_data)
+ self.assertAllEqual(
+ x["token_ids"], [0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1]
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
+ )
+ self.assertAllEqual(x["mask_positions"], [1, 2, 3, 0, 0])
+ self.assertAllEqual(y, [7, 9, 11, 0, 0])
+ self.assertAllEqual(sw, [1.0, 1.0, 1.0, 0.0, 0.0])
+
+ def test_preprocess_list_of_strings(self):
+ input_data = [" brown fox quick"] * 13
+
+ x, y, sw = self.preprocessor(input_data)
+ self.assertAllEqual(
+ x["token_ids"], [[0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1]] * 13
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]] * 13
+ )
+ self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 0, 0]] * 13)
+ self.assertAllEqual(y, [[7, 9, 11, 0, 0]] * 13)
+ self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 0.0, 0.0]] * 13)
+
+ def test_preprocess_dataset(self):
+ sentences = tf.constant([" brown fox quick"] * 13)
+ ds = tf.data.Dataset.from_tensor_slices(sentences)
+ ds = ds.map(self.preprocessor)
+ x, y, sw = ds.batch(13).take(1).get_single_element()
+ self.assertAllEqual(
+ x["token_ids"], [[0, 13, 13, 13, 2, 1, 1, 1, 1, 1, 1, 1]] * 13
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]] * 13
+ )
+ self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 0, 0]] * 13)
+ self.assertAllEqual(y, [[7, 9, 11, 0, 0]] * 13)
+ self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 0.0, 0.0]] * 13)
+
+ def test_mask_multiple_sentences(self):
+ sentence_one = tf.constant(" airplane")
+ sentence_two = tf.constant(" round")
+
+ x, y, sw = self.preprocessor((sentence_one, sentence_two))
+ self.assertAllEqual(
+ x["token_ids"], [0, 2, 2, 2, 13, 2, 1, 1, 1, 1, 1, 1]
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
+ )
+ self.assertAllEqual(x["mask_positions"], [4, 0, 0, 0, 0])
+ self.assertAllEqual(y, [12, 0, 0, 0, 0])
+ self.assertAllEqual(sw, [1.0, 0.0, 0.0, 0.0, 0.0])
+
+ def test_no_masking_zero_rate(self):
+ no_mask_preprocessor = XLMRobertaMaskedLMPreprocessor(
+ self.preprocessor.tokenizer,
+ mask_selection_rate=0.0,
+ mask_selection_length=5,
+ sequence_length=12,
+ )
+ input_data = " quick brown fox"
+
+ x, y, sw = no_mask_preprocessor(input_data)
+ self.assertAllEqual(
+ x["token_ids"], [0, 11, 7, 9, 2, 1, 1, 1, 1, 1, 1, 1]
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
+ )
+ self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0])
+ self.assertAllEqual(y, [0, 0, 0, 0, 0])
+ self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0])
+
+ def test_serialization(self):
+ config = keras.utils.serialize_keras_object(self.preprocessor)
+ new_preprocessor = keras.utils.deserialize_keras_object(config)
+ self.assertEqual(
+ new_preprocessor.get_config(),
+ self.preprocessor.get_config(),
+ )
+
+ @parameterized.named_parameters(
+ ("tf_format", "tf", "model"),
+ ("keras_format", "keras_v3", "model.keras"),
+ )
+ @pytest.mark.large
+ def test_saved_model(self, save_format, filename):
+ input_data = tf.constant([" quick brown fox"])
+
+ inputs = keras.Input(dtype="string", shape=())
+ outputs = self.preprocessor(inputs)
+ model = keras.Model(inputs, outputs)
+
+ path = os.path.join(self.get_temp_dir(), filename)
+ model.save(path, save_format=save_format)
+
+ restored_model = keras.models.load_model(path)
+ outputs = model(input_data)[0]["token_ids"]
+ restored_outputs = restored_model(input_data)[0]["token_ids"]
+ self.assertAllEqual(outputs, restored_outputs)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py
new file mode 100644
index 0000000000..168a219da3
--- /dev/null
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py
@@ -0,0 +1,134 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for XLM-RoBERTa masked language model."""
+
+import io
+import os
+
+import pytest
+import sentencepiece
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm import (
+ XLMRobertaMaskedLM,
+)
+from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import (
+ XLMRobertaMaskedLMPreprocessor,
+)
+from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import (
+ XLMRobertaTokenizer,
+)
+
+
+class XLMRobertaMaskedLMTest(tf.test.TestCase, parameterized.TestCase):
+ def setUp(self):
+ bytes_io = io.BytesIO()
+ vocab_data = tf.data.Dataset.from_tensor_slices(
+ ["the quick brown fox", "the slow brown fox"]
+ )
+ sentencepiece.SentencePieceTrainer.train(
+ sentence_iterator=vocab_data.as_numpy_iterator(),
+ model_writer=bytes_io,
+ vocab_size=5,
+ model_type="WORD",
+ pad_id=0,
+ unk_id=1,
+ bos_id=2,
+ eos_id=3,
+            pad_piece="<pad>",
+            unk_piece="<unk>",
+            bos_piece="<s>",
+            eos_piece="</s>",
+ user_defined_symbols="[MASK]",
+ )
+ self.proto = bytes_io.getvalue()
+
+ self.preprocessor = XLMRobertaMaskedLMPreprocessor(
+ XLMRobertaTokenizer(proto=self.proto),
+ sequence_length=5,
+ mask_selection_length=2,
+ )
+
+ self.backbone = XLMRobertaBackbone(
+ vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ num_layers=2,
+ num_heads=2,
+ hidden_dim=2,
+ intermediate_dim=4,
+ max_sequence_length=self.preprocessor.packer.sequence_length,
+ )
+
+ self.masked_lm = XLMRobertaMaskedLM(
+ self.backbone,
+ preprocessor=self.preprocessor,
+ )
+
+ self.raw_batch = tf.constant(
+ ["the quick brown fox", "the slow brown fox"]
+ )
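+        # The preprocessor returns an `(x, y, sample_weight)` tuple; keep only
+        # `x` for calling the model directly.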
+ self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
+ self.raw_dataset = tf.data.Dataset.from_tensor_slices(
+ self.raw_batch
+ ).batch(2)
+ self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
+
+ def test_valid_call_masked_lm(self):
+ self.masked_lm(self.preprocessed_batch)
+
+    def test_masked_lm_predict(self):
+ self.masked_lm.predict(self.raw_batch)
+ self.masked_lm.preprocessor = None
+ self.masked_lm.predict(self.preprocessed_batch)
+
+    def test_masked_lm_fit(self):
+ self.masked_lm.fit(self.raw_dataset)
+ self.masked_lm.preprocessor = None
+ self.masked_lm.fit(self.preprocessed_dataset)
+
+    def test_masked_lm_fit_no_xla(self):
+ self.masked_lm.preprocessor = None
+ self.masked_lm.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+ jit_compile=False,
+ )
+ self.masked_lm.fit(self.preprocessed_dataset)
+
+ def test_serialization(self):
+ config = keras.utils.serialize_keras_object(self.masked_lm)
+        new_masked_lm = keras.utils.deserialize_keras_object(config)
+        self.assertEqual(
+            new_masked_lm.get_config(),
+            self.masked_lm.get_config(),
+ )
+
+ @parameterized.named_parameters(
+ ("tf_format", "tf", "model"),
+ ("keras_format", "keras_v3", "model.keras"),
+ )
+ @pytest.mark.large
+ def test_saved_model(self, save_format, filename):
+ save_path = os.path.join(self.get_temp_dir(), filename)
+ self.masked_lm.save(save_path, save_format=save_format)
+ restored_model = keras.models.load_model(save_path)
+
+ # Check we got the real object back.
+ self.assertIsInstance(restored_model, XLMRobertaMaskedLM)
+
+ model_output = self.masked_lm(self.preprocessed_batch)
+ restored_output = restored_model(self.preprocessed_batch)
+
+ self.assertAllClose(model_output, restored_output)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
index d679ab9ba3..13a8a7bac5 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
@@ -100,10 +100,11 @@ def __init__(self, proto, **kwargs):
        self.pad_token_id = 1  # <pad>
        self.end_token_id = 2  # </s>
        self.unk_token_id = 3  # <unk>
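+        # `<mask>` is not part of the sentencepiece proto; it is assigned the
+        # final id in the extended vocabulary.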
+        self.mask_token_id = self.vocabulary_size() - 1  # <mask>
def vocabulary_size(self):
"""Get the size of the tokenizer vocabulary."""
- return super().vocabulary_size() + 1
+ return super().vocabulary_size() + 2
def get_vocabulary(self):
"""Get the size of the tokenizer vocabulary."""
@@ -112,10 +113,14 @@ def get_vocabulary(self):
tf.range(super().vocabulary_size())
)
)
- return self._vocabulary_prefix + vocabulary[3:]
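+        # `<mask>` lives outside the sentencepiece proto, so append it as the
+        # last entry to line up with `self.mask_token_id`.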
+        return self._vocabulary_prefix + vocabulary[3:] + ["<mask>"]
def id_to_token(self, id):
"""Convert an integer id to a string token."""
+
+        if id == self.mask_token_id:
+            return "<mask>"
+
if id < len(self._vocabulary_prefix):
return self._vocabulary_prefix[id]
@@ -139,25 +144,9 @@ def tokenize(self, inputs):
# Shift the tokens IDs right by one.
return tf.add(tokens, 1)
- def detokenize(self, inputs):
- if inputs.dtype == tf.string:
- return super().detokenize(inputs)
-
- # Shift the tokens IDs left by one.
- tokens = tf.subtract(inputs, 1)
-
- # Correct `unk_token_id`, `end_token_id`, `start_token_id`, respectively.
- # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the
- # proto does not contain `pad_token_id`. This mapping of the pad token
- # is done automatically by the above subtraction.
- tokens = tf.where(tf.equal(tokens, self.unk_token_id - 1), 0, tokens)
- tokens = tf.where(tf.equal(tokens, self.end_token_id - 1), 2, tokens)
- tokens = tf.where(tf.equal(tokens, self.start_token_id - 1), 1, tokens)
-
-        # Note: Even though we map `"<s>"` and `"</s>"` to the correct IDs,
- # the `detokenize` method will return empty strings for these tokens.
- # This is a vagary of the `sentencepiece` library.
- return super().detokenize(tokens)
+ def detokenize(self, ids):
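+        # `<mask>` has no entry in the sentencepiece model, so strip any mask
+        # ids before delegating to sentencepiece for detokenization.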
+ ids = tf.ragged.boolean_mask(ids, tf.not_equal(ids, self.mask_token_id))
+ return super().detokenize(ids)
@classproperty
def presets(cls):
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
index 912403fc4e..6c216403c7 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
@@ -66,7 +66,7 @@ def test_unk_token(self):
def test_detokenize(self):
input_data = tf.constant([[4, 9, 5, 7]])
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, tf.constant(["brown round earth is"]))
def test_vocabulary(self):
vocabulary = self.tokenizer.get_vocabulary()
@@ -84,9 +84,10 @@ def test_vocabulary(self):
"▁is",
"▁quick",
"▁round",
+                "<mask>",
],
)
- self.assertEqual(self.tokenizer.vocabulary_size(), 11)
+ self.assertEqual(self.tokenizer.vocabulary_size(), 12)
def test_id_to_token(self):
print(self.tokenizer.id_to_token(9))