fix the gradient backward issue when joint training with s3prl frontend #5159
Conversation
Force-pushed from 5c62225 to 33aa097
Force-pushed from 3434492 to d3dbd76
Codecov Report
|          | master | #5159  | +/-    |
|----------|--------|--------|--------|
| Coverage | 74.99% | 74.99% | -0.01% |
| Files    | 618    | 618    |        |
| Lines    | 55588  | 55589  | +1     |
| Hits     | 41689  | 41689  |        |
| Misses   | 13899  | 13900  | +1     |
@Emrys365, can you review this PR?
LGTM. I can also verify the gradient issue is resolved on my side.
if getattr(
    upstream.upstream, "model", None
) is not None and upstream.upstream.model.__class__.__name__ in [
    "Wav2Vec2Model",
    "HubertModel",
]:
    upstream.upstream.model.encoder.layerdrop = 0.0
Why are these lines removed?
I think this is because S3PRL already sets `encoder_layerdrop=0` when initializing an upstream, so we no longer need to do this in ESPnet. Am I right?
Yes, it is correct.
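For reference, a quick way to double-check this is to load an upstream and inspect its layerdrop value directly. This is only a sketch: it assumes the same `upstream.upstream.model` attribute path as the removed snippet above, which may differ across s3prl versions.

```python
# Sketch: confirm that an s3prl upstream already has layer drop disabled,
# which is why the ESPnet-side override shown above could be removed.
# The attribute path mirrors the snippet discussed in this thread.
from s3prl.nn import S3PRLUpstream

upstream = S3PRLUpstream("wav2vec2")
model = getattr(upstream.upstream, "model", None)
if model is not None and hasattr(model, "encoder"):
    print(model.encoder.layerdrop)  # expected to be 0.0 per this discussion
```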
> LGTM. I can also verify the gradient issue is resolved on my side.

Thanks a lot for fixing it!
In s3prl, `feature_grad_mult` was set to 0, so the forward pass runs inside a `torch.no_grad()` context here. That stops gradient back-propagation during joint training. To fix it, we simply set it to 1 manually.

Also in this PR, the encoder layerdrop handling is removed, because s3prl already sets it to 0 for all upstreams, e.g. WavLM.
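For anyone hitting the same problem on an older version, below is a minimal sketch of the workaround described above, not the exact diff in this PR. It assumes the `upstream.upstream.model` attribute path from the snippet discussed earlier and a fairseq-style upstream (e.g. Wav2Vec2Model, HubertModel) that exposes a `feature_grad_mult` attribute.

```python
# Minimal sketch of the fix described above (illustrative, not the PR diff).
# With feature_grad_mult=0, fairseq-style upstreams run the convolutional
# feature extractor under torch.no_grad(), which blocks back-propagation
# during joint training; setting it back to 1.0 restores the gradient path.
def enable_feature_grad(upstream):
    """Re-enable gradient flow through the s3prl upstream."""
    model = getattr(upstream.upstream, "model", None)
    if model is not None and hasattr(model, "feature_grad_mult"):
        model.feature_grad_mult = 1.0
```

A natural place to call such a helper would be right after the frontend constructs its s3prl upstream, before training starts.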