In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="spacy_transformers.layers")
warnings.filterwarnings("ignore", category=FutureWarning, module="thinc.shims")
warnings.filterwarnings("ignore", category=UserWarning, module="spacy.util")
warnings.filterwarnings("ignore", category=UserWarning, module="spacy_transformers.layers")


In [2]:
import spacy
from spacy.tokens import Doc
from wasabi import msg
from spacy import displacy

In [3]:
nlp = spacy.load("en_core_web_trf")
nlp_coref = spacy.load("en_coreference_web_trf")
# use replace_listeners for the coref components
nlp_coref.replace_listeners("transformer", "coref", ["model.tok2vec"])
nlp_coref.replace_listeners("transformer", "span_resolver", ["model.tok2vec"])

# we won't copy over the span cleaner - this keeps the head cluster information, which we want
nlp.add_pipe("merge_entities")
nlp.add_pipe("coref", source=nlp_coref)
nlp.add_pipe("span_resolver", source=nlp_coref)

<spacy_experimental.coref.span_resolver_component.SpanResolver at 0x730bb60bd5b0>

In [4]:
def resolve_references(doc: Doc) -> str:
    """Function for resolving references with the coref output
    doc (Doc): The doc object processed by coref pipeline
    RETURNS (str): The doc string with resolved references
    """
    # Saves token_id: reference_text
    token_mention_mapper = {}
    output_string = ""
    clusters = [val for key, val in doc.spans.items() if key.startswith("coref_cluster")]

    for cluster in clusters:
        # Saves first span of every cluster
        first_mention = cluster[0]
        # Iterate though every other span in the cluster
        for mentions in list(cluster)[1:]:
            token_mention_mapper[mentions[0].idx] = first_mention.text
            for token in mentions[1:]:
                # Set empty string for every other token
                token_mention_mapper[token.idx] = ""

    # Iterate through every token in doc
    for token in doc:
        
        if token.idx in token_mention_mapper:
            # Check if token exists in token_mention_mapper
            output_string += token_mention_mapper[token.idx] + token.whitespace_
        else:
            # Add original text
            output_string += token.text + token.whitespace_

    return output_string

In [5]:
text = "Philip plays the bass because he loves it"
text = "Sarah enjoiys a nice cup of tea in the morning. She likes it with sugar and a drop of milk."
doc = nlp_coref(text)

In [6]:
msg.info("Pipeline components")
for i, pipe in enumerate(nlp.pipe_names):
    print(f"{i}: {pipe}")

msg.info("Found clusters")
for cluster in doc.spans:
    print(f"{cluster}: {doc.spans[cluster]}")

[38;5;4mℹ Pipeline components[0m
0: transformer
1: tagger
2: parser
3: attribute_ruler
4: lemmatizer
5: ner
6: merge_entities
7: coref
8: span_resolver
[38;5;4mℹ Found clusters[0m
coref_clusters_1: [Sarah, She]
coref_clusters_2: [a nice cup of tea, it]


In [7]:
msg.info("Document with resolved references")
print(resolve_references(doc))

[38;5;4mℹ Document with resolved references[0m
Sarah enjoiys a nice cup of tea in the morning. Sarah likes a nice cup of tea with sugar and a drop of milk.
