Transformer.predict: do not broadcast to listeners #345

Merged
merged 6 commits into explosion:master on Jan 30, 2023

Conversation

danieldk (Contributor)

The output of a transformer is passed through in two different ways:

- Prediction: the data is passed through the `Doc._.trf_data` attribute (see the sketch after this list).
- Training: the data is broadcast directly to the transformer's listeners.
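At prediction time, downstream components read the transformer output from the `Doc._.trf_data` extension rather than from a broadcast batch. A minimal sketch of that path, assuming a transformer-based pipeline such as `en_core_web_trf` is installed (the model name is only an example):

```python
import spacy

# Load a transformer-based pipeline; en_core_web_trf is an example,
# any pipeline containing a Transformer component would do.
nlp = spacy.load("en_core_web_trf")
doc = nlp("The transformer output travels on the Doc at prediction time.")

# Transformer.set_annotations() stores the output on the Doc, so
# listeners can fall back to it when no batch was broadcast.
print(type(doc._.trf_data))
```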

However, the `Transformer.predict` method breaks this strict separation between training and prediction by also broadcasting transformer outputs to its listeners.

This breaks down when we are training a model with an unfrozen transformer that is also listed in `annotating_components`. The transformer will first (as part of its update step) broadcast the tensors and backprop function to its listeners. Then, when acting as an annotating component, it immediately overrides its own output and clears the backprop function. As a result, gradients will not flow into the transformer.

This change removes the broadcast from the `predict` method. If a listener does not receive a batch, it instead attempts to get the transformer output from the `Doc` instances, as sketched below. This also makes it possible to train a pipeline with a frozen transformer.
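A simplified sketch of that fallback. The class and attribute names (`_batch_id`, `receive`, `get_outputs`) mirror the shape of the real `TransformerListener`, but this is an illustrative reconstruction, not the library's actual implementation:

```python
from typing import List
from spacy.tokens import Doc

class TransformerListenerSketch:
    """Illustrative stand-in for the real TransformerListener."""

    def __init__(self):
        self._batch_id = None   # id of the batch broadcast during update()
        self._outputs = None    # tensors received from the transformer
        self._backprop = None   # callback to propagate gradients upstream

    def receive(self, batch_id, outputs, backprop):
        # Called by Transformer.update() only; after this change,
        # Transformer.predict() no longer calls it.
        self._batch_id = batch_id
        self._outputs = outputs
        self._backprop = backprop

    def get_outputs(self, docs: List[Doc], batch_id):
        if self._batch_id == batch_id and self._outputs is not None:
            # Training path: use the broadcast batch so gradients can flow.
            return self._outputs, self._backprop
        # Prediction path (or frozen transformer): fall back to the
        # annotations stored on the Doc objects by Transformer.predict().
        return [doc._.trf_data for doc in docs], None
```

With `predict` no longer broadcasting, the broadcast path is only taken during `update`, so an annotating transformer can no longer clobber its own tensors and backprop callback.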

This ports explosion/spaCy#11385 to spacy-transformers. Alternative to #342.

@danieldk added the `bug` (Something isn't working) and `feat / pipeline` (Feature: Pipeline components) labels on Aug 31, 2022.
@danieldk marked this pull request as ready for review on January 26, 2023.
@svlandeg merged commit e66c73d into explosion:master on Jan 30, 2023.
@adrianeboyd added a commit to adrianeboyd/spacy-transformers that referenced this pull request on Feb 11, 2023.