diff --git a/keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py index cd9a5f4cc5..ce2bba2fb5 100644 --- a/keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py @@ -12,24 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, -) -from keras_nlp.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, -) +from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.AlbertMaskedLMPreprocessor") -class AlbertMaskedLMPreprocessor( - AlbertTextClassifierPreprocessor, MaskedLMPreprocessor -): +class AlbertMaskedLMPreprocessor(MaskedLMPreprocessor): """ALBERT preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling @@ -120,82 +110,5 @@ class AlbertMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.cls_token_id, - self.tokenizer.sep_token_id, - self.tokenizer.pad_token_id, - ], - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config - - @tf_preprocessing_function - def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." 
- ) - - x = super().call(x) - token_ids, segment_ids, padding_mask = ( - x["token_ids"], - x["segment_ids"], - x["padding_mask"], - ) - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "segment_ids": segment_ids, - "padding_mask": padding_mask, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + backbone_cls = AlbertBackbone + tokenizer_cls = AlbertTokenizer diff --git a/keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py b/keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py index 260fb7109a..be533be259 100644 --- a/keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras - from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export( @@ -154,61 +148,3 @@ class AlbertTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = AlbertBackbone tokenizer_cls = AlbertTokenizer - - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.packer = None - self.truncate = truncate - self.sequence_length = sequence_length - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. 
- self.packer = MultiSegmentPacker( - start_value=self.tokenizer.cls_token_id, - end_value=self.tokenizer.sep_token_id, - pad_value=self.tokenizer.pad_token_id, - truncate=self.truncate, - sequence_length=self.sequence_length, - ) - self.built = True - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - - @tf_preprocessing_function - def call(self, x, y=None, sample_weight=None): - x = x if isinstance(x, tuple) else (x,) - x = tuple(self.tokenizer(segment) for segment in x) - token_ids, segment_ids = self.packer(x) - x = { - "token_ids": token_ids, - "segment_ids": segment_ids, - "padding_mask": token_ids != self.tokenizer.pad_token_id, - } - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/albert/albert_tokenizer.py b/keras_nlp/src/models/albert/albert_tokenizer.py index 53dd561af9..b96b11cac4 100644 --- a/keras_nlp/src/models/albert/albert_tokenizer.py +++ b/keras_nlp/src/models/albert/albert_tokenizer.py @@ -89,35 +89,12 @@ class AlbertTokenizer(SentencePieceTokenizer): backbone_cls = AlbertBackbone def __init__(self, proto, **kwargs): - self.cls_token = "[CLS]" - self.sep_token = "[SEP]" - self.pad_token = "<pad>" - self.mask_token = "[MASK]" - + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("<pad>", "pad_token") + self._add_special_token("[MASK]", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. + self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [ - self.cls_token, - self.sep_token, - self.pad_token, - self.mask_token, - ]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - - self.cls_token_id = self.token_to_id(self.cls_token) - self.sep_token_id = self.token_to_id(self.sep_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.mask_token_id = self.token_to_id(self.mask_token) - else: - self.cls_token_id = None - self.sep_token_id = None - self.pad_token_id = None - self.mask_token_id = None diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index b252761e43..059ff2a6e4 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -160,7 +160,7 @@ def from_preset( to save and load a pre-trained model. The `preset` can be passed as a one of: - 1. a built in preset identifier like `'bert_base_en'` + 1. a built-in preset identifier like `'bert_base_en'` 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` 3. a Hugging Face handle like `'hf://user/bert_base_en'` 4. a path to a local preset directory like `'./bert_base_en'` @@ -175,7 +175,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from all built-in presets available on the class.
Args: - preset: string. A built in preset identifier, a Kaggle Models + preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. load_weights: bool. If `True`, the weights will be loaded into the model architecture. If `False`, the weights will be randomly diff --git a/keras_nlp/src/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/src/models/bart/bart_seq_2_seq_lm_preprocessor.py index 8908185e2b..315242511e 100644 --- a/keras_nlp/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/src/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -13,23 +13,15 @@ # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.bart.bart_preprocessor import BartPreprocessor +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.bart.bart_backbone import BartBackbone +from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function - -try: - import tensorflow as tf -except ImportError: - tf = None @keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor") -class BartSeq2SeqLMPreprocessor(BartPreprocessor, Seq2SeqLMPreprocessor): +class BartSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): """BART Seq2Seq LM preprocessor. This layer is used as preprocessor for seq2seq tasks using the BART model. @@ -124,134 +116,20 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor, Seq2SeqLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - *, - encoder_sequence_length=None, - decoder_sequence_length=None, - # `sequence_length` is an alias for `decoder_sequence_length` - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` " - "from the provided input data, i.e., `x`. However, non-`None`" - "values have been passed for `y` or `sample_weight` or both. " - "These values will be ignored." - ) - - if encoder_sequence_length is None: - encoder_sequence_length = self.encoder_sequence_length - decoder_sequence_length = decoder_sequence_length or sequence_length - if decoder_sequence_length is None: - decoder_sequence_length = self.decoder_sequence_length - - x = super().call( - x, - encoder_sequence_length=encoder_sequence_length, - decoder_sequence_length=decoder_sequence_length + 1, - ) - decoder_token_ids = x.pop("decoder_token_ids") - decoder_padding_mask = x.pop("decoder_padding_mask") - - # The last token does not have a next token. Hence, we truncate it. - x = { - **x, - "decoder_token_ids": decoder_token_ids[..., :-1], - "decoder_padding_mask": decoder_padding_mask[..., :-1], - } - # Target `y` will be the decoder input sequence shifted one step to the - # left (i.e., the next token). 
- y = decoder_token_ids[..., 1:] - sample_weight = decoder_padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - *, - encoder_sequence_length=None, - # `sequence_length` is an alias for `decoder_sequence_length` - decoder_sequence_length=None, - sequence_length=None, - ): - """Convert encoder and decoder input strings to integer token inputs for generation. - - Similar to calling the layer for training, this method takes in a dict - containing `"encoder_text"` and `"decoder_text"`, with strings or tensor - strings for values, tokenizes and packs the input, and computes a - padding mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a tokenizer.end_token_id to the end of - the decoder sequence (as generation is expected to continue at the end - of the inputted decoder prompt). - """ - if not self.built: - self.build(None) - - if isinstance(x, dict): - encoder_text = x["encoder_text"] - decoder_text = x["decoder_text"] - else: - encoder_text = x - # Initialize empty prompt for the decoder. - decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") - - if encoder_sequence_length is None: - encoder_sequence_length = self.encoder_sequence_length - decoder_sequence_length = decoder_sequence_length or sequence_length - if decoder_sequence_length is None: - decoder_sequence_length = self.decoder_sequence_length - - # Tokenize and pack the encoder inputs. - encoder_token_ids = self.tokenizer(encoder_text) - encoder_token_ids, encoder_padding_mask = self.encoder_packer( - encoder_token_ids, - sequence_length=encoder_sequence_length, - ) - - # Tokenize and pack the decoder inputs. - decoder_token_ids = self.tokenizer(decoder_text) - decoder_token_ids, decoder_padding_mask = self.decoder_packer( - decoder_token_ids, - sequence_length=decoder_sequence_length, - add_end_value=False, - ) - - return { - "encoder_token_ids": encoder_token_ids, - "encoder_padding_mask": encoder_padding_mask, - "decoder_token_ids": decoder_token_ids, - "decoder_padding_mask": decoder_padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - if not self.built: - self.build(None) - - token_ids, padding_mask = ( - x["decoder_token_ids"], - x["decoder_padding_mask"], - ) - ids_to_strip = ( - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, + backbone_cls = BartBackbone + tokenizer_cls = BartTokenizer + + def build(self, input_shape): + super().build(input_shape) + # The decoder is packed a bit differently; the format is as follows: + # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`. 
+ self.decoder_packer = StartEndPacker( + start_value=[ + self.tokenizer.end_token_id, + self.tokenizer.start_token_id, + ], + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.decoder_sequence_length, + return_padding_mask=True, ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/src/models/bart/bart_tokenizer.py b/keras_nlp/src/models/bart/bart_tokenizer.py index d7a5c57495..45dd2189b8 100644 --- a/keras_nlp/src/models/bart/bart_tokenizer.py +++ b/keras_nlp/src/models/bart/bart_tokenizer.py @@ -87,46 +87,11 @@ def __init__( merges=None, **kwargs, ): - self.start_token = "<s>" - self.pad_token = "<pad>" - self.end_token = "</s>" - + self._add_special_token("<s>", "start_token") + self._add_special_token("</s>", "end_token") + self._add_special_token("<pad>", "pad_token") super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ - self.start_token, - self.pad_token, - self.end_token, - ], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.pad_token, self.end_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - - self.start_token_id = self.token_to_id(self.start_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.end_token_id = self.token_to_id(self.end_token) - else: - self.start_token_id = None - self.pad_token_id = None - self.end_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/bart/bart_tokenizer_test.py b/keras_nlp/src/models/bart/bart_tokenizer_test.py index 25a81768ed..3f01c6e712 100644 --- a/keras_nlp/src/models/bart/bart_tokenizer_test.py +++ b/keras_nlp/src/models/bart/bart_tokenizer_test.py @@ -37,10 +37,9 @@ def test_tokenizer_basics(self): cls=BartTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - # TODO: `</s>` should not get tokenized as `<s>` - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - "<s> airplane at airport<s><pad>", + "<s> airplane at airport</s><pad>", " airplane airport", ], ) diff --git a/keras_nlp/src/models/bert/bert_masked_lm_preprocessor.py b/keras_nlp/src/models/bert/bert_masked_lm_preprocessor.py index 29549587e5..ef060d2e46 100644 --- a/keras_nlp/src/models/bert/bert_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/bert/bert_masked_lm_preprocessor.py @@ -12,24 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License.
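To make the decoder layout above concrete, here is a minimal sketch using the public `keras_nlp.layers.StartEndPacker` API; it assumes BART's conventional special token ids (`<s>` = 0, `<pad>` = 1, `</s>` = 2) and made-up content ids, so the exact numbers are illustrative only:

```python
import keras_nlp

# Mirror the decoder packer above: [end, start, tokens..., end, pad...].
packer = keras_nlp.layers.StartEndPacker(
    start_value=[2, 0],  # [end_token_id, start_token_id]
    end_value=2,         # end_token_id
    pad_value=1,         # pad_token_id
    sequence_length=8,
    return_padding_mask=True,
)
token_ids, padding_mask = packer([[11, 12, 13]])
# token_ids    -> [[2, 0, 11, 12, 13, 2, 1, 1]]
# padding_mask -> [[True, True, True, True, True, True, False, False]]
```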
-import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, -) -from keras_nlp.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, -) +from keras_nlp.src.models.bert.bert_backbone import BertBackbone +from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.BertMaskedLMPreprocessor") -class BertMaskedLMPreprocessor( - BertTextClassifierPreprocessor, MaskedLMPreprocessor -): +class BertMaskedLMPreprocessor(MaskedLMPreprocessor): """BERT preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling @@ -123,83 +113,5 @@ class BertMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.cls_token_id, - self.tokenizer.sep_token_id, - self.tokenizer.pad_token_id, - ], - ) - - @tf_preprocessing_function - def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." 
- ) - - x = super().call(x) - - token_ids, padding_mask, segment_ids = ( - x["token_ids"], - x["padding_mask"], - x["segment_ids"], - ) - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "padding_mask": padding_mask, - "segment_ids": segment_ids, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config + backbone_cls = BertBackbone + tokenizer_cls = BertTokenizer diff --git a/keras_nlp/src/models/bert/bert_text_classifier_preprocessor.py b/keras_nlp/src/models/bert/bert_text_classifier_preprocessor.py index 357a86853b..5ed47b7694 100644 --- a/keras_nlp/src/models/bert/bert_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/bert/bert_text_classifier_preprocessor.py @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras - from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) from keras_nlp.src.models.bert.bert_backbone import BertBackbone from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export( @@ -132,61 +126,3 @@ class BertTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = BertBackbone tokenizer_cls = BertTokenizer - - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.packer = None - self.sequence_length = sequence_length - self.truncate = truncate - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. 
- self.packer = MultiSegmentPacker( - start_value=self.tokenizer.cls_token_id, - end_value=self.tokenizer.sep_token_id, - pad_value=self.tokenizer.pad_token_id, - truncate=self.truncate, - sequence_length=self.sequence_length, - ) - self.built = True - - @tf_preprocessing_function - def call(self, x, y=None, sample_weight=None): - x = x if isinstance(x, tuple) else (x,) - x = tuple(self.tokenizer(segment) for segment in x) - token_ids, segment_ids = self.packer(x) - x = { - "token_ids": token_ids, - "segment_ids": segment_ids, - "padding_mask": token_ids != self.tokenizer.pad_token_id, - } - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/bert/bert_tokenizer.py b/keras_nlp/src/models/bert/bert_tokenizer.py index b218114df7..da11efec21 100644 --- a/keras_nlp/src/models/bert/bert_tokenizer.py +++ b/keras_nlp/src/models/bert/bert_tokenizer.py @@ -77,41 +77,18 @@ def __init__( self, vocabulary=None, lowercase=False, - special_tokens_in_strings=False, **kwargs, ): - self.cls_token = "[CLS]" - self.sep_token = "[SEP]" - self.pad_token = "[PAD]" - self.mask_token = "[MASK]" + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("[PAD]", "pad_token") + self._add_special_token("[MASK]", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. + self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") super().__init__( vocabulary=vocabulary, lowercase=lowercase, - special_tokens=[ - self.cls_token, - self.sep_token, - self.pad_token, - self.mask_token, - ], - special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) - - def set_vocabulary(self, vocabulary): - super().set_vocabulary(vocabulary) - - if vocabulary is not None: - self.cls_token_id = self.token_to_id(self.cls_token) - self.sep_token_id = self.token_to_id(self.sep_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.mask_token_id = self.token_to_id(self.mask_token) - else: - self.cls_token_id = None - self.sep_token_id = None - self.pad_token_id = None - self.mask_token_id = None - - def get_config(self): - config = super().get_config() - del config["special_tokens"] # Not configurable; set in __init__. - return config diff --git a/keras_nlp/src/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/src/models/bloom/bloom_causal_lm_preprocessor.py index f8ac0eeacf..f713e2c578 100644 --- a/keras_nlp/src/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/bloom/bloom_causal_lm_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
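A minimal sketch of what the simplified tokenizer constructors above provide, assuming a toy WordPiece vocabulary (the ids printed below depend entirely on that made-up vocabulary):

```python
import keras_nlp

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "air", "##plane"]
tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
# `_add_special_token` registers both the token string and its id.
print(tokenizer.cls_token, tokenizer.cls_token_id)  # [CLS] 2
print(tokenizer.pad_token, tokenizer.pad_token_id)  # [PAD] 0
# `start_token_id`/`end_token_id` alias `[CLS]`/`[SEP]` for cross-model code.
print(tokenizer.start_token_id, tokenizer.end_token_id)  # 2 3
# All registered ids, e.g. for stripping during generation postprocessing.
print(tokenizer.special_token_ids)
```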
-import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.bloom.bloom_preprocessor import BloomPreprocessor +from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.BloomCausalLMPreprocessor") -class BloomCausalLMPreprocessor(BloomPreprocessor, CausalLMPreprocessor): +class BloomCausalLMPreprocessor(CausalLMPreprocessor): """BLOOM Causal LM preprocessor. This preprocessing layer is meant for use with @@ -90,87 +87,5 @@ class BloomCausalLMPreprocessor(BloomPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`BloomCausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. 
- """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = BloomBackbone + tokenizer_cls = BloomTokenizer diff --git a/keras_nlp/src/models/bloom/bloom_tokenizer.py b/keras_nlp/src/models/bloom/bloom_tokenizer.py index 6cd0a44d51..fb9debb628 100644 --- a/keras_nlp/src/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/src/models/bloom/bloom_tokenizer.py @@ -77,46 +77,11 @@ def __init__( merges=None, **kwargs, ): - self.start_token = "" - self.end_token = "" - self.pad_token = "" - + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ - self.start_token, - self.end_token, - self.pad_token, - ], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.end_token, self.pad_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = self.token_to_id(self.pad_token) - else: - self.start_token_id = None - self.end_token_id = None - self.pad_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/causal_lm_preprocessor.py b/keras_nlp/src/models/causal_lm_preprocessor.py index 7b0492f7cc..1713ce6566 100644 --- a/keras_nlp/src/models/causal_lm_preprocessor.py +++ b/keras_nlp/src/models/causal_lm_preprocessor.py @@ -11,8 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import keras + from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import strip_to_ragged +from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.CausalLMPreprocessor") @@ -66,4 +71,125 @@ class CausalLMPreprocessor(Preprocessor): ``` """ - # TODO: move common code down to this base class where possible. + def __init__( + self, + tokenizer, + sequence_length=1024, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. 
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + @tf_preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + sequence_length = sequence_length or self.sequence_length + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @tf_preprocessing_function + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + @tf_preprocessing_function + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + ids_to_strip = self.tokenizer.special_token_ids + token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) + return self.tokenizer.detokenize(token_ids) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py b/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py index e9e866641a..8096e76e37 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py @@ -13,23 +13,20 @@ # limitations under the License. 
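A minimal sketch of the shared `call()` behavior implemented in the base class above, assuming the real `"bloom_560m_multi"` preset (any causal LM preset behaves the same way; the download is for illustration only):

```python
import keras_nlp

preprocessor = keras_nlp.models.BloomCausalLMPreprocessor.from_preset(
    "bloom_560m_multi", sequence_length=8
)
x, y, sample_weight = preprocessor(["airplane at airport"])
# `y` is `x["token_ids"]` shifted left by one position, and `sample_weight`
# masks out padding, so every subclass gets next-token labels for free.
```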
import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, +from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( + DebertaV3Backbone, ) -from keras_nlp.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, +from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( + DebertaV3Tokenizer, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.DebertaV3MaskedLMPreprocessor") -class DebertaV3MaskedLMPreprocessor( - DebertaV3TextClassifierPreprocessor, MaskedLMPreprocessor -): +class DebertaV3MaskedLMPreprocessor(MaskedLMPreprocessor): """DeBERTa preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling @@ -119,78 +116,13 @@ class DebertaV3MaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.cls_token_id, - self.tokenizer.sep_token_id, - self.tokenizer.pad_token_id, - ], - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config + backbone_cls = DebertaV3Backbone + tokenizer_cls = DebertaV3Tokenizer @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - - x = super().call(x) - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "padding_mask": padding_mask, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. 
+ del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py b/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py index f74fe7ba13..6f1996d3b4 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py @@ -12,13 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone, ) @@ -160,59 +156,10 @@ class DebertaV3TextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = DebertaV3Backbone tokenizer_cls = DebertaV3Tokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.packer = None - self.truncate = truncate - self.sequence_length = sequence_length - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.packer = MultiSegmentPacker( - start_value=self.tokenizer.cls_token_id, - end_value=self.tokenizer.sep_token_id, - pad_value=self.tokenizer.pad_token_id, - truncate=self.truncate, - sequence_length=self.sequence_length, - ) - self.built = True - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - x = x if isinstance(x, tuple) else (x,) - x = tuple(self.tokenizer(segment) for segment in x) - token_ids, _ = self.packer(x) - x = { - "token_ids": token_ids, - "padding_mask": token_ids != self.tokenizer.pad_token_id, - } + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. + del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py index bef7277697..168df159ec 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py @@ -101,37 +101,34 @@ class DebertaV3Tokenizer(SentencePieceTokenizer): backbone_cls = DebertaV3Backbone def __init__(self, proto, **kwargs): - self.cls_token = "[CLS]" - self.sep_token = "[SEP]" - self.pad_token = "[PAD]" + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("[PAD]", "pad_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. 
+ self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") + # Handle mask separately as it's not always in the vocab. self.mask_token = "[MASK]" - + self.mask_token_id = None super().__init__(proto=proto, **kwargs) + @property + def special_tokens(self): + return super().special_tokens + [self.mask_token] + + @property + def special_token_ids(self): + return super().special_token_ids + [self.mask_token_id] + def set_proto(self, proto): super().set_proto(proto) if proto is not None: - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in super().get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - - self.cls_token_id = self.token_to_id(self.cls_token) - self.sep_token_id = self.token_to_id(self.sep_token) - self.pad_token_id = self.token_to_id(self.pad_token) - # If the mask token is not in the vocabulary, add it to the end of the - # vocabulary. if self.mask_token in super().get_vocabulary(): self.mask_token_id = super().token_to_id(self.mask_token) else: self.mask_token_id = super().vocabulary_size() else: - self.cls_token_id = None - self.sep_token_id = None - self.pad_token_id = None self.mask_token_id = None def vocabulary_size(self): @@ -142,6 +139,8 @@ def vocabulary_size(self): def get_vocabulary(self): sentence_piece_vocabulary = super().get_vocabulary() + if self.mask_token_id is None: + return sentence_piece_vocabulary if self.mask_token_id < super().vocabulary_size(): return sentence_piece_vocabulary return sentence_piece_vocabulary + ["[MASK]"] diff --git a/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py b/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py index b06c5f7100..9004fa06ba 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py @@ -13,23 +13,20 @@ # limitations under the License. import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, +from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, ) -from keras_nlp.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, +from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.DistilBertMaskedLMPreprocessor") -class DistilBertMaskedLMPreprocessor( - DistilBertTextClassifierPreprocessor, MaskedLMPreprocessor -): +class DistilBertMaskedLMPreprocessor(MaskedLMPreprocessor): """DistilBERT preprocessing for the masked language modeling task. 
This preprocessing layer will prepare inputs for a masked language modeling @@ -123,77 +120,13 @@ class DistilBertMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.cls_token_id, - self.tokenizer.sep_token_id, - self.tokenizer.pad_token_id, - ], - ) + backbone_cls = DistilBertBackbone + tokenizer_cls = DistilBertTokenizer @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - - x = super().call(x) - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "padding_mask": padding_mask, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. 
+ del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config diff --git a/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py b/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py index 445fdebc6f..90ebe7ef3b 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py @@ -16,9 +16,6 @@ import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( DistilBertBackbone, ) @@ -129,59 +126,10 @@ class DistilBertTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = DistilBertBackbone tokenizer_cls = DistilBertTokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.packer = None - self.sequence_length = sequence_length - self.truncate = truncate - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.packer = MultiSegmentPacker( - start_value=self.tokenizer.cls_token_id, - end_value=self.tokenizer.sep_token_id, - pad_value=self.tokenizer.pad_token_id, - truncate=self.truncate, - sequence_length=self.sequence_length, - ) - @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - x = x if isinstance(x, tuple) else (x,) - x = tuple(self.tokenizer(segment) for segment in x) - token_ids, _ = self.packer(x) - x = { - "token_ids": token_ids, - "padding_mask": token_ids != self.tokenizer.pad_token_id, - } + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. 
+ del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py index 9cefa2418f..f99a41069b 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py @@ -81,41 +81,18 @@ def __init__( self, vocabulary, lowercase=False, - special_tokens_in_strings=False, **kwargs, ): - self.cls_token = "[CLS]" - self.sep_token = "[SEP]" - self.pad_token = "[PAD]" - self.mask_token = "[MASK]" + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("[PAD]", "pad_token") + self._add_special_token("[MASK]", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. + self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") super().__init__( vocabulary=vocabulary, lowercase=lowercase, - special_tokens=[ - self.cls_token, - self.sep_token, - self.pad_token, - self.mask_token, - ], - special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) - - def set_vocabulary(self, vocabulary): - super().set_vocabulary(vocabulary) - - if vocabulary is not None: - self.cls_token_id = self.token_to_id(self.cls_token) - self.sep_token_id = self.token_to_id(self.sep_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.mask_token_id = self.token_to_id(self.mask_token) - else: - self.cls_token_id = None - self.sep_token_id = None - self.pad_token_id = None - self.mask_token_id = None - - def get_config(self): - config = super().get_config() - del config["special_tokens"] # Not configurable; set in __init__. - return config diff --git a/keras_nlp/src/models/electra/electra_tokenizer.py b/keras_nlp/src/models/electra/electra_tokenizer.py index 85835925d5..aa73e6e0b9 100644 --- a/keras_nlp/src/models/electra/electra_tokenizer.py +++ b/keras_nlp/src/models/electra/electra_tokenizer.py @@ -72,41 +72,18 @@ def __init__( self, vocabulary, lowercase=False, - special_tokens_in_strings=False, **kwargs, ): - self.cls_token = "[CLS]" - self.sep_token = "[SEP]" - self.pad_token = "[PAD]" - self.mask_token = "[MASK]" + self._add_special_token("[CLS]", "cls_token") + self._add_special_token("[SEP]", "sep_token") + self._add_special_token("[PAD]", "pad_token") + self._add_special_token("[MASK]", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. 
+ self._add_special_token("[CLS]", "start_token") + self._add_special_token("[SEP]", "end_token") super().__init__( vocabulary=vocabulary, lowercase=lowercase, - special_tokens=[ - self.cls_token, - self.sep_token, - self.pad_token, - self.mask_token, - ], - special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) - - def set_vocabulary(self, vocabulary): - super().set_vocabulary(vocabulary) - - if vocabulary is not None: - self.cls_token_id = self.token_to_id(self.cls_token) - self.sep_token_id = self.token_to_id(self.sep_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.mask_token_id = self.token_to_id(self.mask_token) - else: - self.cls_token_id = None - self.sep_token_id = None - self.pad_token_id = None - self.mask_token_id = None - - def get_config(self): - config = super().get_config() - del config["special_tokens"] # Not configurable; set in __init__. - return config diff --git a/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py b/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py index a450442ca6..7628f26ee4 100644 --- a/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py @@ -13,23 +13,16 @@ # limitations under the License. import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, -) -from keras_nlp.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, -) +from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.FNetMaskedLMPreprocessor") -class FNetMaskedLMPreprocessor( - FNetTextClassifierPreprocessor, MaskedLMPreprocessor -): +class FNetMaskedLMPreprocessor(MaskedLMPreprocessor): """FNet preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling @@ -125,79 +118,13 @@ class FNetMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None - - def build(self, input_shape): - super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. 
- self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.cls_token_id, - self.tokenizer.sep_token_id, - self.tokenizer.pad_token_id, - ], - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config + backbone_cls = FNetBackbone + tokenizer_cls = FNetTokenizer @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - x = super().call(x) - token_ids, segment_ids = ( - x["token_ids"], - x["segment_ids"], - ) - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "segment_ids": segment_ids, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # FNet has no padding mask. + del x["padding_mask"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py b/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py index 6eceb87106..c1f1f3f71e 100644 --- a/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py @@ -16,9 +16,6 @@ import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) @@ -127,59 +124,10 @@ class FNetTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = FNetBackbone tokenizer_cls = FNetTokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - self.tokenizer = tokenizer - self.packer = None - self.truncate = truncate - self.sequence_length = sequence_length - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model.
-        self.packer = MultiSegmentPacker(
-            start_value=self.tokenizer.cls_token_id,
-            end_value=self.tokenizer.sep_token_id,
-            pad_value=self.tokenizer.pad_token_id,
-            truncate=self.truncate,
-            sequence_length=self.sequence_length,
-        )
-        self.built = True
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "truncate": self.truncate,
-            }
-        )
-        return config
-
     @tf_preprocessing_function
     def call(self, x, y=None, sample_weight=None):
-        x = x if isinstance(x, tuple) else (x,)
-        x = tuple(self.tokenizer(segment) for segment in x)
-        token_ids, segment_ids = self.packer(x)
-        x = {
-            "token_ids": token_ids,
-            "segment_ids": segment_ids,
-        }
+        # FNet has no padding mask, so drop it from the model inputs.
+        output = super().call(x, y=y, sample_weight=sample_weight)
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output)
+        del x["padding_mask"]
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
diff --git a/keras_nlp/src/models/f_net/f_net_tokenizer.py b/keras_nlp/src/models/f_net/f_net_tokenizer.py
index 233294e4d6..3b4d48fec3 100644
--- a/keras_nlp/src/models/f_net/f_net_tokenizer.py
+++ b/keras_nlp/src/models/f_net/f_net_tokenizer.py
@@ -66,34 +66,12 @@ class FNetTokenizer(SentencePieceTokenizer):
     backbone_cls = FNetBackbone

     def __init__(self, proto, **kwargs):
-        self.cls_token = "[CLS]"
-        self.sep_token = "[SEP]"
-        self.pad_token = ""
-        self.mask_token = "[MASK]"
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("", "pad_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
         super().__init__(proto=proto, **kwargs)
-
-    def set_proto(self, proto):
-        super().set_proto(proto)
-        if proto is not None:
-            for token in [
-                self.cls_token,
-                self.sep_token,
-                self.pad_token,
-                self.mask_token,
-            ]:
-                if token not in self.get_vocabulary():
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.cls_token_id = self.token_to_id(self.cls_token)
-            self.sep_token_id = self.token_to_id(self.sep_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            self.mask_token_id = self.token_to_id(self.mask_token)
-        else:
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.pad_token_id = None
-            self.mask_token_id = None
diff --git a/keras_nlp/src/models/falcon/falcon_causal_lm_preprocessor.py b/keras_nlp/src/models/falcon/falcon_causal_lm_preprocessor.py
index 62372affd3..2b6ae080ec 100644
--- a/keras_nlp/src/models/falcon/falcon_causal_lm_preprocessor.py
+++ b/keras_nlp/src/models/falcon/falcon_causal_lm_preprocessor.py
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer @keras_nlp_export("keras_nlp.models.FalconCausalLMPreprocessor") -class FalconCausalLMPreprocessor(FalconPreprocessor, CausalLMPreprocessor): +class FalconCausalLMPreprocessor(CausalLMPreprocessor): """Falcon Causal LM preprocessor. This preprocessing layer is meant for use with @@ -90,84 +86,5 @@ class FalconCausalLMPreprocessor(FalconPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`FalconCausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. 
- """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = (self.tokenizer.end_token_id,) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = FalconBackbone + tokenizer_cls = FalconTokenizer diff --git a/keras_nlp/src/models/falcon/falcon_tokenizer.py b/keras_nlp/src/models/falcon/falcon_tokenizer.py index 9d3263bafb..c59e0ce8da 100644 --- a/keras_nlp/src/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/src/models/falcon/falcon_tokenizer.py @@ -77,40 +77,11 @@ def __init__( merges=None, **kwargs, ): - # Falcon uses the same start as end token, i.e., "<|endoftext|>". - self.end_token = self.start_token = "<|endoftext|>" - + self._add_special_token("<|endoftext|>", "end_token") + self._add_special_token("<|endoftext|>", "start_token") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - - self.end_token_id = self.token_to_id(self.end_token) - self.start_token_id = self.end_token_id - self.pad_token_id = 0 - else: - self.end_token_id = None - self.start_token_id = None - self.pad_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/src/models/gemma/gemma_causal_lm_preprocessor.py index f86f9ecf40..9a35f9baba 100644 --- a/keras_nlp/src/models/gemma/gemma_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/gemma/gemma_causal_lm_preprocessor.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.gemma.gemma_preprocessor import GemmaPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer @keras_nlp_export("keras_nlp.models.GemmaCausalLMPreprocessor") -class GemmaCausalLMPreprocessor(GemmaPreprocessor, CausalLMPreprocessor): +class GemmaCausalLMPreprocessor(CausalLMPreprocessor): """Gemma Causal LM preprocessor. 
This preprocessing layer is meant for use with @@ -83,85 +79,5 @@ class GemmaCausalLMPreprocessor(GemmaPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`GemmaCausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess(self, x): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. 
- """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - self.tokenizer.pad_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = GemmaBackbone + tokenizer_cls = GemmaTokenizer diff --git a/keras_nlp/src/models/gemma/gemma_tokenizer.py b/keras_nlp/src/models/gemma/gemma_tokenizer.py index b8a7065ed4..b66fc8df68 100644 --- a/keras_nlp/src/models/gemma/gemma_tokenizer.py +++ b/keras_nlp/src/models/gemma/gemma_tokenizer.py @@ -86,26 +86,7 @@ class GemmaTokenizer(SentencePieceTokenizer): backbone_cls = GemmaBackbone def __init__(self, proto, **kwargs): - self.start_token = "" - self.end_token = "" - self.pad_token = "" - + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [self.end_token, self.pad_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = self.token_to_id(self.pad_token) - else: - self.start_token_id = None - self.end_token_id = None - self.pad_token_id = None diff --git a/keras_nlp/src/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/src/models/gpt2/gpt2_causal_lm_preprocessor.py index 0438ecc82a..1855031706 100644 --- a/keras_nlp/src/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer @keras_nlp_export("keras_nlp.models.GPT2CausalLMPreprocessor") -class GPT2CausalLMPreprocessor(GPT2Preprocessor, CausalLMPreprocessor): +class GPT2CausalLMPreprocessor(CausalLMPreprocessor): """GPT2 Causal LM preprocessor. This preprocessing layer is meant for use with @@ -90,84 +86,5 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. 
- token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = (self.tokenizer.end_token_id,) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = GPT2Backbone + tokenizer_cls = GPT2Tokenizer diff --git a/keras_nlp/src/models/gpt2/gpt2_tokenizer.py b/keras_nlp/src/models/gpt2/gpt2_tokenizer.py index e83e30a2a9..a55e86d716 100644 --- a/keras_nlp/src/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/src/models/gpt2/gpt2_tokenizer.py @@ -78,39 +78,11 @@ def __init__( **kwargs, ): # GPT2 uses the same start as end token, i.e., "<|endoftext|>". - self.end_token = self.start_token = "<|endoftext|>" - + self._add_special_token("<|endoftext|>", "end_token") + self._add_special_token("<|endoftext|>", "start_token") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - - self.end_token_id = self.token_to_id(self.end_token) - self.start_token_id = self.end_token_id - self.pad_token_id = 0 - else: - self.end_token_id = None - self.start_token_id = None - self.pad_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. 
- del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py index a8bc5e8028..a3c5efe4b7 100644 --- a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py @@ -12,20 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_preprocessor import ( - GPTNeoXPreprocessor, -) -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone +from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer @keras_nlp_export("keras_nlp.models.GPTNeoXCausalLMPreprocessor") -class GPTNeoXCausalLMPreprocessor(GPTNeoXPreprocessor, CausalLMPreprocessor): +class GPTNeoXCausalLMPreprocessor(CausalLMPreprocessor): """GPT-NeoX Causal LM preprocessor. This preprocessing layer is meant for use with @@ -58,84 +52,5 @@ class GPTNeoXCausalLMPreprocessor(GPTNeoXPreprocessor, CausalLMPreprocessor): """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`GPTNeoXCausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. 
- - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = (self.tokenizer.end_token_id,) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = GPTNeoXBackbone + tokenizer_cls = GPTNeoXTokenizer diff --git a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py index c5390d1cd8..1fcb44b539 100644 --- a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +++ b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py @@ -56,39 +56,11 @@ def __init__( **kwargs, ): # GPTNeoX uses the same start as end token, i.e., "<|endoftext|>". - self.end_token = self.start_token = "<|endoftext|>" - + self._add_special_token("<|endoftext|>", "end_token") + self._add_special_token("<|endoftext|>", "start_token") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.end_token], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." - ) - - self.end_token_id = self.token_to_id(self.end_token) - self.start_token_id = self.end_token_id - self.pad_token_id = 0 - else: - self.end_token_id = None - self.start_token_id = None - self.pad_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/src/models/llama/llama_causal_lm_preprocessor.py index effc87eb0e..ac7c444bbd 100644 --- a/keras_nlp/src/models/llama/llama_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/llama/llama_causal_lm_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.llama.llama_preprocessor import LlamaPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.src.models.llama.llama_tokenizer import LlamaTokenizer @keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor") -class LlamaCausalLMPreprocessor(LlamaPreprocessor, CausalLMPreprocessor): +class LlamaCausalLMPreprocessor(CausalLMPreprocessor): """Llama Causal LM preprocessor. 
This preprocessing layer is meant for use with @@ -90,81 +87,5 @@ class LlamaCausalLMPreprocessor(LlamaPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`LlamaCausalLMPreprocessor` generates `y` and " - "`sample_weight` based on your input data, but your data " - "already contains `y` or `sample_weight`. Your `y` and " - "`sample_weight` will be ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = (self.tokenizer.end_token_id,) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = LlamaBackbone + tokenizer_cls = LlamaTokenizer diff --git a/keras_nlp/src/models/llama/llama_causal_lm_preprocessor_test.py b/keras_nlp/src/models/llama/llama_causal_lm_preprocessor_test.py index 36320cec8e..5cb902baed 100644 --- a/keras_nlp/src/models/llama/llama_causal_lm_preprocessor_test.py +++ b/keras_nlp/src/models/llama/llama_causal_lm_preprocessor_test.py @@ -42,11 +42,11 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output=( { - "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], - "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + "token_ids": [[1, 3, 8, 4, 6, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], }, - [[3, 8, 4, 6, 0, 0, 0, 0]], # Pass through labels. - [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. 
+ [[3, 8, 4, 6, 2, 0, 0, 0]], + [[1, 1, 1, 1, 1, 0, 0, 0]], ), ) diff --git a/keras_nlp/src/models/llama/llama_tokenizer.py b/keras_nlp/src/models/llama/llama_tokenizer.py index eeb904b501..ad2240bda9 100644 --- a/keras_nlp/src/models/llama/llama_tokenizer.py +++ b/keras_nlp/src/models/llama/llama_tokenizer.py @@ -65,24 +65,7 @@ class LlamaTokenizer(SentencePieceTokenizer): backbone_cls = LlamaBackbone def __init__(self, proto, **kwargs): - self.start_token = "" - self.end_token = "" + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self.pad_token_id = 0 super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [self.start_token, self.end_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = 0 - else: - self.start_token_id = None - self.end_token_id = None - self.pad_token_id = None diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py index 984aff5d13..88bb06456b 100644 --- a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.llama3.llama3_preprocessor import Llama3Preprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer @keras_nlp_export("keras_nlp.models.Llama3CausalLMPreprocessor") -class Llama3CausalLMPreprocessor(Llama3Preprocessor, CausalLMPreprocessor): +class Llama3CausalLMPreprocessor(CausalLMPreprocessor): """Llama 3 Causal LM preprocessor. This preprocessing layer is meant for use with @@ -90,84 +86,5 @@ class Llama3CausalLMPreprocessor(Llama3Preprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`Llama3CausalLMPreprocessor` generates `y` and " - "`sample_weight` based on your input data, but your data " - "already contains `y` or `sample_weight`. Your `y` and " - "`sample_weight` will be ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. 
- y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.end_token_id, - self.tokenizer.start_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = Llama3Backbone + tokenizer_cls = Llama3Tokenizer diff --git a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py index 2b79bd0d4f..8c7d8a15bf 100644 --- a/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py +++ b/keras_nlp/src/models/llama3/llama3_causal_lm_preprocessor_test.py @@ -46,11 +46,11 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output=( { - "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0]], - "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + "token_ids": [[6, 1, 3, 4, 2, 5, 7, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], }, - [[1, 3, 4, 2, 5, 0, 0, 0]], # Pass through labels. - [[1, 1, 1, 1, 1, 0, 0, 0]], # Pass through sample_weights. + [[1, 3, 4, 2, 5, 7, 0, 0]], + [[1, 1, 1, 1, 1, 1, 0, 0]], ), ) diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer.py b/keras_nlp/src/models/llama3/llama3_tokenizer.py index 3900da5269..b4793312b3 100644 --- a/keras_nlp/src/models/llama3/llama3_tokenizer.py +++ b/keras_nlp/src/models/llama3/llama3_tokenizer.py @@ -32,40 +32,11 @@ def __init__( merges=None, **kwargs, ): - self.start_token = "<|begin_of_text|>" - self.end_token = "<|end_of_text|>" - + self._add_special_token("<|begin_of_text|>", "start_token") + self._add_special_token("<|end_of_text|>", "end_token") + self.pad_token_id = 0 super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[self.start_token, self.end_token], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - if self.end_token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{self.end_token}'` in the provided " - f"`vocabulary`. Please provide `'{self.end_token}'` in " - "your `vocabulary` or use a pretrained `vocabulary` name." 
- ) - - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = 0 - else: - self.end_token_id = None - self.start_token_id = None - self.pad_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/masked_lm_preprocessor.py b/keras_nlp/src/models/masked_lm_preprocessor.py index 840b329f54..491b715280 100644 --- a/keras_nlp/src/models/masked_lm_preprocessor.py +++ b/keras_nlp/src/models/masked_lm_preprocessor.py @@ -11,8 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import keras + from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( + MaskedLMMaskGenerator, +) +from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.MaskedLMPreprocessor") @@ -61,4 +70,87 @@ class MaskedLMPreprocessor(Preprocessor): ``` """ - # TODO: move common code down to this base class where possible. + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + mask_selection_rate=0.15, + mask_selection_length=96, + mask_token_rate=0.8, + random_token_rate=0.1, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.truncate = truncate + self.mask_selection_rate = mask_selection_rate + self.mask_selection_length = mask_selection_length + self.mask_token_rate = mask_token_rate + self.random_token_rate = random_token_rate + self.masker = None + + def build(self, input_shape): + super().build(input_shape) + # Defer masker creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. 
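+        # The packer adds the tokenizer's start/end/pad tokens and truncates
+        # paired segments to `sequence_length` before masking is applied.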
+ self.packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + truncate=self.truncate, + sequence_length=self.sequence_length, + ) + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=self.mask_selection_rate, + mask_selection_length=self.mask_selection_length, + mask_token_rate=self.mask_token_rate, + random_token_rate=self.random_token_rate, + vocabulary_size=self.tokenizer.vocabulary_size(), + mask_token_id=self.tokenizer.mask_token_id, + unselectable_token_ids=self.tokenizer.special_token_ids, + ) + + @tf_preprocessing_function + def call(self, x, y=None, sample_weight=None): + x = x if isinstance(x, tuple) else (x,) + x = tuple(self.tokenizer(segment) for segment in x) + token_ids, segment_ids = self.packer(x) + padding_mask = token_ids != self.tokenizer.pad_token_id + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "padding_mask": padding_mask, + "segment_ids": segment_ids, + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "truncate": self.truncate, + "mask_selection_rate": self.mask_selection_rate, + "mask_selection_length": self.mask_selection_length, + "mask_token_rate": self.mask_token_rate, + "random_token_rate": self.random_token_rate, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor.py index 4444a0a55e..fe74de6c6b 100644 --- a/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor.py @@ -12,20 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.mistral.mistral_preprocessor import ( - MistralPreprocessor, -) -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone +from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer @keras_nlp_export("keras_nlp.models.MistralCausalLMPreprocessor") -class MistralCausalLMPreprocessor(MistralPreprocessor, CausalLMPreprocessor): +class MistralCausalLMPreprocessor(CausalLMPreprocessor): """Mistral Causal LM preprocessor. 
This preprocessing layer is meant for use with @@ -92,84 +86,5 @@ class MistralCausalLMPreprocessor(MistralPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`MistralCausalLMPreprocessor` generates `y` and " - "`sample_weight` based on your input data, but your data " - "already contains `y` or `sample_weight`. Your `y` and " - "`sample_weight` will be ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = MistralBackbone + tokenizer_cls = MistralTokenizer diff --git a/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor_test.py b/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor_test.py index 28a7abf597..e2b9bf185b 100644 --- a/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor_test.py +++ b/keras_nlp/src/models/mistral/mistral_causal_lm_preprocessor_test.py @@ -44,11 +44,11 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output=( { - "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], - "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + "token_ids": [[1, 3, 8, 4, 6, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], }, - [[3, 8, 4, 6, 0, 0, 0, 0]], # Pass through labels. - [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. + [[3, 8, 4, 6, 2, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Pass through sample_weights. 
), ) diff --git a/keras_nlp/src/models/mistral/mistral_tokenizer.py b/keras_nlp/src/models/mistral/mistral_tokenizer.py index 332e436629..42895adc4f 100644 --- a/keras_nlp/src/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/src/models/mistral/mistral_tokenizer.py @@ -65,22 +65,7 @@ class MistralTokenizer(SentencePieceTokenizer): backbone_cls = MistralBackbone def __init__(self, proto, **kwargs): - self.start_token = "" - self.end_token = "" + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self.pad_token_id = 0 super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [self.start_token, self.end_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - else: - self.start_token_id = None - self.end_token_id = None diff --git a/keras_nlp/src/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/src/models/opt/opt_causal_lm_preprocessor.py index 66d8fee4cb..1823b1acc7 100644 --- a/keras_nlp/src/models/opt/opt_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/opt/opt_causal_lm_preprocessor.py @@ -11,19 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.opt.opt_backbone import OPTBackbone +from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer @keras_nlp_export("keras_nlp.models.OPTCausalLMPreprocessor") -class OPTCausalLMPreprocessor(OPTPreprocessor, CausalLMPreprocessor): +class OPTCausalLMPreprocessor(CausalLMPreprocessor): """OPT Causal LM preprocessor. This preprocessing layer is primarily meant to be used with @@ -91,87 +86,5 @@ class OPTCausalLMPreprocessor(OPTPreprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. 
- y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. - """ - if not self.built: - self.build(None) - - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.end_token_id, - self.tokenizer.pad_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = OPTBackbone + tokenizer_cls = OPTTokenizer diff --git a/keras_nlp/src/models/opt/opt_tokenizer.py b/keras_nlp/src/models/opt/opt_tokenizer.py index 8dee2a1235..0565c24404 100644 --- a/keras_nlp/src/models/opt/opt_tokenizer.py +++ b/keras_nlp/src/models/opt/opt_tokenizer.py @@ -77,46 +77,11 @@ def __init__( merges=None, **kwargs, ): - self.start_token = "" - self.pad_token = "" - self.end_token = "" - + self._add_special_token("", "end_token") + self._add_special_token("", "start_token") + self._add_special_token("", "pad_token") super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ - self.start_token, - self.pad_token, - self.end_token, - ], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.start_token, self.pad_token, self.end_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - - self.start_token_id = self.token_to_id(self.start_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.end_token_id = self.token_to_id(self.end_token) - else: - self.start_token_id = None - self.pad_token_id = None - self.end_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. 
- del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py index 2c71b9bac4..764b570fa2 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( GemmaCausalLMPreprocessor, ) @@ -32,24 +30,10 @@ @keras_nlp_export("keras_nlp.models.PaliGemmaCausalLMPreprocessor") -class PaliGemmaCausalLMPreprocessor( - GemmaCausalLMPreprocessor, CausalLMPreprocessor -): +class PaliGemmaCausalLMPreprocessor(GemmaCausalLMPreprocessor): backbone_cls = PaliGemmaBackbone tokenizer_cls = PaliGemmaTokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - add_start_token=True, - add_end_token=True, - **kwargs, - ): - super().__init__( - tokenizer, sequence_length, add_start_token, add_end_token, **kwargs - ) - def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer # assets have loaded when restoring a saved model. @@ -70,15 +54,7 @@ def call( sample_weight=None, sequence_length=None, ): - if y is not None or sample_weight is not None: - logging.warning( - "`PaliGemmaCausalLMPreprocessor` generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) sequence_length = sequence_length or self.sequence_length - images, prompts, responses = x["images"], x["prompts"], x["responses"] prompts = self.tokenizer(prompts) responses = self.tokenizer(responses) diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py b/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py index 02670d04b8..fbdb0693d0 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py @@ -85,5 +85,3 @@ class PaliGemmaTokenizer(GemmaTokenizer): """ backbone_cls = PaliGemmaBackbone - - pass diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py index 71e1f7ada1..2dcbc778c5 100644 --- a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import keras -from absl import logging - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor -from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer @keras_nlp_export("keras_nlp.models.Phi3CausalLMPreprocessor") -class Phi3CausalLMPreprocessor(Phi3Preprocessor, CausalLMPreprocessor): +class Phi3CausalLMPreprocessor(CausalLMPreprocessor): """Phi3 Causal LM preprocessor. This preprocessing layer is meant for use with @@ -90,84 +86,5 @@ class Phi3CausalLMPreprocessor(Phi3Preprocessor, CausalLMPreprocessor): ``` """ - @tf_preprocessing_function - def call( - self, - x, - y=None, - sample_weight=None, - sequence_length=None, - ): - if y is not None or sample_weight is not None: - logging.warning( - "`Phi3CausalLMPreprocessor` generates `y` and " - "`sample_weight` based on your input data, but your data " - "already contains `y` or `sample_weight`. Your `y` and " - "`sample_weight` will be ignored." - ) - sequence_length = sequence_length or self.sequence_length - - x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @tf_preprocessing_function - def generate_preprocess( - self, - x, - sequence_length=None, - ): - """Convert strings to integer token input for generation. - - Similar to calling the layer for training, this method takes in strings - or tensor strings, tokenizes and packs the input, and computes a padding - mask masking all inputs not filled in with a padded value. - - Unlike calling the layer for training, this method does not compute - labels and will never append a `tokenizer.end_token_id` to the end of - the sequence (as generation is expected to continue at the end of the - inputted prompt). - """ - if not self.built: - self.build(None) - - x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) - return { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - - @tf_preprocessing_function - def generate_postprocess( - self, - x, - ): - """Convert integer token output to strings for generation. - - This method reverses `generate_preprocess()`, by first removing all - padding and start/end tokens, and then converting the integer sequence - back to a string. 
- """ - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - ids_to_strip = ( - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - ) - token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) - return self.tokenizer.detokenize(token_ids) + backbone_cls = Phi3Backbone + tokenizer_cls = Phi3Tokenizer diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py index b88ba11301..a09f268c10 100644 --- a/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py +++ b/keras_nlp/src/models/phi3/phi3_causal_lm_preprocessor_test.py @@ -43,13 +43,11 @@ def test_preprocessor_basics(self): input_data=self.input_data, expected_output=( { - "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 0]], - "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], + "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 15]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], }, - [[3, 5, 6, 4, 3, 9, 7, 11, 0, 0]], # Pass through labels. - [ - [1, 1, 1, 1, 1, 1, 1, 1, 0, 0] - ], # Pass through sample_weights. + [[3, 5, 6, 4, 3, 9, 7, 11, 15, 0]], + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], ), ) diff --git a/keras_nlp/src/models/phi3/phi3_tokenizer.py b/keras_nlp/src/models/phi3/phi3_tokenizer.py index 8b4ac20bd4..8a1db63442 100644 --- a/keras_nlp/src/models/phi3/phi3_tokenizer.py +++ b/keras_nlp/src/models/phi3/phi3_tokenizer.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone -from keras_nlp.src.models.phi3.phi3_presets import backbone_presets from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) -from keras_nlp.src.utils.python_utils import classproperty @keras_nlp_export( @@ -68,31 +64,7 @@ class Phi3Tokenizer(SentencePieceTokenizer): backbone_cls = Phi3Backbone def __init__(self, proto, **kwargs): - self.start_token = "" - self.end_token = "<|endoftext|>" + self._add_special_token("", "start_token") + self._add_special_token("<|endoftext|>", "end_token") + self.pad_token_id = 0 super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [self.start_token, self.end_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.start_token_id = self.token_to_id(self.start_token) - self.end_token_id = self.token_to_id(self.end_token) - # TODO: `pad_token` is `<|endoftext|>`, but setting it to `` - # for now, because of the way sampler works. sampler will think that - # `pad_token` is `end_token` and stop generation immediatly. - self.pad_token_id = 0 - else: - self.start_token_id = None - self.end_token_id = None - self.pad_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py index 686d010bf8..ee79ad77d6 100644 --- a/keras_nlp/src/models/preprocessor.py +++ b/keras_nlp/src/models/preprocessor.py @@ -97,7 +97,7 @@ def from_preset( to save and load a pre-trained model. The `preset` can be passed as a one of: - 1. 
a built in preset identifier like `'bert_base_en'` + 1. a built-in preset identifier like `'bert_base_en'` 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` 3. a Hugging Face handle like `'hf://user/bert_base_en'` 4. a path to a local preset directory like `'./bert_base_en'` @@ -110,7 +110,7 @@ def from_preset( `keras_nlp.models.BertTextClassifierPreprocessor.from_preset()`. Args: - preset: string. A built in preset identifier, a Kaggle Models + preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. Examples: diff --git a/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py b/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py index 1585963c8b..4b3bf02b6c 100644 --- a/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py @@ -13,23 +13,19 @@ # limitations under the License. import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, +from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, -) +from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone +from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.RobertaMaskedLMPreprocessor") -class RobertaMaskedLMPreprocessor( - RobertaTextClassifierPreprocessor, MaskedLMPreprocessor -): +class RobertaMaskedLMPreprocessor(MaskedLMPreprocessor): """RoBERTa preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling @@ -126,78 +122,25 @@ class RobertaMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None + backbone_cls = RobertaBackbone + tokenizer_cls = RobertaTokenizer def build(self, input_shape): super().build(input_shape) - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - self.tokenizer.pad_token_id, - ], + # Roberta doubles up the sep token, so we override build.
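To make the doubled separator concrete before the packer construction that follows, here is a hedged sketch of the packing this build step sets up, using made-up token ids (0 = start, 2 = end, 1 = pad) and assuming the layer accepts unbatched Python lists the way other KerasNLP packers do:
```
from keras_nlp.layers import MultiSegmentPacker

packer = MultiSegmentPacker(
    start_value=0,
    end_value=2,
    sep_value=[2, 2],  # RoBERTa-style "</s></s>" between segments
    pad_value=1,
    truncate="round_robin",
    sequence_length=12,
)
token_ids, segment_ids = packer(([11, 12, 13], [21, 22]))
# token_ids packs: start, segment one, the doubled sep, segment two, end, padding.
```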
+ self.packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + sep_value=[self.tokenizer.end_token_id] * 2, + pad_value=self.tokenizer.pad_token_id, + truncate=self.truncate, + sequence_length=self.sequence_length, ) - self.built = True @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - - x = super().call(x) - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "padding_mask": padding_mask, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. + del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } - ) - return config diff --git a/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py b/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py index 1da00372d3..377cc2cb75 100644 --- a/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import keras from keras_nlp.src.api_export import keras_nlp_export @@ -140,23 +139,8 @@ class RobertaTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = RobertaBackbone tokenizer_cls = RobertaTokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - - self.tokenizer = tokenizer - self.packer = None - self.truncate = truncate - self.sequence_length = sequence_length - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. + # Roberta doubles up the sep token, so we override build. self.packer = MultiSegmentPacker( start_value=self.tokenizer.start_token_id, end_value=self.tokenizer.end_token_id,
+ del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/roberta/roberta_tokenizer.py b/keras_nlp/src/models/roberta/roberta_tokenizer.py index b99f4bedef..1097e3ba09 100644 --- a/keras_nlp/src/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/src/models/roberta/roberta_tokenizer.py @@ -82,55 +82,12 @@ def __init__( merges=None, **kwargs, ): - self.start_token = "" - self.pad_token = "" - self.end_token = "" - self.mask_token = "" - + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + self._add_special_token("", "pad_token") + self._add_special_token("", "mask_token") super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=[ - self.start_token, - self.pad_token, - self.end_token, - self.mask_token, - ], **kwargs, ) - - def set_vocabulary_and_merges(self, vocabulary, merges): - super().set_vocabulary_and_merges(vocabulary, merges) - - if vocabulary is not None: - # Check for necessary special tokens. - for token in [ - self.start_token, - self.pad_token, - self.end_token, - self.mask_token, - ]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - - self.start_token_id = self.token_to_id(self.start_token) - self.pad_token_id = self.token_to_id(self.pad_token) - self.end_token_id = self.token_to_id(self.end_token) - self.mask_token_id = self.token_to_id(self.mask_token) - else: - self.start_token_id = None - self.pad_token_id = None - self.end_token_id = None - self.mask_token_id = None - - def get_config(self): - config = super().get_config() - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - return config diff --git a/keras_nlp/src/models/roberta/roberta_tokenizer_test.py b/keras_nlp/src/models/roberta/roberta_tokenizer_test.py index 86b4984c71..35c7b628e2 100644 --- a/keras_nlp/src/models/roberta/roberta_tokenizer_test.py +++ b/keras_nlp/src/models/roberta/roberta_tokenizer_test.py @@ -38,9 +38,9 @@ def test_tokenizer_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, # TODO: should not get tokenized as - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - " airplane at airport", + " airplane at airport", " airplane airport", ], ) diff --git a/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py b/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py index bea771259a..8da697dc99 100644 --- a/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py @@ -11,12 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import keras + from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import strip_to_ragged +from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function + +try: + import tensorflow as tf +except ImportError: + tf = None @keras_nlp_export("keras_nlp.models.Seq2SeqLMPreprocessor") -class Seq2SeqLMPreprocessor(CausalLMPreprocessor): +class Seq2SeqLMPreprocessor(Preprocessor): """Base class for seq2seq language modeling preprocessing layers. `Seq2SeqLMPreprocessor` tasks wrap a `keras_nlp.tokenizer.Tokenizer` to @@ -71,4 +81,189 @@ class Seq2SeqLMPreprocessor(CausalLMPreprocessor): ``` """ - # TODO: move common code down to this base class where possible. + def __init__( + self, + tokenizer, + encoder_sequence_length=1024, + decoder_sequence_length=1024, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.encoder_packer = None + self.decoder_packer = None + self.encoder_sequence_length = encoder_sequence_length + self.decoder_sequence_length = decoder_sequence_length + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.encoder_packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.encoder_sequence_length, + return_padding_mask=True, + ) + self.decoder_packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.decoder_sequence_length, + return_padding_mask=True, + ) + self.built = True + + @tf_preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, + ): + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + encoder_inputs = self.tokenizer(x["encoder_text"]) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_inputs, + sequence_length=encoder_sequence_length, + ) + decoder_inputs = self.tokenizer(x["decoder_text"]) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + ) + x = { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + # Target `y` will be the decoder input sequence shifted one step to the + # left (i.e., the next token). 
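Before the label slicing below, it may help to see the new base-class `call` from the user side. A hedged usage sketch, relying only on the `bart_base_en` preset and the constructor arguments exercised elsewhere in this change:
```
import keras_nlp

preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=64,
    decoder_sequence_length=32,
)
x, y, sample_weight = preprocessor(
    {
        "encoder_text": ["A long article to summarize."],
        "decoder_text": ["A short summary."],
    }
)
# x holds encoder/decoder token ids and padding masks; y is the decoder
# sequence shifted one step left, weighted by the decoder padding mask.
```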
+ y = decoder_token_ids[..., 1:] + sample_weight = decoder_padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + @tf_preprocessing_function + def generate_preprocess( + self, + x, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, + ): + """Convert encoder and decoder input strings to integer token inputs for generation. + + Similar to calling the layer for training, this method takes in a dict + containing `"encoder_text"` and `"decoder_text"`, with strings or tensor + strings for values, tokenizes and packs the input, and computes a + padding mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a tokenizer.end_token_id to the end of + the decoder sequence (as generation is expected to continue at the end + of the inputted decoder prompt). + """ + if not self.built: + self.build(None) + + if isinstance(x, dict): + encoder_text = x["encoder_text"] + decoder_text = x["decoder_text"] + else: + encoder_text = x + # Initialize empty prompt for the decoder. + decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + # Tokenize and pack the encoder inputs. + encoder_token_ids = self.tokenizer(encoder_text) + encoder_token_ids, encoder_padding_mask = self.encoder_packer( + encoder_token_ids, + sequence_length=encoder_sequence_length, + ) + + # Tokenize and pack the decoder inputs. + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_end_value=False, + ) + + return { + "encoder_token_ids": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + @tf_preprocessing_function + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. 
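As an aside on the `generate_preprocess()` method defined above, a hedged usage sketch (again assuming only the `bart_base_en` preset and the input formats handled by the code itself):
```
import keras_nlp

preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset("bart_base_en")

# With a dict input, generation continues from the end of "decoder_text".
features = preprocessor.generate_preprocess(
    {
        "encoder_text": ["A long article to summarize."],
        "decoder_text": ["The article says"],
    }
)
# With plain strings, the decoder prompt defaults to empty strings.
features = preprocessor.generate_preprocess(["A long article to summarize."])
# `features` holds encoder/decoder token ids and padding masks; no end token
# is appended to the decoder prompt.
```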
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = ( + x["decoder_token_ids"], + x["decoder_padding_mask"], + ) + ids_to_strip = self.tokenizer.special_token_ids + token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) + return self.tokenizer.detokenize(token_ids) + + @property + def encoder_sequence_length(self): + """The padded length of encoder input sequences.""" + return self._encoder_sequence_length + + @encoder_sequence_length.setter + def encoder_sequence_length(self, value): + self._encoder_sequence_length = value + if self.encoder_packer is not None: + self.encoder_packer.sequence_length = value + + @property + def decoder_sequence_length(self): + """The padded length of decoder input sequences.""" + return self._decoder_sequence_length + + @decoder_sequence_length.setter + def decoder_sequence_length(self, value): + self._decoder_sequence_length = value + if self.decoder_packer is not None: + self.decoder_packer.sequence_length = value + + @property + def sequence_length(self): + """Alias for `decoder_sequence_length`.""" + return self.decoder_sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self.decoder_sequence_length = value diff --git a/keras_nlp/src/models/seq_2_seq_lm_preprocessor_test.py b/keras_nlp/src/models/seq_2_seq_lm_preprocessor_test.py new file mode 100644 index 0000000000..b1d353ab9d --- /dev/null +++ b/keras_nlp/src/models/seq_2_seq_lm_preprocessor_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.src.models.bart.bart_preprocessor import BartPreprocessor +from keras_nlp.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( + BartSeq2SeqLMPreprocessor, +) +from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_nlp.src.tests.test_case import TestCase + + +class TestSeq2SeqLMPreprocessor(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertTokenizer.presets.keys()) + bart_presets = set(BartPreprocessor.presets.keys()) + all_presets = set(Seq2SeqLMPreprocessor.presets.keys()) + self.assertTrue(bert_presets.isdisjoint(all_presets)) + self.assertTrue(bart_presets.issubset(all_presets)) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + Seq2SeqLMPreprocessor.from_preset("bart_base_en"), + BartSeq2SeqLMPreprocessor, + ) + self.assertIsInstance( + BartSeq2SeqLMPreprocessor.from_preset("bart_base_en"), + BartSeq2SeqLMPreprocessor, + ) + + @pytest.mark.large + def test_from_preset_with_sequence_length(self): + preprocessor = Seq2SeqLMPreprocessor.from_preset( + "bart_base_en", decoder_sequence_length=16 + ) + self.assertEqual(preprocessor.decoder_sequence_length, 16) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + # No loading on an incorrect class. 
+ BartSeq2SeqLMPreprocessor.from_preset("bert_tiny_en_uncased") + with self.assertRaises(ValueError): + # No loading on a non-keras model. + BartSeq2SeqLMPreprocessor.from_preset("hf://spacy/en_core_web_sm") diff --git a/keras_nlp/src/models/t5/t5_tokenizer.py b/keras_nlp/src/models/t5/t5_tokenizer.py index 6bdab8e8a1..7a1987f275 100644 --- a/keras_nlp/src/models/t5/t5_tokenizer.py +++ b/keras_nlp/src/models/t5/t5_tokenizer.py @@ -83,26 +83,8 @@ class T5Tokenizer(SentencePieceTokenizer): backbone_cls = T5Backbone def __init__(self, proto, **kwargs): - self.end_token = "</s>" - self.pad_token = "<pad>" - + # T5 uses the same start token as end token, i.e., "</s>". + self._add_special_token("</s>", "end_token") + self._add_special_token("</s>", "start_token") + self._add_special_token("<pad>", "pad_token") super().__init__(proto=proto, **kwargs) - - def set_proto(self, proto): - super().set_proto(proto) - if proto is not None: - for token in [self.end_token, self.pad_token]: - if token not in self.get_vocabulary(): - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.end_token_id = self.token_to_id(self.end_token) - self.pad_token_id = self.token_to_id(self.pad_token) - # T5 uses the same start token as end token, i.e., "<\s>". - self.start_token_id = self.end_token_id - else: - self.end_token_id = None - self.pad_token_id = None - self.start_token_id = None diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index d5aa8eb0b8..7e494910f0 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -154,7 +154,7 @@ def from_preset( to save and load a pre-trained model. The `preset` can be passed as a one of: - 1. a built in preset identifier like `'bert_base_en'` + 1. a built-in preset identifier like `'bert_base_en'` 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` 3. a Hugging Face handle like `'hf://user/bert_base_en'` 4. a path to a local preset directory like `'./bert_base_en'` @@ -169,7 +169,7 @@ def from_preset( will be inferred from the config in the preset directory. Args: - preset: string. A built in preset identifier, a Kaggle Models + preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. load_weights: bool. If `True`, the weights will be loaded into the model architecture. If `False`, the weights will be randomly diff --git a/keras_nlp/src/models/text_classifier_preprocessor.py b/keras_nlp/src/models/text_classifier_preprocessor.py index c774d15a4a..6b8639d045 100644 --- a/keras_nlp/src/models/text_classifier_preprocessor.py +++ b/keras_nlp/src/models/text_classifier_preprocessor.py @@ -11,8 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import keras + from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.TextClassifierPreprocessor") @@ -73,4 +79,60 @@ class TextClassifierPreprocessor(Preprocessor): ``` """ - # TODO: move common code down to this base class where possible.
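A hedged usage sketch of the behavior this base class now provides (the BERT class and preset name come from the docstrings earlier in this diff; exact token outputs depend on the preset's vocabulary):
```
import keras_nlp

preprocessor = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en", sequence_length=128
)
# Single segments and sentence pairs are both packed into the same features.
x = preprocessor("The quick brown fox jumped.")
x = preprocessor(("The quick brown fox.", "It jumped over the dog."))
# x contains "token_ids", "padding_mask" and "segment_ids".
```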
+ def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.truncate = truncate + + def build(self, input_shape): + super().build(input_shape) + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + truncate=self.truncate, + sequence_length=self.sequence_length, + ) + + @tf_preprocessing_function + def call(self, x, y=None, sample_weight=None): + x = x if isinstance(x, tuple) else (x,) + x = tuple(self.tokenizer(segment) for segment in x) + token_ids, segment_ids = self.packer(x) + x = { + "token_ids": token_ids, + "padding_mask": token_ids != self.tokenizer.pad_token_id, + "segment_ids": segment_ids, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "truncate": self.truncate, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/src/models/whisper/whisper_preprocessor.py b/keras_nlp/src/models/whisper/whisper_preprocessor.py index bc83139952..8a65e6d004 100644 --- a/keras_nlp/src/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/src/models/whisper/whisper_preprocessor.py @@ -203,10 +203,11 @@ def build(self, input_shape): bos_tokens += [self.tokenizer.language_tokens[self.language]] + special_token_dict = self.tokenizer._special_token_dict if self.task == "transcribe": - bos_tokens += [self.tokenizer.special_tokens["<|transcribe|>"]] + bos_tokens += [special_token_dict["<|transcribe|>"]] elif self.task == "translate": - bos_tokens += [self.tokenizer.special_tokens["<|translate|>"]] + bos_tokens += [special_token_dict["<|translate|>"]] else: if self.language is not None: logging.info( diff --git a/keras_nlp/src/models/whisper/whisper_tokenizer.py b/keras_nlp/src/models/whisper/whisper_tokenizer.py index 5bf33aec52..972c33502d 100644 --- a/keras_nlp/src/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/src/models/whisper/whisper_tokenizer.py @@ -102,20 +102,22 @@ def __init__( self.translate_token_id = special_tokens[self.translate_token] self.transcribe_token_id = special_tokens[self.transcribe_token] - self.special_tokens = special_tokens + self._special_token_dict = special_tokens self.language_tokens = language_tokens - - # TODO: Add language tokens to `unsplittable_tokens` once we figure - # out the performance issue with a large list.
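The Whisper change below keeps the raw mapping in `_special_token_dict` and exposes it through list-valued properties. A simplified standalone sketch of that pattern (not the KerasNLP class itself; the example ids are made up):
```
class SpecialTokenView:
    """Expose one token->id mapping as parallel lists of tokens and ids."""

    def __init__(self, special_token_dict):
        self._special_token_dict = dict(special_token_dict)

    @property
    def special_tokens(self):
        return list(self._special_token_dict.keys())

    @property
    def special_token_ids(self):
        return list(self._special_token_dict.values())


view = SpecialTokenView({"<|transcribe|>": 7, "<|translate|>": 8})
print(view.special_tokens)     # ['<|transcribe|>', '<|translate|>']
print(view.special_token_ids)  # [7, 8]
```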
- unsplittable_tokens = list(special_tokens.keys()) - super().__init__( vocabulary=vocabulary, merges=merges, - unsplittable_tokens=unsplittable_tokens, **kwargs, ) + @property + def special_tokens(self): + return list(self._special_token_dict.keys()) + + @property + def special_token_ids(self): + return list(self._special_token_dict.values()) + def save_assets(self, dir_path): # TODO: whisper is currently mutating it's vocabulary before passing # it to the super class, so we need to restore the unmutated vocabulary @@ -148,7 +150,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.translate_token, self.transcribe_token, ]: - vocabulary[token] = self.special_tokens[token] + vocabulary[token] = self._special_token_dict[token] else: self._initial_vocabulary = None @@ -156,15 +158,9 @@ def set_vocabulary_and_merges(self, vocabulary, merges): def get_config(self): config = super().get_config() - - # In the constructor, we pass the list of special tokens to the - # `unsplittable_tokens` arg of the superclass' constructor. Hence, we - # delete it from the config here. - del config["unsplittable_tokens"] - config.update( { - "special_tokens": self.special_tokens, + "special_tokens": self._special_token_dict, "language_tokens": self.language_tokens, } ) diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py index 24930c413b..8feb7d674f 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py @@ -13,23 +13,23 @@ # limitations under the License. import keras -from absl import logging from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, +from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, +from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import ( + XLMRobertaBackbone, +) +from keras_nlp.src.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, ) from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLMPreprocessor") -class XLMRobertaMaskedLMPreprocessor( - XLMRobertaTextClassifierPreprocessor, MaskedLMPreprocessor -): +class XLMRobertaMaskedLMPreprocessor(MaskedLMPreprocessor): """XLM-RoBERTa preprocessing for the masked language modeling task. 
This preprocessing layer will prepare inputs for a masked language modeling @@ -124,77 +124,26 @@ class XLMRobertaMaskedLMPreprocessor( ``` """ - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - mask_selection_rate=0.15, - mask_selection_length=96, - mask_token_rate=0.8, - random_token_rate=0.1, - **kwargs, - ): - super().__init__( - tokenizer, - sequence_length=sequence_length, - truncate=truncate, - **kwargs, - ) - self.mask_selection_rate = mask_selection_rate - self.mask_selection_length = mask_selection_length - self.mask_token_rate = mask_token_rate - self.random_token_rate = random_token_rate - self.masker = None + backbone_cls = XLMRobertaBackbone + tokenizer_cls = XLMRobertaTokenizer def build(self, input_shape): super().build(input_shape) - # Defer masker creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - self.masker = MaskedLMMaskGenerator( - mask_selection_rate=self.mask_selection_rate, - mask_selection_length=self.mask_selection_length, - mask_token_rate=self.mask_token_rate, - random_token_rate=self.random_token_rate, - vocabulary_size=self.tokenizer.vocabulary_size(), - mask_token_id=self.tokenizer.mask_token_id, - unselectable_token_ids=[ - self.tokenizer.start_token_id, - self.tokenizer.end_token_id, - self.tokenizer.pad_token_id, - ], - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "mask_selection_rate": self.mask_selection_rate, - "mask_selection_length": self.mask_selection_length, - "mask_token_rate": self.mask_token_rate, - "random_token_rate": self.random_token_rate, - } + # Roberta doubles up the sep token, so we override build. + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + sep_value=[self.tokenizer.end_token_id] * 2, + pad_value=self.tokenizer.pad_token_id, + truncate=self.truncate, + sequence_length=self.sequence_length, ) - return config + self.built = True @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - if y is not None or sample_weight is not None: - logging.warning( - f"{self.__class__.__name__} generates `y` and `sample_weight` " - "based on your input data, but your data already contains `y` " - "or `sample_weight`. Your `y` and `sample_weight` will be " - "ignored." - ) - - x = super().call(x) - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - masker_outputs = self.masker(token_ids) - x = { - "token_ids": masker_outputs["token_ids"], - "padding_mask": padding_mask, - "mask_positions": masker_outputs["mask_positions"], - } - y = masker_outputs["mask_ids"] - sample_weight = masker_outputs["mask_weights"] + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input.
+ del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py index 82fa56c3ff..756b935dd0 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py @@ -157,23 +157,8 @@ def train_sentencepiece(ds, vocab_size): backbone_cls = XLMRobertaBackbone tokenizer_cls = XLMRobertaTokenizer - def __init__( - self, - tokenizer, - sequence_length=512, - truncate="round_robin", - **kwargs, - ): - super().__init__(**kwargs) - - self.tokenizer = tokenizer - self.packer = None - self.truncate = truncate - self.sequence_length = sequence_length - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. + # Roberta is doubles up the sep token, so we override build. self.packer = MultiSegmentPacker( start_value=self.tokenizer.start_token_id, end_value=self.tokenizer.end_token_id, @@ -184,34 +169,10 @@ def build(self, input_shape): ) self.built = True - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "truncate": self.truncate, - } - ) - return config - @tf_preprocessing_function def call(self, x, y=None, sample_weight=None): - x = x if isinstance(x, tuple) else (x,) - x = tuple(self.tokenizer(segment) for segment in x) - token_ids, _ = self.packer(x) - x = { - "token_ids": token_ids, - "padding_mask": token_ids != self.tokenizer.pad_token_id, - } + output = super().call(x, y=y, sample_weight=sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) + # Backbone has no segment ID input. + del x["segment_ids"] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - @property - def sequence_length(self): - """The padded length of model input sequences.""" - return self._sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self._sequence_length = value - if self.packer is not None: - self.packer.sequence_length = value diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py index b5c7367416..1fba910f10 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -100,16 +100,21 @@ def train_sentencepiece(ds, vocab_size): backbone_cls = XLMRobertaBackbone def __init__(self, proto, **kwargs): - # List of special tokens. - self._vocabulary_prefix = ["", "", "", ""] + # Handle special tokens manually, as the tokenizer maps these tokens in + # a way that is not reflected in the vocabulary. + self.start_token, self.start_token_id = "", 0 + self.pad_token, self.pad_token_id = "", 1 + self.end_token, self.end_token_id = "", 2 + self.unk_token, self.unk_token_id = "", 3 + super().__init__(proto=proto, **kwargs) - # IDs of special tokens. 
- self.start_token_id = 0 # <s> - self.pad_token_id = 1 # <pad> - self.end_token_id = 2 # </s> - self.unk_token_id = 3 # <unk> + @property + def special_tokens(self): + return ["<s>", "<pad>", "</s>", "<unk>"] - super().__init__(proto=proto, **kwargs) + @property + def special_token_ids(self): + return [0, 1, 2, 3] def set_proto(self, proto): super().set_proto(proto) diff --git a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py index 2044a39b16..26096f5ec4 100644 --- a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py @@ -64,12 +64,17 @@ def create_alts_for_unsplittable_tokens(unsplittable_tokens): # Create alternates for all special tokens that will be not split during # tokenization. alts = [] - prefix = "Ĵ" - # Trim out splitters. - replace_pattern = r"'|\s+|[^\p{L}\p{N}]+" - for token in unsplittable_tokens: - token = re.sub(replace_pattern, "", token) - alts.append(prefix + token) + for index in range(len(unsplittable_tokens)): + # Map unsplittable tokens to ĴA, ĴB, ĴC, etc., which we assume will be + # a very uncommon string in any input data. We can't use a literal + # numeric counter here because we will split on all numbers. Ĵ is a + # random character we chose as it is likely to be unique. + prefix = "Ĵ" + digits = [int(d) for d in str(index)] + # Map the digits to uppercase characters so our token is still + # unsplittable. + suffix = "".join([chr(ord("A") + d) for d in digits]) + alts.append(prefix + suffix) return alts @@ -291,6 +296,8 @@ def __init__( super().__init__(dtype=dtype, **kwargs) self.sequence_length = sequence_length self.add_prefix_space = add_prefix_space + if unsplittable_tokens is None: + unsplittable_tokens = self.special_tokens self.unsplittable_tokens = unsplittable_tokens self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME] @@ -385,6 +392,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges): list(range(len(self.merges))), default=self.merge_ranks_lookup_default, ) + self._update_special_token_ids() def get_vocabulary(self): """Get the tokenizer vocabulary as a list of strings tokens.""" diff --git a/keras_nlp/src/tokenizers/byte_tokenizer.py b/keras_nlp/src/tokenizers/byte_tokenizer.py index d33ac285ce..594b2c2ffc 100644 --- a/keras_nlp/src/tokenizers/byte_tokenizer.py +++ b/keras_nlp/src/tokenizers/byte_tokenizer.py @@ -200,6 +200,7 @@ def __init__( self._char_lst = tf.constant( [i.tobytes() for i in np.arange(256, dtype=np.uint8)] ) + self._update_special_token_ids() def vocabulary_size(self): """Get the integer size of the tokenizer vocabulary.""" diff --git a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py index 14a86f8968..0d38998efe 100644 --- a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py @@ -176,6 +176,7 @@ def set_proto(self, proto): # Keras cannot serialize a bytestring, so we base64 encode the model # byte array as a string for saving. self.proto = proto_bytes + self._update_special_token_ids() def vocabulary_size(self): """Get the integer size of the tokenizer vocabulary.""" diff --git a/keras_nlp/src/tokenizers/tokenizer.py b/keras_nlp/src/tokenizers/tokenizer.py index 81c7b50176..f1304c8413 100644 --- a/keras_nlp/src/tokenizers/tokenizer.py +++ b/keras_nlp/src/tokenizers/tokenizer.py @@ -139,6 +139,55 @@ def token_to_id(self, token): f"{self.__class__.__name__}."
) + @property + def special_tokens(self): + """List all built-in special tokens for the tokenizer.""" + if not hasattr(self, "_special_token_attrs"): + return [] + tokens = set(getattr(self, a) for a in self._special_token_attrs) + return list(tokens) + + @property + def special_token_ids(self): + """List all built-in special token ids for the tokenizer.""" + if not hasattr(self, "_special_token_attrs"): + return [] + ids = set(getattr(self, f"{a}_id") for a in self._special_token_attrs) + if None in ids: + raise ValueError( + "Cannot access `special_token_ids` before a vocabulary has " + "been set on the tokenizer." + ) + return list(ids) + + def _add_special_token(self, token, name): + if not hasattr(self, "_special_token_attrs"): + self._special_token_attrs = [] + self._special_token_attrs.append(name) + setattr(self, name, token) + try: + id = self.token_to_id(token) + except (ValueError, AttributeError): + id = None + setattr(self, f"{name}_id", id) + + def _update_special_token_ids(self): + if not hasattr(self, "_special_token_attrs"): + return + vocabulary = self.get_vocabulary() + for attr in set(self._special_token_attrs): + token = getattr(self, attr) + if token not in vocabulary: + classname = self.__class__.__name__ + raise ValueError( + f"Cannot find special token `'{token}'` in the provided " + f"vocabulary for `{classname}`. Please ensure `'{token}'` " + "is in the provided vocabulary when creating the Tokenizer." + ) + for attr in self._special_token_attrs: + token = getattr(self, attr) + setattr(self, f"{attr}_id", self.token_to_id(token)) + def save_to_preset(self, preset_dir): """Save tokenizer to a preset directory. @@ -185,7 +234,7 @@ def from_preset( to save and load a pre-trained model. The `preset` can be passed as a one of: - 1. a built in preset identifier like `'bert_base_en'` + 1. a built-in preset identifier like `'bert_base_en'` 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` 3. a Hugging Face handle like `'hf://user/bert_base_en'` 4. a path to a local preset directory like `'./bert_base_en'` @@ -200,7 +249,7 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from will be inferred from the config in the preset directory. Args: - preset: string. A built in preset identifier, a Kaggle Models + preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. load_weights: bool. If `True`, the weights will be loaded into the model architecture. 
If `False`, the weights will be randomly diff --git a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py index 5c50a4db38..16115fa199 100644 --- a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py @@ -255,6 +255,7 @@ def __init__( self.input_encoding = input_encoding self.output_encoding = output_encoding self._vocabulary_size = vocabulary_size + self._update_special_token_ids() def get_config(self): config = super().get_config() diff --git a/keras_nlp/src/tokenizers/word_piece_tokenizer.py b/keras_nlp/src/tokenizers/word_piece_tokenizer.py index b7f46b3918..8336ffe83b 100644 --- a/keras_nlp/src/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/src/tokenizers/word_piece_tokenizer.py @@ -167,7 +167,7 @@ def pretokenize( if special_tokens_pattern is not None: # the idea here is to pass the special tokens regex to the split # function as delimiter regex pattern, so the input will be splitted - # by them, but also the function will treat each on of them as one + # by them, but also the function will treat each one of them as one # entity that shouldn't be splitted even if they have other # delimiter regex pattern inside them. then pass the special tokens # regex also as keep delimiter regex pattern, so they will @@ -264,12 +264,6 @@ class WordPieceTokenizer(tokenizer.Tokenizer): oov_token: str. The string value to substitute for an unknown token. It must be included in the vocab. Defaults to `"[UNK]"`. - special_tokens: list. A list of special tokens. when - `special_tokens_in_strings` is set to `True`, the tokenizer will map - every special token in the input strings to its id, even if these - special tokens contain characters that should be splitted before - tokenization such as punctuation. `special_tokens` must be included - in `vocabulary`. special_tokens_in_strings: bool. A bool to indicate if the tokenizer should expect special tokens in input strings that should be tokenized and mapped correctly to their ids. Defaults to False. @@ -370,19 +364,9 @@ def __init__( self.split_on_cjk = split_on_cjk self.suffix_indicator = suffix_indicator self.oov_token = oov_token - self.special_tokens = special_tokens - self._special_tokens_pattern = None - if self.split and special_tokens_in_strings: - # the idea here is to pass the special tokens regex to the - # split function as delimiter regex pattern, so the input will - # be splitted by them, but also the function will treat each on - # of them as one entity that shouldn't be splitted even if they - # have other delimiter regex pattern inside them. then pass the - # special tokens regex also as keep delimiter regex - # pattern, so they will not be removed. - self._special_tokens_pattern = get_special_tokens_pattern( - self.special_tokens - ) + self._init_special_tokens = special_tokens + self.special_tokens_in_strings = special_tokens_in_strings + self.set_vocabulary(vocabulary) self.file_assets = [VOCAB_FILENAME] @@ -424,16 +408,6 @@ def set_vocabulary(self, vocabulary): "the `oov_token` argument when creating the tokenizer." ) - # Check for special tokens in the vocabulary - if self.special_tokens is not None: - for token in self.special_tokens: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." 
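To make the `_add_special_token` / `special_token_ids` mechanism added to the base `Tokenizer` above easier to follow, here is a simplified, standalone sketch of the bookkeeping those helpers perform; it is not the KerasNLP class itself, just plain Python with a toy vocabulary:
```
class ToyTokenizer:
    """Standalone sketch of the special token bookkeeping, not the real class."""

    def __init__(self, vocabulary):
        self._vocabulary = list(vocabulary)
        self._special_token_attrs = []

    def token_to_id(self, token):
        return self._vocabulary.index(token)

    def _add_special_token(self, token, name):
        self._special_token_attrs.append(name)
        setattr(self, name, token)
        try:
            token_id = self.token_to_id(token)
        except ValueError:
            token_id = None  # vocabulary not loaded yet
        setattr(self, f"{name}_id", token_id)

    @property
    def special_tokens(self):
        return [getattr(self, name) for name in self._special_token_attrs]

    @property
    def special_token_ids(self):
        return [getattr(self, f"{name}_id") for name in self._special_token_attrs]


tokenizer = ToyTokenizer(["<pad>", "<s>", "</s>", "the", "quick", "fox"])
tokenizer._add_special_token("<s>", "start_token")
tokenizer._add_special_token("</s>", "end_token")
print(tokenizer.start_token_id, tokenizer.end_token_id)  # 1 2
print(tokenizer.special_tokens)  # ['<s>', '</s>']
```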
- ) - self._fast_word_piece = tf_text.FastWordpieceTokenizer( vocab=self.vocabulary, token_out_type=self.compute_dtype, suffix_indicator=self.suffix_indicator, unknown_token=self.oov_token, no_pretokenization=True, support_detokenization=True, ) + self._update_special_token_ids() def get_vocabulary(self): """Get the tokenizer vocabulary as a list of strings tokens.""" @@ -482,7 +457,8 @@ def get_config(self): "split": self.split, "suffix_indicator": self.suffix_indicator, "oov_token": self.oov_token, - "special_tokens": self.special_tokens, + "special_tokens": self._init_special_tokens, + "special_tokens_in_strings": self.special_tokens_in_strings, } ) return config @@ -498,13 +474,26 @@ def _check_vocabulary(self): def tokenize(self, inputs): self._check_vocabulary() unbatched = inputs.shape.rank == 0 + pattern = None + if self.split and self.special_tokens_in_strings: + # the idea here is to pass the special tokens regex to the + # split function as delimiter regex pattern, so the input will + # be split by them, but also the function will treat each one + # of them as one entity that shouldn't be split even if they + # have other delimiter regex pattern inside them. then pass the + # special tokens regex also as keep delimiter regex + # pattern, so they will not be removed. + special_tokens = self.special_tokens + if self._init_special_tokens: + special_tokens += self._init_special_tokens + pattern = get_special_tokens_pattern(special_tokens) inputs = pretokenize( inputs, self.lowercase, self.strip_accents, self.split, self.split_on_cjk, - self._special_tokens_pattern, + pattern, ) # Apply WordPiece and coerce shape for outputs. diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index f5a6dc62ce..e7bd5a74db 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -123,7 +123,7 @@ def find_subclass(preset, cls, backbone_cls): directs = list(filter(lambda x: x in cls.__bases__, subclasses)) if len(directs) > 1: subclasses = directs - # Return the subclass that was registered first (prefer built in classes). + # Return the subclass that was registered first (prefer built-in classes). return subclasses[0]
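As exercised by the new `Seq2SeqLMPreprocessor` test earlier in this diff, the practical effect of this subclass lookup is that base-class `from_preset` calls resolve to the concrete model-specific class. A brief hedged illustration:
```
import keras_nlp

# Both calls construct the same preprocessor type for the "bart_base_en" preset.
p1 = keras_nlp.models.Seq2SeqLMPreprocessor.from_preset("bart_base_en")
p2 = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset("bart_base_en")
assert type(p1) is type(p2)
```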