Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder #15938

Merged
merged 3 commits into huggingface:master from flax-wav2vec2 on Mar 7, 2022

Conversation

sanchit-gandhi
Contributor

This PR correctly implements a backpropagation test to verify the functionality of the freeze_feature_encoder argument added to the FlaxWav2Vec2 model in #15873. It tests (a rough sketch of the comparison follows the list):

  1. That the computed loss for the frozen-feature-encoder model and the unfrozen model are equal.
  2. That the gradients of the frozen feature encoder differ from those of the unfrozen feature encoder.
  3. That the gradients of all other (unfrozen) layers remain equal.
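
Roughly, the comparison works like the sketch below. This is illustrative only, not the test's exact code: the tiny config, the dummy inputs, and the compute_loss helper are assumptions made for the example.

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict
from transformers import FlaxWav2Vec2Model, Wav2Vec2Config

# small randomly-initialised model and dummy inputs, purely for illustration
config = Wav2Vec2Config(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = FlaxWav2Vec2Model(config)
input_values = jnp.ones((1, 1024), dtype=jnp.float32)

# illustrative scalar loss over the model outputs
def compute_loss(params, freeze_feature_encoder=False):
    outputs = model(input_values, params=params, freeze_feature_encoder=freeze_feature_encoder)
    return jnp.sum(outputs.last_hidden_state ** 2)

# forward + backward pass with the feature encoder unfrozen and frozen
loss, grads = jax.value_and_grad(compute_loss)(model.params)
loss_frozen, grads_frozen = jax.value_and_grad(compute_loss)(model.params, freeze_feature_encoder=True)

# 1. freezing only affects the backward pass, so the losses should be equal
assert jnp.allclose(loss, loss_frozen)

grads = flatten_dict(grads, sep="/")
grads_frozen = flatten_dict(grads_frozen, sep="/")

for name in grads:
    if "feature_extractor" in name:
        # 2. feature-encoder gradients should differ once frozen
        assert not jnp.allclose(grads[name], grads_frozen[name])
    else:
        # 3. all other gradients should be unaffected by freezing
        assert jnp.allclose(grads[name], grads_frozen[name])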

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Contributor

patrickvonplaten left a comment


Great! Very nice test, and good to see that no modeling code had to be changed.

@sanchit-gandhi
Contributor Author

If @patil-suraj is happy with this test I'll merge!

Contributor

patil-suraj left a comment


Really nice tests, great job! Just left a comment.

Comment on lines 276 to 283
# ensure that the gradients of the frozen layers differ, i.e. that the feature encoder is properly frozen
feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)

for feature_extractor_grad, feature_extractor_grad_frozen in zip(
    feature_extractor_grads, feature_extractor_grads_frozen
):
    self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)
Contributor


Could we also add one more check to verify that the grads of the frozen module are all precisely zero, since that's what jax.lax.stop_gradient is supposed to do?
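
For reference, a toy example (not from this PR) showing that jax.lax.stop_gradient yields gradients that are exactly zero rather than merely small:

import jax
import jax.numpy as jnp

def loss(w):
    # stand-in for the frozen feature encoder: everything behind
    # stop_gradient is cut out of the backward pass
    features = jax.lax.stop_gradient(w * 2.0)
    return jnp.sum(features ** 2)

print(jax.grad(loss)(jnp.ones(3)))  # [0. 0. 0.] -- exactly zero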

Contributor Author


Thank you for the feedback! Sure thing, I'll look into that now!

Contributor Author


The most recent commit (ca918e9) adds an assertion that verifies that the gradients of the frozen feature encoder layers are precisely zero!
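
For illustration, such a check might look roughly like the following, assuming grads_frozen is the flattened gradient dict from the frozen backward pass (the actual assertion lives in ca918e9):

for name, grad in grads_frozen.items():
    if "feature_extractor" in name:
        # gradients of the frozen feature encoder must be exactly zero
        self.assertTrue((grad == 0.0).all(), f"gradient of {name} is not exactly zero")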

sanchit-gandhi merged commit 1a62b25 into huggingface:master on Mar 7, 2022
sanchit-gandhi deleted the flax-wav2vec2 branch on March 8, 2022