Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

edit: cast attention_mask to long in DataCollatorCTCWithPadding #19369

Merged

Conversation

ddobokki
Copy link
Contributor

@ddobokki ddobokki commented Oct 6, 2022

What does this PR do?

many inf values generated when training Wav2Vec2ForCTC by referring to run_speech_recognition_ctc.py using DeepSpeed library.

because Wav2Vec2ForCTC's forword has logics that sum attention_mask, so if you training model using DeepSpeed,

def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:

this method cast attention_mask's dtype int32 to float16

Wav2Vec2FeatureExtractor makes attention_mask and it's dtype int32
here is example

import torch
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(return_attention_mask=True)
data = [{'input_values':[0.1,0.1,0.1]},{'input_values':[0.2,0.2,0.2,0.2,0.2]}]
attn_mask = feature_extractor.pad(data,padding = "longest",return_tensors="pt")['attention_mask']
print(attn_mask.dtype)
-> torch.int32

so i add one line in DataCollatorCTCWithPadding that attention_mask casting long type from int32

batch['attention_mask'] = batch['attention_mask'].to(torch.long)

Fixes # 18080

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 6, 2022

The documentation is not available anymore as the PR was closed or merged.

@ddobokki
Copy link
Contributor Author

ddobokki commented Oct 6, 2022

i add more line

if "attention_mask" in batch:

bcz some case, feature_extractor has config that ["return_attention_mask": false]
but is

if self.processor.feature_extractor.return_attention_mask:

more good to read? if so, i'll change that

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

@sgugger sgugger merged commit fa4bcd5 into huggingface:main Oct 7, 2022
ajsanjoaquin pushed a commit to ajsanjoaquin/transformers that referenced this pull request Oct 12, 2022
…ingface#19369)

* edit: casting attention_mask to long in DataCollatorCTCWithPadding

* edit: casting attention_mask to long in DataCollatorCTCWithPadding
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 18, 2022
…ingface#19369)

* edit: casting attention_mask to long in DataCollatorCTCWithPadding

* edit: casting attention_mask to long in DataCollatorCTCWithPadding
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants