Skip to content

Commit

Permalink
Adding the prefix feature to Marqo V2 (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed May 7, 2024
1 parent 6380014 commit a44f06c
Show file tree
Hide file tree
Showing 23 changed files with 899 additions and 41 deletions.
10 changes: 3 additions & 7 deletions src/marqo/api/models/embed_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,20 @@
https://pydantic-docs.helpmanual.io/usage/types/#enums-and-choices
"""
import pydantic
from pydantic import BaseModel, root_validator, Field
from typing import Union, List, Dict, Optional, Any

from marqo.tensor_search import validation
from marqo.tensor_search.enums import SearchMethod
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.score_modifiers_object import ScoreModifier
from marqo.tensor_search.models.search import SearchContext, SearchContextTensor
from marqo.tensor_search.models.api_models import BaseMarqoModel
from marqo.api.exceptions import InvalidArgError
from marqo.core.models.marqo_index import MarqoIndex
from marqo.core.embed.embed import EmbedContentType



class EmbedRequest(BaseMarqoModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
image_download_headers: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
content_type: Optional[EmbedContentType] = EmbedContentType.Query

@pydantic.validator('content')
def validate_content(cls, value):
Expand Down
31 changes: 29 additions & 2 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from timeit import default_timer as timer
from typing import List, Optional, Union, Dict
from enum import Enum

import pydantic

Expand All @@ -10,10 +11,14 @@
from marqo.tensor_search.models.search import Qidx
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.core.utils.prefix import determine_text_prefix, DeterminePrefixContentType
from marqo.vespa.vespa_client import VespaClient

logger = get_logger(__name__)

class EmbedContentType(str, Enum):
    """
    Kind of content passed to the embed endpoint.

    Members are also plain strings (str mixin), so they compare equal to
    their raw values and can be looked up from them.
    """

    # Content is embedded as a search query.
    Query = "query"
    # Content is embedded as a document.
    Document = "document"

class Embed:
def __init__(self, vespa_client: VespaClient, index_management: IndexManagement, default_device: str):
Expand All @@ -30,7 +35,8 @@ def validate_default_device(cls, value):
def embed_content(
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None
model_auth: Optional[ModelAuth] = None,
content_type: Optional[EmbedContentType] = EmbedContentType.Query
) -> Dict:
"""
Use the index's model to embed the content
Expand All @@ -39,6 +45,14 @@ def embed_content(
List of embeddings corresponding to the content. If content is a list, the return list will be in the same order.
If content is a string, the return list will only have 1 item.
"""
"""
NOTE: PARAMETER: content_type
3 Options: ‘query’, ‘document’, None. Defaults to ‘query’.
1. If the user wants to use the default text_query_prefix, leave it as ‘query’.
2. If the user wants to use the default text_chunk_prefix, set it to ‘document’.
3. If the user wants a custom prefix, they must set it to None and put the prefix in the content itself.
"""


# TODO: Remove this config constructor once vectorise pipeline doesn't need it. Just pass the vespa client
# and index management objects.
Expand All @@ -52,6 +66,7 @@ def embed_content(
if device is None:
device = self.default_device


# Content validation is done in API model layer
t0 = timer()

Expand All @@ -67,6 +82,17 @@ def embed_content(
else:
raise base_exceptions.InternalError(f"Content type {type(content)} is not supported for embed endpoint.")

# Decide on the prefix
if content_type == EmbedContentType.Query:
prefix = determine_text_prefix(None, marqo_index, DeterminePrefixContentType.TextQuery)
elif content_type == EmbedContentType.Document:
prefix = determine_text_prefix(None, marqo_index, DeterminePrefixContentType.TextChunk)
elif content_type is None:
prefix = ""
else:
# use [item.value for item in list(EmbedContentType)], but formatted not as a list
raise ValueError(f"Invalid content_type: {content_type}. Must be {', '.join([item.value for item in list(EmbedContentType)])}, or None.")

queries = []
for content_entry in content_list:
queries.append(
Expand All @@ -75,7 +101,8 @@ def embed_content(
q=content_entry,
index=marqo_index,
image_download_headers=image_download_headers,
modelAuth=model_auth
modelAuth=model_auth,
text_query_prefix=prefix
# TODO: Check if it's fine that we leave out the other parameters
)
)
Expand Down
6 changes: 5 additions & 1 deletion src/marqo/core/models/marqo_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ class MarqoIndex(ImmutableStrictBaseModel, ABC):
vector_numeric_type: VectorNumericType
hnsw_config: HnswConfig
marqo_version: str
override_text_query_prefix: Optional[str]
override_text_chunk_prefix: Optional[str]
created_at: int = pydantic.Field(gt=0)
updated_at: int = pydantic.Field(gt=0)
_cache: Dict[str, Any] = PrivateAttr()
Expand Down Expand Up @@ -262,7 +264,7 @@ def parse_obj(cls, obj: Any) -> Union['UnstructuredMarqoIndex', 'StructuredMarqo
else:
raise ValidationError(f"Invalid index type {obj['type']}")

raise ValidationError(f"Index type not found in {obj}")
raise ValidationError(f"Index type not found in {obj}")

def _cache_or_get(self, key: str, func):
if key not in self._cache:
Expand All @@ -274,6 +276,7 @@ class UnstructuredMarqoIndex(MarqoIndex):
type = IndexType.Unstructured
treat_urls_and_pointers_as_images: bool
filter_string_max_length: int


@classmethod
def _valid_type(cls) -> IndexType:
Expand All @@ -284,6 +287,7 @@ class StructuredMarqoIndex(MarqoIndex):
type = IndexType.Structured
fields: List[Field] # all fields, including tensor fields
tensor_fields: List[TensorField]


def __init__(self, **data):
super().__init__(**data)
Expand Down
2 changes: 2 additions & 0 deletions src/marqo/core/models/marqo_index_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class MarqoIndexRequest(ImmutableStrictBaseModel, ABC):
distance_metric: marqo_index.DistanceMetric
vector_numeric_type: marqo_index.VectorNumericType
hnsw_config: marqo_index.HnswConfig
override_text_query_prefix: Optional[str]
override_text_chunk_prefix: Optional[str]
marqo_version: str
created_at: int
updated_at: int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def _generate_document_section(self, schema_name: str) -> (List[str], Structured
created_at=self._index_request.created_at,
updated_at=self._index_request.updated_at,
fields=fields,
tensor_fields=tensor_fields
tensor_fields=tensor_fields,
override_text_query_prefix=self._index_request.override_text_query_prefix,
override_text_chunk_prefix=self._index_request.override_text_chunk_prefix
)

return document, marqo_index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def _generate_unstructured_marqo_index(self, schema_name: str) -> UnstructuredMa
updated_at=self._index_request.updated_at,
treat_urls_and_pointers_as_images=self._index_request.treat_urls_and_pointers_as_images,
filter_string_max_length=self._index_request.filter_string_max_length,
override_text_query_prefix=self._index_request.override_text_query_prefix,
override_text_chunk_prefix=self._index_request.override_text_chunk_prefix
)

@classmethod
Expand Down
57 changes: 57 additions & 0 deletions src/marqo/core/utils/prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

from marqo.s2_inference.errors import InvalidModelPropertiesError
from marqo.core.models.marqo_index import *
from marqo.s2_inference.s2_inference import get_model_properties_from_registry

class DeterminePrefixContentType(Enum):
    """
    Selects which kind of text prefix to resolve.

    Each member's value is the key under which the corresponding default
    prefix is looked up in a model's properties.
    """

    # Prefix prepended to document text chunks.
    TextChunk = "text_chunk_prefix"
    # Prefix prepended to search queries.
    TextQuery = "text_query_prefix"

def determine_text_prefix(request_level_prefix: str, marqo_index: MarqoIndex, prefix_type: DeterminePrefixContentType) -> str:
    """
    Determine the text prefix to prepend to text chunks or search queries.

    Resolution order (first match wins):
        1. The request-level prefix, whenever it is not None. An empty string
           is returned as-is, i.e. it is an explicit "no prefix" choice.
        2. The index-level override matching prefix_type
           (override_text_query_prefix or override_text_chunk_prefix).
        3. The prefix stored under prefix_type.value in the model properties,
           taken from the index's model or, if the model carries none, from
           the model registry.
        4. "" when the model properties define no prefix of that type.

    Args:
        request_level_prefix: Prefix supplied with the request, or None to
            fall back to the index/model defaults.
        marqo_index: The index whose overrides and model are consulted.
        prefix_type: Whether a query prefix or a chunk prefix is wanted.

    Returns:
        The resolved prefix. Never None; "" means "no prefix".

    Raises:
        ValueError: If the index has no model, or no properties can be found
            for its model.
        InvalidModelPropertiesError: Presumably raised by the registry lookup
            for unknown models; it is deliberately not caught here
            (NOTE(review): confirm against get_model_properties_from_registry).
    """
    if request_level_prefix is not None:
        return request_level_prefix

    # Index-level override for the requested prefix type, if one is set.
    if prefix_type == DeterminePrefixContentType.TextQuery:
        index_override = marqo_index.override_text_query_prefix
    elif prefix_type == DeterminePrefixContentType.TextChunk:
        index_override = marqo_index.override_text_chunk_prefix
    else:
        index_override = None
    if index_override is not None:
        return index_override

    # Fall back to the prefix defined by the model's properties.
    model = marqo_index.model
    if model is None:
        # Fail with the documented ValueError; building the message from
        # model.name here would raise AttributeError instead.
        raise ValueError("Model properties not found: the index has no model")

    model_properties = model.properties
    if model_properties is None:
        model_properties = get_model_properties_from_registry(model.name)
    if model_properties is None:
        raise ValueError(f"Model properties not found for model: {model.name}")

    default_prefix = model_properties.get(prefix_type.value)
    return default_prefix if default_prefix is not None else ""



32 changes: 32 additions & 0 deletions src/marqo/s2_inference/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,83 +511,107 @@ def _get_hf_properties() -> Dict:
"tokens": 192,
"type": "hf",
"model_size": 0.1342,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-base":
{"name": 'intfloat/e5-base',
"dimensions": 768,
"tokens": 192,
"type": "hf",
"model_size": 0.438,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-large":
{"name": 'intfloat/e5-large',
"dimensions": 1024,
"tokens": 192,
"type": "hf",
"model_size": 1.3,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-large-unsupervised":
{"name": 'intfloat/e5-large-unsupervised',
"dimensions": 1024,
"tokens": 128,
"type": "hf",
"model_size": 1.3,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-base-unsupervised":
{"name": 'intfloat/e5-base-unsupervised',
"dimensions": 768,
"tokens": 128,
"type": "hf",
"model_size": 0.438,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-small-unsupervised":
{"name": 'intfloat/e5-small-unsupervised',
"dimensions": 384,
"tokens": 128,
"type": "hf",
"model_size": 0.134,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/multilingual-e5-small":
{"name": 'intfloat/multilingual-e5-small',
"dimensions": 384,
"tokens": 512,
"type": "hf",
"model_size": 0.471,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/multilingual-e5-base":
{"name": 'intfloat/multilingual-e5-base',
"dimensions": 768,
"tokens": 512,
"type": "hf",
"model_size": 1.11,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/multilingual-e5-large":
{"name": 'intfloat/multilingual-e5-large',
"dimensions": 1024,
"tokens": 512,
"type": "hf",
"model_size": 2.24,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-small-v2":
{"name": 'intfloat/e5-small-v2',
"dimensions": 384,
"tokens": 512,
"type": "hf",
"model_size": 0.134,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-base-v2":
{"name": 'intfloat/e5-base-v2',
"dimensions": 768,
"tokens": 512,
"type": "hf",
"model_size": 0.438,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
"hf/e5-large-v2":
{"name": 'intfloat/e5-large-v2',
"dimensions": 1024,
"tokens": 512,
"type": "hf",
"model_size": 1.34,
"text_query_prefix": "query: ",
"text_chunk_prefix": "passage: ",
"notes": ""},
}
return HF_MODEL_PROPERTIES
Expand Down Expand Up @@ -674,6 +698,14 @@ def _get_sbert_test_properties() -> Dict:
"tokens":128,
"type":"test",
"notes": ""},
"test_prefix":
{"name": "sentence-transformers/all-MiniLM-L6-v1",
"dimensions": 16,
"tokens":128,
"type":"test",
"text_query_prefix": "test query: ",
"text_chunk_prefix": "test passage: ",
"notes": ""},
}
return TEST_MODEL_PROPERTIES

Expand Down
17 changes: 17 additions & 0 deletions src/marqo/s2_inference/processing/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,22 @@ def split_text(text: str, split_by: str = 'sentence', split_length: int = 2, spl
# assume a uniform seperator when reconstructing the sentences
text_splits = _reconstruct_multi_list(segments, seperator)


return text_splits

def prefix_text_chunks(text_splits: List[str], text_chunk_prefix: str) -> List[str]:
    """
    Prepend a prefix to every text chunk.

    The prefix is concatenated directly onto each chunk with no separator, so
    any space wanted between prefix and text must already be part of
    text_chunk_prefix.

    Args:
        text_splits (List[str]): Chunk list without prefixes
        text_chunk_prefix (str): Prefix to add before each text chunk

    Returns:
        List[str]: A new list with the prefix prepended to each chunk, or
        text_splits itself (unchanged) when the prefix is empty or None
    """
    if text_chunk_prefix:
        return [text_chunk_prefix + chunk for chunk in text_splits]

    # Nothing to prepend: hand back the caller's list untouched.
    return text_splits

4 changes: 3 additions & 1 deletion src/marqo/tensor_search/add_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from PIL.ImageFile import ImageFile

from marqo.s2_inference import clip_utils
from marqo.s2_inference.errors import InvalidModelPropertiesError
from marqo.tensor_search.telemetry import RequestMetricsStore, RequestMetrics
from marqo.tensor_search import enums
from marqo.tensor_search import constants
import marqo.core.exceptions as core_exceptions
import marqo.exceptions as base_exceptions
from marqo.core.models.marqo_index import *

from concurrent.futures import ThreadPoolExecutor


Expand Down Expand Up @@ -193,4 +195,4 @@ def determine_document_dict_field_type(field_name: str, field_content, mappings:
else:
raise base_exceptions.InternalError(f"Invalid dict field type: '{mappings[field_name]['type']}' for field: '{field_name}' in mappings. Must be one of {[t.value for t in enums.MappingsObjectType]}")
else:
return None
return None

0 comments on commit a44f06c

Please sign in to comment.