24 changes: 22 additions & 2 deletions keras_nlp/models/bart/bart_preprocessor.py
@@ -15,6 +15,8 @@
 
 import copy
 
+import tensorflow as tf
+
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
 from keras_nlp.models.bart.bart_presets import backbone_presets
@@ -160,20 +162,26 @@ def __init__(
             truncate=truncate,
             sequence_length=encoder_sequence_length,
         )
+        # The decoder is packed a bit differently; the format is as follows:
+        # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`.
+        # Hence, we pass `sequence_length - 1` to the packer.
         self.decoder_packer = MultiSegmentPacker(
             start_value=self.tokenizer.start_token_id,
             end_value=self.tokenizer.end_token_id,
             pad_value=self.tokenizer.pad_token_id,
             truncate=truncate,
-            sequence_length=decoder_sequence_length,
+            sequence_length=decoder_sequence_length - 1,
         )
+        # Maintain a private copy of `decoder_sequence_length` for config
+        # purposes.
+        self._decoder_sequence_length = decoder_sequence_length
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "encoder_sequence_length": self.encoder_packer.sequence_length,
-                "decoder_sequence_length": self.decoder_packer.sequence_length,
+                "decoder_sequence_length": self._decoder_sequence_length,
                 "truncate": self.encoder_packer.truncate,
             }
         )
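
The packer arithmetic above can be sanity-checked with a minimal pure-Python sketch (not library code; the toy ids 0/2/1 for start/end/pad are assumptions matching this PR's test fixtures): packing at `decoder_sequence_length - 1` and then prepending the end token restores the configured length.

def pack_decoder(tokens, sequence_length, start_id=0, end_id=2, pad_id=1):
    # The packer itself sees one less slot than the configured length.
    inner_length = sequence_length - 1
    packed = [start_id] + tokens[: inner_length - 2] + [end_id]
    packed += [pad_id] * (inner_length - len(packed))
    # Prepending the end token yields the documented layout:
    # `[end_token_id, start_token_id, tokens..., end_token_id, padding...]`.
    return [end_id] + packed

assert pack_decoder([7, 8, 9, 10, 11], 9) == [2, 0, 7, 8, 9, 10, 11, 2, 1]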
@@ -208,6 +216,18 @@ def call(self, x, y=None, sample_weight=None):
         decoder_inputs = [self.tokenizer(segment) for segment in decoder_text]
         decoder_token_ids, _ = self.decoder_packer(decoder_inputs)
 
+        # Prepend `end_token_id` to `decoder_token_ids`.
+        input_is_1d = decoder_token_ids.shape.rank == 1
+        if input_is_1d:
+            decoder_token_ids = decoder_token_ids[tf.newaxis, :]
+        batch_size = tf.shape(decoder_token_ids)[0]
+        end_token_ids = tf.fill((batch_size, 1), self.tokenizer.end_token_id)
+        decoder_token_ids = tf.concat(
+            [end_token_ids, decoder_token_ids], axis=1
+        )
+        if input_is_1d:
+            decoder_token_ids = tf.squeeze(decoder_token_ids, axis=0)
+
         x = {
             "encoder_token_ids": encoder_token_ids,
             "encoder_padding_mask": encoder_token_ids
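The prepend step itself is easy to check in isolation. A minimal TensorFlow sketch, assuming an end token id of 2 and a toy pre-packed batch; the rank-1 branch in the diff exists because unbatched input needs a temporary batch dimension before concatenating along axis 1.

import tensorflow as tf

# Toy batch standing in for the decoder packer's output (assumed ids).
decoder_token_ids = tf.constant([[0, 7, 8, 2, 1], [0, 9, 2, 1, 1]])
batch_size = tf.shape(decoder_token_ids)[0]
end_token_ids = tf.fill((batch_size, 1), 2)  # one end token per row
print(tf.concat([end_token_ids, decoder_token_ids], axis=1).numpy())
# [[2 0 7 8 2 1]
#  [2 0 9 2 1 1]]
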
18 changes: 9 additions & 9 deletions keras_nlp/models/bart/bart_preprocessor_test.py
@@ -54,7 +54,7 @@ def setUp(self):
                 merges=merges,
             ),
             encoder_sequence_length=10,
-            decoder_sequence_length=8,
+            decoder_sequence_length=9,
         )
 
     def test_tokenize_strings(self):
@@ -71,10 +71,10 @@ def test_tokenize_strings(self):
             output["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
         )
         self.assertAllEqual(
-            output["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1]
+            output["decoder_token_ids"], [2, 0, 7, 8, 9, 10, 11, 2, 1]
         )
         self.assertAllEqual(
-            output["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]
+            output["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 0]
         )
 
     def test_tokenize_list_of_strings(self):
@@ -91,10 +91,10 @@ def test_tokenize_list_of_strings(self):
             output["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
         )
         self.assertAllEqual(
-            output["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+            output["decoder_token_ids"], [[2, 0, 7, 8, 9, 10, 11, 2, 1]] * 4
         )
         self.assertAllEqual(
-            output["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+            output["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0]] * 4
         )
 
     def test_tokenize_labeled_batch(self):
@@ -112,10 +112,10 @@ def test_tokenize_labeled_batch(self):
             x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
         )
         self.assertAllEqual(
-            x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+            x_out["decoder_token_ids"], [[2, 0, 7, 8, 9, 10, 11, 2, 1]] * 4
         )
         self.assertAllEqual(
-            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0]] * 4
         )
         self.assertAllEqual(y_out, y)
         self.assertAllEqual(sw_out, sw)
@@ -137,10 +137,10 @@ def test_tokenize_labeled_dataset(self):
             x_out["encoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4
         )
         self.assertAllEqual(
-            x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+            x_out["decoder_token_ids"], [[2, 0, 7, 8, 9, 10, 11, 2, 1]] * 4
         )
         self.assertAllEqual(
-            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0]] * 4
         )
         self.assertAllEqual(y_out, y)
         self.assertAllEqual(sw_out, sw)
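The updated expectations follow mechanically from the new ids: the padding mask is a comparison against the pad token. A minimal sketch, assuming `pad_token_id` is 1 as in the fixture.

import tensorflow as tf

decoder_token_ids = tf.constant([2, 0, 7, 8, 9, 10, 11, 2, 1])
# 1 wherever the token is real, 0 at padding positions.
decoder_padding_mask = tf.cast(decoder_token_ids != 1, tf.int32)
print(decoder_padding_mask.numpy())  # [1 1 1 1 1 1 1 1 0]
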
18 changes: 9 additions & 9 deletions keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -56,7 +56,7 @@ def setUp(self):
                 merges=merges,
             ),
             encoder_sequence_length=10,
-            decoder_sequence_length=8,
+            decoder_sequence_length=9,
         )
 
     def test_tokenize_strings(self):
@@ -73,13 +73,13 @@ def test_tokenize_strings(self):
             x_out["encoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
         )
         self.assertAllEqual(
-            x_out["decoder_token_ids"], [0, 7, 8, 9, 10, 11, 2, 1]
+            x_out["decoder_token_ids"], [2, 0, 7, 8, 9, 10, 11, 2, 1]
         )
         self.assertAllEqual(
-            x_out["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]
+            x_out["decoder_padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 0]
         )
-        self.assertAllEqual(y_out, [7, 8, 9, 10, 11, 2, 1, 1])
-        self.assertAllEqual(sw_out, [1, 1, 1, 1, 1, 1, 0, 0])
+        self.assertAllEqual(y_out, [0, 7, 8, 9, 10, 11, 2, 1, 1])
+        self.assertAllEqual(sw_out, [1, 1, 1, 1, 1, 1, 1, 0, 0])
 
     def test_tokenize_list_of_strings(self):
         input_data = {
@@ -96,13 +96,13 @@ def test_tokenize_list_of_strings(self):
             [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]] * 4,
         )
         self.assertAllEqual(
-            x_out["decoder_token_ids"], [[0, 7, 8, 9, 10, 11, 2, 1]] * 4
+            x_out["decoder_token_ids"], [[2, 0, 7, 8, 9, 10, 11, 2, 1]] * 4
         )
         self.assertAllEqual(
-            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
+            x_out["decoder_padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 0]] * 4
         )
-        self.assertAllEqual(y_out, [[7, 8, 9, 10, 11, 2, 1, 1]] * 4)
-        self.assertAllEqual(sw_out, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
+        self.assertAllEqual(y_out, [[0, 7, 8, 9, 10, 11, 2, 1, 1]] * 4)
+        self.assertAllEqual(sw_out, [[1, 1, 1, 1, 1, 1, 1, 0, 0]] * 4)
 
     def test_error_multi_segment_input(self):
         input_data = {
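The new `y_out`/`sw_out` expectations are consistent with next-token targets: shift the decoder ids one step left and pad with the pad id, and shift the mask the same way padding with 0. A minimal sketch under the fixture's assumptions (pad_token_id of 1), not the preprocessor's actual implementation.

import tensorflow as tf

ids = tf.constant([2, 0, 7, 8, 9, 10, 11, 2, 1])
mask = tf.constant([1, 1, 1, 1, 1, 1, 1, 1, 0])
# Labels: ids shifted left by one, padded with the pad id.
y = tf.concat([ids[1:], [1]], axis=0)    # [0 7 8 9 10 11 2 1 1]
# Sample weights: mask shifted left by one, padded with 0.
sw = tf.concat([mask[1:], [0]], axis=0)  # [1 1 1 1 1 1 1 0 0]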