Skip to content

Commit

Permalink
Rerank search results with cross-encoder when using an inference server
Browse files Browse the repository at this point in the history
If an inference server is being used, we can expect the cross encoder
to be running fast enough to rerank search results by default
  • Loading branch information
debanjum committed Mar 10, 2024
1 parent 44c8d09 commit 53d4024
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
21 changes: 11 additions & 10 deletions src/khoj/processor/embeddings.py
Expand Up @@ -33,8 +33,11 @@ def __init__(
self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

def inference_server_enabled(self) -> bool:
    """Whether a remote inference server is configured for this model.

    Returns True only when both the endpoint URL and its API key are set;
    otherwise embeddings are generated locally.
    """
    return all(
        value is not None
        for value in (self.api_key, self.inference_endpoint)
    )

def embed_query(self, query):
if self.api_key is not None and self.inference_endpoint is not None:
if self.inference_server_enabled():
return self.embed_with_api([query])[0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]

Expand Down Expand Up @@ -62,11 +65,10 @@ def embed_with_api(self, docs):
return response.json()["embeddings"]

def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
if "huggingface" not in target_url:
if self.inference_server_enabled():
if "huggingface" not in self.inference_endpoint:
logger.warning(
f"Using custom inference endpoint {target_url} is not yet supported. Please use a HuggingFace inference endpoint."
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
Expand All @@ -93,12 +95,11 @@ def __init__(
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key

def inference_server_enabled(self) -> bool:
    """True if a remote cross-encoder inference server can be used.

    Requires both the inference endpoint URL and its API key to be configured.
    """
    has_api_key = self.api_key is not None
    has_endpoint = self.inference_endpoint is not None
    return has_api_key and has_endpoint

def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
if (
self.api_key is not None
and self.inference_endpoint is not None
and "huggingface" in self.inference_endpoint
):
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
target_url = f"{self.inference_endpoint}"
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
Expand Down
8 changes: 6 additions & 2 deletions src/khoj/search_type/text_search.py
Expand Up @@ -177,9 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):


def rerank_and_sort_results(hits, query, rank_results, search_model_name):
# Rerank results if explicitly requested or if device has GPU
# Rerank results if explicitly requested, if can use inference server or if device has GPU
# AND if we have more than one result
rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1
rank_results = (
rank_results
or state.cross_encoder_model[search_model_name].inference_server_enabled()
or state.device.type != "cpu"
) and len(list(hits)) > 1

# Score all retrieved entries using the cross-encoder
if rank_results:
Expand Down

0 comments on commit 53d4024

Please sign in to comment.