Skip to content

Commit

Permalink
[MLflow] Update mlflow langchain metadata to write dependencies_schem…
Browse files Browse the repository at this point in the history
…as (#12045)

Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
  • Loading branch information
sunishsheth2009 committed May 19, 2024
1 parent e18ba3f commit fa6e185
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 72 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@
("py:class", "mlflow.models.signature.ModelSignature"),
("py:class", "mlflow.models.resources.Resource"),
("py:class", "mlflow.models.resources.ResourceType"),
("py:class", "mlflow.models.dependencies_schema.set_vector_search_schema"),
("py:class", "mlflow.metrics.genai.base.EvaluationExample"),
("py:class", "mlflow.models.evaluation.base.EvaluationMetric"),
("py:class", "MlflowInferableDataset"),
Expand Down
17 changes: 16 additions & 1 deletion mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
register_pydantic_v1_serializer_cm,
)
from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info
from mlflow.models.dependencies_schema import _clear_dependencies_schema, _get_dependencies_schema
from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH, MODEL_CONFIG
from mlflow.models.resources import _ResourceBuilder
from mlflow.models.signature import _infer_signature_from_input_example
Expand Down Expand Up @@ -323,6 +324,13 @@ def load_retriever(persist_directory):
if metadata is not None:
mlflow_model.metadata = metadata

with _get_dependencies_schema() as dependencies_schema:
schema = dependencies_schema.to_dict()
if schema is not None:
if mlflow_model.metadata is None:
mlflow_model.metadata = {}
mlflow_model.metadata.update(schema)

streamable = isinstance(lc_model, lc_runnables_types())

if not isinstance(model_code_path, str):
Expand Down Expand Up @@ -843,7 +851,14 @@ def _load_model_from_local_fs(local_model_path):
os.path.basename(flavor_code_path),
)

return _load_model_code_path(code_path, model_config)
try:
model = _load_model_code_path(code_path, model_config)
finally:
# We would like to clean up the dependencies schema which is set to global
# after loading the mode to avoid the schema being used in the next model loading
_clear_dependencies_schema()

return model
else:
_add_code_from_conf_to_system_path(local_model_path, flavor_conf)
with patch_langchain_type_to_cls_dict():
Expand Down
50 changes: 1 addition & 49 deletions mlflow/langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import warnings
from functools import lru_cache
from importlib.util import find_spec
from typing import Callable, List, NamedTuple, Optional
from typing import Callable, NamedTuple

import cloudpickle
import yaml
Expand Down Expand Up @@ -688,51 +688,3 @@ def register_pydantic_v1_serializer_cm():
yield
finally:
unregister_pydantic_serializer()


DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY = "__databricks_vector_search_primary_key__"
DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN = "__databricks_vector_search_text_column__"
DATABRICKS_VECTOR_SEARCH_DOC_URI = "__databricks_vector_search_doc_uri__"
DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS = "__databricks_vector_search_other_columns__"


def set_vector_search_schema(
primary_key: str,
text_column: str = "",
doc_uri: str = "",
other_columns: Optional[List[str]] = None,
):
"""
After defining your vector store in a Python file or notebook, call
set_vector_search_schema() so that we can correctly map the vector index
columns. These columns would be used during tracing and in the review UI.
Args:
primary_key: The primary key of the vector index.
text_column: The name of the text column to use for the embeddings.
doc_uri: The name of the column that contains the document URI.
other_columns: A list of other columns that are part of the vector index
that need to be retrieved during trace logging.
Note: Make sure the text column specified is in the index.
Example:
.. code-block:: python
from mlflow.langchain.utils import set_vector_search_schema
set_vector_search_schema(
primary_key="chunk_id",
text_column="chunk_text",
doc_uri="doc_uri",
other_columns=["title"],
)
"""
globals()[DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY] = primary_key
globals()[DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN] = text_column
globals()[DATABRICKS_VECTOR_SEARCH_DOC_URI] = doc_uri
globals()[DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS] = other_columns or []


