Skip to content

Commit

Permalink
KBQA fixes (#1591)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitrijeuseew committed Nov 5, 2022
1 parent d39818c commit 2bc186b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
28 changes: 15 additions & 13 deletions deeppavlov/models/entity_extraction/entity_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
self.use_tags = use_tags
self.full_paragraph = full_paragraph
self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
self.not_found_str = "not in wiki"
self.not_found_str = "not_in_wiki"

self.load()

Expand Down Expand Up @@ -277,27 +277,31 @@ def process_cand_ent(self, cand_ent_init, entities_and_ids, entity_substr_split,
cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels))
return cand_ent_init

def find_title(self, entity_substr):
    """Query the FTS ``inverted_index`` table for titles matching *entity_substr*.

    Args:
        entity_substr: surface form of the entity to match against indexed titles.

    Returns:
        List of matching rows from ``inverted_index`` (empty list when nothing
        matches or the FTS query cannot be parsed).
    """
    entities_and_ids = []
    try:
        # Parameterized query instead of str.format: prevents SQL injection and
        # avoids a guaranteed OperationalError whenever the substring contains a
        # single quote (e.g. "o'brien"). MATCH still interprets FTS query syntax
        # inside the bound string, so malformed FTS expressions are caught below.
        res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH ?;", (entity_substr,))
        entities_and_ids = res.fetchall()
    except sqlite3.OperationalError as e:
        # Best-effort lookup: an unparsable FTS expression just yields no candidates.
        log.debug(f"error in searching an entity {e}")
    return entities_and_ids

def find_exact_match(self, entity_substr, tag):
entity_substr_split = entity_substr.split()
cand_ent_init = defaultdict(set)
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr))
entities_and_ids = res.fetchall()
entities_and_ids = self.find_title(entity_substr)
if entities_and_ids:
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
if entity_substr.startswith("the "):
entity_substr = entity_substr.split("the ")[1]
entity_substr_split = entity_substr_split[1:]
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr))
entities_and_ids = res.fetchall()
entities_and_ids = self.find_title(entity_substr)
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
if self.lang == "@ru":
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
entity_substr_lemm = " ".join(entity_substr_split_lemm)
if entity_substr_lemm != entity_substr:
res = self.cur.execute(
"SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr_lemm)
)
entities_and_ids = res.fetchall()
entities_and_ids = self.find_title(entity_substr_lemm)
if entities_and_ids:
cand_ent_init = self.process_cand_ent(
cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
Expand All @@ -311,14 +315,12 @@ def find_fuzzy_match(self, entity_substr_split, tag):
entity_substr_split_lemm = entity_substr_split
cand_ent_init = defaultdict(set)
for word in entity_substr_split:
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(word))
part_entities_and_ids = res.fetchall()
part_entities_and_ids = self.find_title(word)
cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
if self.lang == "@ru":
word_lemm = self.morph.parse(word)[0].normal_form
if word != word_lemm:
res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(word_lemm))
part_entities_and_ids = res.fetchall()
part_entities_and_ids = self.find_title(word_lemm)
cand_ent_init = self.process_cand_ent(
cand_ent_init,
part_entities_and_ids,
Expand Down
7 changes: 5 additions & 2 deletions deeppavlov/models/kbqa/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,11 @@ def query_parser(self, question: str, query_info: Dict[str, str],
rel_combs = make_combs(rels, permut=False)
entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')]
log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
selected_entity_ids, selected_type_ids = [], []
if entity_ids:
selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
if type_ids:
selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
entity_combs = make_combs(selected_entity_ids, permut=True)
type_combs = make_combs(selected_type_ids, permut=False)
log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]},"
Expand Down
5 changes: 3 additions & 2 deletions deeppavlov/models/kbqa/rel_ranking_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ def __call__(self, questions_list: List[str],
answer_ids = answers_with_scores[0][0]
if self.return_all_possible_answers and isinstance(answer_ids, tuple):
answer_ids_input = [(answer_id, question) for answer_id in answer_ids]
answer_ids = [answer_id.split("/")[-1] for answer_id in answer_ids]
answer_ids = list(map(lambda x: x.split("/")[-1] if str(x).startswith("http") else x, answer_ids))
else:
answer_ids_input = [(answer_ids, question)]
answer_ids = answer_ids.split("/")[-1]
if str(answer_ids).startswith("http:"):
answer_ids = answer_ids.split("/")[-1]
parser_info_list = ["find_label" for _ in answer_ids_input]
answer_labels = self.wiki_parser(parser_info_list, answer_ids_input)
log.debug(f"answer_labels {answer_labels}")
Expand Down
5 changes: 3 additions & 2 deletions deeppavlov/models/kbqa/type_define.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,15 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
break
elif token.head.text == type_noun and token.dep_ == "prep":
if len(list(token.children)) == 1 \
and not any([list(token.children)[0] in entity_substr.lower()
and not any([[tok.text for tok in token.children][0] in entity_substr.lower()
for entity_substr in entity_substr_list]):
types_substr += [token.text, list(token.children)[0]]
types_substr += [token.text, [tok.text for tok in token.children][0]]
elif any([word in question for word in self.pronouns]):
for token in doc:
if token.dep_ == "nsubj" and not any([token.text in entity_substr.lower()
for entity_substr in entity_substr_list]):
types_substr.append(token.text)

types_substr = [(token, token_pos_dict[token]) for token in types_substr]
types_substr = sorted(types_substr, key=lambda x: x[1])
types_substr = " ".join([elem[0] for elem in types_substr])
Expand Down

0 comments on commit 2bc186b

Please sign in to comment.