Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder #15938
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Great! Very nice test, and good to see that no modeling code had to be changed.
If @patil-suraj is happy with this test I'll merge!
Really nice tests, great job! Just left a comment.
# ensure that the gradients of the frozen layers differ, i.e. that the feature encoder is properly frozen
feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)

for feature_extractor_grad, feature_extractor_grad_frozen in zip(
    feature_extractor_grads, feature_extractor_grads_frozen
):
    self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)
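The idea behind the test can be illustrated with a minimal sketch. This is a hypothetical toy model, not the actual FlaxWav2Vec2 code: a "feature_extractor" parameter is routed through `jax.lax.stop_gradient` when frozen, and the gradients of the two runs are then compared.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss (names are illustrative, not from the PR): when the
# encoder is frozen, its parameter is passed through stop_gradient.
def loss(params, x, freeze_encoder=False):
    enc = params["feature_extractor"]
    if freeze_encoder:
        enc = jax.lax.stop_gradient(enc)
    return jnp.sum(x * enc * params["head"])

params = {"feature_extractor": jnp.array(2.0), "head": jnp.array(3.0)}
x = jnp.array(1.5)

grads = jax.grad(loss)(params, x)
grads_frozen = jax.grad(loss)(params, x, freeze_encoder=True)

# the encoder gradients differ between the two runs...
print(grads["feature_extractor"])         # 4.5 (= x * head)
print(grads_frozen["feature_extractor"])  # 0.0
# ...while the unfrozen head receives the same gradient in both
print(grads["head"], grads_frozen["head"])  # 3.0 3.0
```

In the real test the two gradient dictionaries come from the model with and without `freeze_feature_encoder`, and `assert_difference` checks that corresponding entries are not within tolerance of each other.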
Could we also add one more check to see if the grads of the frozen module are all precisely zero, since that's what jax.lax.stop_gradient is supposed to do?
Thank you for the feedback! Sure thing, I'll look into that now!
The most recent commit (ca918e9) adds an assertion that verifies that the gradients of the frozen feature encoder layers are precisely zero!
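This exact-zero property is worth underlining: `jax.lax.stop_gradient` acts as the identity in the forward pass but blocks all gradient flow, so the gradient with respect to a stopped value is exactly zero, not merely small. A minimal illustration (not the PR's test code):

```python
import jax
import jax.numpy as jnp

# stop_gradient leaves the forward computation unchanged but yields an
# exactly-zero gradient for the stopped input
def f(x):
    return jnp.sum(jax.lax.stop_gradient(x) ** 2)

x = jnp.arange(3.0)
print(f(x))            # forward pass unchanged: 0 + 1 + 4 = 5.0
print(jax.grad(f)(x))  # [0. 0. 0.]
```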
This PR correctly implements a backpropagation test to verify the functionality of the freeze_feature_encoder argument added to the FlaxWav2Vec2 model in #15873. It tests: