Add type hints for PyTorch Models #16425


@karthikrangasai
Contributor

karthikrangasai commented Mar 26, 2022

What does this PR do?

Add type hints to as many PyTorch models as possible.

This PR type hints the entire modeling files of the following models:

  • Albert
  • Bart
  • Bert
  • BertGeneration
  • BigBird
  • BigBirdPegasus
  • Canine
  • ConvBert
  • ConvNext
  • CTRL
  • Data2VecText
  • Data2VecAudio
  • Hubert
  • Marian
  • MBart
  • Nystromformer
  • Wav2Vec2
  • WavLM
  • XGLM
  • XLMRobertaXL
  • Yoso

Any other file that has been edited is a result of running make fix-copies.

In the next PR, I will target a few other models and type hint their complete files.

Fixes #16059

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@karthikrangasai
Contributor Author

Hello all,

I changed the cookiecutter template code because the check_copies script couldn't correct code that had been reformatted onto multiple lines.

Changing a method in Bart model:

##########
## From ##
##########
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):

########
## To ##
########
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
    self,
    attention_mask: torch.Tensor,
    input_shape: torch.Size,
    inputs_embeds: torch.FloatTensor,
    past_key_values_length: int,
) -> Optional[torch.Tensor]:

The code-correcting script would then produce the following:

# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    self,
    attention_mask: torch.Tensor,
    input_shape: torch.Size,
    inputs_embeds: torch.FloatTensor,
    past_key_values_length: int,
) -> Optional[torch.Tensor]:

This caused the errors in one of the test runs.

Just commenting the reason here in case someone is taking a look at this PR later.
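
The failure mode described above can be reproduced with a toy example. This is only a sketch and not the actual check_copies logic: the helper `replace_first_line_only` is hypothetical and simply assumes a fixer that overwrites the first line of a copied signature with the one-line reference while leaving the continuation lines in place. `ast.parse` from the standard library then confirms that the result is no longer valid Python.

```python
import ast

# One-line signature, as in the reference before type hints were added.
ORIGINAL_ONE_LINE = "def f(self, attention_mask, input_shape):\n    pass\n"

# The same function after type hints pushed the signature onto several lines.
# String annotations keep the sketch free of any PyTorch import.
UPDATED_MULTI_LINE = (
    "def f(\n"
    "    self,\n"
    "    attention_mask: 'Tensor',\n"
    "    input_shape: 'Size',\n"
    ") -> None:\n"
    "    pass\n"
)


def replace_first_line_only(copy_source: str, reference_first_line: str) -> str:
    """Hypothetical naive fixer: overwrite only the first line of the copy."""
    lines = copy_source.splitlines(keepends=True)
    lines[0] = reference_first_line
    return "".join(lines)


def is_valid_python(source: str) -> bool:
    """Return True if the source parses, False on a SyntaxError."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


# The continuation lines of the multi-line signature are left dangling.
broken = replace_first_line_only(
    UPDATED_MULTI_LINE, "def f(self, attention_mask, input_shape):\n"
)
```

Here `is_valid_python(broken)` is false because the leftover `) -> None:` line no longer closes anything, which matches the shape of the broken output quoted above.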

@karthikrangasai karthikrangasai changed the title [WIP] Add type hints for PyTorch Models Add type hints for PyTorch Models Mar 28, 2022
@karthikrangasai
Contributor Author

karthikrangasai commented Mar 28, 2022

Hello all,

I have updated the code base with type hints for a few models.
I will open a new PR for the remaining models after this one is merged, since this PR is getting bigger.

Thanks

cc: @Rocketknight1

@karthikrangasai
Contributor Author

Hello all,

The "Add model like runner" test is failing when it starts the "Run all PyTorch modeling test" section, with the following ImportError:

ImportError: cannot import name 'get_current_traceback' from 'werkzeug.debug.tbtools' (/home/runner/.local/lib/python3.8/site-packages/werkzeug/debug/tbtools.py)

I am unsure what is causing the error; any leads on how to resolve it would be appreciated.

Thank you

@Rocketknight1
Member

Wow, this is a huge PR! Did you do this manually, or have you figured out some kind of tool for it?

@karthikrangasai
Contributor Author

Hello @Rocketknight1 ,

Yeah, I did all of this manually. This was how I spent my weekend 😛.

@Rocketknight1
Member

That's amazing! I'll try to review now.

@Rocketknight1
Member

This is a huge and very impressive PR, thank you! The main suggestion I have is that bools are not annotated in some cases, e.g. output_attentions=False should be output_attentions: bool = False, or output_attentions: Optional[bool] = None when the default is None. I'll try to recruit a couple of people from Hugging Face to help me review the whole thing once that's resolved!
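
As a minimal sketch of the two annotation patterns suggested above: the function names and the `config_default` parameter are hypothetical and only stand in for a model's forward method. `typing.get_type_hints` is used to show what a type checker or doc tool would see.

```python
from typing import Optional, get_type_hints


def forward_flag_default(output_attentions: bool = False) -> bool:
    """Default is a concrete bool, so the annotation is plain `bool`."""
    return output_attentions


def forward_none_default(
    output_attentions: Optional[bool] = None,
    config_default: bool = False,  # hypothetical stand-in for a config value
) -> bool:
    """Default is None, so the annotation is `Optional[bool]`.

    None means "not specified", and the body falls back to another value.
    """
    return output_attentions if output_attentions is not None else config_default


# Inspect the resolved annotations of both variants.
hints_flag = get_type_hints(forward_flag_default)
hints_none = get_type_hints(forward_none_default)
```

The `Optional[bool]` form matters when `None` carries its own meaning (fall back to a default chosen elsewhere), which a plain `bool` annotation would hide.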

@karthikrangasai
Copy link
Contributor Author

Hello,

Sure, I can also take another look to fix the missing ones.
I had doubts about a few of the types. I will post them here later to confirm the right types and then update the PR.

Thanks for the update and glad you liked the work.

@Rocketknight1
Member

Absolutely! I saw in some cases past was missing annotations - if you're unsure about annotations like that, you can usually check the docstrings for the past or past_key_values argument for that model - it'll be something like Tuple[torch.Tensor]

@karthikrangasai
Contributor Author

Sure, later in the process I figured out the type and added it for a few files. Will fix the others as well.

@Rocketknight1
Member

Note that past/past_key_values can have different structure in different models!
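
To illustrate why one annotation cannot be copied blindly between models, here is a sketch of two cache structures a `past` / `past_key_values` argument might take. Everything in it is an assumption for illustration: `Tensor` is a placeholder class standing in for `torch.Tensor` so the sketch runs without PyTorch, and the alias and function names are hypothetical. The real structure for a given model should be taken from its docstring.

```python
from typing import Optional, Tuple, get_type_hints


class Tensor:
    """Placeholder standing in for torch.Tensor (assumption, not the real class)."""


# Hypothetical aliases: a flat tuple of tensors versus one inner tuple per layer.
FlatCache = Tuple[Tensor, ...]
LayerwiseCache = Tuple[Tuple[Tensor, ...], ...]


def forward_flat(past: Optional[FlatCache] = None) -> int:
    """Return the number of cached tensors (0 when there is no cache)."""
    return 0 if past is None else len(past)


def forward_layerwise(past_key_values: Optional[LayerwiseCache] = None) -> int:
    """Return the number of cached layers (0 when there is no cache)."""
    return 0 if past_key_values is None else len(past_key_values)


# Example values matching each alias.
flat_example: FlatCache = (Tensor(), Tensor())
layerwise_example: LayerwiseCache = ((Tensor(), Tensor()), (Tensor(), Tensor()))
```

Naming the structure with an alias keeps long signatures readable and makes a mismatch between two models' caches visible to a type checker.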

@karthikrangasai
Contributor Author

Ohhh, thanks for the heads up.

@Tegzes
Contributor

Tegzes commented Mar 30, 2022

@karthikrangasai The best way to make sure the type hints are correct is to check the [Model Name]_INPUTS_DOCSTRING, which sits right before the first user-facing forward method.

@karthikrangasai
Contributor Author

Hello @Tegzes ,
I checked that for all forward methods, but it is possible that I missed it for a few files.

I have type hinted each entire file, from the first function to the last class, so I might have missed something in other places.

@Rocketknight1
Member

Hi @karthikrangasai! This is totally my bad - other PRs came in and I reviewed them without realizing they would create conflicts with yours. Would it be possible to break this PR up into a few separate ones and submit them one at a time? That greatly reduces the chances of conflicts for each one, and it'll make it possible for me to add specific comments/suggestions, whereas at this size I can really only give general advice!

@karthikrangasai
Contributor Author

karthikrangasai commented Apr 1, 2022

Hello @Rocketknight1 ,

Yeah sure.
I would like to take on the typing issue in full, if that's fine with you (for all PyTorch models, complete files).

I will break the PR into multiple ones, based on the corrections made or the model that was type hinted.

Should I close this one then?

@Rocketknight1
Member

Hi @karthikrangasai, sorry for the delay! Yeah, it's probably easiest to close this one, make new ones and just tag me in them. Thank you!

Development

Successfully merging this pull request may close these issues.

Add missing type hints
4 participants