Skip to content

Commit

Permalink
Merge pull request #332 from MirkoLenz/fix-yake-normalization
Browse files Browse the repository at this point in the history
Fix normalization for the keyword extractor YAKE
  • Loading branch information
bdewilde committed May 31, 2021
2 parents e5807c1 + fde5e84 commit c950fe0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/textacy/extract/keyterms/yake.py
Expand Up @@ -30,6 +30,7 @@ def yake(
doc: spaCy ``Doc`` from which to extract keyterms.
Must be sentence-segmented; optionally POS-tagged.
normalize: If "lemma", lemmatize terms; if "lower", lowercase terms;
if "norm", use the norm of the terms (as set in a language's tokenizer exceptions);
if None, use the form of terms as they appeared in ``doc``.
.. note:: Unlike the other keyterm extraction functions, this one
Expand Down Expand Up @@ -125,12 +126,12 @@ def yake(

def _get_attr_name(normalize: Optional[str], as_strings: bool) -> str:
if normalize is None:
attr_name = "norm"
elif normalize in ("lemma", "lower"):
attr_name = "orth"
elif normalize in ("lemma", "lower", "norm"):
attr_name = normalize
else:
raise ValueError(
errors.value_invalid_msg("normalize", normalize, {"lemma", "lower", None})
errors.value_invalid_msg("normalize", normalize, {"lemma", "lower", "norm", None})
)
if as_strings is True:
attr_name = attr_name + "_"
Expand Down
12 changes: 10 additions & 2 deletions tests/extract/keyterms/test_yake.py
@@ -1,10 +1,8 @@
import pytest

import textacy
from textacy import datasets
from textacy.extract import keyterms as kt


DATASET = datasets.CapitolWords()

pytestmark = pytest.mark.skipif(
Expand Down Expand Up @@ -34,6 +32,11 @@ def test_default(spacy_doc):
)


def test_normalize_none(spacy_doc):
    """Check that ``kt.yake`` gives a non-empty result when ``normalize=None``."""
    keyterms = kt.yake(spacy_doc, normalize=None)
    n_keyterms = len(keyterms)
    assert n_keyterms > 0


def test_normalize_lower(spacy_doc):
result = kt.yake(spacy_doc, normalize="lower")
assert len(result) > 0
Expand All @@ -46,6 +49,11 @@ def test_normalize_lemma(spacy_doc):
assert any(term != term.lower() for term, _ in result)


def test_normalize_norm(spacy_doc):
    """Check that ``kt.yake`` gives a non-empty result when ``normalize="norm"``."""
    keyterms = kt.yake(spacy_doc, normalize="norm")
    n_keyterms = len(keyterms)
    assert n_keyterms > 0


def test_ngrams_1(spacy_doc):
result = kt.yake(spacy_doc, ngrams=1)
assert len(result) > 0
Expand Down

0 comments on commit c950fe0

Please sign in to comment.