Language.replace_listeners: Pass the replaced listener and the tok2vec pipe to the callback #12785

Merged

39 changes: 27 additions & 12 deletions extra/DEVELOPER_DOCS/Listeners.md
@@ -1,14 +1,17 @@
# Listeners

1. [Overview](#1-overview)
2. [Initialization](#2-initialization)
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [B. Shape inference](#2b-shape-inference)
3. [Internal communication](#3-internal-communication)
- [A. During prediction](#3a-during-prediction)
- [B. During training](#3b-during-training)
- [C. Frozen components](#3c-frozen-components)
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
- [1. Overview](#1-overview)
- [2. Initialization](#2-initialization)
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [2B. Shape inference](#2b-shape-inference)
- [3. Internal communication](#3-internal-communication)
- [3A. During prediction](#3a-during-prediction)
- [3B. During training](#3b-during-training)
- [Training with multiple listeners](#training-with-multiple-listeners)
- [3C. Frozen components](#3c-frozen-components)
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
- [The upstream component is frozen](#the-upstream-component-is-frozen)
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)

## 1. Overview

@@ -62,7 +65,7 @@ of this `find_listener()` method will specifically identify sublayers of a model

If it's a Transformer-based pipeline, a
[`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py)
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
sublayers of downstream components.

### 2B. Shape inference
@@ -154,7 +157,7 @@ as a tagger or a parser. This used to be impossible before 3.1, but has become s
embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components)
list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes.

However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`.

#### The upstream component is frozen
@@ -216,5 +219,17 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
```

The new config and model are then properly stored on the `nlp` object.
Note that this functionality (running the replacement for a transformer listener) was broken prior to
Note that this functionality (running the replacement for a transformer listener) was broken prior to
`spacy-transformers` 1.0.5.

In spaCy 3.7, `Language.replace_listeners` was updated to pass two additional arguments to the `replace_listener` callback:
the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards compatibility,
the method only passes these extra arguments to callbacks that accept them:

```
def replace_listener_pre_37(copied_tok2vec_model):
    ...

def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
    ...
```
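
As a rough illustration of the compatibility rule described above, the sketch below dispatches on the number of parameters a callback declares. The helper `call_replace_listener` and the placeholder string arguments are hypothetical and exist only for this example; the actual logic lives in `Language.replace_listeners`.

```python
import inspect


def replace_listener_pre_37(copied_tok2vec_model):
    # Old-style callback: only receives the copied model.
    return copied_tok2vec_model


def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
    # New-style callback: also receives the listener and the source pipe.
    return (copied_tok2vec_model, replaced_listener, tok2vec_pipe)


def call_replace_listener(callback, new_model, listener, tok2vec_pipe):
    """Hypothetical helper: pick the call shape from the declared parameters."""
    num_params = len(inspect.signature(callback).parameters)
    if num_params == 1:
        return callback(new_model)
    if num_params == 3:
        return callback(new_model, listener, tok2vec_pipe)
    # Any other arity is rejected, mirroring the wording of error E1055.
    raise ValueError(
        f"The 'replace_listener' callback expects {num_params} parameters, "
        "but only callbacks with one or three parameters are supported"
    )
```

With this helper, a one-parameter callback is called exactly as before 3.7, a three-parameter callback also receives the listener and the pipe, and any other arity raises a `ValueError`.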
2 changes: 2 additions & 0 deletions spacy/errors.py
@@ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes):
             " 'min_length': {min_length}, 'max_length': {max_length}")
    E1054 = ("The text, including whitespace, must match between reference and "
             "predicted docs when training {component}.")
    E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
             "but only callbacks with one or three parameters are supported")


# Deprecated model shortcuts, only used in errors and warnings
17 changes: 15 additions & 2 deletions spacy/language.py
@@ -1,4 +1,5 @@
import functools
import inspect
import itertools
import multiprocessing as mp
import random
@@ -2033,8 +2034,20 @@ def replace_listeners(
            # Go over the listener layers and replace them
            for listener in pipe_listeners:
                new_model = tok2vec_model.copy()
                if "replace_listener" in tok2vec_model.attrs:
                    new_model = tok2vec_model.attrs["replace_listener"](new_model)
                replace_listener_func = tok2vec_model.attrs.get("replace_listener")
                if replace_listener_func is not None:
                    # Pass the extra args to the callback without breaking compatibility with
                    # old library versions that only expect a single parameter.
                    num_params = len(
                        inspect.signature(replace_listener_func).parameters
                    )
                    if num_params == 1:
                        new_model = replace_listener_func(new_model)
                    elif num_params == 3:
                        new_model = replace_listener_func(new_model, listener, tok2vec)
                    else:
                        raise ValueError(Errors.E1055.format(num_params=num_params))

                util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
                tok2vec.remove_listener(listener, pipe_name)
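
Because the check above counts every entry in `inspect.signature(...).parameters`, the shape of a callback's signature matters. The sketch below applies the same counting rule to a few illustrative callables; `count_params`, `with_defaults`, `with_varargs` and `Replacer` are hypothetical names used only for this example.

```python
import inspect


def count_params(func):
    # Same counting rule as the diff above: every declared parameter counts,
    # including ones with defaults and *args/**kwargs.
    return len(inspect.signature(func).parameters)


def with_defaults(model, listener=None, pipe=None):
    # Three declared parameters, two of them optional.
    return model


def with_varargs(model, *extra):
    # Two declared parameters: `model` and the variadic `extra`.
    return model


class Replacer:
    def replace(self, model):
        # As a bound method, `self` is not part of the signature.
        return model


n_defaults = count_params(with_defaults)
n_varargs = count_params(with_varargs)
n_bound = count_params(Replacer().replace)
```

Under this rule, `with_defaults` counts as three parameters and would be called with the extra arguments, a bound method such as `Replacer().replace` counts as one, and `with_varargs` counts as two, so it would fall into the `E1055` branch even though it could accept three positional arguments.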
