-
Notifications
You must be signed in to change notification settings - Fork 301
Add an XLMRobertaMaskedLM task model #950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6823288
b132e5e
b94e3f0
87f6c2f
1912609
1d76df0
467e204
d3f751f
9798993
d2d0d9b
f227996
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""XLM-RoBERTa masked lm model.""" | ||
|
||
import copy | ||
|
||
from tensorflow import keras | ||
|
||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.layers.masked_lm_head import MaskedLMHead | ||
from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer | ||
from keras_nlp.models.task import Task | ||
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone | ||
from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( | ||
XLMRobertaMaskedLMPreprocessor, | ||
) | ||
from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets | ||
from keras_nlp.utils.keras_utils import is_xla_compatible | ||
from keras_nlp.utils.python_utils import classproperty | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLM") | ||
class XLMRobertaMaskedLM(Task): | ||
"""An end-to-end XLM-RoBERTa model for the masked language modeling task. | ||
|
||
This model will train XLM-RoBERTa on a masked language modeling task. | ||
The model will predict labels for a number of masked tokens in the | ||
input data. For usage of this model with pre-trained weights, see the | ||
`from_preset()` method. | ||
|
||
This model can optionally be configured with a `preprocessor` layer, in | ||
which case inputs can be raw string features during `fit()`, `predict()`, | ||
and `evaluate()`. Inputs will be tokenized and dynamically masked during | ||
training and evaluation. This is done by default when creating the model | ||
with `from_preset()`. | ||
|
||
Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||
warranties or conditions of any kind. The underlying model is provided by a | ||
third party and subject to a separate license, available | ||
[here](https://github.com/facebookresearch/fairseq). | ||
|
||
Args: | ||
backbone: A `keras_nlp.models.XLMRobertaBackbone` instance. | ||
preprocessor: A `keras_nlp.models.XLMRobertaMaskedLMPreprocessor` or | ||
`None`. If `None`, this model will not apply preprocessing, and | ||
inputs should be preprocessed before calling the model. | ||
|
||
Example usage: | ||
|
||
Raw string inputs and pretrained backbone. | ||
```python | ||
# Create a dataset with raw string features. Labels are inferred. | ||
features = ["The quick brown fox jumped.", "I forgot my homework."] | ||
|
||
# Pretrained language model | ||
# on an MLM task. | ||
masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset( | ||
"xlm_roberta_base_multi", | ||
) | ||
masked_lm.fit(x=features, batch_size=2) | ||
``` | ||
|
||
# Re-compile (e.g., with a new learning rate). | ||
masked_lm.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
optimizer=keras.optimizers.Adam(5e-5), | ||
jit_compile=True, | ||
) | ||
# Access backbone programatically (e.g., to change `trainable`). | ||
masked_lm.backbone.trainable = False | ||
# Fit again. | ||
masked_lm.fit(x=features, batch_size=2) | ||
``` | ||
|
||
Preprocessed integer data. | ||
```python | ||
# Create a preprocessed dataset where 0 is the mask token. | ||
features = { | ||
"token_ids": tf.constant( | ||
[[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) | ||
), | ||
"padding_mask": tf.constant( | ||
[[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8) | ||
), | ||
"mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)) | ||
} | ||
# Labels are the original masked values. | ||
labels = [[3, 5]] * 2 | ||
|
||
masked_lm = keras_nlp.models.XLMRobertaMaskedLM.from_preset( | ||
"xlm_roberta_base_multi", | ||
preprocessor=None, | ||
) | ||
|
||
masked_lm.fit(x=features, y=labels, batch_size=2) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
backbone, | ||
preprocessor=None, | ||
**kwargs, | ||
): | ||
inputs = { | ||
**backbone.input, | ||
"mask_positions": keras.Input( | ||
shape=(None,), dtype="int32", name="mask_positions" | ||
), | ||
} | ||
backbone_outputs = backbone(backbone.input) | ||
outputs = MaskedLMHead( | ||
vocabulary_size=backbone.vocabulary_size, | ||
embedding_weights=backbone.token_embedding.embeddings, | ||
intermediate_activation="gelu", | ||
kernel_initializer=roberta_kernel_initializer(), | ||
name="mlm_head", | ||
)(backbone_outputs, inputs["mask_positions"]) | ||
|
||
# Instantiate using Functional API Model constructor. | ||
super().__init__( | ||
inputs=inputs, | ||
outputs=outputs, | ||
include_preprocessing=preprocessor is not None, | ||
**kwargs, | ||
) | ||
# All references to `self` below this line | ||
self.backbone = backbone | ||
self.preprocessor = preprocessor | ||
|
||
self.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
optimizer=keras.optimizers.Adam(5e-5), | ||
weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), | ||
jit_compile=is_xla_compatible(self), | ||
) | ||
|
||
@classproperty | ||
def backbone_cls(cls): | ||
return XLMRobertaBackbone | ||
|
||
@classproperty | ||
def preprocessor_cls(cls): | ||
return XLMRobertaMaskedLMPreprocessor | ||
|
||
@classproperty | ||
def presets(cls): | ||
return copy.deepcopy(backbone_presets) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""XLM-RoBERTa masked language model preprocessor layer.""" | ||
|
||
from absl import logging | ||
|
||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator | ||
from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import ( | ||
XLMRobertaPreprocessor, | ||
) | ||
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLMPreprocessor") | ||
class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor): | ||
"""XLM-RoBERTa preprocessing for the masked language modeling task. | ||
|
||
This preprocessing layer will prepare inputs for a masked language modeling | ||
task. It is primarily intended for use with the | ||
`keras_nlp.models.XLMRobertaMaskedLM` task model. Preprocessing will occur in | ||
multiple steps. | ||
|
||
1. Tokenize any number of input segments using the `tokenizer`. | ||
2. Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and | ||
`"<pad>"` tokens, i.e., adding a single `"<s>"` at the start of the | ||
entire sequence, `"</s></s>"` between each segment, | ||
and a `"</s>"` at the end of the entire sequence. | ||
3. Randomly select non-special tokens to mask, controlled by | ||
`mask_selection_rate`. | ||
4. Construct a `(x, y, sample_weight)` tuple suitable for training with a | ||
`keras_nlp.models.XLMRobertaMaskedLM` task model. | ||
|
||
Args: | ||
tokenizer: A `keras_nlp.models.XLMRobertaTokenizer` instance. | ||
sequence_length: int. The length of the packed inputs. | ||
truncate: string. The algorithm to truncate a list of batched segments | ||
to fit within `sequence_length`. The value can be either | ||
`round_robin` or `waterfall`: | ||
- `"round_robin"`: Available space is assigned one token at a | ||
time in a round-robin fashion to the inputs that still need | ||
some, until the limit is reached. | ||
- `"waterfall"`: The allocation of the budget is done using a | ||
"waterfall" algorithm that allocates quota in a | ||
left-to-right manner and fills up the buckets until we run | ||
out of budget. It supports an arbitrary number of segments. | ||
mask_selection_rate: float. The probability an input token will be | ||
dynamically masked. | ||
mask_selection_length: int. The maximum number of masked tokens | ||
in a given sample. | ||
mask_token_rate: float. The probability the a selected token will be | ||
replaced with the mask token. | ||
random_token_rate: float. The probability the a selected token will be | ||
replaced with a random token from the vocabulary. A selected token | ||
will be left as is with probability | ||
`1 - mask_token_rate - random_token_rate`. | ||
|
||
Call arguments: | ||
x: A tensor of single string sequences, or a tuple of multiple | ||
tensor sequences to be packed together. Inputs may be batched or | ||
unbatched. For single sequences, raw python inputs will be converted | ||
to tensors. For multiple sequences, pass tensors directly. | ||
y: Label data. Should always be `None` as the layer generates labels. | ||
sample_weight: Label weights. Should always be `None` as the layer | ||
generates label weights. | ||
|
||
Examples: | ||
```python | ||
# Load the preprocessor from a preset. | ||
preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset( | ||
"xlm_roberta_base_multi" | ||
) | ||
|
||
# Tokenize and mask a single sentence. | ||
preprocessor("The quick brown fox jumped.") | ||
# Tokenize and mask a batch of single sentences. | ||
preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
# Tokenize and mask sentence pairs. | ||
# In this case, always convert input to tensors before calling the layer. | ||
first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
second = tf.constant(["The fox tripped.", "Oh look, a whale."]) | ||
preprocessor((first, second)) | ||
``` | ||
|
||
Mapping with `tf.data.Dataset`. | ||
```python | ||
preprocessor = keras_nlp.models.XLMRobertaMaskedLMPreprocessor.from_preset( | ||
"xlm_roberta_base_multi" | ||
) | ||
first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) | ||
second = tf.constant(["The fox tripped.", "Oh look, a whale."]) | ||
|
||
# Map single sentences. | ||
ds = tf.data.Dataset.from_tensor_slices(first) | ||
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
|
||
# Map sentence pairs. | ||
ds = tf.data.Dataset.from_tensor_slices((first, second)) | ||
# Watch out for tf.data's default unpacking of tuples here! | ||
# Best to invoke the `preprocessor` directly in this case. | ||
ds = ds.map( | ||
lambda first, second: preprocessor(x=(first, second)), | ||
num_parallel_calls=tf.data.AUTOTUNE, | ||
) | ||
``` | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
tokenizer, | ||
sequence_length=512, | ||
truncate="round_robin", | ||
mask_selection_rate=0.15, | ||
mask_selection_length=96, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure, is 96 a good default for XLM-R MLM? @abheesht17 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be good! Max sequence length of the model is 512, mask rate above is 0.15, so 76 average. 96 should cover almost all samples. |
||
mask_token_rate=0.8, | ||
random_token_rate=0.1, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
tokenizer, | ||
sequence_length=sequence_length, | ||
truncate=truncate, | ||
**kwargs, | ||
) | ||
|
||
self.masker = MaskedLMMaskGenerator( | ||
mask_selection_rate=mask_selection_rate, | ||
mask_selection_length=mask_selection_length, | ||
mask_token_rate=mask_token_rate, | ||
random_token_rate=random_token_rate, | ||
vocabulary_size=tokenizer.vocabulary_size(), | ||
mask_token_id=tokenizer.mask_token_id, | ||
unselectable_token_ids=[ | ||
tokenizer.start_token_id, | ||
tokenizer.end_token_id, | ||
tokenizer.pad_token_id, | ||
], | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"mask_selection_rate": self.masker.mask_selection_rate, | ||
"mask_selection_length": self.masker.mask_selection_length, | ||
"mask_token_rate": self.masker.mask_token_rate, | ||
"random_token_rate": self.masker.random_token_rate, | ||
} | ||
) | ||
return config | ||
|
||
def call(self, x, y=None, sample_weight=None): | ||
if y is not None or sample_weight is not None: | ||
logging.warning( | ||
f"{self.__class__.__name__} generates `y` and `sample_weight` " | ||
"based on your input data, but your data already contains `y` " | ||
"or `sample_weight`. Your `y` and `sample_weight` will be " | ||
"ignored." | ||
) | ||
|
||
x = super().call(x) | ||
token_ids, padding_mask = x["token_ids"], x["padding_mask"] | ||
masker_outputs = self.masker(token_ids) | ||
x = { | ||
"token_ids": masker_outputs["token_ids"], | ||
"padding_mask": padding_mask, | ||
"mask_positions": masker_outputs["mask_positions"], | ||
} | ||
y = masker_outputs["mask_ids"] | ||
sample_weight = masker_outputs["mask_weights"] | ||
return pack_x_y_sample_weight(x, y, sample_weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we have this style elsewhere (just add one empty line and the sentence below)...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't get that, all other files have similar formatting
https://github.com/keras-team/keras-nlp/blob/aaa6d230a42783eae6d9d695e1cc5de2c3b68f8b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py#L73-L80
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I was thinking this, but we still aren't consistent in the new style...
https://github.com/keras-team/keras-nlp/blob/aaa6d230a42783eae6d9d695e1cc5de2c3b68f8b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py#L75-L81