Add StartEndPacker layer #221
Conversation
@abheesht17 Will review line by line later, but noticed one spot where the functionality is broken here: we need to ensure that the end token is still added when the input length exceeds `sequence_length`. You might be able to do this with regular slicing. Something like...

We should update our unit tests so we check this case somewhere too!
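A minimal sketch of the slicing approach in eager mode, assuming the layer reserves two positions for the start and end tokens; the `sequence_length` and token values here are illustrative, not the layer's actual configuration:

```python
import tensorflow as tf

sequence_length = 5
start_value, end_value = 1, 2  # illustrative token ids

inputs = tf.ragged.constant([[11, 12, 13, 14, 15, 16], [21, 22]])

# Trim so the start and end tokens still fit. Ragged tensors support
# regular Python-style slicing on the ragged dimension.
trimmed = inputs[:, : sequence_length - 2]

# Concatenate a start and an end token onto every row. tf.concat
# dispatches to the ragged implementation when any input is ragged.
batch_size = trimmed.nrows()
start = tf.fill([batch_size, 1], start_value)
end = tf.fill([batch_size, 1], end_value)
packed = tf.concat([start, trimmed, end], axis=1)

print(packed.to_list())
```

With this sketch, the long first row is trimmed to three tokens before the markers are added, so the end token survives even when the input is longer than `sequence_length`.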
Looks good! Just a few minor comments. The big thing to fix here is still adding an end token when input length > sequence length.
keras_nlp/layers/start_end_packer.py
Outdated
class StartEndPacker(keras.layers.Layer):
    """Adds start and end tokens to a sequence and pads to a fixed length.

    If inputs are batched, input should be a `tf.RaggedTensor` with shape
Might be good to add a short intro paragraph:

This layer is useful when tokenizing inputs for tasks like translation, where each sequence should be marked with a start and end marker. It should be called after tokenization. The layer will first trim inputs to fit, then add start/end tokens, and finally pad, if necessary, to `sequence_length`.
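The order of operations in that intro paragraph (trim first, then add the start/end tokens, then pad) can be sketched in plain Python; the helper name and token values are illustrative:

```python
def pack(sequence, sequence_length, start_value=1, end_value=2, pad_value=0):
    """Sketch of the described behavior: trim, add start/end, then pad."""
    # 1. Trim so the start and end tokens still fit.
    trimmed = sequence[: sequence_length - 2]
    # 2. Add the start and end markers.
    packed = [start_value] + trimmed + [end_value]
    # 3. Pad up to sequence_length if necessary.
    return packed + [pad_value] * (sequence_length - len(packed))

print(pack([11, 12, 13, 14, 15, 16], 5))  # long input: trimmed, no padding
print(pack([11, 12], 5))                  # short input: padded at the end
```

Note the long input still ends with the end token, which is exactly the case the review asks to fix and test.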
Right. Thanks! Pretty much copied it verbatim :)
keras_nlp/layers/start_end_packer.py
Outdated
    self.pad_value = pad_value

def call(self, inputs):
    input_is_tensor = isinstance(inputs, tf.Tensor)
Before this, we should probably do the same:

if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
    inputs = tf.convert_to_tensor(inputs)

That will allow us to support things like numpy or list inputs. Useful for demos.
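A quick sanity check of that conversion path (the wrapper function here is a stand-in, not the layer's actual code):

```python
import numpy as np
import tensorflow as tf

def to_tensor(inputs):
    # Mirror of the suggested guard: convert plain Python or NumPy inputs,
    # but leave tf.Tensor and tf.RaggedTensor values untouched.
    if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
        inputs = tf.convert_to_tensor(inputs)
    return inputs

print(to_tensor([5, 6, 7]))            # list input becomes a tf.Tensor
print(to_tensor(np.array([5, 6, 7])))  # so does a NumPy array
```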
👍🏼 I've removed the `else` block below:
else:
raise ValueError(
"Input must be of type `tf.Tensor` or `tf.RaggedTensor`, "
f"but got {type(inputs)}"
)
One minor thing - I would rename this to `input_is_dense` to match `input_is_ragged`.
Done!
keras_nlp/layers/start_end_packer.py
Outdated
if input_is_tensor:
    if inputs.shape.rank != 1:
        raise ValueError(
            "Input dense tensor must be of rank 1. "
I think it would be more helpful to say:

`Input must be either dense and rank 1 or ragged and rank 2. Received dense input with rank={...}`
keras_nlp/layers/start_end_packer.py
Outdated
elif input_is_ragged:
    if inputs.shape.rank != 2:
        raise ValueError(
            "Input ragged tensor must be of rank 2. "
Similar edit here
]
self.assertAllEqual(output, expected_output)

def test_functional_model(self):
Add a config test.
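A config test usually just round-trips `get_config()` through `from_config()`. Since the real layer lives in keras_nlp, here is a minimal pure-Python stand-in (class name and argument list are illustrative) that shows the pattern:

```python
class StartEndPackerStandIn:
    """Minimal stand-in for the layer, just to show the config round-trip."""

    def __init__(self, sequence_length, start_value=None, end_value=None,
                 pad_value=None):
        self.sequence_length = sequence_length
        self.start_value = start_value
        self.end_value = end_value
        self.pad_value = pad_value

    def get_config(self):
        # A real Keras layer would also merge in super().get_config().
        return {
            "sequence_length": self.sequence_length,
            "start_value": self.start_value,
            "end_value": self.end_value,
            "pad_value": self.pad_value,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

original = StartEndPackerStandIn(sequence_length=5, start_value=1, end_value=2)
restored = StartEndPackerStandIn.from_config(original.get_config())
assert restored.get_config() == original.get_config()
```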
def test_dense_input_error(self): | ||
input_data = tf.constant([[5, 6, 7]]) | ||
start_end_packer = StartEndPacker(sequence_length=5) | ||
with self.assertRaises(ValueError): |
Please include the error message (just prefix works) to ensure we are capturing the right error.
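One way to pin the test to a message prefix is `assertRaisesRegex`; a sketch with a stand-in for the layer call (the exact message text is illustrative):

```python
import re
import unittest

class ErrorMessageExample(unittest.TestCase):
    def test_dense_input_error(self):
        def call_layer_with_rank_2_input():
            # Stand-in for calling the layer with a rank-2 dense tensor.
            raise ValueError(
                "Input must be either dense and rank 1 or ragged and rank 2. "
                "Received dense input with rank=2"
            )

        # Matching an escaped prefix is enough to prove we caught the
        # intended error, without coupling the test to the full message.
        with self.assertRaisesRegex(ValueError, re.escape("Input must be")):
            call_layer_with_rank_2_input()

suite = unittest.defaultTestLoader.loadTestsFromTestCase(ErrorMessageExample)
result = unittest.TextTestRunner().run(suite)
```

Reporting the whole message is fine too; the prefix just makes the test less brittle to later copy edits.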
Done! I've reported the whole error message. Hope that's okay?
def test_ragged_input_error(self):
    input_data = tf.ragged.constant([[[5, 6, 7], [8, 9, 10, 11]]])
    start_end_packer = StartEndPacker(sequence_length=5)
    with self.assertRaises(ValueError):
Same here
@mattdangerw, addressed your comments. Thanks for the review!
LGTM! Just pushed a few copy edits.
One more bug actually!
keras_nlp/layers/start_end_packer.py
Outdated
        f"rank={inputs.shape.rank}"
    )

batch_size = inputs.shape[0]
Actually, one more issue here. We need to support the case where the static shape of the batch size is `None`. I think you will need to call `tf.shape(inputs)[0]` to get the dynamic batch size.

We should also add another unit test, using `tf.data.Dataset.map`, and calling `batch()` on the dataset before applying the layer. That should catch this bug.
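The failure mode is easy to reproduce: after `Dataset.batch()`, the static batch dimension is `None`, so `inputs.shape[0]` cannot be used to build new tensors inside `map`. A sketch with an illustrative mapped function and token value:

```python
import tensorflow as tf

def add_start_column(inputs):
    # Inside a mapped function, inputs.shape[0] is None after batch(),
    # so use the dynamic shape instead of the static one.
    batch_size = tf.shape(inputs)[0]
    start = tf.fill([batch_size, 1], 99)  # 99 is an illustrative start token
    return tf.concat([start, inputs], axis=1)

ds = (
    tf.data.Dataset.from_tensor_slices(tf.constant([[1, 2], [3, 4], [5, 6]]))
    .batch(2)               # static batch dim becomes None
    .map(add_start_column)  # would fail if we used inputs.shape[0]
)
for batch in ds:
    print(batch.numpy().tolist())
```

The final partial batch (size 1) is exactly why the dynamic shape is needed: the batch size genuinely varies at runtime.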
Fixed!
LGTM! one minor comment about variable naming
Thank you! LGTM!
Resolves #220