Implement ChatModel (pyfunc subclass) (mlflow#10820)
Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
Signed-off-by: lu-wang-dl <lu.wang@databricks.com>
daniellok-db authored and lu-wang-dl committed Feb 6, 2024
1 parent ffa3064 commit ed038cd
Showing 10 changed files with 767 additions and 4 deletions.
8 changes: 7 additions & 1 deletion docs/source/python_api/mlflow.pyfunc.rst
@@ -18,8 +18,14 @@ mlflow.pyfunc
:members:
:undoc-members:

.. Include ``PythonModelContext``, which is imported from `mlflow.pyfunc.model`, in the
.. Include ``PythonModel``, which is imported from `mlflow.pyfunc.model`, in the
`mlflow.pyfunc` namespace
.. autoclass:: mlflow.pyfunc.PythonModel
:members:
:undoc-members:

.. Include ``ChatModel``, which is imported from `mlflow.pyfunc.model`, in the
`mlflow.pyfunc` namespace
.. autoclass:: mlflow.pyfunc.ChatModel
:members:
:undoc-members:
6 changes: 6 additions & 0 deletions docs/source/python_api/mlflow.types.rst
@@ -4,3 +4,9 @@ mlflow.types
.. automodule:: mlflow.types
:members:
:show-inheritance:

.. automodule:: mlflow.types.llm
:members:

.. autoclass:: mlflow.types.llm._BaseDataclass
:undoc-members:
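The new ``mlflow.types.llm`` dataclasses are built from plain dicts, which is exactly how the rest of this commit uses them (``ChatMessage(**m)`` in the save path and loader below). A minimal sketch of that conversion, touching only fields this diff itself exercises (``role``/``content`` on messages, ``temperature`` as an optional param):

```python
from mlflow.types.llm import ChatMessage, ChatParams

# A chat request as a plain dict: a "messages" list plus optional
# inference parameters, in the shape CHAT_MODEL_INPUT_EXAMPLE implies.
request = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.5,
}

# The same conversion this commit performs at save time and in the loader:
messages = [ChatMessage(**m) for m in request.pop("messages")]
params = ChatParams(**request)

print(messages[0].role, params.temperature)
```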
36 changes: 36 additions & 0 deletions mlflow/pyfunc/__init__.py
@@ -227,6 +227,7 @@
import yaml

import mlflow
import mlflow.pyfunc.loaders
import mlflow.pyfunc.model
from mlflow.environment_variables import (
_MLFLOW_TESTING,
@@ -254,6 +255,7 @@
RESOURCE_DOES_NOT_EXIST,
)
from mlflow.pyfunc.model import (
ChatModel,
PythonModel,
PythonModelContext, # noqa: F401
_log_warning_if_params_not_in_predict_signature,
@@ -263,6 +265,14 @@
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import (
CHAT_MODEL_INPUT_EXAMPLE,
CHAT_MODEL_INPUT_SCHEMA,
CHAT_MODEL_OUTPUT_SCHEMA,
ChatMessage,
ChatParams,
ChatResponse,
)
from mlflow.utils import (
PYTHON_VERSION,
_is_in_ipython_notebook,
@@ -1991,6 +2001,13 @@ def predict(model_input: List[str]) -> List[str]:

hints = None
if signature is not None:
if isinstance(python_model, ChatModel):
raise MlflowException(
"ChatModel subclasses have a standard signature that is set "
"automatically. Please remove the `signature` parameter from "
"the call to log_model() or save_model().",
error_code=INVALID_PARAMETER_VALUE,
)
mlflow_model.signature = signature
elif python_model is not None:
if callable(python_model):
@@ -1999,6 +2016,25 @@
python_model, input_arg_index, input_example=input_example
):
mlflow_model.signature = signature
elif isinstance(python_model, ChatModel):
mlflow_model.signature = ModelSignature(
CHAT_MODEL_INPUT_SCHEMA,
CHAT_MODEL_OUTPUT_SCHEMA,
)
input_example = CHAT_MODEL_INPUT_EXAMPLE

# perform output validation and raise if
# the output is not coercible to ChatResponse
messages = [ChatMessage(**m) for m in input_example["messages"]]
params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
output = python_model.predict(None, messages, params)
if not isinstance(output, ChatResponse):
raise MlflowException(
"Failed to save ChatModel. Please ensure that the model's predict() method "
"returns a ChatResponse object. If your predict() method currently returns "
"a dict, you can instantiate a ChatResponse by unpacking the output, e.g. "
"`ChatResponse(**output)`",
)
elif isinstance(python_model, PythonModel):
input_arg_index = 1 # second argument
if signature := _infer_signature_from_type_hints(
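To make the save-time behavior above concrete, here is a minimal sketch of logging a ``ChatModel`` subclass: the signature and input example are attached automatically, and ``predict()`` is run once against ``CHAT_MODEL_INPUT_EXAMPLE`` to verify it returns a ``ChatResponse``. The response fields below mirror the OpenAI-style chat-completion schema that ``mlflow.types.llm`` models; treat the exact field set as an assumption, not a spec:

```python
import time
from typing import List

import mlflow
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse


class EchoModel(ChatModel):
    """Toy chat model that echoes the last message back."""

    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        # Field names here follow the OpenAI-style schema that
        # mlflow.types.llm mirrors; the exact required fields are assumed.
        return ChatResponse(
            id="chatcmpl-echo-0",
            object="chat.completion",
            created=int(time.time()),
            model="echo",
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": messages[-1].content},
                    "finish_reason": "stop",
                }
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )


with mlflow.start_run():
    # No `signature` argument: ChatModel subclasses get the standard chat
    # signature automatically, and passing one explicitly raises
    # INVALID_PARAMETER_VALUE per the hunk above.
    model_info = mlflow.pyfunc.log_model(artifact_path="chat_model", python_model=EchoModel())
```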
1 change: 1 addition & 0 deletions mlflow/pyfunc/loaders/__init__.py
@@ -0,0 +1 @@
import mlflow.pyfunc.loaders.chat_model # noqa: F401
69 changes: 69 additions & 0 deletions mlflow/pyfunc/loaders/chat_model.py
@@ -0,0 +1,69 @@
from typing import Any, Dict, Optional

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.pyfunc.model import (
_load_context_model_and_signature,
)
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
from mlflow.utils.annotations import experimental


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
context, chat_model, signature = _load_context_model_and_signature(model_path, model_config)
return _ChatModelPyfuncWrapper(chat_model=chat_model, context=context, signature=signature)


@experimental
class _ChatModelPyfuncWrapper:
"""
Wrapper class that converts dict inputs to the dataclass objects accepted by :class:`~ChatModel`.
"""

def __init__(self, chat_model, context, signature):
"""
Args:
chat_model: An instance of a subclass of :class:`~ChatModel`.
context: A :class:`~PythonModelContext` instance containing artifacts that
``chat_model`` may use when performing inference.
signature: :class:`~ModelSignature` instance describing model input and output.
"""
self.chat_model = chat_model
self.context = context
self.signature = signature

def _convert_input(self, model_input):
# model_input should be correct from signature validation, so just convert it to dict here
dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}

messages = [ChatMessage(**message) for message in dict_input.pop("messages", [])]
params = ChatParams(**dict_input)

return messages, params

def predict(
self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Args:
model_input: Model input data in the form of a chat request.
params: Additional parameters to pass to the model for inference.
Unused in this implementation, as the params are handled
via ``self._convert_input()``.
Returns:
Model predictions in :py:class:`~ChatResponse` format.
"""
messages, params = self._convert_input(model_input)
response = self.chat_model.predict(self.context, messages, params)

if not isinstance(response, ChatResponse):
# shouldn't happen since there is validation at save time ensuring that
# the output is a ChatResponse, so raise an exception if it isn't
raise MlflowException(
"Model returned an invalid response. Expected a ChatResponse, but "
f"got {type(response)} instead.",
error_code=INTERNAL_ERROR,
)

return response.to_dict()
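From the caller's side, the wrapper means a loaded chat model consumes and produces plain dicts. A sketch, assuming a model logged as in the earlier ``EchoModel`` example (``model_info`` carries over from there):

```python
import mlflow

# Load the model logged above; the pyfunc flavor's loader_module points at
# mlflow.pyfunc.loaders.chat_model, so this returns a _ChatModelPyfuncWrapper.
loaded = mlflow.pyfunc.load_model(model_info.model_uri)

# Signature enforcement validates the dict against CHAT_MODEL_INPUT_SCHEMA,
# then _convert_input() turns it into ChatMessage/ChatParams objects.
response = loaded.predict(
    {
        "messages": [{"role": "user", "content": "What is MLflow?"}],
        "temperature": 0.2,
    }
)

# predict() returns ChatResponse.to_dict(), i.e. a plain dict
# (structure assumed from the OpenAI-style schema).
print(response["choices"][0]["message"]["content"])
```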
53 changes: 51 additions & 2 deletions mlflow/pyfunc/model.py
@@ -21,6 +21,7 @@
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import _extract_type_hints
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
from mlflow.utils.annotations import experimental
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
@@ -192,6 +193,40 @@ def model_config(self):
return self._model_config


@experimental
class ChatModel(PythonModel, metaclass=ABCMeta):
"""
A subclass of :class:`~PythonModel` that makes it more convenient to implement models
that are compatible with popular LLM chat APIs. By subclassing :class:`~ChatModel`,
users can create MLflow models with a ``predict()`` method that is more convenient
for chat tasks than the generic :class:`~PythonModel` API. ChatModels automatically
define input/output signatures and an input example, so manually specifying these values
when calling :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>` is not necessary.
See the documentation of the ``predict()`` method below for details on the parameters and
outputs expected by the ``ChatModel`` API.
"""

@abstractmethod
def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
"""
Evaluates a chat input and produces a chat output.
Args:
messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
objects representing chat history.
params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
containing various parameters used to modify model behavior during
inference.
Returns:
A :py:class:`ChatResponse <mlflow.types.llm.ChatResponse>` object containing
the model's response(s), as well as other metadata.
"""


def _save_model_with_class_artifacts_params(
path,
python_model,
@@ -306,7 +341,7 @@ def _save_model_with_class_artifacts_params(

mlflow.pyfunc.add_to_model(
model=mlflow_model,
loader_module=__name__,
loader_module=_get_pyfunc_loader_module(python_model),
code=saved_code_subpath,
conda_env=_CONDA_ENV_FILE_NAME,
python_env=_PYTHON_ENV_FILE_NAME,
@@ -351,7 +386,9 @@ def _save_model_with_class_artifacts_params(
_PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
def _load_context_model_and_signature(
model_path: str, model_config: Optional[Dict[str, Any]] = None
):
pyfunc_config = _get_flavor_configuration(
model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
)
@@ -391,6 +428,12 @@ def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None)
context = PythonModelContext(artifacts=artifacts, model_config=model_config)
python_model.load_context(context=context)
signature = mlflow.models.Model.load(model_path).signature

return context, python_model, signature


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
context, python_model, signature = _load_context_model_and_signature(model_path, model_config)
return _PythonModelPyfuncWrapper(
python_model=python_model, context=context, signature=signature
)
@@ -467,3 +510,9 @@ def predict(self, model_input, params: Optional[Dict[str, Any]] = None):
)
_log_warning_if_params_not_in_predict_signature(_logger, params)
return self.python_model.predict(self.context, self._convert_input(model_input))


def _get_pyfunc_loader_module(python_model):
if isinstance(python_model, ChatModel):
return mlflow.pyfunc.loaders.chat_model.__name__
return __name__
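Two details of this file worth noting: ``_get_pyfunc_loader_module`` is the hook that routes ``ChatModel`` instances to the new loader module, and ``ChatModel`` itself is abstract, so it cannot be instantiated until ``predict()`` is implemented. A quick sketch of the latter:

```python
from mlflow.pyfunc import ChatModel

# ChatModel uses ABCMeta with an abstract predict(), so direct
# instantiation fails until a subclass implements predict().
try:
    ChatModel()
except TypeError as err:
    print(err)  # Can't instantiate abstract class ChatModel ...
```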
10 changes: 9 additions & 1 deletion mlflow/types/__init__.py
@@ -3,6 +3,14 @@
components to describe interface independent of other frameworks or languages.
"""

import mlflow.types.llm # noqa: F401
from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec

__all__ = ["Schema", "ColSpec", "DataType", "TensorSpec", "ParamSchema", "ParamSpec"]
__all__ = [
"Schema",
"ColSpec",
"DataType",
"TensorSpec",
"ParamSchema",
"ParamSpec",
]
