Update generate() to work like fit() and predict() #932
Conversation
@fchollet @chenmoneygithub opening up a draft of a refactor for `generate()`. We will probably need some tweaks here, but overall I think we need something like this. Long term, we should probably factor a lot of the code from this into a common generative task class.
Excellent!
Thanks Matt! The proposed workflow is nice, and the code is cleaner with `self.packer`! Left some comments on implementation details.
output = generate_function(prompt, input_mask, min_length)

def preprocess(x, y=None, sample_weight=None):
    if self.preprocessor is not None:
        return self.preprocessor(x, sequence_length=max_length)
If we use `preprocessor` here, I would vote for setting `add_end_token=False` by default; adding the end token is only really useful for chatbot models.
Currently both `add_start_token` and `add_end_token` default to `False`, right?
I actually ended up flipping the defaults on this and removing the end token during generation only. Will add a comment below.
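A minimal sketch of that "remove the end token during generation only" behavior, assuming the tokenizer exposes an `end_token_id` attribute (the names here are illustrative, not the PR's exact code):

    import tensorflow as tf

    def strip_end_token(token_ids, padding_mask, end_token_id):
        # Mask out any end token in the prompt so generation continues past it,
        # while fit() preprocessing keeps the end token for next-token targets.
        is_end = tf.equal(token_ids, end_token_id)
        padding_mask = tf.logical_and(
            tf.cast(padding_mask, "bool"), tf.logical_not(is_end)
        )
        return token_ids, padding_mask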
x = super().call(x)
# Tokenize with one extra token to account for the truncation below.
sequence_length = (sequence_length or self.sequence_length) + 1
This is a bit odd to me... `sequence_length` has a higher priority than `self.sequence_length` because we need to respect `max_length` in the `generate()` method. However, for users without this context it could be weird. Can we override the `sequence_length` in the `generate()` method instead?
Like changing the name from `max_length` to `sequence_length`? No strong preference there.
Let's avoid mutating config state on a layer, but I'm not sure what the best approach is. This just seemed "least bad" to me.
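A minimal sketch of the "least bad" call-time override, versus the mutation we want to avoid (illustrative names, not the PR's exact code):

    class PackerSketch:
        """Call-time override of a configured default, without touching config."""

        def __init__(self, sequence_length=128):
            self.sequence_length = sequence_length

        def call(self, x, sequence_length=None):
            # generate() can pass max_length here; the configured default is
            # left untouched for later fit()/predict() calls.
            sequence_length = sequence_length or self.sequence_length
            return x[:sequence_length]

    # The avoided alternative mutates layer config and leaks into later calls:
    #     packer.sequence_length = max_length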
(Force-pushed from db7d1c0 to d5d1b6d.)
OK! I have addressed comments and tried to get the high-level workflows looking the way we want. The major awkwardness we are facing is that preprocessing for generation and fine-tuning look quite different. This PR makes all preprocessing run through the attached preprocessor. Overall I think this is worth it, but the fact that we are shoving two preprocessing flows into a single task & preprocessor is a little awkward. Options I can think of...

Currently I am going with 1., mainly because it is the least disruptive to what we currently have set up.
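For concreteness, the two flows being funneled through one preprocessor look roughly like this (a conceptual sketch with assumed call signatures, not the PR's code):

    def preprocess_for_fit(preprocessor, texts):
        # Fine-tuning: tokenize, then split into inputs and shifted next-token labels.
        x = preprocessor(texts)  # assumed to return {"token_ids", "padding_mask"}
        features = {
            "token_ids": x["token_ids"][..., :-1],
            "padding_mask": x["padding_mask"][..., :-1],
        }
        labels = x["token_ids"][..., 1:]
        sample_weight = x["padding_mask"][..., 1:]
        return features, labels, sample_weight

    def preprocess_for_generate(preprocessor, texts, max_length):
        # Generation: only model inputs, padded to max_length, no end token appended.
        return preprocessor(texts, sequence_length=max_length, add_end_token=False)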
(Force-pushed from 9ca7d8f to b09ae70.)
(Force-pushed from b09ae70 to 5c05dd7.)
Thanks! The functionality looks good, left some initial comments.
"But I watch youtube while coding!", | ||
] | ||
ds = tf.data.Dataset.from_tensor_slices(features).batch(2) | ||
# Prompt with 50256, the `"<|endoftext|>"` token id. |
The purpose of this prompt could be a bit unclear to readers. My guess is we are showing that generate will still do its work if the prompt contains `"<|endoftext|>"`; should we reflect that in the comment?
Oh, actually I think the example I was showing is calling `generate()` without preprocessing, so this has nothing to do with our tokenizer. I updated the example a bit for clarity and removed the endoftext part.
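For context, the no-preprocessor path being discussed would look roughly like this (the preset name, the `preprocessor=None` argument, and the token ids are assumptions for illustration, not the PR's exact example):

    import tensorflow as tf
    import keras_nlp

    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        "gpt2_base_en", preprocessor=None
    )
    # Already-tokenized, already-padded inputs; 50256 is the `"<|endoftext|>"` id.
    prompt = {
        "token_ids": tf.constant([[50256, 464, 2068, 7586, 21831] + [0] * 59]),
        "padding_mask": tf.constant([[True] * 5 + [False] * 59]),
    }
    # With no preprocessor attached, generate() consumes these tensors directly.
    output = causal_lm.generate(prompt)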
(Force-pushed from d6947f4 to 3059061.)
Comments addressed!
(Force-pushed from 3059061 to e69b262.)
Thanks! This seems very cool but I've fallen a bit behind on some context. Most of my questions are asking about simplifications.
Also: how much of this code is specific to GPT-2? Would it make sense to put this in a base class, or does it make sense to rewrite it from scratch for OPT?
    `tf.Tensor` or `tf.data.Dataset` with keys `"token_ids"` and
    `"padding_mask"`.
max_length: int. The max length of the generated sequence.
batch_size: int. Only pass if `inputs` is a `tf.Tensor`. If set,
What if the input tensors themselves are batched?
The user-friendly thing to do is just to throw in that case. `fit()` and `predict()` do not handle batched tensor/numpy input. If you attempt `bert_classifier.predict(batched_tensors)` you would get an error IIUC, as you would with any vanilla Keras model for that matter. We could definitely add an error message.
I see, so one batch at most?
Well, I guess it depends on what you mean by "batched". Do you mean a tensor with shape `(num_batches, batch_size, feature shapes)`, or do you mean a Python list of tensor batches? The former would not work in `fit()` or `predict()` etc. The latter might for vanilla Keras? I would need to check.

The main modes of input to the high-level Keras APIs are either a batched dataset, or a single numpy/tensor with shape `(num_samples, feature shapes)` and a `batch_size` provided separately IIUC.
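For concreteness, those two input modes with a plain Keras model:

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    x = np.random.random((100, 4)).astype("float32")

    # 1) A single un-batched array of shape (num_samples, features),
    #    with batch_size passed separately.
    model.predict(x, batch_size=8)

    # 2) An already-batched tf.data.Dataset; no batch_size argument.
    ds = tf.data.Dataset.from_tensor_slices(x).batch(8)
    model.predict(ds)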
So you can't call `predict()` on a batch? Asking for a friend.
It seems that predict has no issues working on batched tensors (colab)
Definitely a good idea. There's no way we can support this much complexity for each class. I opened #868 for it a bit ago. My plan was to keep this as a follow-up.
Approved!
Dropped some non-blocking comments; let's keep discussing and do roll-forward fixes.
-    prompt,
-    max_length,
+    inputs,
+    max_length=None,
Sorry for the back and forth, but I am starting to feel that giving `max_length` a default of the model capacity could lead to a bad UX. The vanilla GPT-2 won't stop until 1024 tokens, so users would need to wait a while to get results back. But anyway, let's get this PR in, and I can open a few follow-ups.
Let's chat! I think we should think a little more generally and long term. How many people actually want exactly 50 unconditioned tokens from GPT-2? It's kind of a toy use case I think.

Most real workflows will involve fine-tuning on sequences that terminate in some way, e.g. building a summarizer, a chatbot responder, or a translator. This gets even more true when we think about seq2seq models like T5 and BART, which will follow the same UX we are establishing here.

We want to choose a default that will scale well towards the future states of the library, and I am skeptical that requiring a `max_length` in all cases is a good idea.
sample_weight: Any label weight data. Will be passed through unaltered.
sequence_length: Pass to override the configured `sequence_length` of
    the layer.
add_start_token: Pass to override the configured value of
(Cannot remember if I have posted this or not...) So in both `call()` and `__init__()` we have `add_start_token`, but they are actually different types. This could lead to some confusion.
There was a comment chain above that confused me, but basically, if we go with this override approach, we need three possible values in `call()` so it can act as an actual override: `None` defaults to the configured value, `True` overrides to true, and `False` overrides to false. If you only accept `True`/`False` here, you have made the init argument obsolete, right?
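A tiny sketch of that three-state override (illustrative, not the PR's exact code):

    class PreprocessorSketch:
        def __init__(self, add_start_token=True, add_end_token=True):
            self.add_start_token = add_start_token
            self.add_end_token = add_end_token

        def call(self, x, add_start_token=None, add_end_token=None):
            # None -> use the configured default; True/False -> per-call override.
            if add_start_token is None:
                add_start_token = self.add_start_token
            if add_end_token is None:
                add_end_token = self.add_end_token
            return x, add_start_token, add_end_token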
(Force-pushed from e69b262 to 7137316.)
Seems reasonable, just a few questions.
    is attached to the model, inputs should instead be a nested
    `tf.Tensor` or `tf.data.Dataset` with keys `"token_ids"` and
    `"padding_mask"`.
max_length: Optional. int. The max length of the generated sequence.
But what if there's no preprocessor? In that case this arg appears not to work.
Yeah, this arg is worth discussing. `max_length` is really conceptually a preprocessing arg and does absolutely nothing if we are passing preprocessed inputs, which should already have shape `(batch_size, max_length)`.

This feels somewhat like what you would call a denormalized argument, in that `sequence_length` on the preprocessor and `max_length` here are really setting the same thing.

We could...

- Think about removing it, though I would do that as a follow-up perhaps.
- Document that it does nothing when `preprocessor` is `None`.
- Throw an error if `inputs.shape[1] != max_length` (see the sketch below).

Thoughts?
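For reference, the error option would be something like this sketch (not what was ultimately chosen; the documenting approach below was):

    if self.preprocessor is None and max_length is not None:
        input_length = inputs["token_ids"].shape[1]
        if input_length != max_length:
            raise ValueError(
                f"`max_length` is {max_length}, but preprocessed `inputs` have "
                f"length {input_length}. Pad inputs to `max_length`, or leave "
                f"`max_length` unset."
            )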
I think I will go with the lightweight approach of documenting for now.
Slight edit for clarity:
"""
If `preprocessor` is `None`, `inputs` should be padded to the desired maximum
length and this argument will be ignored.
"""
Thanks for the hard work here!
This updates `generate()` to feel more like `fit()`/`predict()`/`evaluate()`. Inputs to generate can be a dataset or raw tensors. Inputs can be preprocessed or not, depending on whether the model has a preprocessor layer attached. The preprocessing layer is used to preprocess all inputs before generation.
Fixes #911, #912, #913 and #844 (if we go with this approach).
Colab with usage
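A rough sketch of the usage described above (the preset name and prompts are assumptions for illustration; see the colab for the exact code):

    import tensorflow as tf
    import keras_nlp

    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Raw strings run through the attached preprocessor, just like fit()/predict().
    causal_lm.generate("The quick brown fox", max_length=64)

    # A tf.data.Dataset of strings also works, batched like any Keras input pipeline.
    ds = tf.data.Dataset.from_tensor_slices(
        ["That quick brown fox.", "A lazy dog."]
    ).batch(2)
    causal_lm.generate(ds, max_length=64)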