Skip to content

Commit

Permalink
[MLflow] Renaming vector search index to retriever (#12051)
Browse files Browse the repository at this point in the history
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
  • Loading branch information
sunishsheth2009 committed May 19, 2024
1 parent f7c420f commit 3779805
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 67 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@
("py:class", "mlflow.models.signature.ModelSignature"),
("py:class", "mlflow.models.resources.Resource"),
("py:class", "mlflow.models.resources.ResourceType"),
("py:class", "mlflow.models.dependencies_schema.set_vector_search_schema"),
("py:class", "mlflow.models.dependencies_schema.set_retriever_schema"),
("py:class", "mlflow.metrics.genai.base.EvaluationExample"),
("py:class", "mlflow.models.evaluation.base.EvaluationMetric"),
("py:class", "MlflowInferableDataset"),
Expand Down
4 changes: 2 additions & 2 deletions mlflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
For details, see `MLflow Models <../models.html>`_.
"""
from mlflow.models.dependencies_schema import set_vector_search_schema
from mlflow.models.dependencies_schema import set_retriever_schema
from mlflow.models.evaluation import (
EvaluationArtifact,
EvaluationMetric,
Expand Down Expand Up @@ -58,7 +58,7 @@
"EvaluationResult",
"get_model_info",
"set_model",
"set_vector_search_schema",
"set_retriever_schema",
"list_evaluators",
"MetricThreshold",
"build_docker",
Expand Down
86 changes: 46 additions & 40 deletions mlflow/models/dependencies_schema.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional

from mlflow.models.resources import ResourceType
from mlflow.utils.annotations import experimental

DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY = "__databricks_vector_search_primary_key__"
DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN = "__databricks_vector_search_text_column__"
DATABRICKS_VECTOR_SEARCH_DOC_URI = "__databricks_vector_search_doc_uri__"
DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS = "__databricks_vector_search_other_columns__"
DATABRICKS_RETRIEVER_PRIMARY_KEY = "__databricks_retriever_primary_key__"
DATABRICKS_RETRIEVER_TEXT_COLUMN = "__databricks_retriever_text_column__"
DATABRICKS_RETRIEVER_DOC_URI = "__databricks_retriever_doc_uri__"
DATABRICKS_RETRIEVER_OTHER_COLUMNS = "__databricks_retriever_other_columns__"


class DependenciesSchemasType(Enum):
"""
Enum to define the different types of dependencies schemas for the model.
"""

RETRIEVERS = "retrievers"


@experimental
def set_vector_search_schema(
def set_retriever_schema(
primary_key: str,
text_column: str,
doc_uri: Optional[str] = None,
other_columns: Optional[List[str]] = None,
):
"""
After defining your vector store in a Python file or notebook, call
set_vector_search_schema() so that, when MLflow retrieves documents during
set_retriever_schema() so that, when MLflow retrieves documents during
model inference, MLflow can interpret the fields in each retrieved document and
determine which fields correspond to the document text, document URI, etc.
Expand All @@ -36,67 +44,65 @@ def set_vector_search_schema(
.. code-block:: Python
:caption: Example
from mlflow.models import set_vector_search_schema
from mlflow.models import set_retriever_schema
set_vector_search_schema(
set_retriever_schema(
primary_key="chunk_id",
text_column="chunk_text",
doc_uri="doc_uri",
other_columns=["title"],
)
"""
globals()[DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY] = primary_key
globals()[DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN] = text_column
globals()[DATABRICKS_VECTOR_SEARCH_DOC_URI] = doc_uri
globals()[DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS] = other_columns or []
globals()[DATABRICKS_RETRIEVER_PRIMARY_KEY] = primary_key
globals()[DATABRICKS_RETRIEVER_TEXT_COLUMN] = text_column
globals()[DATABRICKS_RETRIEVER_DOC_URI] = doc_uri
globals()[DATABRICKS_RETRIEVER_OTHER_COLUMNS] = other_columns or []


def _get_vector_search_schema():
def _get_retriever_schema():
"""
Get the vector search schema defined by the user.
Returns:
VectorSearchIndex: The vector search index schema.
"""
if not globals().get(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None) or not globals().get(
DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None
if not globals().get(DATABRICKS_RETRIEVER_PRIMARY_KEY, None) or not globals().get(
DATABRICKS_RETRIEVER_TEXT_COLUMN, None
):
return []

return [
VectorSearchIndexSchema(
name="vector_search_index",
primary_key=globals().get(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None),
text_column=globals().get(DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None),
doc_uri=globals().get(DATABRICKS_VECTOR_SEARCH_DOC_URI, None),
other_columns=globals().get(DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS, None),
RetrieverSchema(
name="retriever",
primary_key=globals().get(DATABRICKS_RETRIEVER_PRIMARY_KEY, None),
text_column=globals().get(DATABRICKS_RETRIEVER_TEXT_COLUMN, None),
doc_uri=globals().get(DATABRICKS_RETRIEVER_DOC_URI, None),
other_columns=globals().get(DATABRICKS_RETRIEVER_OTHER_COLUMNS, None),
)
]


def _clear_vector_search_schema():
def _clear_retriever_schema():
"""
Clear the vector search schema defined by the user.
"""
globals().pop(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_DOC_URI, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS, None)
globals().pop(DATABRICKS_RETRIEVER_PRIMARY_KEY, None)
globals().pop(DATABRICKS_RETRIEVER_TEXT_COLUMN, None)
globals().pop(DATABRICKS_RETRIEVER_DOC_URI, None)
globals().pop(DATABRICKS_RETRIEVER_OTHER_COLUMNS, None)


def _clear_dependencies_schema():
"""
Clear all the dependencies schema defined by the user.
"""
# Clear the vector search schema
_clear_vector_search_schema()
_clear_retriever_schema()


@contextmanager
def _get_dependencies_schema():
dependencies_schema = DependenciesSchemas(
vector_search_index_schemas=_get_vector_search_schema()
)
dependencies_schema = DependenciesSchemas(retriever_schemas=_get_retriever_schema())
try:
yield dependencies_schema
finally:
Expand All @@ -112,7 +118,7 @@ class Schema(ABC):
type (ResourceType): The type of the schema.
"""

type: ResourceType
type: DependenciesSchemasType

@abstractmethod
def to_dict(self):
Expand All @@ -131,7 +137,7 @@ def from_dict(cls, data: Dict[str, str]):


@dataclass
class VectorSearchIndexSchema(Schema):
class RetrieverSchema(Schema):
"""
Define vector search index resource to serve a model.
Expand All @@ -151,7 +157,7 @@ def __init__(
doc_uri: Optional[str] = None,
other_columns: Optional[List[str]] = None,
):
super().__init__(type=ResourceType.VECTOR_SEARCH_INDEX)
super().__init__(type=DependenciesSchemasType.RETRIEVERS)
self.name = name
self.primary_key = primary_key
self.text_column = text_column
Expand Down Expand Up @@ -184,17 +190,17 @@ def from_dict(cls, data: Dict[str, str]):

@dataclass
class DependenciesSchemas:
vector_search_index_schemas: List[VectorSearchIndexSchema] = field(default_factory=list)
retriever_schemas: List[RetrieverSchema] = field(default_factory=list)

def to_dict(self) -> Dict[str, Dict[ResourceType, List[Dict]]]:
if not self.vector_search_index_schemas:
def to_dict(self) -> Dict[str, Dict[DependenciesSchemasType, List[Dict]]]:
if not self.retriever_schemas:
return None

return {
"dependencies_schemas": {
ResourceType.VECTOR_SEARCH_INDEX.value: [
index.to_dict()[ResourceType.VECTOR_SEARCH_INDEX.value][0]
for index in self.vector_search_index_schemas
DependenciesSchemasType.RETRIEVERS.value: [
index.to_dict()[DependenciesSchemasType.RETRIEVERS.value][0]
for index in self.retriever_schemas
],
}
}
4 changes: 2 additions & 2 deletions tests/langchain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

from mlflow.models import ModelConfig, set_model, set_vector_search_schema
from mlflow.models import ModelConfig, set_model, set_retriever_schema

base_config = ModelConfig(development_config="tests/langchain/config.yml")

Expand Down Expand Up @@ -74,7 +74,7 @@ def _llm_type(self) -> str:
)

set_model(retrieval_chain)
set_vector_search_schema(
set_retriever_schema(
primary_key="primary-key",
text_column="text-column",
doc_uri="doc-uri",
Expand Down
5 changes: 3 additions & 2 deletions tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
IS_PICKLE_SERIALIZATION_RESTRICTED,
)
from mlflow.models import Model
from mlflow.models.dependencies_schema import DependenciesSchemasType
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex, Resource
from mlflow.models.signature import ModelSignature, Schema, infer_signature
from mlflow.pyfunc.context import Context
Expand Down Expand Up @@ -2397,10 +2398,10 @@ def test_save_load_chain_as_code(chain_model_signature):
"serving_endpoint": [{"name": "fake-endpoint"}]
}
assert reloaded_model.metadata["dependencies_schemas"] == {
"vector_search_index": [
DependenciesSchemasType.RETRIEVERS.value: [
{
"doc_uri": "doc-uri",
"name": "vector_search_index",
"name": "retriever",
"other_columns": ["column1", "column2"],
"primary_key": "primary-key",
"text_column": "text-column",
Expand Down
41 changes: 21 additions & 20 deletions tests/models/test_dependencies_schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from mlflow.models.dependencies_schema import (
DependenciesSchemas,
VectorSearchIndexSchema,
DependenciesSchemasType,
RetrieverSchema,
_get_dependencies_schema,
_get_vector_search_schema,
set_vector_search_schema,
_get_retriever_schema,
set_retriever_schema,
)


def test_vector_search_index_creation():
vsi = VectorSearchIndexSchema(
def test_retriever_creation():
vsi = RetrieverSchema(
name="index-name",
primary_key="primary-key",
text_column="text-column",
Expand All @@ -22,16 +23,16 @@ def test_vector_search_index_creation():
assert vsi.other_columns == ["column1", "column2"]


def test_vector_search_index_to_dict():
vsi = VectorSearchIndexSchema(
def test_retriever_to_dict():
vsi = RetrieverSchema(
name="index-name",
primary_key="primary-key",
text_column="text-column",
doc_uri="doc-uri",
other_columns=["column1", "column2"],
)
expected_dict = {
"vector_search_index": [
DependenciesSchemasType.RETRIEVERS.value: [
{
"name": "index-name",
"primary_key": "primary-key",
Expand All @@ -44,15 +45,15 @@ def test_vector_search_index_to_dict():
assert vsi.to_dict() == expected_dict


def test_vector_search_index_from_dict():
def test_retriever_from_dict():
data = {
"name": "index-name",
"primary_key": "primary-key",
"text_column": "text-column",
"doc_uri": "doc-uri",
"other_columns": ["column1", "column2"],
}
vsi = VectorSearchIndexSchema.from_dict(data)
vsi = RetrieverSchema.from_dict(data)
assert vsi.name == "index-name"
assert vsi.primary_key == "primary-key"
assert vsi.text_column == "text-column"
Expand All @@ -61,17 +62,17 @@ def test_vector_search_index_from_dict():


def test_dependencies_schemas_to_dict():
vsi = VectorSearchIndexSchema(
vsi = RetrieverSchema(
name="index-name",
primary_key="primary-key",
text_column="text-column",
doc_uri="doc-uri",
other_columns=["column1", "column2"],
)
schema = DependenciesSchemas(vector_search_index_schemas=[vsi])
schema = DependenciesSchemas(retriever_schemas=[vsi])
expected_dict = {
"dependencies_schemas": {
"vector_search_index": [
DependenciesSchemasType.RETRIEVERS.value: [
{
"name": "index-name",
"primary_key": "primary-key",
Expand All @@ -85,19 +86,19 @@ def test_dependencies_schemas_to_dict():
assert schema.to_dict() == expected_dict


def test_set_vector_search_schema_creation():
set_vector_search_schema(
def test_set_retriever_schema_creation():
set_retriever_schema(
primary_key="primary-key",
text_column="text-column",
doc_uri="doc-uri",
other_columns=["column1", "column2"],
)
with _get_dependencies_schema() as schema:
assert schema.vector_search_index_schemas[0].to_dict() == {
"vector_search_index": [
assert schema.retriever_schemas[0].to_dict() == {
DependenciesSchemasType.RETRIEVERS.value: [
{
"doc_uri": "doc-uri",
"name": "vector_search_index",
"name": "retriever",
"other_columns": ["column1", "column2"],
"primary_key": "primary-key",
"text_column": "text-column",
Expand All @@ -107,9 +108,9 @@ def test_set_vector_search_schema_creation():

with _get_dependencies_schema() as schema:
assert schema.to_dict() is None
assert _get_vector_search_schema() == []
assert _get_retriever_schema() == []


def test_set_vector_search_schema_empty_creation():
def test_set_retriever_schema_empty_creation():
with _get_dependencies_schema() as schema:
assert schema.to_dict() is None

0 comments on commit 3779805

Please sign in to comment.