Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali K committed Apr 30, 2024
1 parent 00ab7fc commit 7eb17a6
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 46 deletions.
26 changes: 0 additions & 26 deletions src/marqo/core/index_management/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from enum import Enum
from decimal import Decimal
from marqo.tensor_search.models.index_settings import IndexSettings
from pydantic import ValidationError
import json
Expand Down Expand Up @@ -29,27 +27,3 @@ def validate_settings_object(index_name, settings_json) -> bool:
except Exception as e:
logger.error(f'Exception while validating index {index_name}: {e}')
raise e


def convert_marqo_request_to_dict(index_settings):
"""Converts a MarqoIndexRequest to a dictionary.
Returns
A dictionary representation of the MarqoIndexRequest
"""
index_settings_dict = index_settings.dict(exclude_none=True)
return convert_enums_to_values(index_settings_dict)


def convert_enums_to_values(data):
if isinstance(data, dict):
return {key: convert_enums_to_values(value) for key, value in data.items()}
elif isinstance(data, list):
return [convert_enums_to_values(element) for element in data]
elif isinstance(data, Enum):
if isinstance(data.value, float):
return Decimal(str(data.value))
return data.value
elif isinstance(data, float):
return Decimal(str(data))
else:
return data
5 changes: 2 additions & 3 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,10 @@ def memory():
return memory_profiler.get_memory_profile()


@app.post('/validate/index')
@app.post('/validate/{index_name}')
@utils.enable_ops_api()
def schema_validation(index_name, settings_object):
def schema_validation(index_name: str, settings_object: str):
try:
settings_object = json.loads(settings_object)
validate_settings_object(index_name, settings_object)
return JSONResponse(
content={
Expand Down
108 changes: 91 additions & 17 deletions tests/tensor_search/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import uuid
from unittest import mock
from unittest.mock import patch
Expand Down Expand Up @@ -37,7 +38,7 @@ def test_add_or_replace_documents_tensor_fields(self):
)
self.assertEqual(response.status_code, 200)
mock_add_documents.assert_called_once()

def test_memory(self):
"""
Test that the memory endpoint returns the expected keys when debug API is enabled.
Expand All @@ -61,7 +62,95 @@ def test_memory_disabled_403(self):
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_DEBUG_API: 'FALSE'}):
response = self.client.get("/memory")
self.assertEqual(response.status_code, 403)



class ValidationApiTests(MarqoTestCase):
def setUp(self):
self.client = TestClient(api.app)

def test_schema_validation_defaultDisabled(self):
"""
Test that the schema_validation endpoint returns 403 by default.
"""
data = {
"type": "structured",
"allFields": [],
"tensorFields": []
}
index_name = "test-index"
response = self.client.post(f"/validate/{index_name}?settings_object={data}")
self.assertEqual(response.status_code, 403)

def test_ops_api_disabled_403(self):
"""
Test that the ops-api endpoint returns 403 when debug API is disabled explicitly.
"""
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_OPS_API: 'FALSE'}):
data = {
"type": "structured",
"allFields": [],
"tensorFields": [],
"settings_object": {}
}
index_name = "test-index"
response = self.client.post(f"/validate/{index_name}?settings_object={data}")
self.assertEqual(response.status_code, 403)

def test_ops_api_200(self):
"""
Test that the ops-api endpoint returns 200 when debug API is enabled.
"""
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_OPS_API: 'TRUE'}):
data = {
"treatUrlsAndPointersAsImages": False,
"model": "hf/e5-large",
"normalizeEmbeddings": True,
"textPreprocessing": {
"splitLength": 2,
"splitOverlap": 0,
"splitMethod": "sentence",
},
"imagePreprocessing": {"patchMethod": None},
"annParameters": {
"spaceType": "euclidean",
"parameters": {"efConstruction": 128, "m": 16},
},
"type": "unstructured",
}
index_name = "test-index"
settings_json = json.dumps(data)
response = self.client.post(f"/validate/{index_name}?settings_object={settings_json}")
self.assertEqual(response.json(), {'validated': True, 'index': 'test-index'})

def test_ops_api_400(self):
"""
Test that the ops-api endpoint returns 400 when debug API is enabled and the input is invalid.
"""
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_OPS_API: 'TRUE'}):
data = {
"treatUrlsAndPointersAsImages": False,
"model": "hf/e5-large",
"normalizeEmbeddings": True,
"textPreprocessing": {
"splitLength": 2,
"splitOverlap": 0,
"splitMethod": "sentence",
},
"imagePreprocessing": {"patchMethod": None},
"annParameters": {
"spaceType": "euclidean",
"parameters": {"efConstruction": 128, "m": 16},
},
"type": "unknown" # invalid type
}
index_name = "test-index"
settings_json = json.dumps(data)
response = self.client.post(f"/validate/{index_name}?settings_object={settings_json}")
self.assertEqual(response.status_code, 400)
self.assertFalse(response.json()['validated'])
self.assertEqual(response.json()['index'], 'test-index')
self.assertTrue(response.json()['validation_error'].startswith('1 validation error for IndexSettings'))


class TestApiCustomEnvVars(MarqoTestCase):
@classmethod
Expand Down Expand Up @@ -99,21 +188,6 @@ def test_search_timeout_short_timer_fails(self):
self.assertEqual(res.json()["code"], "vector_store_timeout")
self.assertEqual(res.json()["type"], "invalid_request")

def test_schema_validation_defaultDisabled(self):
"""
Test that the schema_validation endpoint returns 403 by default.
"""
response = self.client.post("/validate/index")
self.assertEqual(response.status_code, 403)

def test_ops_api_disabled_403(self):
"""
Test that the ops-api endpoint returns 403 when debug API is disabled explicitly.
"""
with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_OPS_API: 'FALSE'}):
response = self.client.post("/validate/index")
self.assertEqual(response.status_code, 403)


class TestApiErrors(MarqoTestCase):
"""
Expand Down

0 comments on commit 7eb17a6

Please sign in to comment.