clip typhinting huggingface#16059
jacobdineen committed Mar 12, 2022
1 parent 9442b3c commit 14f00db
Showing 1 changed file with 65 additions and 60 deletions: src/transformers/models/clip/modeling_clip.py
@@ -16,7 +16,7 @@


from dataclasses import dataclass
- from typing import Any, Optional, Tuple
+ from typing import Any, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
@@ -95,12 +95,12 @@ class CLIPOutput(ModelOutput):
"""

loss: Optional[torch.FloatTensor] = None
- logits_per_image: torch.FloatTensor = None
- logits_per_text: torch.FloatTensor = None
- text_embeds: torch.FloatTensor = None
- image_embeds: torch.FloatTensor = None
- text_model_output: BaseModelOutputWithPooling = None
- vision_model_output: BaseModelOutputWithPooling = None
+ logits_per_image: Optional[torch.FloatTensor] = None
+ logits_per_text: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: Optional[BaseModelOutputWithPooling] = None
+ vision_model_output: Optional[BaseModelOutputWithPooling] = None

def to_tuple(self) -> Tuple[Any]:
return tuple(
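
The change matters because `ModelOutput` fields that the model does not fill in simply keep their `None` default, so a bare `torch.FloatTensor` annotation with a `None` default contradicts what a type checker sees. A minimal sketch of the pattern, using a hypothetical `TinyOutput` dataclass rather than the real class:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    # Hypothetical toy dataclass mirroring the CLIPOutput pattern above: any
    # field that is not supplied keeps its None default, so Optional[...] is
    # the annotation that matches runtime behaviour.
    @dataclass
    class TinyOutput:
        loss: Optional[torch.FloatTensor] = None
        logits: Optional[torch.FloatTensor] = None

    out = TinyOutput(logits=torch.ones(2, 2))
    assert out.loss is None  # no labels were passed, so no loss was computed
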
@@ -128,7 +128,7 @@ def __init__(self, config: CLIPVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

- def forward(self, pixel_values):
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -150,7 +150,12 @@ def __init__(self, config: CLIPTextConfig):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

- def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

if position_ids is None:
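
Either `input_ids` or `inputs_embeds` may be supplied, and `position_ids` is derived when omitted, which is why all three parameters are `Optional`. A quick sketch with a randomly initialized module (the token ids are made-up values, not a real tokenization):

    import torch
    from transformers import CLIPTextConfig
    from transformers.models.clip.modeling_clip import CLIPTextEmbeddings

    config = CLIPTextConfig()
    embeddings = CLIPTextEmbeddings(config)

    # Path 1: token ids (made-up values for illustration).
    from_ids = embeddings(input_ids=torch.tensor([[1, 2, 3]]))

    # Path 2: precomputed embeddings; position_ids are filled in internally.
    from_embeds = embeddings(inputs_embeds=torch.randn(1, 3, config.hidden_size))

    assert from_ids.shape == from_embeds.shape  # (1, 3, hidden_size)
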
@@ -193,7 +198,7 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
+ output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

@@ -272,7 +277,7 @@ def __init__(self, config):
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
@@ -293,8 +298,8 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
- output_attentions: bool = False,
- ):
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
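
Per the new return annotation, the layer always returns a plain tuple: the hidden states, plus the attention weights only when `output_attentions=True`. A small sketch with a randomly initialized layer:

    import torch
    from transformers import CLIPTextConfig
    from transformers.models.clip.modeling_clip import CLIPEncoderLayer

    config = CLIPTextConfig()
    layer = CLIPEncoderLayer(config)

    hidden = torch.randn(1, 4, config.hidden_size)
    outputs = layer(hidden, attention_mask=None, causal_attention_mask=None,
                    output_attentions=True)

    hidden_out, attn_weights = outputs  # a 2-tuple when attentions are requested
    assert hidden_out.shape == hidden.shape
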
@@ -502,12 +507,12 @@ def __init__(self, config: CLIPConfig):
def forward(
self,
inputs_embeds,
- attention_mask=None,
- causal_attention_mask=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
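
The `Union[Tuple, BaseModelOutput]` annotation reflects the `return_dict` switch: the same forward pass comes back either as a `ModelOutput` dataclass or as a plain tuple. A sketch with a randomly initialized encoder:

    import torch
    from transformers import CLIPTextConfig
    from transformers.models.clip.modeling_clip import CLIPEncoder

    config = CLIPTextConfig()
    encoder = CLIPEncoder(config)
    encoder.eval()

    embeds = torch.randn(1, 4, config.hidden_size)
    with torch.no_grad():
        as_output = encoder(inputs_embeds=embeds, return_dict=True)
        as_tuple = encoder(inputs_embeds=embeds, return_dict=False)

    # Same tensor either way; only the container type differs.
    assert torch.equal(as_output.last_hidden_state, as_tuple[0])
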
@@ -600,13 +605,13 @@ def __init__(self, config: CLIPTextConfig):
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- position_ids=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -689,13 +694,13 @@ def set_input_embeddings(self, value):
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- position_ids=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
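
The forward pass delegates to the underlying `CLIPTextTransformer` and returns its `BaseModelOutputWithPooling` (or a tuple when `return_dict=False`). Usage, loosely adapted from the library's docstring example:

    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()

    inputs = tokenizer(["a photo of a cat"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    last_hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
    pooled = outputs.pooler_output           # (batch, hidden_size), EOS token state
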
@@ -738,11 +743,11 @@ def __init__(self, config: CLIPVisionConfig):
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
def forward(
self,
- pixel_values=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -798,11 +803,11 @@ def get_input_embeddings(self) -> nn.Module:
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
def forward(
self,
- pixel_values=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
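
The vision tower follows the same pattern, taking `pixel_values` instead of token ids. A sketch in which random values stand in for an image preprocessed by `CLIPProcessor`:

    import torch
    from transformers import CLIPVisionModel

    model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()

    # Random values standing in for a (1, 3, 224, 224) preprocessed image.
    pixel_values = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    pooled = outputs.pooler_output  # (1, hidden_size) image representation
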
@@ -870,13 +875,13 @@ def __init__(self, config: CLIPConfig):
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
- input_ids=None,
- attention_mask=None,
- position_ids=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
@@ -914,7 +919,7 @@ def get_image_features(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
- ):
+ ) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
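
Both helpers return a bare projected embedding tensor rather than a `ModelOutput`, which the `-> torch.FloatTensor` annotations now make explicit. A sketch of both calls (random pixel values stand in for a real image):

    import torch
    from transformers import CLIPProcessor, CLIPModel

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()

    text_inputs = processor(text=["a photo of a cat"], return_tensors="pt")
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        image_features = model.get_image_features(
            pixel_values=torch.randn(1, 3, 224, 224)
        )

    # Both live in the shared projection space: (batch_size, projection_dim).
    assert text_features.shape[-1] == image_features.shape[-1]
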
@@ -953,15 +958,15 @@ def get_image_features(
@replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
def forward(
self,
- input_ids=None,
- pixel_values=None,
- attention_mask=None,
- position_ids=None,
- return_loss=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CLIPOutput]:
r"""
Returns:
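
The full forward pass accepts both modalities and returns a `CLIPOutput` (or tuple), with the contrastive loss included only when `return_loss=True`. A sketch loosely adapted from the library's docstring example; the loss needs matching text and image batch sizes, hence the duplicated image:

    import requests
    import torch
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    inputs = processor(text=["a photo of a cat", "a photo of a dog"],
                       images=[image, image], return_tensors="pt", padding=True)

    outputs = model(**inputs, return_loss=True)      # CLIPOutput
    probs = outputs.logits_per_image.softmax(dim=1)  # image-text match probabilities
    loss = outputs.loss                              # contrastive loss over the 2x2 logits
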