Add BartPreprocessor #856
Conversation
Thanks for the PR!

A high-level comment about the label: the current version creates a causal LM label, but BART has multiple use cases, so shall we just let `y` pass through? For specific tasks, we can add special label creators, e.g. a `BartCausalLMPreprocessor`, just like `GPT2CausalLMPreprocessor`.
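For context, a minimal sketch of the two behaviors under discussion (the `tokenize_and_pack` helper is hypothetical, standing in for the layer's tokenization and packing logic):

```python
from tensorflow import keras

# "Let y pass through": the general preprocessor leaves any user-provided
# labels untouched and only transforms the features.
def general_call(x, y=None, sample_weight=None):
    x = tokenize_and_pack(x)  # hypothetical helper
    return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

# Causal LM labeling: derive y by shifting the decoder ids one step left.
def causal_lm_call(x, y=None, sample_weight=None):
    x = tokenize_and_pack(x)  # hypothetical helper
    token_ids = x["decoder_token_ids"]
    x = {**x, "decoder_token_ids": token_ids[..., :-1]}
    y = token_ids[..., 1:]  # each position predicts the next token
    return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
```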
```python
# TODO: Allow users to pass separate `sequence_length`s for encoder and
# decoder.
# Note: We use `MultiSegmentPacker` instead of `StartEndPacker` because
# we might want to support multiple segments in the future (at least for
```
I think we can drop the "in the future (at least for the encoder)" part; multiple segments are already used for MNLI. From the paper:
"""
The fine-tuned model concatenates the two sentences with an EOS token appended, and passes them to both the BART encoder and decoder. In contrast to BERT, the representation of the EOS token is used to classify the sentence relations.
"""
```python
    and ["encoder_inputs", "decoder_inputs"] == list(x.keys())
):
    raise ValueError(
        f'`x` must be a dictionary, containing the keys `"encoder_inputs"`'
```
nit: the first line doesn't need to be an f-string (nothing is interpolated in it).
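One way to address this (the second line of the message is an assumed continuation, since the quoted diff cuts off after the first):

```python
raise ValueError(
    '`x` must be a dictionary, containing the keys `"encoder_inputs"`'
    f' and `"decoder_inputs"`. Received x={x}.'  # assumed continuation
)
```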
```python
# Get the labels by shifting the decoder inputs one place to the left.
if decoder_token_ids.shape.rank == 1:
    y = decoder_token_ids[1:]
```
This is a causal LM label, but from the BART paper, causal LM is not the only use; IIUC it's really only needed for machine translation. Should we just let `y` pass through by default?
Left some comments; in particular, I'm not sure we are doing the label offset right for seq2seq.
````python
sequence_length: The length of the packed inputs.

Examples:
```python
````
let's rework this to follow #843
```python
super().__init__(**kwargs)
self.tokenizer = tokenizer

# TODO: Allow users to pass separate `sequence_length`s for encoder and
```
should we make an issue for this?
Resolved it in this PR itself.
```python
Args:
    tokenizer: A `keras_nlp.models.BartTokenizer` instance.
    sequence_length: The length of the packed inputs.
```
Probably worth mentioning that this is the length for both encoder and decoder sequences (for now).
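For example, the docstring could read something like:

```python
sequence_length: The length of the packed inputs. This length is used
    for both the encoder and decoder inputs.
```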
```python
# The last token does not have a next token. Hence, we truncate it.
x = {
    **x,
    "decoder_token_ids": decoder_token_ids[..., :-1],
```
Will this actually work as we want? I think this will generate an encoder sequence of length `sequence_length` but a decoder sequence of length `sequence_length - 1`.

We want both feature sequences to have the same length, I think, which means we have to tokenize the encoder sequence to length `sequence_length` and the decoder to length `sequence_length + 1` before the feature/label offsetting.
@mattdangerw - in that case, we'll need to define two `MultiSegmentPacker`s. Might as well work on #904 in this PR itself instead of saving it for later?
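A rough sketch of what that could look like (the constructor arguments are assumptions about `MultiSegmentPacker`'s API, and the special-token attributes on the tokenizer are illustrative):

```python
from keras_nlp.layers import MultiSegmentPacker

# Pack the encoder inputs to `sequence_length` tokens.
encoder_packer = MultiSegmentPacker(
    start_value=tokenizer.start_token_id,
    end_value=tokenizer.end_token_id,
    sequence_length=sequence_length,
)

# Pack the decoder inputs one token longer, so that after the
# feature/label shift both end up with `sequence_length` tokens.
decoder_packer = MultiSegmentPacker(
    start_value=tokenizer.start_token_id,
    end_value=tokenizer.end_token_id,
    sequence_length=sequence_length + 1,
)

decoder_full = decoder_packer(decoder_tokens)[0]  # the packer may also return segment ids
decoder_token_ids = decoder_full[..., :-1]  # feature, length `sequence_length`
y = decoder_full[..., 1:]                   # label, length `sequence_length`
```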
```python
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:
```
Let's rework all of these pull requests to match the style here #843
```python
# Tokenize and pack a sentence pair.
inputs = {
    "encoder_inputs": (
```
Looking over this more, let's keep it simple on the first attempt and have no support for multiple segments in the base preprocessor layer for now. This will fit with the GPT2 code.

IMO this is still just too complicated, and I'm not sure of the use case. For classification, we can support multiple segments, but I don't see a huge need for multiple segments with separate encoder and decoder inputs. Do we have a clear use case there we want to support?

If not, let's land this with the simpler feature set.
```python
# Tokenize and pack a single sentence.
inputs = {
    "encoder_inputs": "The fox was sleeping.",
```
Open question: should we call this `"encoder_text"` to better accommodate `"encoder_audio"` for Whisper? Or will it be simpler to have the same names everywhere? I somewhat like the self-documenting property of saying this is text input.
Yep, `"encoder_text"` and `"decoder_text"` sound good to me!
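With that renaming, the docstring example would look something like this (the preset name and sentences are illustrative):

```python
import keras_nlp

preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")

# Tokenize and pack a single sentence for the encoder and decoder.
inputs = {
    "encoder_text": "The fox was sleeping.",
    "decoder_text": "The fox was awake.",
}
x = preprocessor(inputs)
```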
/gcbrun
When should one use `BartPreprocessor` directly? Would we ever make a task-specific preprocessor for it?
@jbischof, overall, the idea for all model preprocessors is to have a general preprocessor, and then task-specific preprocessors which subclass the general preprocessor. We follow the same pattern for the other models we have in the library so far.
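A skeletal sketch of that layering, assuming the library's `Preprocessor` base class (the `BartCausalLMPreprocessor` name is hypothetical, following the `GPT2CausalLMPreprocessor` precedent):

```python
class BartPreprocessor(Preprocessor):
    # General preprocessor: tokenizes and packs the features, and lets
    # any user-provided `y` pass through unchanged.
    ...


class BartCausalLMPreprocessor(BartPreprocessor):
    # Task-specific preprocessor: reuses the parent's tokenization and
    # packing, then derives `y` by shifting the decoder token ids.
    ...
```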
Oops, accidentally closed the PR.
Thanks for the clarification @abheesht17!
/gcbrun
Thanks! Just minor comments
("tf_format", "tf", "model"), | ||
("keras_format", "keras_v3", "model.keras"), | ||
) | ||
def test_saved_model(self, save_format, filename): |
"decoder_text": " kohli is the best", | ||
} | ||
|
||
output = self.preprocessor(input_data) |
This would be much more readable as `x, y, sw = self.preprocessor(input_data)` (and use those below).
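That is, something along these lines (the `expected_*` values are placeholders for the test's actual expectations):

```python
x, y, sw = self.preprocessor(input_data)
self.assertAllEqual(x["encoder_token_ids"], expected_encoder_ids)
self.assertAllEqual(x["decoder_token_ids"], expected_decoder_ids)
self.assertAllEqual(y, expected_labels)
self.assertAllEqual(sw, expected_sample_weights)
```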
```python
    Each value in the dictionary should be a tensor of single string
    sequences. Inputs may be batched or unbatched. Raw python inputs
    will be converted to tensors.
y: Any label data. Any passed value will be ignored since this is
```
```python
model_output = model(input_data)
restored_model_output = restored_model(input_data)

self.assertAllEqual(
```
I think `assertAllClose` will handle a nested structure here, so you could just do `assertAllClose(outputs, restored_outputs)`.
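In other words, the whole comparison collapses to a single assertion (variable names as in the quoted test):

```python
model_output = model(input_data)
restored_model_output = restored_model(input_data)

# `assertAllClose` recurses into nested structures (dicts/lists of tensors).
self.assertAllClose(model_output, restored_model_output)
```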
```python
model_output = model(input_data)
restored_model_output = restored_model(input_data)

self.assertAllEqual(
```
Same here, this could get considerably shorter with `assertAllClose`.
/gcbrun
LGTM! Will pull in as soon as testing is done
Looks like the failure is unrelated, so I will pull this in.
Congrats @abheesht17!
Thanks, @jbischof! Text generation with BART next up!
Resolves #904