Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Language.replace_listeners: Pass the replaced listener and the tok2vec pipe to the callback #12785

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 29 additions & 12 deletions extra/DEVELOPER_DOCS/Listeners.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# Listeners

1. [Overview](#1-overview)
2. [Initialization](#2-initialization)
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [B. Shape inference](#2b-shape-inference)
3. [Internal communication](#3-internal-communication)
- [A. During prediction](#3a-during-prediction)
- [B. During training](#3b-during-training)
- [C. Frozen components](#3c-frozen-components)
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
- [Listeners](#listeners)
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
- [1. Overview](#1-overview)
- [2. Initialization](#2-initialization)
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [2B. Shape inference](#2b-shape-inference)
- [3. Internal communication](#3-internal-communication)
- [3A. During prediction](#3a-during-prediction)
- [3B. During training](#3b-during-training)
- [Training with multiple listeners](#training-with-multiple-listeners)
- [3C. Frozen components](#3c-frozen-components)
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
- [The upstream component is frozen](#the-upstream-component-is-frozen)
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)

## 1. Overview

Expand Down Expand Up @@ -62,7 +66,7 @@ of this `find_listener()` method will specifically identify sublayers of a model

If it's a Transformer-based pipeline, a
[`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py)
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
sublayers of downstream components.

### 2B. Shape inference
Expand Down Expand Up @@ -154,7 +158,7 @@ as a tagger or a parser. This used to be impossible before 3.1, but has become s
embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components)
list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes.

However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`.

#### The upstream component is frozen
Expand Down Expand Up @@ -216,5 +220,18 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
```

The new config and model are then properly stored on the `nlp` object.
Note that this functionality (running the replacement for a transformer listener) was broken prior to
Note that this functionality (running the replacement for a transformer listener) was broken prior to
`spacy-transformers` 1.0.5.

In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback:
the listener to be replaced, the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatiblity,
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
the method only passes these extra arguments for callbacks that support them:

```
def replace_listener_pre_37(copied_tok2vec_model):
...

def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
...

```
18 changes: 16 additions & 2 deletions spacy/language.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import inspect
import itertools
import multiprocessing as mp
import random
Expand Down Expand Up @@ -2033,8 +2034,21 @@ def replace_listeners(
# Go over the listener layers and replace them
for listener in pipe_listeners:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
if replace_listener_func is not None:
# Pass the extra args to the callback without breaking compatibility with
# old library versions that only expect a single parameter.
num_params = len(
inspect.signature(replace_listener_func).parameters
)
assert num_params in (
1,
3,
), f"'replace_listener' callback expects {num_params} parameters, but we only support one or three"
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
if num_params == 1:
new_model = replace_listener_func(new_model)
else:
new_model = replace_listener_func(new_model, listener, tok2vec)
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)

Expand Down