Merged
21 changes: 11 additions & 10 deletions keras_nlp/layers/multi_segment_packer.py
@@ -40,12 +40,8 @@ class MultiSegmentPacker(keras.layers.Layer):
     is always 0, and the segment id of each `end_value` is the segment that
     precedes it.
 
-    If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
-    `[batch_size, None]` and will be packed and converted to a dense tensor with
-    shape `[batch_size, sequence_length]`.
-
-    If inputs are unbatched, inputs should be dense rank-1 tensors of any shape,
-    and will be packed to shape `[sequence_length]`.
+    Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
+    either rank-1 or rank-2.
 
     Args:
         sequence_length: The desired output length.
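Note: a minimal usage sketch of the relaxed input contract described above, using illustrative string tokens like the tests further down. It assumes the layer is called on a tuple of segments and returns a `(token_ids, segment_ids)` pair, as the `call` body later in this diff suggests.

import tensorflow as tf

from keras_nlp.layers import MultiSegmentPacker

packer = MultiSegmentPacker(
    sequence_length=7, start_value="[CLS]", end_value="[SEP]"
)

# Unbatched, rank-1 dense inputs: packed and truncated to [sequence_length].
seq1 = tf.constant(["a", "b", "c"])
seq2 = tf.constant(["x", "y", "z"])
token_ids, segment_ids = packer((seq1, seq2))
print(token_ids.shape, segment_ids.shape)  # (7,) (7,)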
@@ -155,6 +151,13 @@ def _sanitize_inputs(self, inputs):
             )
         return inputs
 
+    def _convert_dense(self, x):
+        """Converts inputs to rank 2 ragged tensors."""
+        if isinstance(x, tf.Tensor):
+            return tf.RaggedTensor.from_tensor(x)
+        else:
+            return x
+
     def _trim_inputs(self, inputs):
         """Trim inputs to desired length."""
         num_special_tokens = len(inputs) + 1
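Note: the new `_convert_dense` helper leans on `tf.RaggedTensor.from_tensor`, which wraps a dense rank-2 tensor as a ragged tensor with one row per batch element, while already-ragged inputs fall through unchanged. A small illustration (values are illustrative):

import tensorflow as tf

dense = tf.constant([["a", "b", "c"], ["x", "y", "z"]])
ragged = tf.RaggedTensor.from_tensor(dense)
print(ragged)  # <tf.RaggedTensor [[b'a', b'b', b'c'], [b'x', b'y', b'z']]>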
@@ -199,22 +202,20 @@ def _combine_inputs(self, segments):
     def call(self, inputs):
         inputs = self._sanitize_inputs(inputs)
 
-        # If rank 1, add a batch dim and convert to ragged.
+        # If rank 1, add a batch dim.
         rank_1 = inputs[0].shape.rank == 1
         if rank_1:
             inputs = [tf.expand_dims(x, 0) for x in inputs]
-            inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs]
+        inputs = [self._convert_dense(x) for x in inputs]
 
         segments = self._trim_inputs(inputs)
         token_ids, segment_ids = self._combine_inputs(segments)
 
         # Pad to dense tensor output.
         shape = tf.cast([-1, self.sequence_length], "int64")
         token_ids = token_ids.to_tensor(
             shape=shape, default_value=self.pad_value
         )
         segment_ids = segment_ids.to_tensor(shape=shape)
 
         # Remove the batch dim if added.
         if rank_1:
             token_ids = tf.squeeze(token_ids, 0)
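Note: a sketch of the padding step near the end of `call` above. `RaggedTensor.to_tensor` pads each row with `default_value` out to the fixed `sequence_length`, yielding a dense `[batch_size, sequence_length]` output; the values below are illustrative.

import tensorflow as tf

token_ids = tf.ragged.constant([[1, 2, 3], [1, 2]])
shape = tf.cast([-1, 5], "int64")  # -1 keeps the batch dimension unchanged
print(token_ids.to_tensor(shape=shape, default_value=0))
# tf.Tensor([[1 2 3 0 0]
#            [1 2 0 0 0]], shape=(2, 5), dtype=int32)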
6 changes: 3 additions & 3 deletions keras_nlp/layers/multi_segment_packer_test.py
@@ -67,8 +67,8 @@ def test_trim_multiple_inputs_waterfall(self):
         )
 
     def test_trim_batched_inputs_round_robin(self):
-        seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]])
-        seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
+        seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]])
+        seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
         packer = MultiSegmentPacker(
             7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin"
         )
@@ -89,7 +89,7 @@ def test_trim_batched_inputs_round_robin(self):
 
     def test_trim_batched_inputs_waterfall(self):
         seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
-        seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
+        seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
         packer = MultiSegmentPacker(
             7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
         )
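Note: switching these fixtures from `tf.ragged.constant` to `tf.constant` makes the batched tests exercise the new dense-input path (the remaining ragged fixture has rows of unequal length and cannot be expressed as a dense tensor). A rough interactive check, with the same caveat as above about the assumed `(token_ids, segment_ids)` return convention:

import tensorflow as tf

from keras_nlp.layers import MultiSegmentPacker

seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]])
seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
packer = MultiSegmentPacker(7, start_value="[CLS]", end_value="[SEP]")
token_ids, segment_ids = packer((seq1, seq2))
print(token_ids.shape, segment_ids.shape)  # (2, 7) (2, 7)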