Skip to content

Commit

Permalink
fixed searches returning searched words or documents
Browse files Browse the repository at this point in the history
  • Loading branch information
ddangelov committed Oct 17, 2020
1 parent abb2e7d commit fa13a67
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions top2vec/Top2Vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,8 +1443,19 @@ def similar_words(self, keywords, num_words, keywords_neg=None):
  word_vecs = self._get_word_vectors(keywords)
  neg_word_vecs = self._get_word_vectors(keywords_neg)
  combined_vector = self._get_combined_vec(word_vecs, neg_word_vecs)
+
+ num_res = min(num_words + len(keywords) + len(keywords_neg), len(self.vocab))
+
  word_indexes, word_scores = self._search_vectors_by_vector(self.word_vectors,
- combined_vector, num_words)
+ combined_vector, num_res)
+
+ # do not return words that were searched
+ search_word_indexes = [self.word2index[word] for word in list(keywords) + list(keywords_neg)]
+ res_indexes = [index for index, word_ind in enumerate(word_indexes)
+ if word_ind not in search_word_indexes][:num_words]
+ word_indexes = word_indexes[res_indexes]
+ word_scores = word_scores[res_indexes]
+
  words = [self.vocab[word] for word in word_indexes]

  return words, word_scores
Expand Down Expand Up @@ -1593,8 +1604,18 @@ def search_documents_by_documents(self, doc_ids, num_docs, doc_ids_neg=None, ret
  doc_vecs = [self.document_vectors[ind] for ind in doc_indexes]
  doc_vecs_neg = [self.document_vectors[ind] for ind in doc_indexes_neg]
  combined_vector = self._get_combined_vec(doc_vecs, doc_vecs_neg)
+
+ num_res = min(num_docs + len(doc_indexes) + len(doc_indexes_neg),
+ self._get_document_vectors().shape[0])
+
+ # don't return documents that were searched
+ search_doc_indexes = list(doc_indexes) + list(doc_indexes_neg)
  doc_indexes, doc_scores = self._search_vectors_by_vector(self._get_document_vectors(),
- combined_vector, num_docs)
+ combined_vector, num_res)
+ res_indexes = [index for index, doc_ind in enumerate(doc_indexes)
+ if doc_ind not in search_doc_indexes][:num_docs]
+ doc_indexes = doc_indexes[res_indexes]
+ doc_scores = doc_scores[res_indexes]

  doc_ids = self._get_document_ids(doc_indexes)

Expand Down

0 comments on commit fa13a67

Please sign in to comment.