Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up query_by_embedding in InMemoryDocumentStore. #2091

Merged
merged 9 commits into from
Feb 4, 2022
105 changes: 99 additions & 6 deletions haystack/document_stores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import logging
import numpy as np
import torch
from copy import deepcopy
from collections import defaultdict
from scipy.spatial.distance import cosine
Expand All @@ -15,6 +16,7 @@
from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.modeling.utils import initialize_device_settings


logger = logging.getLogger(__name__)
Expand All @@ -35,6 +37,8 @@ def __init__(
similarity: str = "dot_product",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
use_gpu: bool = True,
scoring_batch_size: int = 500000,
):
"""
:param index: The documents are scoped to an index attribute that can be used when writing, querying,
Expand All @@ -53,6 +57,13 @@ def __init__(
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param use_gpu: Whether to use a GPU or the CPU for calculating embedding similarity.
Falls back to CPU if no GPU is available.
:param scoring_batch_size: Batch size of documents to calculate similarity for. Very small batch sizes are inefficient.
Very large batch sizes can overrun GPU memory. In general you want to make sure
you have at least `embedding_dim`*`scoring_batch_size`*4 bytes available in GPU memory.
Since the data is originally stored in CPU memory there is little risk of overrunning memory
when running on CPU.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
Expand All @@ -75,6 +86,11 @@ def __init__(
self.similarity = similarity
self.progress_bar = progress_bar
self.duplicate_documents = duplicate_documents
self.use_gpu = use_gpu
self.scoring_batch_size = scoring_batch_size

self.devices, _ = initialize_device_settings(use_cuda=self.use_gpu)
self.main_device = self.devices[0]

def write_documents(
self,
Expand Down Expand Up @@ -193,6 +209,85 @@ def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> Li
documents = [self.indexes[index][id] for id in ids]
return documents

def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
    """
    Calculate similarity scores between query embedding and a list of documents using torch.

    The document embeddings are scored in slices of `self.scoring_batch_size` rows. Each slice is
    moved to `self.main_device`, multiplied with the query embedding, and the resulting scores are
    copied back to the CPU before the next slice is processed.

    :param query_emb: Embedding of the query (e.g. gathered from DPR)
    :param document_to_search: List of documents to compare `query_emb` against.
    :return: One score per document, in the same order as `document_to_search`.
             An empty list is returned if there are no documents.
    :raises ValueError: If `self.scoring_batch_size` is smaller than 1.
    """
    query_emb = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)
    if query_emb.dim() == 1:
        # Treat a flat vector as a single query row.
        query_emb = query_emb.unsqueeze(dim=0)

    doc_embeds = np.array([doc.embedding for doc in document_to_search])
    doc_embeds = torch.as_tensor(doc_embeds, dtype=torch.float)
    if doc_embeds.dim() == 1:
        if doc_embeds.shape[0] == 0:
            # No documents at all -> nothing to score.
            return []
        if doc_embeds.shape[0] == 1:
            doc_embeds = doc_embeds.unsqueeze(dim=0)

    batch_size = self.scoring_batch_size
    if batch_size < 1:
        # A non-positive step would never advance through the documents
        # (the previous while-loop spun forever in that case).
        raise ValueError(f"scoring_batch_size must be a positive integer, got {batch_size}")

    if self.similarity == "cosine":
        # cosine similarity is just a normed dot product;
        # keepdim=True keeps the norms as a column so every row is divided by its own norm
        query_emb = query_emb / torch.norm(query_emb, dim=1, keepdim=True)
        doc_embeds = doc_embeds / torch.norm(doc_embeds, dim=1, keepdim=True)

    scores: List[float] = []
    with torch.no_grad():
        for start in range(0, len(doc_embeds), batch_size):
            # Only the current slice is moved to the scoring device.
            batch = doc_embeds[start : start + batch_size].to(self.main_device)
            batch_scores = torch.matmul(batch, query_emb.T).squeeze(dim=1)
            scores.extend(batch_scores.cpu().tolist())

    return scores

def get_scores_numpy(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
    """
    Calculate similarity scores between query embedding and a list of documents using numpy.

    :param query_emb: Embedding of the query (e.g. gathered from DPR)
    :param document_to_search: List of documents to compare `query_emb` against.
    :return: One score per document, in the same order as `document_to_search`.
             An empty list is returned if there are no documents.
    """
    if len(query_emb.shape) == 1:
        # Treat a flat vector as a single query row.
        query_emb = np.expand_dims(query_emb, 0)

    doc_embeds = np.array([doc.embedding for doc in document_to_search])
    if doc_embeds.ndim == 1:
        if doc_embeds.shape[0] == 0:
            # No documents at all -> nothing to score.
            return []
        if doc_embeds.shape[0] == 1:
            # numpy arrays have no `unsqueeze` (that is the torch API); expand_dims is the equivalent
            doc_embeds = np.expand_dims(doc_embeds, 0)

    if self.similarity == "cosine":
        # cosine similarity is just a normed dot product;
        # keepdims=True keeps the norms as a column so every row is divided by its own norm
        query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
        doc_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)

    # Row 0 of the (queries x documents) product holds the scores for the (single) query.
    return np.dot(query_emb, doc_embeds.T)[0].tolist()

def get_scores(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
    """
    Calculate similarity scores between the query embedding and a list of documents.

    Delegates to `get_scores_torch` when `self.main_device` is a CUDA device and to
    `get_scores_numpy` otherwise.

    :param query_emb: Embedding of the query (e.g. gathered from DPR)
    :param document_to_search: List of documents to compare `query_emb` against.
    :return: Whatever the selected scoring method returns for these arguments.
    """
    on_cuda = self.main_device.type == "cuda"
    scorer = self.get_scores_torch if on_cuda else self.get_scores_numpy
    return scorer(query_emb, document_to_search)

def query_by_embedding(
self,
query_emb: np.ndarray,
Expand Down Expand Up @@ -224,17 +319,15 @@ def query_by_embedding(
return []

document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
scores = self.get_scores(query_emb, document_to_search)

candidate_docs = []
for doc in document_to_search:
for doc, score in zip(document_to_search, scores):
curr_meta = deepcopy(doc.meta)
new_document = Document(id=doc.id, content=doc.content, meta=curr_meta, embedding=doc.embedding)
new_document.embedding = doc.embedding if return_embedding is True else None

if self.similarity == "dot_product":
score = np.dot(query_emb, doc.embedding)
elif self.similarity == "cosine":
# cosine similarity score = 1 - cosine distance
score = 1 - cosine(query_emb, doc.embedding)
new_document.embedding = doc.embedding if return_embedding is True else None
new_document.score = self.finalize_raw_score(score, self.similarity)
candidate_docs.append(new_document)

Expand Down