Skip to content

Commit

Permalink
Configurable search timeout (#843)
Browse files Browse the repository at this point in the history
Cherry picking #813
  • Loading branch information
farshidz committed May 22, 2024
1 parent aa6572d commit 26a5dd1
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/unit_test_200gb_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ on:
push:
branches:
- mainline
- releases/*
paths-ignore:
- '**.md'
pull_request:
branches:
- mainline
- releases/*
paths-ignore:
- '**.md'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _to_vespa_tensor_query(self, marqo_query: MarqoTensorQuery) -> Dict[str, Any

if not marqo_query.approximate:
query['ranking.softtimeout.enable'] = False
query['timeout'] = '300s'
query['timeout'] = 300 * 1000 # 5 minutes

return query

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _to_vespa_tensor_query(self, marqo_query: MarqoTensorQuery) -> Dict[str, Any

if not marqo_query.approximate:
query['ranking.softtimeout.enable'] = False
query['timeout'] = '300s'
query['timeout'] = 300 * 1000 # 5 minutes

return query

Expand Down
1 change: 1 addition & 0 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def generate_config() -> config.Config:
document_url=utils.read_env_vars_and_defaults(EnvVars.VESPA_DOCUMENT_URL),
pool_size=utils.read_env_vars_and_defaults_ints(EnvVars.VESPA_POOL_SIZE),
content_cluster_name=utils.read_env_vars_and_defaults(EnvVars.VESPA_CONTENT_CLUSTER_NAME),
default_search_timeout_ms=utils.read_env_vars_and_defaults_ints(EnvVars.VESPA_SEARCH_TIMEOUT_MS),
)
index_management = IndexManagement(vespa_client)
return config.Config(vespa_client, index_management)
Expand Down
1 change: 1 addition & 0 deletions src/marqo/tensor_search/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def default_env_vars() -> dict:
EnvVars.VESPA_DOCUMENT_URL: "http://localhost:8080",
EnvVars.VESPA_CONTENT_CLUSTER_NAME: "content_default",
EnvVars.VESPA_POOL_SIZE: 10,
EnvVars.VESPA_SEARCH_TIMEOUT_MS: 1000,
EnvVars.MARQO_MAX_INDEX_FIELDS: None,
EnvVars.MARQO_MAX_DOC_BYTES: 100000,
EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: 10000,
Expand Down
1 change: 1 addition & 0 deletions src/marqo/tensor_search/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class EnvVars:
VESPA_QUERY_URL = "VESPA_QUERY_URL"
VESPA_DOCUMENT_URL = "VESPA_DOCUMENT_URL"
VESPA_CONTENT_CLUSTER_NAME = "VESPA_CONTENT_CLUSTER_NAME"
VESPA_SEARCH_TIMEOUT_MS = "VESPA_SEARCH_TIMEOUT_MS"
VESPA_POOL_SIZE = "VESPA_POOL_SIZE"
MARQO_MAX_INDEX_FIELDS = "MARQO_MAX_INDEX_FIELDS"
MARQO_MAX_DOC_BYTES = "MARQO_MAX_DOC_BYTES"
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.2.2"
__version__ = "2.2.3"


def get_version() -> str:
Expand Down
12 changes: 10 additions & 2 deletions src/marqo/vespa/vespa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, current_generation: int, wanted_generation: int, converged: b
self.converged = converged

def __init__(self, config_url: str, document_url: str, query_url: str,
content_cluster_name: str, pool_size: int = 10):
content_cluster_name: str, default_search_timeout_ms: int = 1000, pool_size: int = 10):
"""
Create a VespaClient object.
Args:
Expand All @@ -52,6 +52,7 @@ def __init__(self, config_url: str, document_url: str, query_url: str,
self.http_client = httpx.Client(
limits=httpx.Limits(max_keepalive_connections=pool_size, max_connections=pool_size)
)
self.default_search_timeout_ms = default_search_timeout_ms
self.content_cluster_name = content_cluster_name

def close(self):
Expand Down Expand Up @@ -142,7 +143,7 @@ def wait_for_application_convergence(self, timeout: int = 120) -> None:
raise VespaError(f"Vespa application did not converge within {timeout} seconds")

def query(self, yql: str, hits: int = 10, ranking: str = None, model_restrict: str = None,
query_features: Dict[str, Any] = None, **kwargs) -> QueryResult:
query_features: Dict[str, Any] = None, timeout: int = None, **kwargs) -> QueryResult:
"""
Query Vespa.
Args:
Expand All @@ -166,6 +167,13 @@ def query(self, yql: str, hits: int = 10, ranking: str = None, model_restrict: s
**query_features_list,
**kwargs
}

# Use default timeout if not already set.
if timeout:
query['timeout'] = f"{timeout}ms"
else:
query['timeout'] = f"{self.default_search_timeout_ms}ms"

query = {key: value for key, value in query.items() if value is not None}

logger.debug(f'Query: {query}')
Expand Down
48 changes: 45 additions & 3 deletions tests/tensor_search/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from marqo.tensor_search.enums import EnvVars
from marqo.vespa import exceptions as vespa_exceptions
from tests.marqo_test import MarqoTestCase
import importlib
import sys
import os


class ApiTests(MarqoTestCase):
Expand All @@ -34,7 +37,7 @@ def test_add_or_replace_documents_tensor_fields(self):
)
self.assertEqual(response.status_code, 200)
mock_add_documents.assert_called_once()

def test_memory(self):
"""
Test that the memory endpoint returns the expected keys when debug API is enabled.
Expand All @@ -58,6 +61,43 @@ def test_memory_disabled_403(self):
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_DEBUG_API: 'FALSE'}):
response = self.client.get("/memory")
self.assertEqual(response.status_code, 403)


class TestApiCustomEnvVars(MarqoTestCase):
    """
    API tests that need the API module to be re-imported under customised environment variables.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create one unstructured and one structured index shared by every test in this class.
        """
        super().setUpClass()

        unstructured_index_request = cls.unstructured_marqo_index_request()
        structured_index_request = cls.structured_marqo_index_request(
            fields=[
                FieldRequest(name='field1', type=FieldType.Text),
                FieldRequest(name='field2', type=FieldType.Text)
            ],
            tensor_fields=['field1']
        )

        cls.indexes = cls.create_indexes([unstructured_index_request, structured_index_request])

        # NOTE(review): assumes create_indexes returns the indexes in request order -- confirm
        cls.unstructured_index = cls.indexes[0]
        cls.structured_index = cls.indexes[1]

    def test_search_timeout_short_timer_fails(self):
        """
        A search must time out (HTTP 504, code 'vector_store_timeout') when the API is loaded
        with VESPA_SEARCH_TIMEOUT_MS set to 1.
        """
        api_module = sys.modules['marqo.tensor_search.api']

        # Reloading the API module replaces its state in place. Reload it once more after the test,
        # when the environment patch below has been undone, so that tests running later in the same
        # process do not keep using the state built for the 1ms timeout.
        self.addCleanup(importlib.reload, api_module)

        # Set up the test API client with the correct env vars set
        with mock.patch.dict(os.environ, {"VESPA_SEARCH_TIMEOUT_MS": "1"}):
            importlib.reload(api_module)
            # VespaClient will be created with default timeout of 1ms
            self.client = TestClient(api.app)

            for index in [self.unstructured_index, self.structured_index]:
                with self.subTest(index=index.name):
                    res = self.client.post("/indexes/" + index.name + "/search?device=cpu", json={
                        "q": "irrelevant"
                    })
                    # The search request must timeout, since the timeout is set to 1ms
                    self.assertEqual(res.status_code, 504)
                    body = res.json()
                    self.assertEqual(body["code"], "vector_store_timeout")
                    self.assertEqual(body["type"], "invalid_request")


class TestApiErrors(MarqoTestCase):
Expand All @@ -72,6 +112,7 @@ class TestApiErrors(MarqoTestCase):
def setUpClass(cls) -> None:
super().setUpClass()

unstructured_index_request = cls.unstructured_marqo_index_request()
structured_index_request = cls.structured_marqo_index_request(
fields=[
FieldRequest(name='field1', type=FieldType.Text),
Expand All @@ -80,9 +121,10 @@ def setUpClass(cls) -> None:
tensor_fields=['field1']
)

cls.indexes = cls.create_indexes([structured_index_request])
cls.indexes = cls.create_indexes([unstructured_index_request, structured_index_request])

cls.structured_index = cls.indexes[0]
cls.unstructured_index = cls.indexes[0]
cls.structured_index = cls.indexes[1]

def setUp(self):
self.client = TestClient(api.app)
Expand Down
23 changes: 23 additions & 0 deletions tests/vespa/test_vespa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,29 @@ def modified_post(*args, **kwargs):
yql="select * from sources * where title contains 'Title 1';"
)

def test_default_search_timeout_fails(self):
    """
    A VespaClient built with default_search_timeout_ms=1 raises VespaTimeoutError on query.

    The query call passes no explicit timeout, so the client's default is the value that
    must appear in the request body sent to Vespa.
    """
    vespa_url = "http://localhost:8080"
    client = VespaClient(vespa_url, vespa_url, vespa_url, "content_default",
                         default_search_timeout_ms=1)

    def forward_post(*args, **kwargs):
        # Delegate to the module-level httpx.post so the real request is still made
        return httpx.post(*args, **kwargs)

    with patch.object(httpx.Client, "post", wraps=forward_post) as post_spy:
        with self.assertRaisesStrict(VespaTimeoutError):
            client.query(yql="select * from sources * where title contains 'Title 1';")

        # Ensure that post was called with correct timeout
        sent_json = post_spy.call_args.kwargs['json']
        self.assertEqual(sent_json['timeout'], '1ms')

def test_query_softDoom_fails(self):
"""
VespaTimeoutError error is raised when Vespa responds with a soft doom error.
Expand Down

0 comments on commit 26a5dd1

Please sign in to comment.