Skip to content

Commit

Permalink
Support file paths for model_config in langchain (#11843)
Browse files Browse the repository at this point in the history
Signed-off-by: Ann Zhang <ann.zhang@databricks.com>
Signed-off-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
Co-authored-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
  • Loading branch information
2 people authored and BenWilson2 committed May 6, 2024
1 parent 73d098b commit 39b6b1b
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 86 deletions.
85 changes: 62 additions & 23 deletions mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.model_config import _set_model_config
from mlflow.models.signature import _infer_signature_from_input_example
from mlflow.models.utils import _save_example
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
Expand Down Expand Up @@ -85,7 +86,7 @@
_add_code_from_conf_to_system_path,
_get_flavor_configuration,
_validate_and_copy_code_paths,
_validate_and_copy_model_code_path,
_validate_and_copy_model_code_and_config_paths,
_validate_and_prepare_target_save_path,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
Expand All @@ -94,6 +95,8 @@

FLAVOR_NAME = "langchain"
_MODEL_TYPE_KEY = "model_type"
_MODEL_CODE_CONFIG = "model_config"
_MODEL_CODE_PATH = "model_code_path"


def get_default_pip_requirements():
Expand Down Expand Up @@ -133,6 +136,7 @@ def save_model(
loader_fn=None,
persist_dir=None,
example_no_conversion=False,
model_config=None,
):
"""
Save a LangChain model to a path on the local file system.
Expand Down Expand Up @@ -228,6 +232,11 @@ def load_retriever(persist_directory):
See a complete example in examples/langchain/retrieval_qa_chain.py.
example_no_conversion: {{ example_no_conversion }}
model_config: The model configuration to apply to the model if saving model as code. This
configuration is available during model loading.
.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.
"""
import langchain
from langchain.schema import BaseRetriever
Expand All @@ -238,6 +247,8 @@ def load_retriever(persist_directory):

path = os.path.abspath(path)
_validate_and_prepare_target_save_path(path)

model_config_path = None
model_code_path = None
if isinstance(lc_model, str):
# The LangChain model is defined as Python code located in the file at the path
Expand All @@ -252,17 +263,26 @@ def load_retriever(persist_directory):
"file path or a databricks notebook file path containing the code for defining "
"the chain instance."
)
if code_paths and len(code_paths) > 1:
raise mlflow.MlflowException.invalid_parameter_value(
"When the model is a string, and if the code_paths are specified, "
"it should contain only one path."
"This config path is used to set config.yml file path "
"for the model. This path should be passed in via the code_paths. "
f"Current code paths: {code_paths}"
)

if isinstance(model_config, str):
if os.path.exists(model_config):
model_config_path = model_config
else:
raise mlflow.MlflowException.invalid_parameter_value(
f"If the provided model_config '{model_config}' is a string, it must be a "
"valid yaml file path containing the configuration for the model."
)
# TODO: deal with dicts properly as well

if not model_config:
# If the model_config is not provided, we fall back to getting the config path
# from code_paths so that this remains backwards compatible.
if code_paths and len(code_paths) == 1 and os.path.exists(code_paths[0]):
model_config_path = code_paths[0]

_validate_and_copy_model_code_and_config_paths(lc_model, model_config_path, path)

code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
model_code_dir_subpath = _validate_and_copy_model_code_path(model_code_path, path)

if signature is None:
if input_example is not None:
Expand Down Expand Up @@ -317,19 +337,19 @@ def load_retriever(persist_directory):
**model_data_kwargs,
}
else:
# TODO: use model_config instead
# If the model is a string, we expect the code_path which is ideally config.yml
# would be used in the model. We set the code_path here so it can be set
# globally when the model is loaded with the local path. So the consumer
# can use that path instead of the config.yml path when the model is loaded
# TODO: what if model_config is not a string / file path?
flavor_conf = (
{_CODE_CONFIG: code_paths[0], _CODE_PATH: lc_model}
if code_paths and len(code_paths) >= 1
else {_CODE_CONFIG: None, _CODE_PATH: lc_model}
{_MODEL_CODE_CONFIG: model_config_path, _MODEL_CODE_PATH: lc_model}
if model_config_path
else {_MODEL_CODE_CONFIG: None, _MODEL_CODE_PATH: lc_model}
)
model_data_kwargs = {}

