Skip to content

Commit

Permalink
Fix range in Span.get_lca_matrix (#8115)
Browse files Browse the repository at this point in the history
Fix the adjusted token index / lca matrix index ranges for
`_get_lca_matrix` for spans.

* The range for `k` should correspond to the adjusted indices in
`lca_matrix` with the `start` indexed at `0`
  • Loading branch information
adrianeboyd committed May 17, 2021
1 parent 6ce9f04 commit 5e7e7cd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions spacy/tests/doc/test_span.py
Expand Up @@ -2,6 +2,8 @@
from __future__ import unicode_literals

import pytest
import numpy
from numpy.testing import assert_array_equal
from spacy.attrs import ORTH, LENGTH
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
Expand Down Expand Up @@ -118,6 +120,14 @@ def test_spans_lca_matrix(en_tokenizer):
assert lca[1, 0] == 1 # slept & dog -> slept
assert lca[1, 1] == 1 # slept & slept -> slept

# example from Span API docs
tokens = en_tokenizer("I like New York in Autumn")
doc = get_doc(
tokens.vocab, words=[t.text for t in tokens], heads=[1, 0, 1, -2, -1, -1]
)
lca = doc[1:4].get_lca_matrix()
assert_array_equal(lca, numpy.asarray([[0, 0, 0], [0, 1, 2], [0, 2, 2]]))


def test_span_similarity_match():
doc = Doc(Vocab(), words=["a", "b", "a", "b"])
Expand Down
2 changes: 1 addition & 1 deletion spacy/tokens/doc.pyx
Expand Up @@ -1351,7 +1351,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
j_idx_in_sent = start + j - sent_start
n_missing_tokens_in_sent = len(sent) - j_idx_in_sent
# make sure we do not go past `end`, in cases where `end` < sent.end
max_range = min(j + n_missing_tokens_in_sent, end)
max_range = min(j + n_missing_tokens_in_sent, end - start)
for k in range(j + 1, max_range):
lca = _get_tokens_lca(token_j, doc[start + k])
# if lca is outside of span, we set it to -1
Expand Down

0 comments on commit 5e7e7cd

Please sign in to comment.