refactor: pymorphy2 to spacy in Entity Linking and KBQA (#1618)
Co-authored-by: Fedor Ignatov <ignatov.fedor@gmail.com>
dmitrijeuseew and IgnatovFedor committed Jan 23, 2023
1 parent 9ff98b6 commit 3ee1b85
Showing 5 changed files with 34 additions and 40 deletions.
9 changes: 7 additions & 2 deletions deeppavlov/core/common/requirements_registry.json
@@ -5,7 +5,9 @@
     ],
     "entity_linker": [
         "{DEEPPAVLOV_PATH}/requirements/hdt.txt",
-        "{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt"
+        "{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt",
+        "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
     ],
     "fasttext": [
         "{DEEPPAVLOV_PATH}/requirements/fasttext.txt"
@@ -58,6 +60,7 @@
         "{DEEPPAVLOV_PATH}/requirements/transformers.txt"
     ],
     "ru_adj_to_noun": [
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt",
         "{DEEPPAVLOV_PATH}/requirements/udapi.txt"
     ],
     "russian_words_vocab": [
@@ -147,7 +150,9 @@
         "{DEEPPAVLOV_PATH}/requirements/transformers.txt"
     ],
     "tree_to_sparql": [
-        "{DEEPPAVLOV_PATH}/requirements/udapi.txt"
+        "{DEEPPAVLOV_PATH}/requirements/udapi.txt",
+        "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
     ],
     "typos_custom_reader": [
         "{DEEPPAVLOV_PATH}/requirements/lxml.txt"
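The two new requirement entries pin the spaCy model packages that the refactored components now load at construction time. As a rough illustration (not part of this commit; the helper below is hypothetical, while `spacy.load` and `spacy.cli.download` are standard spaCy calls), the models can be fetched on first use if they are not installed yet:

```python
import spacy


def load_spacy_model(name: str):
    """Load a spaCy pipeline, downloading the model package on first use."""
    try:
        return spacy.load(name)
    except OSError:
        # Model package is not installed yet; fetch it via spaCy's download helper.
        spacy.cli.download(name)
        return spacy.load(name)


nlp_en = load_spacy_model("en_core_web_sm")
nlp_ru = load_spacy_model("ru_core_news_sm")
```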
39 changes: 15 additions & 24 deletions deeppavlov/models/entity_extraction/entity_linking.py
@@ -14,19 +14,19 @@
 
 import re
 import sqlite3
+from collections import defaultdict
 from logging import getLogger
 from typing import List, Dict, Tuple, Union, Any
-from collections import defaultdict
 
-import pymorphy2
+import spacy
 from hdt import HDTDocument
 from nltk.corpus import stopwords
 from rapidfuzz import fuzz
 
+from deeppavlov.core.commands.utils import expand_path
 from deeppavlov.core.common.registry import register
 from deeppavlov.core.models.component import Component
 from deeppavlov.core.models.serializable import Serializable
-from deeppavlov.core.commands.utils import expand_path
 
 log = getLogger(__name__)
 
@@ -75,7 +75,6 @@ def __init__(
             **kwargs:
         """
         super().__init__(save_path=None, load_path=load_path)
-        self.morph = pymorphy2.MorphAnalyzer()
         self.lemmatize = lemmatize
         self.entities_database_filename = entities_database_filename
         self.num_entities_for_bert_ranking = num_entities_for_bert_ranking
@@ -86,8 +85,10 @@ def __init__(
         self.lang = f"@{lang}"
         if self.lang == "@en":
             self.stopwords = set(stopwords.words("english"))
+            self.nlp = spacy.load("en_core_web_sm")
         elif self.lang == "@ru":
             self.stopwords = set(stopwords.words("russian"))
+            self.nlp = spacy.load("ru_core_news_sm")
         self.use_descriptions = use_descriptions
         self.use_connections = use_connections
         self.max_paragraph_len = max_paragraph_len
@@ -198,7 +199,7 @@ def link_entities(
         ):
             cand_ent_scores = []
             if len(entity_substr) > 1:
-                entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
+                entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
                 cand_ent_init = self.find_exact_match(entity_substr, tag)
                 if not cand_ent_init or entity_substr_split != entity_substr_split_lemm:
                     cand_ent_init = self.find_fuzzy_match(entity_substr_split, tag)
@@ -297,28 +298,23 @@ def find_exact_match(self, entity_substr, tag):
                 entity_substr_split = entity_substr_split[1:]
             entities_and_ids = self.find_title(entity_substr)
             cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
-        if self.lang == "@ru":
-            entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
-            entity_substr_lemm = " ".join(entity_substr_split_lemm)
-            if entity_substr_lemm != entity_substr:
-                entities_and_ids = self.find_title(entity_substr_lemm)
-                if entities_and_ids:
-                    cand_ent_init = self.process_cand_ent(
-                        cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
-                    )
-
+        entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
+        entity_substr_lemm = " ".join(entity_substr_split_lemm)
+        if entity_substr_lemm != entity_substr:
+            entities_and_ids = self.find_title(entity_substr_lemm)
+            if entities_and_ids:
+                cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag)
         return cand_ent_init
 
     def find_fuzzy_match(self, entity_substr_split, tag):
-        if self.lang == "@ru":
-            entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
-        else:
-            entity_substr_split_lemm = entity_substr_split
+        entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
         cand_ent_init = defaultdict(set)
         for word in entity_substr_split:
             part_entities_and_ids = self.find_title(word)
             cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
-            if self.lang == "@ru":
-                word_lemm = self.morph.parse(word)[0].normal_form
+            word_lemm = self.nlp(word)[0].lemma_
             if word != word_lemm:
                 part_entities_and_ids = self.find_title(word_lemm)
                 cand_ent_init = self.process_cand_ent(
@@ -329,11 +325,6 @@ def find_fuzzy_match(self, entity_substr_split, tag):
                 )
         return cand_ent_init
 
-    def morph_parse(self, word):
-        morph_parse_tok = self.morph.parse(word)[0]
-        normal_form = morph_parse_tok.normal_form
-        return normal_form
-
     def calc_substr_score(self, cand_entity_title, entity_substr_split):
         label_tokens = cand_entity_title.split()
         cnt = 0.0
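Throughout this file the pattern is the same: the per-token pymorphy2 normalization is replaced by a per-token spaCy lemma lookup, and the `if self.lang == "@ru"` guards disappear because a language-specific pipeline (`en_core_web_sm` or `ru_core_news_sm`) is now picked once in `__init__`. A minimal sketch of the new per-token lemmatization, mirroring the list comprehensions above (illustrative only, not repository code; assumes the `ru_core_news_sm` package is installed):

```python
import spacy

nlp = spacy.load("ru_core_news_sm")


def lemmatize_tokens(tokens):
    # Each token is run through the spaCy pipeline as a one-word text;
    # doc[0] is its only token and .lemma_ is its normal form.
    return [nlp(tok)[0].lemma_ for tok in tokens]


print(lemmatize_tokens(["коалы", "питаются"]))
```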
20 changes: 10 additions & 10 deletions deeppavlov/models/kbqa/tree_to_sparql.py
@@ -19,7 +19,7 @@
 from typing import Any, List, Tuple, Dict, Union
 
 import numpy as np
-import pymorphy2
+import spacy
 from navec import Navec
 from scipy.sparse import csr_matrix
 from slovnet import Syntax
@@ -66,11 +66,10 @@ def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, **kwargs)
         self.adj_set = set([word for word, freq in pos_freq_dict["a"]])
         self.nouns = [noun[0] for noun in self.nouns_with_freq]
         self.matrix = self.make_sparse_matrix(self.nouns).transpose()
-        self.morph = pymorphy2.MorphAnalyzer()
+        self.nlp = spacy.load("ru_core_news_sm")
 
     def search(self, word: str):
-        word = self.morph.parse(word)[0]
-        word = word.normal_form
+        word = self.nlp(word)[0].lemma_
         if word in self.adj_set:
             q_matrix = self.make_sparse_matrix([word])
             scores = q_matrix * self.matrix
@@ -190,6 +189,7 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
             self.begin_tokens = {"начинать", "начать"}
             self.end_tokens = {"завершить", "завершать", "закончить"}
             self.ranking_tokens = {"самый"}
+            self.nlp = spacy.load("ru_core_news_sm")
         elif self.lang == "eng":
             self.q_pronouns = {"what", "who", "how", "when", "where", "which"}
             self.how_many = "how many"
@@ -199,12 +199,12 @@
             self.begin_tokens = set()
             self.end_tokens = set()
             self.ranking_tokens = set()
+            self.nlp = spacy.load("en_core_web_sm")
         else:
             raise ValueError(f"unsupported language {lang}")
         self.sparql_queries_filename = expand_path(sparql_queries_filename)
         self.template_queries = read_json(self.sparql_queries_filename)
         self.adj_to_noun = adj_to_noun
-        self.morph = pymorphy2.MorphAnalyzer()
 
     def __call__(self, syntax_tree_batch: List[str],
                  positions_batch: List[List[List[int]]]) -> Tuple[
@@ -274,7 +274,7 @@ def __call__(self, syntax_tree_batch: List[str],
                         self.root_entity = True
 
                     temporal_order = self.find_first_last(new_root)
-                    new_root_nf = self.morph.parse(new_root.form)[0].normal_form
+                    new_root_nf = self.nlp(new_root.form)[0].lemma_
                     if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens:
                         temporal_order = new_root_nf
                     ranking_tokens = self.find_ranking_tokens(new_root)
@@ -288,7 +288,7 @@ def __call__(self, syntax_tree_batch: List[str],
                     question = []
                     for node in tree.descendants:
                         if node.ord in ranking_tokens or node.form.lower() in self.q_pronouns:
-                            question.append(self.morph.parse(node.form)[0].normal_form)
+                            question.append(self.nlp(node.form)[0].lemma_)
                         else:
                             question.append(node.form)
                     question = ' '.join(question)
@@ -496,9 +496,9 @@ def find_first_last(self, node: Node) -> str:
         for node in nodes:
             node_desc = defaultdict(set)
             for elem in node.children:
-                parsed_elem = self.morph.parse(elem.form.lower())[0].inflect({"masc", "sing", "nomn"})
+                parsed_elem = self.nlp(elem.form.lower())[0].lemma_
                 if parsed_elem is not None:
-                    node_desc[elem.deprel].add(parsed_elem.word)
+                    node_desc[elem.deprel].add(parsed_elem)
                 else:
                     node_desc[elem.deprel].add(elem.form)
             if "amod" in node_desc.keys() and "nmod" in node_desc.keys() and \
@@ -511,7 +511,7 @@ def find_first_last(self, node: Node) -> str:
     def find_ranking_tokens(self, node: Node) -> list:
         ranking_tokens = []
         for elem in node.descendants:
-            if self.morph.parse(elem.form)[0].normal_form in self.ranking_tokens:
+            if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens:
                 ranking_tokens.append(elem.ord)
                 ranking_tokens.append(elem.parent.ord)
         return ranking_tokens
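One practical aside on the new pattern (not something this commit does): calling `self.nlp(tok)` runs the full spaCy pipeline once per word. If that ever becomes a bottleneck, spaCy allows disabling the components that are not needed for lemmas and batching the calls. A hedged sketch under those assumptions:

```python
import spacy

# Assumption: the parser and NER are not needed just to obtain lemmas,
# so they can be switched off for speed.
nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])


def lemmatize_batch(tokens):
    # nlp.pipe streams the one-word texts through the pipeline in batches,
    # which is typically faster than calling nlp() separately for every token.
    return [doc[0].lemma_ for doc in nlp.pipe(tokens)]


print(lemmatize_batch(["начинать", "самый", "коалы"]))
```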
4 changes: 1 addition & 3 deletions deeppavlov/models/kbqa/type_define.py
@@ -15,7 +15,6 @@
 import pickle
 from typing import List
 
-import pymorphy2
 import spacy
 from nltk.corpus import stopwords
 
@@ -43,7 +42,6 @@ def __init__(self, lang: str, types_filename: str, types_sets_filename: str,
         self.types_filename = str(expand_path(types_filename))
         self.types_sets_filename = str(expand_path(types_sets_filename))
         self.num_types_to_return = num_types_to_return
-        self.morph = pymorphy2.MorphAnalyzer()
         if self.lang == "@en":
             self.stopwords = set(stopwords.words("english"))
             self.nlp = spacy.load("en_core_web_sm")
@@ -102,7 +100,7 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
                 types_substr_tokens = types_substr.split()
                 types_substr_tokens = [tok for tok in types_substr_tokens if tok not in self.stopwords]
                 if self.lang == "@ru":
-                    types_substr_tokens = [self.morph.parse(tok)[0].normal_form for tok in types_substr_tokens]
+                    types_substr_tokens = [self.nlp(tok)[0].lemma_ for tok in types_substr_tokens]
                 types_substr_tokens = set(types_substr_tokens)
                 types_scores = []
                 for entity in self.types_dict:
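Since pymorphy2's `normal_form` and spaCy's `lemma_` come from different models, a quick spot check that the two agree on typical question tokens is a reasonable safeguard when porting configs. A hypothetical sanity-check snippet (not part of the commit or the test suite):

```python
import pymorphy2
import spacy

morph = pymorphy2.MorphAnalyzer()    # old normalizer
nlp = spacy.load("ru_core_news_sm")  # new normalizer

for token in ["коалы", "питаются", "написал"]:
    old = morph.parse(token)[0].normal_form  # most probable pymorphy2 parse
    new = nlp(token)[0].lemma_               # lemma of the single spaCy token
    print(f"{token}: pymorphy2={old} spacy={new}")
```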
2 changes: 1 addition & 1 deletion tests/test_quick_start.py
@@ -256,7 +256,7 @@
         ("kbqa/kbqa_cq_ru.json", "kbqa", ('IP',)):
             [
                 ("Кто такой Оксимирон?", ("российский рэп-исполнитель",)),
-                ("Чем питаются коалы?", ("Лист",)),
+                ("Кто написал «Евгений Онегин»?", ("Александр Сергеевич Пушкин",)),
                 ("абв", ("Not Found",))
             ]
 },
