
Conversation

@abheesht17 (Collaborator) commented Mar 16, 2023:

Resolves #904

abheesht17 requested a review from mattdangerw on March 16, 2023.
@chenmoneygithub (Contributor) left a comment:

Thanks for the PR!

A high-level comment about the label: the current version creates a causal LM label, but BART has multiple use cases, so should we just pass y through unchanged? For specific tasks, we can add a dedicated label creator, e.g. BartCausalLMPreprocessor, just like GPT2CausalLMPreprocessor.

# TODO: Allow users to pass separate `sequence_length`s for encoder and
# decoder.
# Note: We use `MultiSegmentPacker` instead of `StartEndPacker` because
# we might want to support multiple segments in the future (at least for
Contributor:

I think we can drop the "in the future (at least for the encoder)" part; multiple segments are already used for MNLI. From the paper:
"""
The fine-tuned model concatenates the two sentences with an appended EOS token, and passes them to both the BART encoder and decoder. In contrast to BERT, the representation of the EOS token is used to classify the sentence relations.
"""

and ["encoder_inputs", "decoder_inputs"] == list(x.keys())
):
raise ValueError(
f'`x` must be a dictionary, containing the keys `"encoder_inputs"`'
Contributor:

Nit: the first line has no interpolated values, so it doesn't need to be an f-string.


# Get the labels by shifting the decoder inputs one place to the left.
if decoder_token_ids.shape.rank == 1:
    y = decoder_token_ids[1:]
Contributor:

This is a causal LM label, but from the BART paper, causal LM is not the only use; IIUC it's really only needed for machine translation. Should we just let y pass through by default?
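
For illustration, a minimal sketch of the two options in plain TensorFlow (illustrative names and placeholder ids, not the merged implementation):

```python
import tensorflow as tf

# Packed decoder ids with shape (batch, sequence_length + 1); placeholder values.
decoder_token_ids = tf.constant([[0, 5, 6, 7, 2, 1]])

# Option 1: create causal LM labels in the preprocessor. The label at position
# i is the token at position i + 1.
features = {"decoder_token_ids": decoder_token_ids[..., :-1]}
labels = decoder_token_ids[..., 1:]


# Option 2: pass-through. The base preprocessor leaves y untouched, and a
# task-specific subclass (e.g. a seq2seq LM preprocessor) does the shifting.
def passthrough(x, y=None, sample_weight=None):
    return x, y, sample_weight
```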

@mattdangerw (Member) left a comment:

Left some comments; in particular, I'm not sure we are doing the label offset right for seq2seq.

sequence_length: The length of the packed inputs.

Examples:
```python
Member:

let's rework this to follow #843

super().__init__(**kwargs)
self.tokenizer = tokenizer

# TODO: Allow users to pass separate `sequence_length`s for encoder and
Member:

should we make an issue for this?

Collaborator (Author):

Resolved it in this PR itself.


Args:
    tokenizer: A `keras_nlp.models.BartTokenizer` instance.
    sequence_length: The length of the packed inputs.
Member:

Probably worth mentioning that this is the length for both encoder and decoder sequences (for now).

# The last token does not have a next token. Hence, we truncate it.
x = {
    **x,
    "decoder_token_ids": decoder_token_ids[..., :-1],
Member:

Will this actually work as we want? I think this will produce an encoder sequence of length sequence_length but a decoder sequence of length sequence_length - 1.

I think we want both feature sequences to have the same length, which means we have to pack the encoder sequence to sequence_length and the decoder to sequence_length + 1 before the feature/label offsetting.

Collaborator (Author):

@mattdangerw - in that case, we'll need to define two MultiSegmentPackers. Might as well work on #904 in this PR itself instead of saving it for later?
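
For illustration, a minimal sketch of the length bookkeeping discussed above, using plain TensorFlow tensors as stand-ins for the packers' outputs:

```python
import tensorflow as tf

sequence_length = 8

# Pack the encoder to `sequence_length` and the decoder to
# `sequence_length + 1`, so that after the feature/label offset both features
# end up the same length. Zeros stand in for packed token ids.
encoder_token_ids = tf.zeros((2, sequence_length), dtype=tf.int32)
decoder_token_ids = tf.zeros((2, sequence_length + 1), dtype=tf.int32)

features = {
    "encoder_token_ids": encoder_token_ids,            # shape (2, 8)
    "decoder_token_ids": decoder_token_ids[..., :-1],  # shape (2, 8)
}
labels = decoder_token_ids[..., 1:]                    # shape (2, 8)
```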

left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Examples:
Member:

Let's rework all of these pull requests to match the style in #843.

mattdangerw self-assigned this on Mar 22, 2023.
abheesht17 requested a review from mattdangerw on March 24, 2023.

# Tokenize and pack a sentence pair.
inputs = {
    "encoder_inputs": (
@mattdangerw (Member) commented Mar 28, 2023:

Looking over this more, let's keep it simple on the first attempt and not support multiple segments in the base preprocessor layer for now. This will fit with the GPT-2 code.

IMO this is still just too complicated, and I'm not sure of the use case. For classification we can support multiple segments, but I don't see a big need for multiple segments with separate encoder and decoder inputs. Do we have a clear use case there we want to support?

If not, let's land this with the simpler feature set.


# Tokenize and pack a single sentence.
inputs = {
    "encoder_inputs": "The fox was sleeping.",
Member:

Open question: should we call this "encoder_text" to better accommodate "encoder_audio" for Whisper? Or would it be simpler to have the same names everywhere? I somewhat like the self-documenting property of saying this is text input.

@abheesht17 (Collaborator, Author) replied Mar 30, 2023:

Yep, "encoder_text" and "decoder_text" sound good to me!

@mattdangerw (Member) commented:

/gcbrun

@jbischof (Contributor) commented Apr 3, 2023:

When should one use BartPreprocessor vs. BartSeq2SeqLMPreprocessor? Just reading the docstrings, I am a little confused.

Would we ever make a BartClassifier subclass which passes a single input sequence to both encoder and decoder?

@abheesht17 (Collaborator, Author) commented Apr 3, 2023:

@jbischof, BartPreprocessor is meant to be a very general layer. If users want to do something funky, they can go with BartPreprocessor. We do not expect users to use this layer often; most use cases will be satisfied by BartSeq2SeqLMPreprocessor.

Overall, the idea for all model preprocessors is to have a general preprocessor, and then task-specific preprocessors that subclass the general one. For the other models we have in the library so far, {model}Preprocessor is the same as {model}ClassifierPreprocessor, so we'll have to add an alias for those models.

BartClassifierPreprocessor will be added in a follow-up PR.
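
For illustration, a hypothetical sketch of that layering (class names follow this thread; the bodies are schematic, not the merged code):

```python
import tensorflow as tf


class BartPreprocessor(tf.keras.layers.Layer):
    """General layer: tokenize and pack inputs, pass labels through as-is."""

    def call(self, x, y=None, sample_weight=None):
        # ... tokenize and pack "encoder_text" / "decoder_text" into
        # "encoder_token_ids" / "decoder_token_ids" here ...
        return x, y, sample_weight


class BartSeq2SeqLMPreprocessor(BartPreprocessor):
    """Task-specific layer: derive seq2seq LM labels from the decoder ids."""

    def call(self, x, y=None, sample_weight=None):
        x, _, sample_weight = super().call(x, y, sample_weight)
        decoder_token_ids = x["decoder_token_ids"]
        x = {**x, "decoder_token_ids": decoder_token_ids[..., :-1]}
        y = decoder_token_ids[..., 1:]
        return x, y, sample_weight
```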

abheesht17 closed this on Apr 3, 2023.
abheesht17 reopened this on Apr 3, 2023.
@abheesht17 (Collaborator, Author) commented:

Oops, accidentally closed the PR

@jbischof (Contributor) commented Apr 3, 2023:

Thanks for the clarification @abheesht17!

@mattdangerw (Member) commented:

/gcbrun

@mattdangerw (Member) left a comment:

Thanks! Just minor comments

("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
)
def test_saved_model(self, save_format, filename):
Member:

two changes we can make for efficiency here...

#945 (don't save traces)
#894 (mark as large, make separate serialization test)

"decoder_text": " kohli is the best",
}

output = self.preprocessor(input_data)
Member:

This would be much more readable as x, y, sw = self.preprocessor(input_data) (and then using x, y, sw below).

Each value in the dictionary should be a tensor of single string
sequences. Inputs may be batched or unbatched. Raw python inputs
will be converted to tensors.
y: Any label data. Any passed value will be ignored since this is

model_output = model(input_data)
restored_model_output = restored_model(input_data)

self.assertAllEqual(
Member:

I think assertAllClose will handle a nested structure here, so you could just assertAllClose(outputs, restored_outputs)

model_output = model(input_data)
restored_model_output = restored_model(input_data)

self.assertAllEqual(
Member:

Same here; this could get considerably shorter with assertAllClose.
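
For illustration, a self-contained sketch of the single-call comparison, with placeholder outputs standing in for the real model calls:

```python
import tensorflow as tf


class OutputComparisonTest(tf.test.TestCase):
    def test_nested_outputs_match(self):
        # Stand-ins for model(input_data) and restored_model(input_data);
        # nested dict outputs are typical for a seq2seq backbone.
        output = {"encoder": tf.ones((2, 4)), "decoder": tf.ones((2, 4))}
        restored_output = {"encoder": tf.ones((2, 4)), "decoder": tf.ones((2, 4))}
        # assertAllClose recurses into nested structures, so one call
        # compares every tensor.
        self.assertAllClose(output, restored_output)
```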

@mattdangerw (Member) commented:

/gcbrun

@mattdangerw (Member) left a comment:

LGTM! Will pull in as soon as testing is done

@mattdangerw (Member) commented:

Looks like the failure is unrelated, so I will pull this in.

mattdangerw merged commit c046ab6 into keras-team:master on Apr 4, 2023.
@jbischof (Contributor) commented Apr 4, 2023:

Congrats @abheesht17!

@abheesht17 (Collaborator, Author) commented:

Thanks, @jbischof! Text generation with BART next up!

Linked issue: Allow Passing Separate Sequence Lengths for Encoder Input and Decoder Input in BartPreprocessor (#904)