Add TFGPT2ForSequenceClassification based on DialogRPT #8714
Conversation
Thank you very much for this very nice addition!!
I left a few comments on it. Also, can you run the following piece of code and tell me if it works properly:
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification
model = tf.function(TFGPT2ForSequenceClassification.from_pretrained("microsoft/dialogrpt"))
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/dialogrpt")
inputs = tokenizer("Hello", return_tensors="tf")
model(inputs)
@LysandreJik I would also recommend waiting a bit for the new input processing to be merged.
""", | ||
GPT2_START_DOCSTRING, | ||
) | ||
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss): |
Be careful: TFSequenceClassificationLoss only takes -100 into account as the pad token id. So either you assume this same value everywhere, or you should redefine the loss computation yourself.
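For context, the -100 convention means that label positions set to -100 are excluded from the loss. A minimal sketch of the idea (an illustration, not the library's actual implementation):

import tensorflow as tf

labels = tf.constant([2, -100, 1])    # -100 marks labels to ignore
logits = tf.random.normal((3, 4))     # 3 examples, 4 classes
active = tf.not_equal(labels, -100)   # boolean mask of real labels
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(tf.boolean_mask(labels, active), tf.boolean_mask(logits, active))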
Apparently you forgot this comment :)
TFGPT2ForSequenceClassification uses the last token in order to do the classification. For a given task, if a pad_token_id is defined in the configuration, it finds the last token that is not a pad token in each row; if no pad_token_id is defined, it simply takes the last value in each row of the batch.
So I think the TFSequenceClassificationLoss loss function should work, as it takes the single last token, which is not a pad token.
Let me know your views on it.
@LysandreJik has already defined pad_token_id: 50256 for this model.
#7493 (comment)
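For readers following the thread, this is roughly how that last-non-pad-token index can be computed, assuming right-padded inputs (values below are illustrative):

import tensorflow as tf

pad_token_id = 50256
input_ids = tf.constant([[15496, 995, 50256, 50256],
                         [40, 1101, 994, 2651]])
# index of the last non-pad token in each row
sequence_lengths = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32), axis=-1
) - 1
# sequence_lengths is now [1, 3]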
Ok, fine for me about this then!
if input_ids is not None:
    batch_size, sequence_length = input_ids.shape[:2]
else:
    batch_size, sequence_length = inputs_embeds.shape[:2]
Use the shape_list function instead of .shape.
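shape_list is a helper in transformers.modeling_tf_utils that returns static dimensions where available and falls back to dynamic ones, which keeps the code graph-mode safe. The quoted hunk would then read roughly:

from transformers.modeling_tf_utils import shape_list

if input_ids is not None:
    batch_size, sequence_length = shape_list(input_ids)[:2]
else:
    batch_size, sequence_length = shape_list(inputs_embeds)[:2]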
Done
result = tf.map_fn(
    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
p_logits = tf.reshape(result, [result.shape[0], result.shape[-1]])
Same here.
Done
def call(
    self,
    inputs,
    past=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
The list is not in the proper order; same thing for the input order processing below. Look at the TFGPT2LMHeadModel class to see an example.
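For reference, a sketch of the expected ordering, modeled on TFGPT2LMHeadModel's signature at the time (copy the exact order from that class; the point is that labels goes after return_dict, with training last):

def call(
    self,
    inputs,
    past=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    labels=None,
    training=False,
):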
Done
Thanks for the review @jplu. I'll update my code with the review comments and the new input processing.
up to date with master
…mplemented review comments and added input processing
Hello @jplu and @LysandreJik, kindly review.
Much better!! Thanks for the updates.
There is still one comment to be addressed and the tests to fix.
""", | ||
GPT2_START_DOCSTRING, | ||
) | ||
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss): |
Apparently you forgot this comment :)
…mplemented review comments and added input processing
@jplu tests are also fixed now.
LGTM! Tested it out locally, it works great :) Thanks @spatil6!
@@ -114,3 +114,9 @@ TFGPT2DoubleHeadsModel

.. autoclass:: transformers.TFGPT2DoubleHeadsModel
    :members: call

TFGPT2ForSequenceClassification
Could you also add the TFSequenceClassifierOutputWithPast here in the model-specific outputs?
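Following the autoclass convention already used in this docs file, the entry would look something like this (assuming the new output class lives in transformers.modeling_tf_outputs):

.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
    :members: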
Ok I'll do it.
Done.
@spatil6 we have merged a PR today that updates the way the booleans are processed. You can see an example in the TF BERT file; can you rebase and apply the same changes, please? It would be awesome if you could do it!
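For reference, the config-default boolean resolution in the TF BERT file follows this general pattern (illustrative; the exact code should be copied from that file):

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict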
Sure, will do that.
Fix TF T5 only encoder model with booleans (huggingface#8925)
The new input processing looks good to me. If it looks good to you @jplu, feel free to merge!
LGTM!! Good work!!
What does this PR do?
This PR implements TFGPT2ForSequenceClassification in order to support DialogRPT.
Strongly based on modifications made in #7501
Fixes #7622
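Once merged, usage would look roughly like the following sketch. The model id comes from the test snippet above; the <|endoftext|> separator and single-logit sigmoid scoring are assumptions based on how DialogRPT is described:

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification

tokenizer = GPT2Tokenizer.from_pretrained("microsoft/dialogrpt")
model = TFGPT2ForSequenceClassification.from_pretrained("microsoft/dialogrpt")

# DialogRPT scores a (context, response) pair; the pair is assumed here
# to be joined with the end-of-text token
inputs = tokenizer("I love NLP!<|endoftext|>Me too!", return_tensors="tf")
outputs = model(inputs)
score = tf.sigmoid(outputs.logits)  # assumes a single-logit ranking head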
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR.
@LysandreJik Please review this PR, let me know if there is anything that should be changed =)