
Commit 9250f79

Move preprocessing to base classes
I think this will be a nice overall simplification for maintenance: push whatever logic we can down onto the base preprocessing classes. This saves a lot of code. To assist with this, I am adding a `special_tokens` property to tokenizers, which I think will be useful anyway.
1 parent 5234a81 commit 9250f79
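
For context, the end-user workflow should be unchanged by this refactor; only the per-model implementation shrinks. A minimal usage sketch (the preset name and the exact outputs are assumptions for illustration, not taken from this commit):

```python
import keras_nlp

# "albert_base_en_uncased" is an assumed preset name, used only for illustration.
preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)
x, y, sample_weight = preprocessor("The quick brown fox jumped.")
# x holds "token_ids", "segment_ids", "padding_mask", and "mask_positions";
# y holds the original ids at the masked positions, sample_weight their weights.
```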

66 files changed: +943 additions, -2681 deletions (only a subset of the changed files is shown below).


keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py

Lines changed: 5 additions & 92 deletions
@@ -12,24 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-from absl import logging
-
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import (
-    MaskedLMMaskGenerator,
-)
-from keras_nlp.src.models.albert.albert_text_classifier_preprocessor import (
-    AlbertTextClassifierPreprocessor,
-)
+from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
-from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
 
 
 @keras_nlp_export("keras_nlp.models.AlbertMaskedLMPreprocessor")
-class AlbertMaskedLMPreprocessor(
-    AlbertTextClassifierPreprocessor, MaskedLMPreprocessor
-):
+class AlbertMaskedLMPreprocessor(MaskedLMPreprocessor):
     """ALBERT preprocessing for the masked language modeling task.
 
     This preprocessing layer will prepare inputs for a masked language modeling
@@ -120,82 +110,5 @@ class AlbertMaskedLMPreprocessor(
     ```
     """
 
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        mask_selection_rate=0.15,
-        mask_selection_length=96,
-        mask_token_rate=0.8,
-        random_token_rate=0.1,
-        **kwargs,
-    ):
-        super().__init__(
-            tokenizer,
-            sequence_length=sequence_length,
-            truncate=truncate,
-            **kwargs,
-        )
-        self.mask_selection_rate = mask_selection_rate
-        self.mask_selection_length = mask_selection_length
-        self.mask_token_rate = mask_token_rate
-        self.random_token_rate = random_token_rate
-        self.masker = None
-
-    def build(self, input_shape):
-        super().build(input_shape)
-        # Defer masker creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.masker = MaskedLMMaskGenerator(
-            mask_selection_rate=self.mask_selection_rate,
-            mask_selection_length=self.mask_selection_length,
-            mask_token_rate=self.mask_token_rate,
-            random_token_rate=self.random_token_rate,
-            vocabulary_size=self.tokenizer.vocabulary_size(),
-            mask_token_id=self.tokenizer.mask_token_id,
-            unselectable_token_ids=[
-                self.tokenizer.cls_token_id,
-                self.tokenizer.sep_token_id,
-                self.tokenizer.pad_token_id,
-            ],
-        )
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "mask_selection_rate": self.mask_selection_rate,
-                "mask_selection_length": self.mask_selection_length,
-                "mask_token_rate": self.mask_token_rate,
-                "random_token_rate": self.random_token_rate,
-            }
-        )
-        return config
-
-    @tf_preprocessing_function
-    def call(self, x, y=None, sample_weight=None):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                f"{self.__class__.__name__} generates `y` and `sample_weight` "
-                "based on your input data, but your data already contains `y` "
-                "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "ignored."
-            )
-
-        x = super().call(x)
-        token_ids, segment_ids, padding_mask = (
-            x["token_ids"],
-            x["segment_ids"],
-            x["padding_mask"],
-        )
-        masker_outputs = self.masker(token_ids)
-        x = {
-            "token_ids": masker_outputs["token_ids"],
-            "segment_ids": segment_ids,
-            "padding_mask": padding_mask,
-            "mask_positions": masker_outputs["mask_positions"],
-        }
-        y = masker_outputs["mask_ids"]
-        sample_weight = masker_outputs["mask_weights"]
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+    backbone_cls = AlbertBackbone
+    tokenizer_cls = AlbertTokenizer
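
The `MaskedLMMaskGenerator` setup and `call` override deleted above are the kind of logic that presumably now lives once on the shared `MaskedLMPreprocessor` base class. A rough sketch reconstructed from the removed code (the base class itself is not shown in this commit, and the use of `special_tokens` here is an assumption based on the commit message):

```python
from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import (
    MaskedLMMaskGenerator,
)


def build_masker(tokenizer, mask_selection_rate=0.15, mask_selection_length=96,
                 mask_token_rate=0.8, random_token_rate=0.1):
    # A generic base class can build the masker from tokenizer metadata alone,
    # instead of hard-coding [CLS]/[SEP]/<pad> ids in every model subclass.
    return MaskedLMMaskGenerator(
        mask_selection_rate=mask_selection_rate,
        mask_selection_length=mask_selection_length,
        mask_token_rate=mask_token_rate,
        random_token_rate=random_token_rate,
        vocabulary_size=tokenizer.vocabulary_size(),
        mask_token_id=tokenizer.mask_token_id,
        # Assumed: exclude every token registered via the new `special_tokens`
        # property, rather than a per-model hard-coded list.
        unselectable_token_ids=[
            tokenizer.token_to_id(token) for token in tokenizer.special_tokens
        ],
    )
```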

keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py

Lines changed: 0 additions & 64 deletions
@@ -12,18 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
-    MultiSegmentPacker,
-)
 from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
 from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.src.models.text_classifier_preprocessor import (
     TextClassifierPreprocessor,
 )
-from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
 
 
 @keras_nlp_export(
@@ -154,61 +148,3 @@ class AlbertTextClassifierPreprocessor(TextClassifierPreprocessor):
 
     backbone_cls = AlbertBackbone
     tokenizer_cls = AlbertTokenizer
-
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.packer = None
-        self.truncate = truncate
-        self.sequence_length = sequence_length
-
-    def build(self, input_shape):
-        # Defer packer creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.packer = MultiSegmentPacker(
-            start_value=self.tokenizer.cls_token_id,
-            end_value=self.tokenizer.sep_token_id,
-            pad_value=self.tokenizer.pad_token_id,
-            truncate=self.truncate,
-            sequence_length=self.sequence_length,
-        )
-        self.built = True
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "truncate": self.truncate,
-            }
-        )
-        return config
-
-    @tf_preprocessing_function
-    def call(self, x, y=None, sample_weight=None):
-        x = x if isinstance(x, tuple) else (x,)
-        x = tuple(self.tokenizer(segment) for segment in x)
-        token_ids, segment_ids = self.packer(x)
-        x = {
-            "token_ids": token_ids,
-            "segment_ids": segment_ids,
-            "padding_mask": token_ids != self.tokenizer.pad_token_id,
-        }
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
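
Likewise, the tokenize-and-pack flow removed here is what a shared `TextClassifierPreprocessor` base class can implement once for all models. A minimal sketch pieced together from the deleted code; the `start_token_id`/`end_token_id` attributes are assumed to come from the generic aliases the tokenizer change below registers, and the real base class may differ in detail:

```python
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
    MultiSegmentPacker,
)


def pack_for_classification(tokenizer, segments, sequence_length=512,
                            truncate="round_robin"):
    # Tokenize each text segment, then add special tokens, truncate, and pad.
    packer = MultiSegmentPacker(
        start_value=tokenizer.start_token_id,
        end_value=tokenizer.end_token_id,
        pad_value=tokenizer.pad_token_id,
        truncate=truncate,
        sequence_length=sequence_length,
    )
    segments = segments if isinstance(segments, tuple) else (segments,)
    token_ids, segment_ids = packer(tuple(tokenizer(s) for s in segments))
    return {
        "token_ids": token_ids,
        "segment_ids": segment_ids,
        "padding_mask": token_ids != tokenizer.pad_token_id,
    }
```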

keras_nlp/src/models/albert/albert_tokenizer.py

Lines changed: 8 additions & 31 deletions
@@ -89,35 +89,12 @@ class AlbertTokenizer(SentencePieceTokenizer):
     backbone_cls = AlbertBackbone
 
     def __init__(self, proto, **kwargs):
-        self.cls_token = "[CLS]"
-        self.sep_token = "[SEP]"
-        self.pad_token = "<pad>"
-        self.mask_token = "[MASK]"
-
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("<pad>", "pad_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
         super().__init__(proto=proto, **kwargs)
-
-    def set_proto(self, proto):
-        super().set_proto(proto)
-        if proto is not None:
-            for token in [
-                self.cls_token,
-                self.sep_token,
-                self.pad_token,
-                self.mask_token,
-            ]:
-                if token not in self.get_vocabulary():
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.cls_token_id = self.token_to_id(self.cls_token)
-            self.sep_token_id = self.token_to_id(self.sep_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            self.mask_token_id = self.token_to_id(self.mask_token)
-        else:
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.pad_token_id = None
-            self.mask_token_id = None
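
The `_add_special_token` calls above rely on the new `special_tokens` support the commit message mentions; the base-class side is not part of the hunks shown here. A hedged sketch of roughly what that bookkeeping could look like (everything except the `_add_special_token` signature and the `special_tokens` name is an assumption):

```python
class SpecialTokenSupport:
    """Sketch of special-token bookkeeping a tokenizer base class might do."""

    def _add_special_token(self, token, name):
        # Expose the token string as an attribute (e.g. `self.cls_token`) and
        # record it so `special_tokens` can report every registered token.
        if not hasattr(self, "_registered_special_tokens"):
            self._registered_special_tokens = []
        setattr(self, name, token)
        if token not in self._registered_special_tokens:
            self._registered_special_tokens.append(token)

    @property
    def special_tokens(self):
        # All special token strings registered by the subclass constructor.
        return list(self._registered_special_tokens)
```

With something like this on the base tokenizer, the vocabulary validation and `*_token_id` lookups that each model previously duplicated in `set_proto` (as in the deleted code above) can be performed once, centrally, after the vocabulary is loaded.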
