Skip to content

Commit

Permalink
Models docstring (#21225)
Browse files Browse the repository at this point in the history
* Clean all models

* Style

* Last to remove

* address review comments

* Address review comments
  • Loading branch information
sgugger committed Jan 23, 2023
1 parent 9e86c4e commit fd5cdae
Show file tree
Hide file tree
Showing 246 changed files with 1,285 additions and 1,890 deletions.
3 changes: 1 addition & 2 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,10 +1124,9 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)


def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
processor_class=tokenizer_class,
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
Expand Down
23 changes: 6 additions & 17 deletions src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
Expand Down Expand Up @@ -582,7 +581,7 @@ class AlbertForPreTrainingOutput(ModelOutput):
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AlbertTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
Expand Down Expand Up @@ -677,7 +676,6 @@ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -818,10 +816,10 @@ def forward(
Example:
```python
>>> from transformers import AlbertTokenizer, AlbertForPreTraining
>>> from transformers import AutoTokenizer, AlbertForPreTraining
>>> import torch
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
>>> model = AlbertForPreTraining.from_pretrained("albert-base-v2")
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
Expand Down Expand Up @@ -967,9 +965,9 @@ def forward(
```python
>>> import torch
>>> from transformers import AlbertTokenizer, AlbertForMaskedLM
>>> from transformers import AutoTokenizer, AlbertForMaskedLM
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
>>> model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
>>> # add mask_token
Expand Down Expand Up @@ -1048,7 +1046,6 @@ def __init__(self, config: AlbertConfig):

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="textattack/albert-base-v2-imdb",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1157,15 +1154,9 @@ def __init__(self, config: AlbertConfig):

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="vumichien/tiny-albert",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=(
"['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', "
"'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']"
),
expected_loss=0.66,
)
def forward(
self,
Expand Down Expand Up @@ -1243,7 +1234,6 @@ def __init__(self, config: AlbertConfig):

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="twmkn9/albert-base-v2-squad2",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1347,7 +1337,6 @@ def __init__(self, config: AlbertConfig):

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down
19 changes: 5 additions & 14 deletions src/transformers/models/albert/modeling_flax_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


@flax.struct.dataclass
Expand Down Expand Up @@ -122,7 +121,7 @@ class FlaxAlbertForPreTrainingOutput(ModelOutput):
input_ids (`numpy.ndarray` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AlbertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Expand Down Expand Up @@ -680,9 +679,7 @@ class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
module_class = FlaxAlbertModule


append_call_sample_docstring(
FlaxAlbertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)
append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)


class FlaxAlbertForPreTrainingModule(nn.Module):
Expand Down Expand Up @@ -757,9 +754,9 @@ class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
Example:
```python
>>> from transformers import AlbertTokenizer, FlaxAlbertForPreTraining
>>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
>>> model = FlaxAlbertForPreTraining.from_pretrained("albert-base-v2")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
Expand Down Expand Up @@ -834,9 +831,7 @@ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
module_class = FlaxAlbertForMaskedLMModule


append_call_sample_docstring(
FlaxAlbertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
)
append_call_sample_docstring(FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)


class FlaxAlbertForSequenceClassificationModule(nn.Module):
Expand Down Expand Up @@ -906,7 +901,6 @@ class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):

append_call_sample_docstring(
FlaxAlbertForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -983,7 +977,6 @@ class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
)
append_call_sample_docstring(
FlaxAlbertForMultipleChoice,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxMultipleChoiceModelOutput,
_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1054,7 +1047,6 @@ class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):

append_call_sample_docstring(
FlaxAlbertForTokenClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxTokenClassifierOutput,
_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1123,7 +1115,6 @@ class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):

append_call_sample_docstring(
FlaxAlbertForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
Expand Down
23 changes: 6 additions & 17 deletions src/transformers/models/albert/modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@

_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"

TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"albert-base-v1",
Expand Down Expand Up @@ -738,7 +737,7 @@ class TFAlbertForPreTrainingOutput(ModelOutput):
input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AlbertTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
Expand Down Expand Up @@ -802,7 +801,6 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -895,9 +893,9 @@ def call(
```python
>>> import tensorflow as tf
>>> from transformers import AlbertTokenizer, TFAlbertForPreTraining
>>> from transformers import AutoTokenizer, TFAlbertForPreTraining
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
>>> model = TFAlbertForPreTraining.from_pretrained("albert-base-v2")
>>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
Expand Down Expand Up @@ -1015,9 +1013,9 @@ def call(
```python
>>> import tensorflow as tf
>>> from transformers import AlbertTokenizer, TFAlbertForMaskedLM
>>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
>>> model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2")
>>> # add mask_token
Expand Down Expand Up @@ -1101,7 +1099,6 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="vumichien/albert-base-v2-imdb",
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1196,15 +1193,9 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="vumichien/tiny-albert",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=(
"['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', "
"'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']"
),
expected_loss=0.66,
)
def call(
self,
Expand Down Expand Up @@ -1285,7 +1276,6 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint="vumichien/albert-base-v2-squad2",
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down Expand Up @@ -1400,7 +1390,6 @@ def dummy_inputs(self):
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
Expand Down
29 changes: 14 additions & 15 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

logger = logging.get_logger(__name__)

_TOKENIZER_FOR_DOC = "XLMRobertaTokenizer"
_CHECKPOINT_FOR_DOC = "BAAI/AltCLIP"
_CONFIG_FOR_DOC = "AltCLIPConfig"

Expand Down Expand Up @@ -68,7 +67,7 @@
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`XLMRobertaTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Expand Down Expand Up @@ -98,7 +97,7 @@
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand All @@ -115,7 +114,7 @@
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`XLMRobertaTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Expand All @@ -133,7 +132,7 @@
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Expand Down Expand Up @@ -1181,10 +1180,10 @@ def forward(
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AltCLIPProcessor, AltCLIPVisionModel
>>> from transformers import AutoProcessor, AltCLIPVisionModel
>>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP")
>>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
Expand Down Expand Up @@ -1422,10 +1421,10 @@ def forward(
Examples:
```python
>>> from transformers import AltCLIPProcessor, AltCLIPTextModel
>>> from transformers import AutoProcessor, AltCLIPTextModel
>>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP")
>>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
>>> texts = ["it's a cat", "it's a dog"]
Expand Down Expand Up @@ -1526,10 +1525,10 @@ def get_text_features(
Examples:
```python
>>> from transformers import AltCLIPProcessor, AltCLIPModel
>>> from transformers import AutoProcessor, AltCLIPModel
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
>>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> text_features = model.get_text_features(**inputs)
```"""
Expand Down Expand Up @@ -1572,10 +1571,10 @@ def get_image_features(
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AltCLIPProcessor, AltCLIPModel
>>> from transformers import AutoProcessor, AltCLIPModel
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
>>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
Expand Down Expand Up @@ -1622,10 +1621,10 @@ def forward(
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AltCLIPProcessor, AltCLIPModel
>>> from transformers import AutoProcessor, AltCLIPModel
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
>>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
Expand Down
Loading

0 comments on commit fd5cdae

Please sign in to comment.