Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rerank Search Results by Default on GPU machines #668

Merged
3 commits merged on Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -50,7 +50,7 @@ dependencies = [
"pyyaml == 6.0",
"rich >= 13.3.1",
"schedule == 1.1.0",
"sentence-transformers == 2.3.1",
"sentence-transformers == 2.5.1",
"transformers >= 4.28.0",
"torch == 2.0.1",
"uvicorn == 0.17.6",
Expand Down
19 changes: 10 additions & 9 deletions src/interface/desktop/search.html
Expand Up @@ -192,16 +192,17 @@
});
}

let debounceTimeout;
// Debounced incremental search: waits for a pause in typing before querying,
// and only asks the server to rerank when the user commits with 'Enter'.
function incrementalSearch(event) {
    // Run incremental search only after waitTime passed since the last key press
    let waitTime = 300;
    clearTimeout(debounceTimeout);
    debounceTimeout = setTimeout(() => {
        type = 'all';
        // Search with reranking on 'Enter'
        let should_rerank = event.key === 'Enter';
        search(rerank=should_rerank);
    }, waitTime);
}

async function populate_type_dropdown() {
Expand Down
19 changes: 10 additions & 9 deletions src/khoj/interface/web/search.html
Expand Up @@ -193,16 +193,17 @@
});
}

let debounceTimeout;
// Debounced incremental search: waits for a pause in typing before querying,
// and only asks the server to rerank when the user commits with 'Enter'.
function incrementalSearch(event) {
    // Run incremental search only after waitTime passed since the last key press
    let waitTime = 300;
    clearTimeout(debounceTimeout);
    debounceTimeout = setTimeout(() => {
        type = document.getElementById("type").value;
        // Search with reranking on 'Enter'
        let should_rerank = event.key === 'Enter';
        search(rerank=should_rerank);
    }, waitTime);
}

function populate_type_dropdown() {
Expand Down
21 changes: 11 additions & 10 deletions src/khoj/processor/embeddings.py
Expand Up @@ -33,8 +33,11 @@ def __init__(
self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

def inference_server_enabled(self) -> bool:
    """Whether a remote embeddings inference server is fully configured.

    Requires both the endpoint URL and its API key to be set.
    """
    has_credentials = self.api_key is not None
    has_endpoint = self.inference_endpoint is not None
    return has_credentials and has_endpoint

def embed_query(self, query):
    """Embed a single query string.

    Prefers the remote inference server when one is configured; otherwise
    encodes locally with the sentence-transformers model.
    """
    if self.inference_server_enabled():
        return self.embed_with_api([query])[0]
    local_embeddings = self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)
    return local_embeddings[0]

Expand Down Expand Up @@ -62,11 +65,10 @@ def embed_with_api(self, docs):
return response.json()["embeddings"]

def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
if "huggingface" not in target_url:
if self.inference_server_enabled():
if "huggingface" not in self.inference_endpoint:
logger.warning(
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
Expand All @@ -93,12 +95,11 @@ def __init__(
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key

def inference_server_enabled(self) -> bool:
    """True when both the cross-encoder inference endpoint and its API key are set."""
    return not (self.api_key is None or self.inference_endpoint is None)

def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
if (
self.api_key is not None
and self.inference_endpoint is not None
and "huggingface" in self.inference_endpoint
):
if self.inference_server_enabled() and "huggingface" in self.inference_endpoint:
target_url = f"{self.inference_endpoint}"
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
Expand Down
9 changes: 7 additions & 2 deletions src/khoj/search_type/text_search.py
Expand Up @@ -177,8 +177,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):


def rerank_and_sort_results(hits, query, rank_results, search_model_name):
# If we have more than one result and reranking is enabled
rank_results = rank_results and len(list(hits)) > 1
# Rerank results if explicitly requested, if can use inference server or if device has GPU
# AND if we have more than one result
rank_results = (
rank_results
or state.cross_encoder_model[search_model_name].inference_server_enabled()
or state.device.type != "cpu"
) and len(list(hits)) > 1

# Score all retrieved entries using the cross-encoder
if rank_results:
Expand Down
7 changes: 6 additions & 1 deletion src/khoj/utils/helpers.py
Expand Up @@ -331,7 +331,12 @@ def batcher(iterable, max_n):
yield (x for x in chunk if x is not None)


def is_env_var_true(env_var: str, default: str = "false") -> bool:
    """Get state of boolean environment variable.

    The variable is considered true only when its value (or *default*, when
    unset) case-insensitively equals "true".
    """
    raw_value = os.getenv(env_var, default)
    return raw_value.lower() == "true"


def in_debug_mode():
    """Check if Khoj is running in debug mode.
    Set KHOJ_DEBUG environment variable to true to enable debug mode."""
    # Delegate to the shared boolean-env-var helper instead of duplicating the
    # os.getenv(...).lower() == "true" check; the stale pre-refactor return
    # line that shadowed this one has been removed.
    return is_env_var_true("KHOJ_DEBUG")