
Conversation

@abheesht17 (Collaborator) commented Jun 10, 2022

Resolves #220

@mattdangerw (Member)

@abheesht17 I will review line by line later, but one spot was missed where the functionality is broken.

We need to ensure that the end token will show up even if the sequence is longer than sequence length. The start and end tokens should show up in the output for all sequences if they are set. Right now I think an overlong sequence will cause the layer to truncate the end token away.

You might be able to do this with regular slicing. Something like...

if self.end_value is not None:
    end_token_id_tensor = tf.fill((batch_size, 1), self.end_value)
    # Trim first to leave room, so the end token is never truncated away.
    inputs = inputs[..., : self.sequence_length - 1]
    inputs = tf.concat([inputs, end_token_id_tensor], axis=-1)

We should update our unit testing so we check this case somewhere too!
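As a plain-Python sketch of the case worth testing (the function name here is hypothetical, not the layer's API): trimming before appending guarantees the end token survives an overlong input.

```python
def append_end_token(tokens, sequence_length, end_value):
    # Trim first so there is always room for the end token, then append it.
    return tokens[: sequence_length - 1] + [end_value]

# Overlong input: the tail is dropped, but the end token is kept.
append_end_token([5, 6, 7, 8, 9], sequence_length=4, end_value=2)
# -> [5, 6, 7, 2]
```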

@mattdangerw (Member) left a comment

Looks good! Just a few minor comments. The big thing to fix here is still adding an end token when input length > sequence length.

class StartEndPacker(keras.layers.Layer):
    """Adds start and end tokens to a sequence and pads to a fixed length.

    If inputs are batched, inputs should be a `tf.RaggedTensor` with shape
Member:

Might be good to add a short intro paragraph.

This layer is useful when tokenizing inputs for tasks like translation, where each sequence should be marked with a start and end marker. It should be called after tokenization. The layer will first trim inputs to fit, then add start/end tokens, and finally pad, if necessary, to `sequence_length`.
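A plain-Python sketch of that order of operations (trim, add markers, pad). This is an illustration only, not the layer's TF implementation; `pad_value=0` and always-present markers are assumptions here.

```python
def pack(tokens, sequence_length, start_value, end_value, pad_value=0):
    # 1. Trim to leave room for the start and end markers.
    tokens = tokens[: sequence_length - 2]
    # 2. Add the markers.
    tokens = [start_value] + tokens + [end_value]
    # 3. Pad up to sequence_length if needed.
    return tokens + [pad_value] * (sequence_length - len(tokens))

pack([5, 6, 7], sequence_length=6, start_value=1, end_value=2)
# -> [1, 5, 6, 7, 2, 0]
```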

Collaborator Author:

Right. Thanks! Pretty much copied it verbatim :)

        self.pad_value = pad_value

    def call(self, inputs):
        input_is_tensor = isinstance(inputs, tf.Tensor)
Member:

Before this, we should probably do the same:

if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
    inputs = tf.convert_to_tensor(inputs)

That will allow us to support things like numpy or list inputs. Useful for demos.

Collaborator Author:

👍🏼 . I've removed the else block below:

        else:
            raise ValueError(
                "Input must be of type `tf.Tensor` or `tf.RaggedTensor`, "
                f"but got {type(inputs)}"
            )

Contributor:

One minor thing: I would rename this to `input_is_dense` to match `input_is_ragged`.

Collaborator Author:

Done!

        if input_is_tensor:
            if inputs.shape.rank != 1:
                raise ValueError(
                    "Input dense tensor must be of rank 1. "
Member:

I think it would be more helpful to say....

Input must be either dense and rank 1 or ragged and rank 2. Received dense input with rank={...}

        elif input_is_ragged:
            if inputs.shape.rank != 2:
                raise ValueError(
                    "Input ragged tensor must be of rank 2. "
Member:

Similar edit here

        ]
        self.assertAllEqual(output, expected_output)

    def test_functional_model(self):
Member:

Add a config test.

    def test_dense_input_error(self):
        input_data = tf.constant([[5, 6, 7]])
        start_end_packer = StartEndPacker(sequence_length=5)
        with self.assertRaises(ValueError):
Contributor:

Please include the error message (just the prefix works) to ensure we are capturing the right error.

Collaborator Author:

Done! I've reported the whole error message. Hope that's okay?

    def test_ragged_input_error(self):
        input_data = tf.ragged.constant([[[5, 6, 7], [8, 9, 10, 11]]])
        start_end_packer = StartEndPacker(sequence_length=5)
        with self.assertRaises(ValueError):
Contributor:

Same here

@abheesht17 (Collaborator Author) left a comment

@mattdangerw, addressed your comments. Thanks for the review!

@mattdangerw (Member) left a comment

LGTM! Just pushed a few copy edits.

@mattdangerw (Member) left a comment

One more bug actually!

                f"rank={inputs.shape.rank}"
            )

        batch_size = inputs.shape[0]
Member:

Actually, one more issue here. We need to support the case where the static batch size is `None`. I think you will need to call `tf.shape(inputs)[0]` to get the dynamic batch size.

We should also add another unit test using `tf.data`'s `map()`, calling `batch()` on the dataset before applying the layer. That should catch this bug.
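A minimal sketch of the failure mode and fix. After `batch()`, the static batch dimension is `None`, so `inputs.shape[0]` cannot build the end-token tensor; `tf.shape(inputs)[0]` returns it dynamically. `add_end_token` is a hypothetical stand-in, not the PR's code.

```python
import tensorflow as tf

# Four length-3 sequences, batched by tf.data: static batch size becomes None.
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((4, 3), tf.int32)).batch(2)

def add_end_token(inputs, end_value=2):
    # Dynamic batch size; works even when the static shape is None.
    batch_size = tf.shape(inputs)[0]
    end_tokens = tf.fill((batch_size, 1), end_value)
    return tf.concat([inputs, end_tokens], axis=-1)

ds = ds.map(add_end_token)  # Traced in graph mode, like the reported bug.
```

Using `inputs.shape[0]` here would fail at trace time, since `tf.fill` cannot build a shape from `None`.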

@abheesht17 (Collaborator Author)

> One more bug actually!

Fixed!

@chenmoneygithub (Contributor) left a comment

LGTM! One minor comment about variable naming.

@mattdangerw (Member) left a comment

Thank you! LGTM!

@mattdangerw mattdangerw merged commit 3c89e51 into keras-team:master Jun 16, 2022
Merging this pull request closes: Add a start and end token packer layer (#220)