Skip to content

Commit

Permalink
Prefix refactor and backwards compatibility (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed May 9, 2024
1 parent 9ed292d commit 4b13a22
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 51 deletions.
10 changes: 7 additions & 3 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ def embed_content(
raise base_exceptions.InternalError(f"Content type {type(content)} is not supported for embed endpoint.")

# Decide on the prefix
if content_type == EmbedContentType.Query:
prefix = determine_text_prefix(None, marqo_index, DeterminePrefixContentType.TextQuery)

# For backwards compatibility
if marqo_index.model.text_query_prefix is None or marqo_index.model.text_chunk_prefix is None:
prefix = ""
elif content_type == EmbedContentType.Query:
prefix = marqo_index.model.get_text_query_prefix()
elif content_type == EmbedContentType.Document:
prefix = determine_text_prefix(None, marqo_index, DeterminePrefixContentType.TextChunk)
prefix = marqo_index.model.get_text_chunk_prefix()
elif content_type is None:
prefix = ""
else:
Expand Down
13 changes: 13 additions & 0 deletions src/marqo/core/index_management/index_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def create_index(self, marqo_index_request: MarqoIndexRequest) -> MarqoIndex:
logger.debug('Marqo config does not exist. Configuring Vespa as part of index creation')
self._add_marqo_config(app)

# Populate the prefix fields if they are None
if marqo_index_request.model.text_query_prefix is None:
marqo_index_request.model.text_query_prefix = marqo_index_request.model.get_default_text_query_prefix()
if marqo_index_request.model.text_chunk_prefix is None:
marqo_index_request.model.text_chunk_prefix = marqo_index_request.model.get_default_text_chunk_prefix()

vespa_schema = vespa_schema_factory(marqo_index_request)
schema, marqo_index = vespa_schema.generate_schema()

Expand Down Expand Up @@ -161,6 +167,13 @@ def batch_create_indexes(self, marqo_index_requests: List[MarqoIndexRequest]) ->
if self.index_exists(index.name):
raise IndexExistsError(f"Index {index.name} already exists")

# Populate the prefix fields if they are None
for index in marqo_index_requests:
if index.model.text_query_prefix is None:
index.model.text_query_prefix = index.model.get_default_text_query_prefix()
if index.model.text_chunk_prefix is None:
index.model.text_chunk_prefix = index.model.get_default_text_chunk_prefix()

schema_responses = [
vespa_schema_factory(index).generate_schema() # Tuple (schema, MarqoIndex)
for index in marqo_index_requests
Expand Down
55 changes: 49 additions & 6 deletions src/marqo/core/models/marqo_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class Model(StrictBaseModel):
name: str
properties: Optional[Dict[str, Any]]
custom: bool = False
text_query_prefix: Optional[str]
text_chunk_prefix: Optional[str]

@root_validator(pre=False)
def validate_custom_properties(cls, values):
Expand Down Expand Up @@ -188,6 +190,45 @@ def _update_model_properties_from_registry(self) -> None:
raise InvalidArgumentError(
f'Invalid model properties for model={model_name}. Reason: {e}.'
)

def get_text_query_prefix(self, request_level_prefix: Optional[str] = None) -> str:
if request_level_prefix is not None:
return request_level_prefix

# For backwards compatibility. Since older versions of Marqo did not have a text_query_prefix field,
# we need to return an empty string if the model does not have a text_query_prefix.
# We know that the value of text_query_prefix is None in old indexes since the model was not populated
# from the registry.
if self.text_query_prefix is None:
return ""

# Else return the model default as populated during initialization
return self.text_query_prefix

def get_text_chunk_prefix(self, request_level_prefix: Optional[str] = None) -> str:
if request_level_prefix is not None:
return request_level_prefix

# For backwards compatibility. Since older versions of Marqo did not have a text_chunk_prefix field,
# we need to return an empty string if the model does not have a text_chunk_prefix.
# We know that the value of text_chunk_prefix is None in old indexes since the model was not populated
# from the registry.
if self.text_chunk_prefix is None:
return ""

# Else return the model default as populated during initialization
return self.text_chunk_prefix

def get_default_text_query_prefix(self) -> Optional[str]:
return self._get_default_prefix("text_query_prefix")

def get_default_text_chunk_prefix(self) -> Optional[str]:
return self._get_default_prefix("text_chunk_prefix")

def _get_default_prefix(self, prefix_type: str) -> Optional[str]:
model_properties = self.get_properties()
default_prefix = model_properties.get(prefix_type)
return default_prefix


class MarqoIndex(ImmutableStrictBaseModel, ABC):
Expand All @@ -205,8 +246,6 @@ class MarqoIndex(ImmutableStrictBaseModel, ABC):
vector_numeric_type: VectorNumericType
hnsw_config: HnswConfig
marqo_version: str
override_text_query_prefix: Optional[str]
override_text_chunk_prefix: Optional[str]
created_at: int = pydantic.Field(gt=0)
updated_at: int = pydantic.Field(gt=0)
_cache: Dict[str, Any] = PrivateAttr()
Expand Down Expand Up @@ -264,19 +303,23 @@ def parse_obj(cls, obj: Any) -> Union['UnstructuredMarqoIndex', 'StructuredMarqo
else:
raise ValidationError(f"Invalid index type {obj['type']}")

raise ValidationError(f"Index type not found in {obj}")
raise ValidationError(f"Index type not found in {obj}")

def _cache_or_get(self, key: str, func):
if key not in self._cache:
self._cache[key] = func()
return self._cache[key]




class UnstructuredMarqoIndex(MarqoIndex):
type = IndexType.Unstructured
treat_urls_and_pointers_as_images: bool
filter_string_max_length: int


def __init__(self, **data):
super().__init__(**data)

@classmethod
def _valid_type(cls) -> IndexType:
Expand All @@ -287,7 +330,6 @@ class StructuredMarqoIndex(MarqoIndex):
type = IndexType.Structured
fields: List[Field] # all fields, including tensor fields
tensor_fields: List[TensorField]


def __init__(self, **data):
super().__init__(**data)
Expand Down Expand Up @@ -499,7 +541,8 @@ def validate_structured_field(values, marqo_index: bool) -> None:
f'{FieldType.MultimodalCombination.value}'
)

if FieldFeature.LexicalSearch in features and type not in [FieldType.Text, FieldType.ArrayText, FieldType.CustomVector]:
if FieldFeature.LexicalSearch in features and type not in [FieldType.Text, FieldType.ArrayText,
FieldType.CustomVector]:
raise ValueError(
f'{name}: Field with {FieldFeature.LexicalSearch.value} feature must be of type '
f'{FieldType.Text.value} or {FieldType.ArrayText.value}'
Expand Down
2 changes: 0 additions & 2 deletions src/marqo/core/models/marqo_index_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class MarqoIndexRequest(ImmutableStrictBaseModel, ABC):
distance_metric: marqo_index.DistanceMetric
vector_numeric_type: marqo_index.VectorNumericType
hnsw_config: marqo_index.HnswConfig
override_text_query_prefix: Optional[str]
override_text_chunk_prefix: Optional[str]
marqo_version: str
created_at: int
updated_at: int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ def _generate_document_section(self, schema_name: str) -> (List[str], Structured
updated_at=self._index_request.updated_at,
fields=fields,
tensor_fields=tensor_fields,
override_text_query_prefix=self._index_request.override_text_query_prefix,
override_text_chunk_prefix=self._index_request.override_text_chunk_prefix
)

return document, marqo_index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def _generate_unstructured_marqo_index(self, schema_name: str) -> UnstructuredMa
updated_at=self._index_request.updated_at,
treat_urls_and_pointers_as_images=self._index_request.treat_urls_and_pointers_as_images,
filter_string_max_length=self._index_request.filter_string_max_length,
override_text_query_prefix=self._index_request.override_text_query_prefix,
override_text_chunk_prefix=self._index_request.override_text_chunk_prefix
)

@classmethod
Expand Down
17 changes: 9 additions & 8 deletions src/marqo/tensor_search/models/index_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class IndexSettings(StrictBaseModel):
filterStringMaxLength: Optional[int]
model: str = 'hf/e5-base-v2'
modelProperties: Optional[Dict[str, Any]]
textQueryPrefix: Optional[str] = None
textChunkPrefix: Optional[str] = None
normalizeEmbeddings: bool = True
textPreprocessing: core.TextPreProcessing = core.TextPreProcessing(
splitLength=2,
Expand All @@ -41,8 +43,7 @@ class IndexSettings(StrictBaseModel):
m=16
)
)
overrideTextQueryPrefix: Optional[str] = None
overrideTextChunkPrefix: Optional[str] = None


@root_validator(pre=True)
def validate_field_names(cls, values):
Expand Down Expand Up @@ -91,7 +92,9 @@ def to_marqo_index_request(self, index_name: str) -> MarqoIndexRequest:
model=core.Model(
name=self.model,
properties=self.modelProperties,
custom=self.modelProperties is not None
custom=self.modelProperties is not None,
text_query_prefix=self.textQueryPrefix,
text_chunk_prefix=self.textChunkPrefix
),
normalize_embeddings=self.normalizeEmbeddings,
text_preprocessing=self.textPreprocessing,
Expand All @@ -104,8 +107,6 @@ def to_marqo_index_request(self, index_name: str) -> MarqoIndexRequest:
marqo_version=version.get_version(),
created_at=time.time(),
updated_at=time.time(),
override_text_query_prefix=self.overrideTextQueryPrefix,
override_text_chunk_prefix=self.overrideTextChunkPrefix
)
elif self.type == core.IndexType.Unstructured:
if self.allFields is not None:
Expand All @@ -132,7 +133,9 @@ def to_marqo_index_request(self, index_name: str) -> MarqoIndexRequest:
model=core.Model(
name=self.model,
properties=self.modelProperties,
custom=self.modelProperties is not None
custom=self.modelProperties is not None,
text_query_prefix=self.textQueryPrefix,
text_chunk_prefix=self.textChunkPrefix
),
normalize_embeddings=self.normalizeEmbeddings,
text_preprocessing=self.textPreprocessing,
Expand All @@ -142,8 +145,6 @@ def to_marqo_index_request(self, index_name: str) -> MarqoIndexRequest:
hnsw_config=self.annParameters.parameters,
treat_urls_and_pointers_as_images=self.treatUrlsAndPointersAsImages,
filter_string_max_length=self.filterStringMaxLength,
override_text_query_prefix=self.overrideTextQueryPrefix,
override_text_chunk_prefix=self.overrideTextChunkPrefix,
marqo_version=version.get_version(),
created_at=time.time(),
updated_at=time.time()
Expand Down
13 changes: 7 additions & 6 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.vespa.exceptions import VespaStatusError
from marqo.vespa.models import VespaDocument, FeedBatchResponse, QueryResult
from marqo.core.utils.prefix import determine_text_prefix, DeterminePrefixContentType

logger = get_logger(__name__)

Expand Down Expand Up @@ -139,7 +138,8 @@ def _add_documents_unstructured(config: Config, add_docs_params: AddDocsParams,
total_vectorise_time = 0
batch_size = len(add_docs_params.docs)
image_repo = {}
text_chunk_prefix = determine_text_prefix(add_docs_params.text_chunk_prefix, marqo_index, DeterminePrefixContentType.TextChunk)

text_chunk_prefix = marqo_index.model.get_text_chunk_prefix(add_docs_params.text_chunk_prefix)

docs, doc_ids = config.document.remove_duplicated_documents(add_docs_params.docs)

Expand Down Expand Up @@ -562,7 +562,8 @@ def _add_documents_structured(config: Config, add_docs_params: AddDocsParams, ma
total_vectorise_time = 0
batch_size = len(add_docs_params.docs) # use length before deduplication
image_repo = {}
text_chunk_prefix = determine_text_prefix(add_docs_params.text_chunk_prefix, marqo_index, DeterminePrefixContentType.TextChunk)

text_chunk_prefix = marqo_index.model.get_text_chunk_prefix(add_docs_params.text_chunk_prefix)

# Deduplicate docs, keep the latest
docs, doc_ids = config.document.remove_duplicated_documents(add_docs_params.docs)
Expand Down Expand Up @@ -1708,7 +1709,7 @@ def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors
def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSearchQueryEntity]:
prefixed_queries = []
for q in queries:
text_query_prefix = determine_text_prefix(q.text_query_prefix, q.index, DeterminePrefixContentType.TextQuery)
text_query_prefix = q.index.model.get_text_query_prefix(q.text_query_prefix)

if q.q is None:
prefixed_q = q.q
Expand Down Expand Up @@ -1830,7 +1831,7 @@ def _vector_text_search(
marqo_index = index_meta_cache.get_index(config=config, index_name=index_name)

# Determine the text query prefix
text_query_prefix = determine_text_prefix(text_query_prefix, marqo_index, DeterminePrefixContentType.TextQuery)
text_query_prefix = marqo_index.model.get_text_query_prefix(text_query_prefix)

queries = [BulkSearchQueryEntity(
q=query, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.TENSOR, limit=result_count,
Expand Down Expand Up @@ -1943,7 +1944,7 @@ def vectorise_multimodal_combination_field_unstructured(field: str,
doc_id: str, device: str, marqo_index: UnstructuredMarqoIndex,
image_repo, field_map: dict,
model_auth: Optional[ModelAuth] = None,
text_chunk_prefix: str = ""
text_chunk_prefix: str = None
):
'''
This function is used to vectorise multimodal combination field.
Expand Down
2 changes: 2 additions & 0 deletions tests/core/index_management/test_get_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_custom_settings(self):
# Get unstructured custom settings
retrieved_index = self.config.index_management.get_index(self.unstructured_custom_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
print(f"retrieved_settings: {retrieved_settings}")
self.assertEqual(retrieved_settings, expected_unstructured_custom_settings)

with self.subTest("Structured index custom settings"):
Expand Down Expand Up @@ -203,5 +204,6 @@ def test_custom_settings(self):
# Get unstructured default settings
retrieved_index = self.config.index_management.get_index(self.structured_custom_index.name)
retrieved_settings = IndexSettings.from_marqo_index(retrieved_index).dict(exclude_none=True, by_alias=True)
print(f"retrieved_settings: {retrieved_settings}")
self.assertEqual(retrieved_settings, expected_structured_custom_settings)

34 changes: 34 additions & 0 deletions tests/core/index_management/test_index_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,40 @@ def test_create_index_indexExists_fails(self):
with self.assertRaises(IndexExistsError):
self.index_management.create_index(marqo_index_request)

def test_create_index_text_prefix_defaults_successful(self):
"""
Text prefix defaults are set correctly
"""
index_name = 'a' + str(uuid.uuid4()).replace('-', '')
marqo_index_request = self.structured_marqo_index_request(
name=index_name,
model=Model(
name='test_prefix'
),
distance_metric=DistanceMetric.PrenormalizedAngular,
vector_numeric_type=VectorNumericType.Float,
hnsw_config=HnswConfig(ef_construction=100, m=16),
fields=[
FieldRequest(name='title', type=FieldType.Text, features=[FieldFeature.LexicalSearch]),
],
tensor_fields=[]
)

# test create_index
index = self.index_management.create_index(marqo_index_request)
self.assertEqual(index.model.text_query_prefix, "test query: ")
self.assertEqual(index.model.text_chunk_prefix, "test passage: ")

self.index_management.delete_index(index)

# test batch_create_index
indexes = self.index_management.batch_create_indexes([marqo_index_request])
self.assertEqual(indexes[0].model.text_query_prefix, "test query: ")
self.assertEqual(indexes[0].model.text_chunk_prefix, "test passage: ")

self.index_management.delete_index(indexes[0])


def test_get_marqo_version_successful(self):
"""
get_marqo_version returns current version
Expand Down

0 comments on commit 4b13a22

Please sign in to comment.