Skip to content

Commit

Permalink
Add type annotations for segformer classes (#16099)
Browse files Browse the repository at this point in the history
  • Loading branch information
p-mishra1 committed Mar 12, 2022
1 parent 9042dfe commit 62b05b6
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions src/transformers/models/segformer/modeling_segformer.py
Expand Up @@ -17,6 +17,7 @@

import collections
import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
Expand Down Expand Up @@ -373,11 +374,11 @@ def __init__(self, config):

def forward(
self,
pixel_values,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None

Expand Down Expand Up @@ -501,7 +502,13 @@ class PreTrainedModel
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -556,12 +563,12 @@ def __init__(self, config):
)
def forward(
self,
pixel_values=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
Expand Down Expand Up @@ -715,12 +722,12 @@ def __init__(self, config):
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SemanticSegmentationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
Expand Down

0 comments on commit 62b05b6

Please sign in to comment.