Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💫 Allow matching non-ORTH attributes in PhraseMatcher #2925

Merged
merged 5 commits into from Nov 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 36 additions & 5 deletions spacy/matcher.pyx
Expand Up @@ -12,7 +12,7 @@ from .lexeme cimport attr_id_t
from .vocab cimport Vocab
from .tokens.doc cimport Doc
from .tokens.doc cimport get_token_attr
from .attrs cimport ID, attr_id_t, NULL_ATTR
from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
from .errors import Errors, TempErrors, Warnings, deprecation_warning

from .attrs import IDS
Expand Down Expand Up @@ -546,16 +546,21 @@ cdef class PhraseMatcher:
cdef Matcher matcher
cdef PreshMap phrase_ids
cdef int max_length
cdef attr_id_t attr
cdef public object _callbacks
cdef public object _patterns

def __init__(self, Vocab vocab, max_length=0):
def __init__(self, Vocab vocab, max_length=0, attr='ORTH'):
if max_length != 0:
deprecation_warning(Warnings.W010)
self.mem = Pool()
self.max_length = max_length
self.vocab = vocab
self.matcher = Matcher(self.vocab)
if isinstance(attr, long):
self.attr = attr
else:
self.attr = self.vocab.strings[attr]
self.phrase_ids = PreshMap()
abstract_patterns = [
[{U_ENT: True}],
Expand Down Expand Up @@ -609,7 +614,8 @@ cdef class PhraseMatcher:
tags = get_bilou(length)
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, tag in enumerate(tags):
lexeme = self.vocab[doc.c[i].lex.orth]
attr_value = self.get_lex_value(doc, i)
lexeme = self.vocab[attr_value]
lexeme.set_flag(tag, True)
phrase_key[i] = lexeme.orth
phrase_hash = hash64(phrase_key,
Expand All @@ -625,8 +631,16 @@ cdef class PhraseMatcher:
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
matches = []
for _, start, end in self.matcher(doc):
ent_id = self.accept_match(doc, start, end)
if self.attr == ORTH:
match_doc = doc
else:
# If we're not matching on the ORTH, match_doc will be a Doc whose
# token.orth values are the attribute values we're matching on,
# e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
match_doc = Doc(self.vocab, words=words)
ines marked this conversation as resolved.
Show resolved Hide resolved
for _, start, end in self.matcher(match_doc):
ent_id = self.accept_match(match_doc, start, end)
if ent_id is not None:
matches.append((ent_id, start, end))
for i, (ent_id, start, end) in enumerate(matches):
Expand Down Expand Up @@ -680,6 +694,23 @@ cdef class PhraseMatcher:
else:
return ent_id

def get_lex_value(self, Doc doc, int i):
    """Return the value used as the lexeme key for token ``i`` of ``doc``.

    For the default ORTH attribute this is the token's orth ID unchanged.
    For any other attribute, the attribute's value is looked up on the
    token and re-encoded as a namespaced string ('matcher:<ATTR>-<VALUE>')
    so that attribute-based keys cannot collide with real surface words.
    """
    if self.attr == ORTH:
        # Return the regular orth value of the lexeme
        return doc.c[i].lex.orth
    # Get the attribute value instead, e.g. token.pos
    attr_value = get_token_attr(&doc.c[i], self.attr)
    if attr_value in (0, 1):
        # Value is boolean, convert to string
        # NOTE(review): assumes boolean attrs come back as the raw ints 0/1
        # and that no string-valued attr hash is ever 0 or 1 — confirm
        # against the StringStore hashing scheme.
        string_attr_value = str(attr_value)
    else:
        string_attr_value = self.vocab.strings[attr_value]
    string_attr_name = self.vocab.strings[self.attr]
    # Concatenate the attr name and value to not pollute lexeme space
    # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
    # create false positive matches
    return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)


cdef class DependencyTreeMatcher:
"""Match dependency parse tree based on pattern rules."""
Expand Down
52 changes: 52 additions & 0 deletions spacy/tests/matcher/test_phrase_matcher.py
Expand Up @@ -5,6 +5,8 @@
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc

from ..util import get_doc


def test_matcher_phrase_matcher(en_vocab):
doc = Doc(en_vocab, words=["Google", "Now"])
Expand All @@ -28,3 +30,53 @@ def test_phrase_matcher_contains(en_vocab):
matcher.add('TEST', None, Doc(en_vocab, words=['test']))
assert 'TEST' in matcher
assert 'TEST2' not in matcher


def test_phrase_matcher_string_attrs(en_vocab):
    """A PhraseMatcher built with attr='POS' should match on POS tags,
    not on the surface forms of the pattern."""
    pattern_words = ['I', 'like', 'cats']
    pattern_pos = ['PRON', 'VERB', 'NOUN']
    target_words = ['Yes', ',', 'you', 'hate', 'dogs', 'very', 'much']
    target_pos = ['INTJ', 'PUNCT', 'PRON', 'VERB', 'NOUN', 'ADV', 'ADV']
    matcher = PhraseMatcher(en_vocab, attr='POS')
    matcher.add('TEST', None,
                get_doc(en_vocab, words=pattern_words, pos=pattern_pos))
    target = get_doc(en_vocab, words=target_words, pos=target_pos)
    matches = matcher(target)
    # Exactly one span matches: 'you hate dogs' (PRON VERB NOUN).
    assert len(matches) == 1
    assert matches[0] == (en_vocab.strings['TEST'], 2, 5)


def test_phrase_matcher_string_attrs_negative(en_vocab):
    """Tokens whose ORTH equals the internal control codes must *not* match."""
    pattern_words = ['I', 'like', 'cats']
    pattern_pos = ['PRON', 'VERB', 'NOUN']
    # These surface forms collide with the matcher's internal attr encoding,
    # but their actual POS tags ('X') do not match the pattern.
    target_words = ['matcher:POS-PRON', 'matcher:POS-VERB', 'matcher:POS-NOUN']
    target_pos = ['X', 'X', 'X']
    matcher = PhraseMatcher(en_vocab, attr='POS')
    matcher.add('TEST', None,
                get_doc(en_vocab, words=pattern_words, pos=pattern_pos))
    target = get_doc(en_vocab, words=target_words, pos=target_pos)
    assert len(matcher(target)) == 0


def test_phrase_matcher_bool_attrs(en_vocab):
    """Matching on a boolean attr (IS_PUNCT) matches the pattern's
    punctuation shape rather than its words."""
    matcher = PhraseMatcher(en_vocab, attr='IS_PUNCT')
    # Pattern shape: non-punct, non-punct, punct ('Hello world !').
    matcher.add('TEST', None, Doc(en_vocab, words=['Hello', 'world', '!']))
    doc = Doc(en_vocab, words=['No', 'problem', ',', 'he', 'said', '.'])
    matches = matcher(doc)
    # Two adjacent spans in the target share that boolean shape.
    label = en_vocab.strings['TEST']
    assert matches == [(label, 0, 3), (label, 3, 6)]