Skip to content

Commit

Permalink
Replace tensorflow syntax parser with pytorch syntax parser from slov…
Browse files Browse the repository at this point in the history
…net (#1569)

Co-authored-by: Fedor Ignatov <ignatov.fedor@gmail.com>
  • Loading branch information
dmitrijeuseew and IgnatovFedor committed Jun 20, 2022
1 parent d6774bb commit ef2ac99
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 18 deletions.
11 changes: 9 additions & 2 deletions deeppavlov/configs/kbqa/kbqa_cq_ru.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
"lang": "@ru"
},
{
"config_path": "{CONFIGS_PATH}/syntax/syntax_ru_syntagrus_bert.json",
"in": ["x_punct"],
"class_name": "slovnet_syntax_parser",
"load_path": "{MODELS_PATH}/slovnet_syntax_parser",
"navec_filename": "{MODELS_PATH}/slovnet_syntax_parser/navec_news_v1_1B_250K_300d_100q.tar",
"syntax_parser_filename": "{MODELS_PATH}/slovnet_syntax_parser/slovnet_syntax_news_v1.tar",
"in": ["x_punct", "entity_offsets"],
"out": ["syntax_info"]
},
{
Expand Down Expand Up @@ -103,6 +106,10 @@
{
"url": "http://files.deeppavlov.ai/kbqa/wikidata/kbqa_files_ru.tar.gz",
"subdir": "{DOWNLOADS_PATH}/wikidata_rus"
},
{
"url": "http://files.deeppavlov.ai/deeppavlov_data/slovnet_syntax_parser.tar.gz",
"subdir": "{MODELS_PATH}/slovnet_syntax_parser"
}
]
}
Expand Down
1 change: 1 addition & 0 deletions deeppavlov/core/common/registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"siamese_iterator": "deeppavlov.dataset_iterators.siamese_iterator:SiameseIterator",
"simple_vocab": "deeppavlov.core.data.simple_vocab:SimpleVocabulary",
"sklearn_component": "deeppavlov.models.sklearn.sklearn_component:SklearnComponent",
"slovnet_syntax_parser": "deeppavlov.models.kbqa.tree_to_sparql:SlovnetSyntaxParser",
"spelling_error_model": "deeppavlov.models.spelling_correction.brillmoore.error_model:ErrorModel",
"spelling_levenshtein": "deeppavlov.models.spelling_correction.levenshtein.searcher_component:LevenshteinSearcherComponent",
"split_tokenizer": "deeppavlov.models.tokenizers.split_tokenizer:SplitTokenizer",
Expand Down
3 changes: 3 additions & 0 deletions deeppavlov/core/common/requirements_registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
"russian_words_vocab": [
"{DEEPPAVLOV_PATH}/requirements/lxml.txt"
],
"slovnet_syntax_parser": [
"{DEEPPAVLOV_PATH}/requirements/slovnet.txt"
],
"spelling_error_model": [
"{DEEPPAVLOV_PATH}/requirements/lxml.txt"
],
Expand Down
7 changes: 3 additions & 4 deletions deeppavlov/models/entity_extraction/entity_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __call__(
start = end + 1
sentences_offsets_batch.append(sentences_offsets_list)

if entity_tags_batch is None:
entity_tags_batch = [[] for _ in entity_substr_batch]
if entity_tags_batch is None or not entity_tags_batch[0]:
entity_tags_batch = [["" for _ in entity_substr_list] for entity_substr_list in entity_substr_batch]
else:
entity_tags_batch = [[tag.upper() for tag in entity_tags] for entity_tags in entity_tags_batch]

Expand Down Expand Up @@ -264,7 +264,7 @@ def link_entities(
def process_cand_ent(self, cand_ent_init, entities_and_ids, entity_substr_split, tag):
if self.use_tags:
for cand_entity_title, cand_entity_id, cand_entity_rels, cand_tag, *_ in entities_and_ids:
if tag == cand_tag:
if not tag or tag == cand_tag:
substr_score = self.calc_substr_score(cand_entity_title, entity_substr_split)
cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels))
if not cand_ent_init:
Expand Down Expand Up @@ -511,7 +511,6 @@ def rank_by_description(
contexts.append(context)

scores_list = self.entity_ranker(contexts, cand_ent_list, cand_ent_descr_list)

for (entity_substr, candidate_entities, substr_len, entities_scores, scores,) in zip(
entity_substr_list,
cand_ent_list,
Expand Down
2 changes: 1 addition & 1 deletion deeppavlov/models/kbqa/rel_ranking_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __call__(self, questions_list: List[str],
answer_ids = answer_ids.split("/")[-1]
parser_info_list = ["find_label" for _ in answer_ids_input]
answer_labels = self.wiki_parser(parser_info_list, answer_ids_input)
log.info(f"answer_labels {answer_labels}")
log.debug(f"answer_labels {answer_labels}")
if self.return_all_possible_answers:
answer_labels = list(set(answer_labels))
answer_labels = [label for label in answer_labels if (label and label != "Not Found")][:5]
Expand Down
73 changes: 65 additions & 8 deletions deeppavlov/models/kbqa/tree_to_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import defaultdict
from io import StringIO
from typing import Any, List, Tuple, Dict, Union
from logging import getLogger
from collections import defaultdict
from typing import Any, List, Tuple, Dict, Union

import numpy as np
import pymorphy2
import re
from navec import Navec
from scipy.sparse import csr_matrix
from slovnet import Syntax
from udapi.block.read.conllu import Conllu
from udapi.core.node import Node

from deeppavlov.core.models.component import Component
from deeppavlov.core.common.file import read_json
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.file import read_json
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable

log = getLogger(__name__)

Expand Down Expand Up @@ -110,6 +113,57 @@ def make_sparse_matrix(self, words: List[str]):
return matrix


@register('slovnet_syntax_parser')
class SlovnetSyntaxParser(Component, Serializable):
    """Dependency parser component backed by the Slovnet library.

    Sentences are tokenized with a regular expression, parsed by a Slovnet ``Syntax`` model
    and returned as tab-separated, CoNLL-U-shaped strings (10 columns per token, of which
    only ID, FORM, HEAD and DEPREL are filled).

    Args:
        load_path: path handed to ``Serializable`` as the component load path
        navec_filename: path to the Navec embeddings archive
        syntax_parser_filename: path to the Slovnet syntax model archive
        **kwargs: accepted for config compatibility and ignored

    Attributes:
        navec_filename: expanded path to the Navec archive
        syntax_parser_filename: expanded path to the Slovnet syntax archive
        re_tokenizer: compiled pattern splitting text into word tokens
            (word characters and apostrophes) and single non-word, non-space characters
        syntax: the loaded Slovnet ``Syntax`` model
    """

    def __init__(self, load_path: str, navec_filename: str, syntax_parser_filename: str, **kwargs):
        super().__init__(save_path=None, load_path=load_path)
        self.navec_filename = expand_path(navec_filename)
        self.syntax_parser_filename = expand_path(syntax_parser_filename)
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.load()

    def load(self) -> None:
        """Load the Navec embeddings and the syntax model and attach the former to the latter."""
        navec = Navec.load(self.navec_filename)
        self.syntax = Syntax.load(self.syntax_parser_filename)
        self.syntax.navec(navec)

    def save(self) -> None:
        """Do nothing: this component has no state to persist."""
        pass

    @staticmethod
    def _upper_first(text: str) -> str:
        """Upper-case the first character of ``text`` and leave the rest untouched."""
        # ``str.capitalize`` is deliberately avoided: it lower-cases everything after the
        # first character, which would destroy acronyms and capitalised entity names.
        return text[:1].upper() + text[1:]

    @staticmethod
    def _to_conllu(tokens) -> str:
        """Render parsed tokens (objects with id, text, head_id, rel) as CoNLL-U-shaped lines."""
        ids, words, head_ids, rels = [], [], [], []
        for token in tokens:
            ids.append(token.id)
            words.append(token.text)
            head_ids.append(token.head_id)
            rels.append(token.rel)

        # When the parser emitted no "root" relation, promote every token that is its own
        # head to a root attached to the artificial node 0.
        if "root" not in {rel.lower() for rel in rels}:
            for n, (token_id, head_id) in enumerate(zip(ids, head_ids)):
                if token_id == head_id:
                    rels[n] = "root"
                    head_ids[n] = 0

        return "\n".join(
            f"{token_id}\t{word}\t_\t_\t_\t_\t{head_id}\t{rel}\t_\t_"
            for token_id, word, head_id, rel in zip(ids, words, head_ids, rels)
        )

    def __call__(self, sentences: List[str], entity_offsets_batch: List[List[Tuple[int, int]]]) -> List[str]:
        """Parse a batch of sentences into CoNLL-U-shaped dependency markup.

        Before parsing, the first character of every entity span and of the sentence itself
        is upper-cased; all other characters keep their original case.

        Args:
            sentences: batch of raw sentences
            entity_offsets_batch: for each sentence, ``(start, end)`` character offsets of
                the entity mentions in that sentence

        Returns:
            one string per sentence, with one tab-separated line per token
        """
        sentences_tok = []
        for sentence, entity_offsets in zip(sentences, entity_offsets_batch):
            # Spans are processed right to left so that a case conversion which changes the
            # string length cannot shift the spans that are still waiting to be processed.
            for start, end in sorted(entity_offsets, reverse=True):
                # Edit exactly the span at [start:end]; a substring ``replace`` would also
                # rewrite unrelated occurrences of the same text.
                sentence = sentence[:start] + self._upper_first(sentence[start:end]) + sentence[end:]
            sentence = self._upper_first(sentence)
            sentences_tok.append(self.re_tokenizer.findall(sentence))

        markup = list(self.syntax.map(sentences_tok))
        return [self._to_conllu(markup_elem.tokens) for markup_elem in markup]


@register('tree_to_sparql')
class TreeToSparql(Component):
"""
Expand Down Expand Up @@ -163,9 +217,12 @@ def __call__(self, syntax_tree_batch: List[str],
count = False
for syntax_tree, positions in zip(syntax_tree_batch, positions_batch):
log.debug(f"\n{syntax_tree}")
tree = Conllu(filehandle=StringIO(syntax_tree)).read_tree()
root = self.find_root(tree)
tree_desc = tree.descendants
try:
tree = Conllu(filehandle=StringIO(syntax_tree)).read_tree()
root = self.find_root(tree)
tree_desc = tree.descendants
except ValueError:
root = ""
unknown_node = ""
if root:
log.debug(f"syntax tree info, root: {root.form}")
Expand Down
2 changes: 2 additions & 0 deletions deeppavlov/requirements/slovnet.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
slovnet==0.5.0
navec==0.10.0
7 changes: 4 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@

# -- Extension configuration -------------------------------------------------

autodoc_mock_imports = ['bert_dp', 'bs4', 'fastText', 'fasttext', 'hdt', 'kenlm', 'librosa', 'lxml', 'nltk',
'opt_einsum', 'rapidfuzz', 'rasa', 'russian_tagsets', 'sacremoses', 'sortedcontainers', 'spacy',
'tensorflow', 'torch', 'torchcrf', 'transformers', 'udapi', 'ufal_udpipe', 'whapi']
autodoc_mock_imports = ['bert_dp', 'bs4', 'fastText', 'fasttext', 'hdt', 'kenlm', 'librosa', 'lxml', 'navec', 'nltk',
'opt_einsum', 'rapidfuzz', 'rasa', 'russian_tagsets', 'sacremoses', 'slovnet',
'sortedcontainers', 'spacy', 'tensorflow', 'torch', 'torchcrf', 'transformers', 'udapi',
'ufal_udpipe', 'whapi']

extlinks = {
'config': (f'https://github.com/deepmipt/DeepPavlov/blob/{release}/deeppavlov/configs/%s', None)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ numpy<=1.22.2
overrides<=4.1.2
pandas<=1.4.0
prometheus-client<=0.13.1
protobuf<4
pytz<=2021.3
pydantic<=1.9.0
pymorphy2<=0.9.1
Expand Down

0 comments on commit ef2ac99

Please sign in to comment.