def get_databricks_vector_search_key(key):
return globals().get(key)
2 changes: 2 additions & 0 deletions mlflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
For details, see `MLflow Models <../models.html>`_.
"""
from mlflow.models.dependencies_schema import set_vector_search_schema
from mlflow.models.evaluation import (
EvaluationArtifact,
EvaluationMetric,
Expand Down Expand Up @@ -57,6 +58,7 @@
"EvaluationResult",
"get_model_info",
"set_model",
"set_vector_search_schema",
"list_evaluators",
"MetricThreshold",
"build_docker",
Expand Down
57 changes: 48 additions & 9 deletions mlflow/models/dependencies_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict, List, Optional

Expand Down Expand Up @@ -32,9 +33,8 @@ def set_vector_search_schema(
that need to be retrieved during trace logging.
Note: Make sure the text column specified is in the index.
Example:
.. code-block:: python
.. code-block:: Python
:caption: Example
from mlflow.models import set_vector_search_schema
Expand All @@ -58,13 +58,49 @@ def _get_vector_search_schema():
Returns:
VectorSearchIndex: The vector search index schema.
"""
return VectorSearchIndexSchema(
name="vector_search_index",
primary_key=globals().get(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY),
text_column=globals().get(DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN),
doc_uri=globals().get(DATABRICKS_VECTOR_SEARCH_DOC_URI),
other_columns=globals().get(DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS),
if not globals().get(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None) or not globals().get(
DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None
):
return []

return [
VectorSearchIndexSchema(
name="vector_search_index",
primary_key=globals().get(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None),
text_column=globals().get(DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None),
doc_uri=globals().get(DATABRICKS_VECTOR_SEARCH_DOC_URI, None),
other_columns=globals().get(DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS, None),
)
]


def _clear_vector_search_schema():
"""
Clear the vector search schema defined by the user.
"""
globals().pop(DATABRICKS_VECTOR_SEARCH_PRIMARY_KEY, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_TEXT_COLUMN, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_DOC_URI, None)
globals().pop(DATABRICKS_VECTOR_SEARCH_OTHER_COLUMNS, None)


def _clear_dependencies_schema():
"""
Clear all the dependencies schema defined by the user.
"""
# Clear the vector search schema
_clear_vector_search_schema()


@contextmanager
def _get_dependencies_schema():
dependencies_schema = DependenciesSchemas(
vector_search_index_schemas=_get_vector_search_schema()
)
try:
yield dependencies_schema
finally:
_clear_dependencies_schema()


@dataclass
Expand Down Expand Up @@ -151,6 +187,9 @@ class DependenciesSchemas:
vector_search_index_schemas: List[VectorSearchIndexSchema] = field(default_factory=list)

def to_dict(self) -> Dict[str, Dict[ResourceType, List[Dict]]]:
if not self.vector_search_index_schemas:
return None

return {
"dependencies_schemas": {
ResourceType.VECTOR_SEARCH_INDEX.value: [
Expand Down
1 change: 1 addition & 0 deletions mlflow/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ def update_model_requirements(
__mlflow_model__ = None


@experimental
def set_model(model):
"""
When logging model as code, this function can be used to set the model object
Expand Down
8 changes: 7 additions & 1 deletion tests/langchain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

from mlflow.models import ModelConfig, set_model
from mlflow.models import ModelConfig, set_model, set_vector_search_schema

base_config = ModelConfig(development_config="tests/langchain/config.yml")

Expand Down Expand Up @@ -74,3 +74,9 @@ def _llm_type(self) -> str:
)

set_model(retrieval_chain)
set_vector_search_schema(
primary_key="primary-key",
text_column="text-column",
doc_uri="doc-uri",
other_columns=["column1", "column2"],
)
14 changes: 13 additions & 1 deletion tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,17 @@ def test_save_load_chain_as_code(chain_model_signature):
assert reloaded_model.resources["databricks"] == {
"serving_endpoint": [{"name": "fake-endpoint"}]
}
assert reloaded_model.metadata["dependencies_schemas"] == {
"vector_search_index": [
{
"doc_uri": "doc-uri",
"name": "vector_search_index",
"other_columns": ["column1", "column2"],
"primary_key": "primary-key",
"text_column": "text-column",
}
]
}


@pytest.mark.skipif(
Expand Down Expand Up @@ -2565,7 +2576,7 @@ def test_save_load_chain_as_code_optional_code_path(chain_model_signature):
}
]
}
artifact_path = "model_path"
artifact_path = "new_model_path"
with mlflow.start_run() as run:
model_info = mlflow.langchain.log_model(
lc_model="tests/langchain/no_config/chain.py",
Expand Down Expand Up @@ -2607,6 +2618,7 @@ def test_save_load_chain_as_code_optional_code_path(chain_model_signature):
assert reloaded_model.resources["databricks"] == {
"serving_endpoint": [{"name": "fake-endpoint"}]
}
assert reloaded_model.metadata is None


def get_fake_chat_stream_model(endpoint_name="fake-stream-endpoint"):
Expand Down
33 changes: 22 additions & 11 deletions tests/models/test_dependencies_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mlflow.models.dependencies_schema import (
DependenciesSchemas,
VectorSearchIndexSchema,
_get_dependencies_schema,
_get_vector_search_schema,
set_vector_search_schema,
)
Expand Down Expand Up @@ -91,14 +92,24 @@ def test_set_vector_search_schema_creation():
doc_uri="doc-uri",
other_columns=["column1", "column2"],
)
assert _get_vector_search_schema().to_dict() == {
"vector_search_index": [
{
"name": "vector_search_index",
"primary_key": "primary-key",
"text_column": "text-column",
"doc_uri": "doc-uri",
"other_columns": ["column1", "column2"],
}
]
}
with _get_dependencies_schema() as schema:
assert schema.vector_search_index_schemas[0].to_dict() == {
"vector_search_index": [
{
"doc_uri": "doc-uri",
"name": "vector_search_index",
"other_columns": ["column1", "column2"],
"primary_key": "primary-key",
"text_column": "text-column",
}
]
}

with _get_dependencies_schema() as schema:
assert schema.to_dict() is None
assert _get_vector_search_schema() == []


def test_set_vector_search_schema_empty_creation():
with _get_dependencies_schema() as schema:
assert schema.to_dict() is None

0 comments on commit fa6e185

Please sign in to comment.