diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py
index a6cb001de0..20d2448d1c 100644
--- a/keras_nlp/layers/multi_segment_packer.py
+++ b/keras_nlp/layers/multi_segment_packer.py
@@ -40,12 +40,8 @@ class MultiSegmentPacker(keras.layers.Layer):
     is always 0, and the segment id of each `end_value` is the segment that
     precedes it.
 
-    If inputs are batched, inputs should be `tf.RaggedTensor`s with shape
-    `[batch_size, None]` and will be packed and converted to a dense tensor with
-    shape `[batch_size, sequence_length]`.
-
-    If inputs are unbatched, inputs should be dense rank-1 tensors of any shape,
-    and will be packed to shape `[sequence_length]`.
+    Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and
+    either rank-1 or rank-2.
 
     Args:
         sequence_length: The desired output length.
@@ -155,6 +151,13 @@ def _sanitize_inputs(self, inputs):
            )
        return inputs
 
+    def _convert_dense(self, x):
+        """Converts inputs to rank 2 ragged tensors."""
+        if isinstance(x, tf.Tensor):
+            return tf.RaggedTensor.from_tensor(x)
+        else:
+            return x
+
     def _trim_inputs(self, inputs):
         """Trim inputs to desired length."""
         num_special_tokens = len(inputs) + 1
@@ -199,22 +202,20 @@ def _combine_inputs(self, segments):
     def call(self, inputs):
         inputs = self._sanitize_inputs(inputs)
 
-        # If rank 1, add a batch dim and convert to ragged.
+        # If rank 1, add a batch dim.
         rank_1 = inputs[0].shape.rank == 1
         if rank_1:
             inputs = [tf.expand_dims(x, 0) for x in inputs]
-            inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs]
+        inputs = [self._convert_dense(x) for x in inputs]
 
         segments = self._trim_inputs(inputs)
         token_ids, segment_ids = self._combine_inputs(segments)
-        # Pad to dense tensor output.
         shape = tf.cast([-1, self.sequence_length], "int64")
         token_ids = token_ids.to_tensor(
             shape=shape, default_value=self.pad_value
         )
         segment_ids = segment_ids.to_tensor(shape=shape)
 
-        # Remove the batch dim if added.
         if rank_1:
             token_ids = tf.squeeze(token_ids, 0)
             segment_ids = tf.squeeze(segment_ids, 0)
diff --git a/keras_nlp/layers/multi_segment_packer_test.py b/keras_nlp/layers/multi_segment_packer_test.py
index a03b20cdb9..3655ac7fab 100644
--- a/keras_nlp/layers/multi_segment_packer_test.py
+++ b/keras_nlp/layers/multi_segment_packer_test.py
@@ -67,8 +67,8 @@ def test_trim_multiple_inputs_waterfall(self):
         )
 
     def test_trim_batched_inputs_round_robin(self):
-        seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]])
-        seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
+        seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]])
+        seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
         packer = MultiSegmentPacker(
             7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin"
         )
@@ -89,7 +89,7 @@ def test_trim_batched_inputs_round_robin(self):
 
     def test_trim_batched_inputs_waterfall(self):
         seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
-        seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
+        seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
         packer = MultiSegmentPacker(
             7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
         )