# TODO: pass model_config
# TODO: Pass file paths for model_config when it is supported in pyfunc
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.langchain",
Expand All @@ -338,17 +358,17 @@ def load_retriever(persist_directory):
code=code_dir_subpath,
predict_stream_fn="predict_stream",
streamable=streamable,
model_code=model_code_dir_subpath,
model_code_path=model_code_path,
model_config=None if isinstance(model_config, str) else model_config,
**model_data_kwargs,
)

if Version(langchain.__version__) >= Version("0.0.311"):
checker_model = lc_model
if isinstance(lc_model, str):
# TODO: use model_config instead of code_paths[0]
checker_model = (
_load_model_code_path(lc_model, code_paths[0])
if code_paths and len(code_paths) >= 1
_load_model_code_path(lc_model, model_config_path)
if model_config_path
else _load_model_code_path(lc_model)
)

Expand All @@ -359,7 +379,6 @@ def load_retriever(persist_directory):
FLAVOR_NAME,
langchain_version=langchain.__version__,
code=code_dir_subpath,
model_code=model_code_dir_subpath,
streamable=streamable,
**flavor_conf,
)
Expand Down Expand Up @@ -411,6 +430,7 @@ def log_model(
persist_dir=None,
example_no_conversion=False,
run_id=None,
model_config=None,
):
"""
Log a LangChain model as an MLflow artifact for the current run.
Expand Down Expand Up @@ -517,6 +537,11 @@ def load_retriever(persist_directory):
run_id: run_id to associate with this model version. If specified, we resume the
run and log the model to that run. Otherwise, a new run is created.
Default to None.
model_config: The model configuration to apply to the model if saving model as code. This
configuration is available during model loading.
.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.
Returns:
A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
Expand All @@ -541,6 +566,7 @@ def load_retriever(persist_directory):
persist_dir=persist_dir,
example_no_conversion=example_no_conversion,
run_id=run_id,
model_config=model_config,
)


Expand Down Expand Up @@ -722,7 +748,9 @@ def predict_stream(
Returns:
An iterator of model prediction chunks.
"""
from mlflow.langchain.api_request_parallel_processor import process_stream_request
from mlflow.langchain.api_request_parallel_processor import (
process_stream_request,
)

