Rerank search results with cross-encoder when using an inference server
If an inference server is being used, we can expect the cross-encoder to run
fast enough to rerank search results by default
debanjum committed Mar 10, 2024
1 parent 5fd0f20 commit 208b91a
Showing 2 changed files with 10 additions and 7 deletions.
9 changes: 4 additions & 5 deletions src/khoj/processor/embeddings.py
@@ -93,12 +93,11 @@ def __init__(
         self.inference_endpoint = cross_encoder_inference_endpoint
         self.api_key = cross_encoder_inference_endpoint_api_key
 
+    def is_inference_endpoint_enabled(self) -> bool:
+        return self.api_key is not None and self.inference_endpoint is not None
+
     def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
-        if (
-            self.api_key is not None
-            and self.inference_endpoint is not None
-            and "huggingface" in self.inference_endpoint
-        ):
+        if self.is_inference_endpoint_enabled() and "huggingface" in self.inference_endpoint:
             target_url = f"{self.inference_endpoint}"
             payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
             headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
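For context, here is a minimal standalone sketch of the endpoint-vs-local decision this change introduces. The class, constructor arguments, and response field are hypothetical illustrations; only the endpoint check, the payload shape, and the Bearer-token header mirror the diff above.

# Illustrative sketch only; not Khoj's actual cross-encoder class.
from typing import List, Optional

import requests
from sentence_transformers import CrossEncoder


class RemoteOrLocalCrossEncoder:
    def __init__(self, model_name: str, endpoint: Optional[str] = None, api_key: Optional[str] = None):
        self.inference_endpoint = endpoint
        self.api_key = api_key
        # Hypothetical fallback: load the model locally when no inference endpoint is configured
        self.local_model = CrossEncoder(model_name) if not self.is_inference_endpoint_enabled() else None

    def is_inference_endpoint_enabled(self) -> bool:
        # Same check as in the diff: both the endpoint URL and the API key must be set
        return self.api_key is not None and self.inference_endpoint is not None

    def predict(self, query: str, passages: List[str]) -> List[float]:
        if self.is_inference_endpoint_enabled() and "huggingface" in self.inference_endpoint:
            # Payload shape and auth header follow the diff above
            payload = {"inputs": {"query": query, "passages": passages}}
            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            response = requests.post(self.inference_endpoint, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()["scores"]  # assumed response field name
        # Otherwise score query-passage pairs with the local cross-encoder
        return self.local_model.predict([(query, passage) for passage in passages]).tolist()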
8 changes: 6 additions & 2 deletions src/khoj/routers/api.py
@@ -146,8 +146,12 @@ async def search(
     # Collate results
     results += text_search.collate_results(hits, dedupe=dedupe)
 
-    # Rerank results if explicitly requested or if device has GPU
-    rerank = r or state.device.type != "cpu"
+    # Rerank results if explicitly requested, if device has GPU or if an inference server is being used
+    rerank = (
+        r
+        or state.device.type != "cpu"
+        or state.cross_encoder_model[search_model.name].is_inference_endpoint_enabled()
+    )
 
     # Sort results across all content types and take top results
     results = text_search.rerank_and_sort_results(
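To make the effect of the new rerank flag concrete, below is a hedged sketch of how a rerank-and-sort step can consume cross-encoder scores. The function and field names are hypothetical and are not taken from Khoj's text_search module.

# Hypothetical sketch of reranking with cross-encoder scores; not Khoj's rerank_and_sort_results.
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Hit:
    text: str
    score: float  # initial retrieval score, e.g. from the bi-encoder


def rerank_and_sort(hits: List[Hit], query: str, rerank: bool,
                    cross_encoder_predict: Callable[[str, List[str]], List[float]]) -> List[Hit]:
    if rerank:
        # Re-score each hit with the (local or remote) cross-encoder, then sort by that score
        scores = cross_encoder_predict(query, [hit.text for hit in hits])
        for hit, score in zip(hits, scores):
            hit.score = score
    # Without reranking, the original retrieval scores decide the order
    return sorted(hits, key=lambda hit: hit.score, reverse=True)

With an inference endpoint configured, is_inference_endpoint_enabled() returns True, so rerank now defaults to True even on CPU-only deployments.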
