From f45d3886b09e001d697ca9d9d905927b33ae5497 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Wed, 29 Jun 2022 13:32:10 -0700 Subject: [PATCH 1/3] Support for 2D dense tensor --- keras_nlp/layers/multi_segment_packer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py index a6cb001de0..35867de1b0 100644 --- a/keras_nlp/layers/multi_segment_packer.py +++ b/keras_nlp/layers/multi_segment_packer.py @@ -40,12 +40,8 @@ class MultiSegmentPacker(keras.layers.Layer): is always 0, and the segment id of each `end_value` is the segment that precedes it. - If inputs are batched, inputs should be `tf.RaggedTensor`s with shape - `[batch_size, None]` and will be packed and converted to a dense tensor with - shape `[batch_size, sequence_length]`. - - If inputs are unbatched, inputs should be dense rank-1 tensors of any shape, - and will be packed to shape `[sequence_length]`. + Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and + either rank-1 or rank-2. Args: sequence_length: The desired output length. @@ -203,7 +199,7 @@ def call(self, inputs): rank_1 = inputs[0].shape.rank == 1 if rank_1: inputs = [tf.expand_dims(x, 0) for x in inputs] - inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs] + inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs] segments = self._trim_inputs(inputs) token_ids, segment_ids = self._combine_inputs(segments) From 37a75836c243549f0e0b1b18e78703cbe03a51b1 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Wed, 29 Jun 2022 14:00:22 -0700 Subject: [PATCH 2/3] fixes --- keras_nlp/layers/multi_segment_packer.py | 16 ++++++++++++---- keras_nlp/layers/multi_segment_packer_test.py | 6 +++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py index 35867de1b0..66cd46d221 100644 --- a/keras_nlp/layers/multi_segment_packer.py +++ b/keras_nlp/layers/multi_segment_packer.py @@ -151,6 +151,16 @@ def _sanitize_inputs(self, inputs): ) return inputs + def _convert_inputs(self, inputs): + """Converts inputs to rank 2 ragged tensors.""" + new_inputs = [] + for x in inputs: + if isinstance(x, tf.Tensor): + new_inputs.append(tf.RaggedTensor.from_tensor(x)) + else: + new_inputs.append(x) + return new_inputs + def _trim_inputs(self, inputs): """Trim inputs to desired length.""" num_special_tokens = len(inputs) + 1 @@ -195,22 +205,20 @@ def _combine_inputs(self, segments): def call(self, inputs): inputs = self._sanitize_inputs(inputs) - # If rank 1, add a batch dim and convert to ragged. + # If rank 1, add a batch dim. rank_1 = inputs[0].shape.rank == 1 if rank_1: inputs = [tf.expand_dims(x, 0) for x in inputs] - inputs = [tf.RaggedTensor.from_tensor(x) for x in inputs] + inputs = self._convert_inputs(inputs) segments = self._trim_inputs(inputs) token_ids, segment_ids = self._combine_inputs(segments) - # Pad to dense tensor output. shape = tf.cast([-1, self.sequence_length], "int64") token_ids = token_ids.to_tensor( shape=shape, default_value=self.pad_value ) segment_ids = segment_ids.to_tensor(shape=shape) - # Remove the batch dim if added. if rank_1: token_ids = tf.squeeze(token_ids, 0) diff --git a/keras_nlp/layers/multi_segment_packer_test.py b/keras_nlp/layers/multi_segment_packer_test.py index a03b20cdb9..3655ac7fab 100644 --- a/keras_nlp/layers/multi_segment_packer_test.py +++ b/keras_nlp/layers/multi_segment_packer_test.py @@ -67,8 +67,8 @@ def test_trim_multiple_inputs_waterfall(self): ) def test_trim_batched_inputs_round_robin(self): - seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]]) - seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]]) + seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]]) + seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]]) packer = MultiSegmentPacker( 7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin" ) @@ -89,7 +89,7 @@ def test_trim_batched_inputs_round_robin(self): def test_trim_batched_inputs_waterfall(self): seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]]) - seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]]) + seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]]) packer = MultiSegmentPacker( 7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall" ) From 4889694d887d22ac779d359688f0182e59f9e6eb Mon Sep 17 00:00:00 2001 From: jessechancy Date: Wed, 29 Jun 2022 15:19:16 -0700 Subject: [PATCH 3/3] style fixes --- keras_nlp/layers/multi_segment_packer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py index 66cd46d221..20d2448d1c 100644 --- a/keras_nlp/layers/multi_segment_packer.py +++ b/keras_nlp/layers/multi_segment_packer.py @@ -151,15 +151,12 @@ def _sanitize_inputs(self, inputs): ) return inputs - def _convert_inputs(self, inputs): + def _convert_dense(self, x): """Converts inputs to rank 2 ragged tensors.""" - new_inputs = [] - for x in inputs: - if isinstance(x, tf.Tensor): - new_inputs.append(tf.RaggedTensor.from_tensor(x)) - else: - new_inputs.append(x) - return new_inputs + if isinstance(x, tf.Tensor): + return tf.RaggedTensor.from_tensor(x) + else: + return x def _trim_inputs(self, inputs): """Trim inputs to desired length.""" @@ -209,7 +206,7 @@ def call(self, inputs): rank_1 = inputs[0].shape.rank == 1 if rank_1: inputs = [tf.expand_dims(x, 0) for x in inputs] - inputs = self._convert_inputs(inputs) + inputs = [self._convert_dense(x) for x in inputs] segments = self._trim_inputs(inputs) token_ids, segment_ids = self._combine_inputs(segments)