data = self._prepare_predict_stream_messages(data)
return process_stream_request(
Expand Down Expand Up @@ -751,7 +779,9 @@ def _predict_stream_with_callbacks(
Returns:
An iterator of model prediction chunks.
"""
from mlflow.langchain.api_request_parallel_processor import process_stream_request
from mlflow.langchain.api_request_parallel_processor import (
process_stream_request,
)

data = self._prepare_predict_stream_messages(data)
return process_stream_request(
Expand Down Expand Up @@ -840,7 +870,12 @@ def _load_pyfunc(path):

def _load_model_from_local_fs(local_model_path):
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
if _CODE_CONFIG in flavor_conf:
if _MODEL_CODE_PATH in flavor_conf:
code_path = flavor_conf.get(_MODEL_CODE_PATH)
config_path = flavor_conf.get(_MODEL_CODE_CONFIG, None)
return _load_model_code_path(code_path, config_path)
# Code for backwards compatibility, relies on RAG utils - remove in the future
elif _CODE_CONFIG in flavor_conf:
path = flavor_conf.get(_CODE_CONFIG)
flavor_code_config = flavor_conf.get(FLAVOR_CONFIG_CODE)
if path is not None:
Expand Down Expand Up @@ -896,10 +931,14 @@ def load_model(model_uri, dst_path=None):

@contextmanager
def _config_path_context(code_path: Optional[str] = None):
    """Temporarily expose ``code_path`` as the active model-config path.

    Sets both the new ``__mlflow_model_config__`` global (via
    ``_set_model_config``) and the legacy RAG-utils global (via
    ``_set_config_path``) for the duration of the ``with`` block, and clears
    both on exit — even if the body raises.

    Args:
        code_path: Path to the model configuration file, or ``None``.
    """
    _set_model_config(code_path)
    # set rag utils global for backwards compatibility
    _set_config_path(code_path)
    try:
        yield
    finally:
        _set_model_config(None)
        # unset rag utils global for backwards compatibility
        _set_config_path(None)


Expand Down
16 changes: 10 additions & 6 deletions mlflow/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

import yaml

import mlflow
__mlflow_model_config__ = None


# TODO: Let ModelConfig take in a dictionary instead of a file path
class ModelConfig:
"""
ModelConfig used in code to read a YAML configuration file, and this configuration file can be
overridden when logging a model.
"""

def __init__(self, *, development_config: Optional[str] = None):
# TODO: Update global path after we pass in paths using model_config
_mlflow_rag_config_path = getattr(
mlflow.langchain._rag_utils, "__databricks_rag_config_path__", None
)
self.config_path = _mlflow_rag_config_path or development_config
config = globals().get("__mlflow_model_config__", None)
# backwards compatibility
rag_config = globals().get("__databricks_rag_config_path__", None)
self.config_path = config or rag_config or development_config

if not self.config_path:
raise FileNotFoundError("Config file is None. Please provide a valid path.")
Expand Down Expand Up @@ -50,3 +50,7 @@ def get(self, key):
return config_data[key]
else:
raise KeyError(f"Key '{key}' not found in configuration: {config_data}.")


def _set_model_config(model_config):
globals()["__mlflow_model_config__"] = model_config
39 changes: 33 additions & 6 deletions mlflow/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,45 @@ def _validate_and_copy_code_paths(code_paths, path, default_subpath="code"):
return code_dir_subpath


def _validate_and_copy_model_code_path(code_path, path, default_subpath="model_code"):
def _validate_path_exists(path, name):
if path and not os.path.exists(path):
raise MlflowException(
message=(
f"Failed to copy the specified {name} path '{path}' into the model "
f"artifacts. The specified {name }path does not exist. Please specify a valid "
f"{name} path and try again."
),
error_code=INVALID_PARAMETER_VALUE,
)


def _validate_and_copy_model_code_and_config_paths(code_path, config_path, path):
    """Validate the model code/config paths and copy them into the model directory.

    Args:
        code_path: A file containing model code that should be logged as an artifact.
        config_path: An optional file containing the model configuration that should be
            logged as an artifact alongside the code.
        path: The local model path (destination directory).

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if either path does not
            exist, or if copying fails (e.g. the path points at a file that cannot
            be copied, such as a Databricks Notebook).
    """
    _validate_path_exists(code_path, "code")
    _validate_path_exists(config_path, "config")
    try:
        _copy_file_or_tree(src=code_path, dst=path)
        if config_path:
            _copy_file_or_tree(src=config_path, dst=path)
    except OSError as e:
        # A common error is code-paths includes Databricks Notebook. We include it in error
        # message when running in Databricks, but not in other envs to avoid confusion.
        example = ", such as Databricks Notebooks" if is_in_databricks_runtime() else ""
        raise MlflowException(
            # Fixed: the original had a trailing comma after the closing quote,
            # which made ``message`` a 1-tuple instead of a string.
            message=(
                f"Failed to copy the specified code path '{code_path}' into the model "
                "artifacts. It appears that your code path includes file(s) that cannot "
                f"be copied{example}. Please specify a code path that does not include "
                "such files and try again."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        ) from e


def _add_code_to_system_path(code_path):
Expand Down
2 changes: 1 addition & 1 deletion tests/langchain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _llm_type(self) -> str:
return FakeChatModel(endpoint=endpoint)


config_path = mlflow.langchain._rag_utils.__databricks_rag_config_path__
config_path = mlflow.models.model_config.__mlflow_model_config__
assert os.path.exists(config_path)

with open(config_path) as f:
Expand Down

0 comments on commit 39b6b1b

Please sign in to comment.