diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py
index 6936e33ea30..6745ceb126d 100644
--- a/haystack/document_stores/memory.py
+++ b/haystack/document_stores/memory.py
@@ -6,6 +6,7 @@
 import time
 import logging
 import numpy as np
+import torch
 from copy import deepcopy
 from collections import defaultdict
 from scipy.spatial.distance import cosine
@@ -15,6 +16,7 @@
 from haystack.errors import DuplicateDocumentError
 from haystack.document_stores import BaseDocumentStore
 from haystack.document_stores.base import get_batches_from_generator
+from haystack.modeling.utils import initialize_device_settings
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +37,8 @@ def __init__(
         similarity: str = "dot_product",
         progress_bar: bool = True,
         duplicate_documents: str = "overwrite",
+        use_gpu: bool = True,
+        scoring_batch_size: int = 500000,
     ):
         """
         :param index: The documents are scoped to an index attribute that can be used when writing, querying,
@@ -53,6 +57,13 @@ def __init__(
                                     overwrite: Update any existing documents with the same ID when adding documents.
                                     fail: an error is raised if the document ID of the document being added already
                                     exists.
+        :param use_gpu: Whether to use a GPU or the CPU for calculating embedding similarity.
+                        Falls back to CPU if no GPU is available.
+        :param scoring_batch_size: Batch size of documents to calculate similarity for. Very small batch sizes are inefficient.
+                                   Very large batch sizes can overrun GPU memory. In general you want to make sure
+                                   you have at least `embedding_dim`*`scoring_batch_size`*4 bytes available in GPU memory.
+                                   Since the data is originally stored in CPU memory there is little risk of overrunning memory
+                                   when running on CPU.
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -75,6 +86,11 @@ def __init__(
         self.similarity = similarity
         self.progress_bar = progress_bar
         self.duplicate_documents = duplicate_documents
+        self.use_gpu = use_gpu
+        self.scoring_batch_size = scoring_batch_size
+
+        self.devices, _ = initialize_device_settings(use_cuda=self.use_gpu)
+        self.main_device = self.devices[0]
 
     def write_documents(
         self,
@@ -193,6 +209,92 @@ def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> Li
         documents = [self.indexes[index][id] for id in ids]
         return documents
 
+    def get_scores_torch(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+        """
+        Calculate similarity scores between query embedding and a list of documents using torch.
+
+        :param query_emb: Embedding of the query (e.g. gathered from DPR)
+        :param document_to_search: List of documents to compare `query_emb` against.
+        """
+        query_emb = torch.tensor(query_emb, dtype=torch.float).to(self.main_device)
+        if len(query_emb.shape) == 1:
+            query_emb = query_emb.unsqueeze(dim=0)
+
+        doc_embeds = np.array([doc.embedding for doc in document_to_search])
+        doc_embeds = torch.as_tensor(doc_embeds, dtype=torch.float)
+        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
+            doc_embeds = doc_embeds.unsqueeze(dim=0)
+        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
+            return []
+
+        if self.similarity == "cosine":
+            # cosine similarity is just a normed dot product
+            query_emb_norm = torch.norm(query_emb, dim=1)
+            query_emb = torch.div(query_emb, query_emb_norm)
+
+            doc_embeds_norms = torch.norm(doc_embeds, dim=1)
+            doc_embeds = torch.div(doc_embeds.T, doc_embeds_norms).T
+
+        curr_pos = 0
+        scores = []
+        while curr_pos < len(doc_embeds):
+            doc_embeds_slice = doc_embeds[curr_pos : curr_pos + self.scoring_batch_size]
+            doc_embeds_slice = doc_embeds_slice.to(self.main_device)
+            with torch.no_grad():
+                slice_scores = torch.matmul(doc_embeds_slice, query_emb.T).cpu()
+                slice_scores = slice_scores.squeeze(dim=1)
+                slice_scores = slice_scores.numpy().tolist()
+
+            scores.extend(slice_scores)
+            curr_pos += self.scoring_batch_size
+
+        return scores
+
+    def get_scores_numpy(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+        """
+        Calculate similarity scores between query embedding and a list of documents using numpy.
+
+        :param query_emb: Embedding of the query (e.g. gathered from DPR)
+        :param document_to_search: List of documents to compare `query_emb` against.
+        """
+        if len(query_emb.shape) == 1:
+            query_emb = np.expand_dims(query_emb, 0)
+
+        doc_embeds = np.array([doc.embedding for doc in document_to_search])
+        if len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 1:
+            doc_embeds = np.expand_dims(doc_embeds, 0)
+        elif len(doc_embeds.shape) == 1 and doc_embeds.shape[0] == 0:
+            return []
+
+        if self.similarity == "cosine":
+            # cosine similarity is just a normed dot product
+            query_emb_norm = np.apply_along_axis(np.linalg.norm, 1, query_emb)
+            query_emb_norm = np.expand_dims(query_emb_norm, 1)
+            query_emb = np.divide(query_emb, query_emb_norm)
+
+            doc_embeds_norms = np.apply_along_axis(np.linalg.norm, 1, doc_embeds)
+            doc_embeds_norms = np.expand_dims(doc_embeds_norms, 1)
+            doc_embeds = np.divide(doc_embeds, doc_embeds_norms)
+
+        scores = np.dot(query_emb, doc_embeds.T)[0].tolist()
+
+        return scores
+
+    def get_scores(self, query_emb: np.ndarray, document_to_search: List[Document]) -> List[float]:
+        """
+        Calculate similarity scores between query embedding and a list of documents.
+        Uses torch when the main device is a CUDA device and numpy otherwise.
+
+        :param query_emb: Embedding of the query (e.g. gathered from DPR)
+        :param document_to_search: List of documents to compare `query_emb` against.
+        """
+        if self.main_device.type == "cuda":
+            scores = self.get_scores_torch(query_emb, document_to_search)
+        else:
+            scores = self.get_scores_numpy(query_emb, document_to_search)
+
+        return scores
+
     def query_by_embedding(
         self,
         query_emb: np.ndarray,
@@ -224,17 +326,14 @@ def query_by_embedding(
             return []
 
         document_to_search = self.get_all_documents(index=index, filters=filters, return_embedding=True)
+        scores = self.get_scores(query_emb, document_to_search)
+
         candidate_docs = []
-        for doc in document_to_search:
+        for doc, score in zip(document_to_search, scores):
             curr_meta = deepcopy(doc.meta)
             new_document = Document(id=doc.id, content=doc.content, meta=curr_meta, embedding=doc.embedding)
             new_document.embedding = doc.embedding if return_embedding is True else None
 
-            if self.similarity == "dot_product":
-                score = np.dot(query_emb, doc.embedding)
-            elif self.similarity == "cosine":
-                # cosine similarity score = 1 - cosine distance
-                score = 1 - cosine(query_emb, doc.embedding)
             new_document.score = self.finalize_raw_score(score, self.similarity)
             candidate_docs.append(new_document)
 