Refactor Code samples; Test code samples #5036
Conversation
Codecov Report
@@ Coverage Diff @@
## master #5036 +/- ##
==========================================
+ Coverage 79.08% 79.30% +0.22%
==========================================
Files 138 138
Lines 24078 24265 +187
==========================================
+ Hits 19041 19243 +202
+ Misses 5037 5022 -15
Continue to review full report at Codecov.
This is amazing! This way we won't make as many mistakes while copy-pasting code when introducing those task-specific models :-)
Yes, I like this idea a lot! It will be easier to add more TensorFlow examples as well!
src/transformers/file_utils.py
PYTORCH_MULTIPLE_CHOICE_CODE_SAMPLE_DOCSTRING = r"""
    Examples::

        from transformers import BertTokenizer, BertForMultipleChoice
This should be « {tokenizer_class}, {model_class} »
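For context, a templated version of that sample might look roughly like this; a minimal sketch, assuming a {checkpoint} placeholder alongside the class placeholders (the exact template in file_utils.py may differ):

    # Sketch: {tokenizer_class}, {model_class} and {checkpoint} are filled in
    # per model, so a single shared sample serves every architecture.
    PYTORCH_MULTIPLE_CHOICE_CODE_SAMPLE_DOCSTRING = r"""
        Examples::

            from transformers import {tokenizer_class}, {model_class}
            import torch

            tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
            model = {model_class}.from_pretrained('{checkpoint}')
    """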
src/transformers/file_utils.py
        from transformers import BertTokenizer, BertForMultipleChoice
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
Also change here and below
src/transformers/file_utils.py
        choice1 = "It is eaten while held in the hand."
        labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
And we probably want to remove batch_encode_plus here hahah
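For reference, batch_encode_plus was on its way out in favor of calling the tokenizer directly. A minimal sketch of how the sample could read instead, assuming the tokenizer __call__ API with padding=True (the prompt and choice0 strings are illustrative; only choice1 appears above):

    from transformers import BertTokenizer
    import torch

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    prompt = "In Italy, pizza served in formal settings is presented unsliced."  # illustrative
    choice0 = "It is eaten with a fork and a knife."                             # illustrative
    choice1 = "It is eaten while held in the hand."

    # Calling the tokenizer directly replaces batch_encode_plus, and
    # padding=True replaces the deprecated pad_to_max_length=True.
    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
    labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct, batch size 1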
Force-pushed from 18bd438 to 929682e.
Force-pushed from 929682e to 574575a.
@@ -633,7 +610,7 @@ def call(
      mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
      output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
      assert len(inputs) <= 8, "Too many inputs."
-  elif isinstance(inputs, dict):
+  elif isinstance(inputs, (dict, BatchEncoding)):
Non-cosmetic change
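The change matters because BatchEncoding, the object the tokenizers return, behaves like a dict but subclasses collections.UserDict rather than dict, so a plain isinstance(inputs, dict) check rejects it. A self-contained sketch, with BatchEncodingLike standing in for transformers.BatchEncoding:

    from collections import UserDict

    class BatchEncodingLike(UserDict):
        """Stand-in for transformers.BatchEncoding, which subclasses UserDict."""

    enc = BatchEncodingLike({"input_ids": [[101, 7592, 102]]})

    print(isinstance(enc, dict))                       # False: UserDict does not subclass dict
    print(isinstance(enc, (dict, BatchEncodingLike)))  # True: the widened check accepts it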
-  self.roberta = TFBertMainLayer(config, name="roberta")
+  self.roberta = TFRobertaMainLayer(config, name="roberta")
Non-cosmetic change
@@ -863,7 +843,7 @@ def call(
      labels = inputs[4] if len(inputs) > 4 else labels
      output_attentions = inputs[5] if len(inputs) > 5 else output_attentions
      assert len(inputs) <= 6, "Too many inputs."
-  elif isinstance(inputs, dict):
+  elif isinstance(inputs, (BatchEncoding, dict)):
Non-cosmetic change
-  assert len(inputs) <= 10, "Too many inputs."
+  output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
+  labels = inputs[11] if len(inputs) > 11 else labels
+  assert len(inputs) <= 11, "Too many inputs."
Non-cosmetic change
-  assert len(inputs) <= 10, "Too many inputs."
+  output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
+  labels = inputs.get("labels", labels)
+  assert len(inputs) <= 12, "Too many inputs."
Non-cosmetic change
-  config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+  config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
Non-cosmetic change
if isinstance(inputs, (tuple, list)):
    labels = inputs[11] if len(inputs) > 11 else labels
    if len(inputs) > 11:
        inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
    labels = inputs.pop("labels", labels)
Non-cosmetic change
-  outputs = (logits,) + transformer_outputs[2:]  # add hidden states and attention if they are here
+  outputs = (logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here
Non-cosmetic change
Amazing work!
I just highlighted the examples I saw without doctest syntax, just to make sure it was on purpose.
Also there is one nit in quicktour.rst
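For readers skimming the thread: "doctest syntax" means prefixing each statement in a sample with >>> so the documentation test runner can execute it and compare any expected output. A generic sketch (not taken from this PR):

    def double(x):
        """Return twice the input.

        Example::

            >>> double(21)
            42
        """
        return 2 * x

    if __name__ == "__main__":
        import doctest
        doctest.testmod()  # executes the >>> examples and checks the expected output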
@@ -978,36 +978,38 @@ def generate(

    Examples::

        from transformers import AutoTokenizer, AutoModelForCausalLM
Same here
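Filled in, a templated generate() sample might read roughly like this; a sketch in which the 'gpt2' checkpoint and max_length value are illustrative choices, not taken from the PR:

    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained('gpt2')    # checkpoint illustrative
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    # Encode a prompt and let the model continue it.
    inputs = tokenizer("Today is a beautiful day and", return_tensors='pt')
    outputs = model.generate(inputs['input_ids'], max_length=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))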
Can't manage to have a clean doctest with such a big example
Huge work!
Amazing!
* Refactor code samples
* Test docstrings
* Style
* Tokenization examples
* Run rest of tests
* First step to testing source docs
* Style and BART comment
* Test the remainder of the code samples
* Style
* let to const
* Formatting fixes
* Ready for merge
* Fix fixture + Style
* Fix last tests
* Update docs/source/quicktour.rst
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Addressing @sgugger's comments + Fix MobileBERT in TF

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This PR refactors the code samples so the same sample no longer has to be copy/pasted across classes, with the model/tokenizer classes and checkpoint names substituted in per model instead.
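A minimal sketch of the mechanism, with hypothetical names (the real helper lives in src/transformers/file_utils.py and may differ in name and signature): a decorator formats the shared template with each model's classes and checkpoint, then appends the result to the method's docstring.

    # add_code_sample_docstring is a hypothetical name for illustration only.
    def add_code_sample_docstring(template, **fields):
        def decorator(fn):
            # Fill the placeholders and attach the sample to the docstring.
            fn.__doc__ = (fn.__doc__ or "") + template.format(**fields)
            return fn
        return decorator

    SAMPLE = r"""
        Examples::

            from transformers import {tokenizer_class}, {model_class}

            tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
            model = {model_class}.from_pretrained('{checkpoint}')
    """

    class BertForMultipleChoice:
        @add_code_sample_docstring(
            SAMPLE,
            tokenizer_class="BertTokenizer",
            model_class="BertForMultipleChoice",
            checkpoint="bert-base-uncased",
        )
        def forward(self, input_ids=None):
            """Run a forward pass."""

With this in place, each model only declares its class names and checkpoint; the sample itself is written, and can be tested, exactly once.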