
Refactor Code samples; Test code samples #5036

Merged: 17 commits into master from refactor-code-samples, Jun 25, 2020
Conversation

@LysandreJik (Member) commented Jun 15, 2020

This PR refactors the code samples so that the same sample is no longer copy/pasted across classes: a shared template is formatted with each model's tokenizer class, model class, and checkpoint name (see the sketch after the list below).

  • All models now have their docstrings updated.
  • Doctest is used for testing the samples.
  • Fixed a number of bugs in the docstrings, as well as in a few models. All non-cosmetic changes are highlighted below.
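
As a rough illustration of the templating idea, the sketch below uses a hypothetical decorator; the helper's name and signature are stand-ins, not the PR's actual API:

    # Hypothetical sketch: one shared sample template, formatted per model.
    PT_SAMPLE_TEMPLATE = r"""
        Example::

            >>> from transformers import {tokenizer_class}, {model_class}
            >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
            >>> model = {model_class}.from_pretrained('{checkpoint}')
    """

    def add_code_sample(*, tokenizer_class, checkpoint):
        """Append the formatted sample to the decorated method's docstring."""
        def decorator(fn):
            model_class = fn.__qualname__.split(".")[0]  # e.g. "BertForMultipleChoice"
            fn.__doc__ = (fn.__doc__ or "") + PT_SAMPLE_TEMPLATE.format(
                tokenizer_class=tokenizer_class,
                model_class=model_class,
                checkpoint=checkpoint,
            )
            return fn
        return decorator

A model's forward method is then decorated once, e.g. @add_code_sample(tokenizer_class="BertTokenizer", checkpoint="bert-base-uncased"), instead of carrying its own pasted copy of the full example.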

@LysandreJik LysandreJik requested a review from sgugger June 15, 2020 22:17
codecov bot commented Jun 15, 2020

Codecov Report

Merging #5036 into master will increase coverage by 0.22%.
The diff coverage is 97.44%.


@@            Coverage Diff             @@
##           master    #5036      +/-   ##
==========================================
+ Coverage   79.08%   79.30%   +0.22%     
==========================================
  Files         138      138              
  Lines       24078    24265     +187     
==========================================
+ Hits        19041    19243     +202     
+ Misses       5037     5022      -15     
Impacted Files Coverage Δ
src/transformers/configuration_albert.py 100.00% <ø> (ø)
src/transformers/configuration_bart.py 93.75% <ø> (ø)
src/transformers/configuration_bert.py 100.00% <ø> (ø)
src/transformers/configuration_ctrl.py 97.05% <ø> (ø)
src/transformers/configuration_distilbert.py 100.00% <ø> (ø)
src/transformers/configuration_electra.py 100.00% <ø> (ø)
src/transformers/configuration_encoder_decoder.py 100.00% <ø> (ø)
src/transformers/configuration_gpt2.py 97.22% <ø> (ø)
src/transformers/configuration_longformer.py 100.00% <ø> (ø)
src/transformers/configuration_mobilebert.py 97.05% <ø> (ø)
... and 50 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 24f46ea...a9bb134.

@sgugger (Collaborator) commented Jun 15, 2020

This is amazing! This way we won't make as many mistakes while copy-pasting code to introduce those task-specific models :-)

@thomwolf (Member) left a comment

Yes I like this idea a lot! Will be easier to add more TensorFlow examples as well!

PYTORCH_MULTIPLE_CHOICE_CODE_SAMPLE_DOCSTRING = r"""
Examples::

from transformers import BertTokenizer, BertForMultipleChoice
Review comment (Member):
This should be "{tokenizer_class}, {model_class}"

from transformers import BertTokenizer, BertForMultipleChoice
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
Review comment (Member):
Change also here and below

choice1 = "It is eaten while held in the hand."
labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1

encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
Review comment (Member):
And we probably want to remove batch_encode_plus here hahah
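
Putting these review comments together, the templated sample would look roughly as follows. This is a sketch, not the merged text: the prompt and choice0 values are illustrative (the excerpt above elides them), and the direct tokenizer(...) call assumes the then-new __call__ API that was replacing batch_encode_plus:

    PYTORCH_MULTIPLE_CHOICE_CODE_SAMPLE_DOCSTRING = r"""
        Examples::

            >>> from transformers import {tokenizer_class}, {model_class}
            >>> import torch

            >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
            >>> model = {model_class}.from_pretrained('{checkpoint}')

            >>> prompt = "In Italy, pizza is presented unsliced."  # illustrative value
            >>> choice0 = "It is eaten with a fork and a knife."   # illustrative value
            >>> choice1 = "It is eaten while held in the hand."
            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct, batch size 1

            >>> # pair the prompt with each choice; padding=True assumes the newer tokenizer API
            >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
    """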

@LysandreJik LysandreJik changed the title Refactor Code samples [DO NOT MERGE] Refactor Code samples Jun 24, 2020
@LysandreJik LysandreJik changed the title [DO NOT MERGE] Refactor Code samples Refactor Code samples; Test code samples Jun 25, 2020
@@ -633,7 +610,7 @@ def call(
             mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
             output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
             assert len(inputs) <= 8, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
Author comment: Non-cosmetic change.
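
Why it is non-cosmetic (an assumption consistent with the change, worth checking against the version in question): BatchEncoding is a UserDict-style wrapper around the tokenizer output rather than a dict subclass, so the old isinstance(inputs, dict) branch never matched tokenizer outputs. A minimal sketch:

    from collections import UserDict

    class BatchEncoding(UserDict):  # simplified stand-in for the real class
        pass

    enc = BatchEncoding({"input_ids": [[101, 102]]})
    print(isinstance(enc, dict))                   # False: the old check skips it
    print(isinstance(enc, (dict, BatchEncoding)))  # True: the new check accepts it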

Comment on lines -444 to +417
-        self.roberta = TFBertMainLayer(config, name="roberta")
+        self.roberta = TFRobertaMainLayer(config, name="roberta")
Author comment: Non-cosmetic change.

@@ -863,7 +843,7 @@ def call(
             labels = inputs[4] if len(inputs) > 4 else labels
             output_attentions = inputs[5] if len(inputs) > 5 else output_attentions
             assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (BatchEncoding, dict)):
Author comment: Non-cosmetic change.

Comment on lines -935 to +895
-            assert len(inputs) <= 10, "Too many inputs."
+            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
+            labels = inputs[11] if len(inputs) > 11 else labels
+            assert len(inputs) <= 11, "Too many inputs."
Author comment: Non-cosmetic change.

Comment on lines -947 to +909
-            assert len(inputs) <= 10, "Too many inputs."
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+            labels = inputs.get("labels", labels)
+            assert len(inputs) <= 12, "Too many inputs."
Author comment: Non-cosmetic change.

Comment on lines -1004 to +966
-            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+            config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
Author comment: Non-cosmetic change.

Comment on lines +1008 to +1014
+        if isinstance(inputs, (tuple, list)):
+            labels = inputs[11] if len(inputs) > 11 else labels
+            if len(inputs) > 11:
+                inputs = inputs[:11]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            labels = inputs.pop("labels", labels)
Author comment: Non-cosmetic change.

Comment on lines -1075 to +1032
-        outputs = (logits,) + transformer_outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here
Author comment: Non-cosmetic change.

@sgugger (Collaborator) left a comment
Amazing work!
I just highlighted the examples I saw without doctest syntax, to make sure it was on purpose.
Also, there is one nit in quicktour.rst.

Resolved review threads: docs/source/quicktour.rst, src/transformers/modeling_tf_mobilebert.py (two threads), src/transformers/modeling_tf_openai.py
@@ -978,36 +978,38 @@ def generate(

Examples::

from transformers import AutoTokenizer, AutoModelForCausalLM
Review comment (Collaborator):
Same here

Author reply:
Can't manage to have a clean doctest with such a big example
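
For context, the doctest syntax discussed in this review marks executable lines with >>> and compares their printed output against the following line, which is what lets the refactored code samples double as tests. A minimal sketch (the function is a made-up example, not from the PR):

    def square(x):
        """
        Example::

            >>> square(3)
            9
        """
        return x * x

    if __name__ == "__main__":
        import doctest
        doctest.testmod()  # reports any sample whose output no longer matches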

Resolved review thread: src/transformers/modeling_xlm.py
@thomwolf (Member) left a comment
Huge work!
Amazing!

LysandreJik and others added 3 commits June 25, 2020 15:55
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@LysandreJik LysandreJik merged commit 364a5ae into master Jun 25, 2020
@LysandreJik LysandreJik deleted the refactor-code-samples branch June 25, 2020 20:46
jplu pushed a commit to jplu/transformers that referenced this pull request Jun 29, 2020
* Refactor code samples

* Test docstrings

* Style

* Tokenization examples

* Run rest of tests

* First step to testing source docs

* Style and BART comment

* Test the remainder of the code samples

* Style

* let to const

* Formatting fixes

* Ready for merge

* Fix fixture + Style

* Fix last tests

* Update docs/source/quicktour.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing @sgugger's comments + Fix MobileBERT in TF

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Labels: none yet. Projects: none yet. 4 participants.