
Add TFGPT2ForSequenceClassification based on DialogRPT #8714

Merged (15 commits) on Dec 7, 2020

Conversation

@spatil6 (Contributor) commented Nov 22, 2020

What does this PR do?

This PR implements TFGPT2ForSequenceClassification in order to support DialogRPT.
Strongly based on modifications made in #7501
Fixes #7622
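
As an illustration only (not part of the PR): a minimal usage sketch of what this enables. The model id microsoft/DialogRPT-updown, the from_pt=True flag, and the sigmoid readout are assumptions about the published DialogRPT checkpoints, not something verified in this thread.

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification

tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialogRPT-updown")
# from_pt=True assumed, in case only PyTorch weights are published for this checkpoint
model = TFGPT2ForSequenceClassification.from_pretrained("microsoft/DialogRPT-updown", from_pt=True)

# DialogRPT scores a context/response pair separated by the end-of-text token
inputs = tokenizer("Can we restart 2020?<|endoftext|>Yes, please!", return_tensors="tf")
score = tf.sigmoid(model(inputs).logits)  # single regression-style logit
print(float(score))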

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LysandreJik Please review this PR and let me know if there is anything that should be changed =)

@jplu (Contributor) left a comment

Thank you very much for this very nice addition!!

I left a few comments on it. Also, can you run the following piece of code and tell me whether it works properly:

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification

model = tf.function(TFGPT2ForSequenceClassification.from_pretrained("microsoft/dialogrpt"))
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/dialogrpt")
inputs = tokenizer("Hello", return_tensors="tf")
model(inputs)

@LysandreJik I would also recommend waiting a bit for the new input processing to be merged.

""",
GPT2_START_DOCSTRING,
)
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
Contributor:

Be careful: TFSequenceClassificationLoss only takes -100 into account as the pad token id. So either you assume that same value everywhere, or you should redefine the loss computation yourself.

Contributor:

Apparently you forgot this comment :)

Contributor Author:

TFGPT2ForSequenceClassification uses the last token in order to do the classification. So, for a given task, if a pad_token_id is defined in the configuration, it finds the last token that is not a pad token in each row; if no pad_token_id is defined, it simply takes the last value in each row of the batch.

So I think the TFSequenceClassificationLoss loss function should work, as it takes a single last token which is not a pad token.

Let me know your views on it.

@LysandreJik has already defined pad_token_id: 50256 for this model.
#7493 (comment)
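
For what it's worth, a minimal self-contained sketch of that last-non-pad-token selection (toy tensors, pad_token_id = 50256 as mentioned above; not the PR's actual code, and it assumes padding only appears on the right):

import tensorflow as tf

# Toy inputs: two rows of token ids with trailing padding (pad_token_id = 50256).
pad_token_id = 50256
input_ids = tf.constant([[11, 22, 50256, 50256],
                         [33, 44, 55, 50256]])
logits = tf.random.normal((2, 4, 1))  # (batch, seq_len, num_labels)

# Index of the last non-pad token per row (valid when padding is right-sided).
sequence_lengths = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32), axis=-1
) - 1  # -> [1, 2]

# Pull out the logits at those positions: one classification vector per example.
pooled_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
print(pooled_logits.shape)  # (2, 1)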

Contributor Author:

[screenshot: TFSequenceClassificationLoss source]

This is a snapshot of TFSequenceClassificationLoss; there is no mention of pad token id -100 being ignored.

There are other loss functions, like TFTokenClassificationLoss and TFNextSentencePredictionLoss, where token id -100 is ignored.
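
For context, a rough self-contained sketch of the -100 masking pattern those losses use (toy data, not the library source): positions labelled -100 are dropped before the loss is computed.

import tensorflow as tf

labels = tf.constant([[1, -100, 0], [0, 1, -100]])   # -100 marks positions to ignore
logits = tf.random.normal((2, 3, 2))                 # (batch, seq_len, num_labels)

active = tf.reshape(labels, (-1,)) != -100           # keep only "real" positions
masked_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), active)
masked_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(loss_fn(masked_labels, masked_logits))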

@jplu (Contributor), Dec 1, 2020:

Ok, fine for me about this then!

Comment on lines 944 to 947
if input_ids is not None:
    batch_size, sequence_length = input_ids.shape[:2]
else:
    batch_size, sequence_length = inputs_embeds.shape[:2]
Contributor:

Use the shape_list function instead of .shape.
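
For reference, a rough sketch of what the suggested change could look like (shape_list comes from transformers.modeling_tf_utils and handles dynamic shapes under tf.function; input_ids and inputs_embeds are the arguments of the surrounding call method):

from transformers.modeling_tf_utils import shape_list

if input_ids is not None:
    batch_size, sequence_length = shape_list(input_ids)[:2]
else:
    batch_size, sequence_length = shape_list(inputs_embeds)[:2]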

Contributor Author:

Done

result = tf.map_fn(
    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
p_logits = tf.reshape(result, [result.shape[0], result.shape[-1]])
Contributor:

Same here.
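
Same idea as above, sketched for this snippet (result is the tf.map_fn output from the lines under review):

from transformers.modeling_tf_utils import shape_list

result_shape = shape_list(result)
p_logits = tf.reshape(result, [result_shape[0], result_shape[-1]])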

Contributor Author:

Done

Comment on lines 855 to 868
def call(
    self,
    inputs,
    past=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
Contributor:

The list is not in the proper order; same thing for the input order processing below. Look at the TFGPT2LMHeadModel class for an example.
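
For orientation only, the rough argument order used by TFGPT2LMHeadModel.call around that time (reconstructed from memory, so treat it as illustrative rather than the exact signature; the key point is that labels comes after return_dict):

def call(
    self,
    inputs,
    past=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    labels=None,
    training=False,
):
    ...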

Contributor Author:

Done

@spatil6 (Contributor Author) commented Nov 25, 2020

Thanks for the review, @jplu. I'll update my code with the review comments and the new input processing.

@spatil6 (Contributor Author) commented Nov 27, 2020

> Thank you very much for this very nice addition!!
>
> I left a few comments on it. Also, can you run the following piece of code and tell me whether it works properly:
>
> import tensorflow as tf
> from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification
>
> model = tf.function(TFGPT2ForSequenceClassification.from_pretrained("microsoft/dialogrpt"))
> tokenizer = GPT2Tokenizer.from_pretrained("microsoft/dialogrpt")
> inputs = tokenizer("Hello", return_tensors="tf")
> model(inputs)
>
> @LysandreJik I would also recommend waiting a bit for the new input processing to be merged.

[screenshot: output of the snippet above]

@spatil6 (Contributor Author) commented Nov 27, 2020

Hello @jplu and @LysandreJik,
I have refactored the code as per the review comments and added the new input processing as well.

Kindly review.

@spatil6 spatil6 requested a review from jplu November 28, 2020 07:28
@jplu (Contributor) left a comment

Much better!! Thanks for the updates.

There is still one comment to be addressed and the tests to fix.

""",
GPT2_START_DOCSTRING,
)
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
Contributor:

Apparently you forgot this comment :)

@spatil6 spatil6 requested a review from jplu November 30, 2020 16:07
@spatil6 (Contributor Author) commented Dec 2, 2020

> Much better!! Thanks for the updates.
>
> There is still one comment to be addressed and the tests to fix.

@jplu the tests are also fixed now.

@LysandreJik (Member) left a comment

LGTM! Tested it out locally, it works great :) Thanks @spatil6!

@@ -114,3 +114,9 @@ TFGPT2DoubleHeadsModel

.. autoclass:: transformers.TFGPT2DoubleHeadsModel
    :members: call

TFGPT2ForSequenceClassification
Member:

Could you also add the TFSequenceClassifierOutputWithPast here in the model-specific outputs?
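
Something along these lines would likely do it, assuming the output class is exposed as transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast (the exact module path should match wherever the class was added):

TFSequenceClassifierOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
    :members: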

Contributor Author:

Ok I'll do it.

Contributor Author:

Done.

@jplu (Contributor) commented Dec 4, 2020

@spatil6 we merged a PR today that updates the way the booleans are processed. You can see an example in the TF BERT file; can you rebase and make the same changes, please? It would be awesome if you could do it!

@spatil6 (Contributor Author) commented Dec 4, 2020

> @spatil6 we merged a PR today that updates the way the booleans are processed. You can see an example in the TF BERT file; can you rebase and make the same changes, please? It would be awesome if you could do it!

Sure, will do that.

@LysandreJik (Member) left a comment

The new inputs processing looks good to me. If it looks good to you @jplu, feel free to merge!

@jplu (Contributor) left a comment

LGTM!! Good work!!

@jplu jplu merged commit 483e132 into huggingface:master Dec 7, 2020
@spatil6 spatil6 deleted the tf2_gpt2_sequence_model branch December 19, 2020 14:55
Successfully merging this pull request may close these issues:

  • Implement a TF2 version of GPT2ForSequenceClassification