Implement ChatModel (pyfunc subclass) #10820

Merged
12 commits merged on Feb 2, 2024
8 changes: 7 additions & 1 deletion docs/source/python_api/mlflow.pyfunc.rst
@@ -18,8 +18,14 @@ mlflow.pyfunc
:members:
:undoc-members:

.. Include ``PythonModelContext``, which is imported from `mlflow.pyfunc.model`, in the
.. Include ``PythonModel``, which is imported from `mlflow.pyfunc.model`, in the
`mlflow.pyfunc` namespace
.. autoclass:: mlflow.pyfunc.PythonModel
:members:
:undoc-members:

.. Include ``ChatModel``, which is imported from `mlflow.pyfunc.model`, in the
`mlflow.pyfunc` namespace
.. autoclass:: mlflow.pyfunc.ChatModel
:members:
:undoc-members:
6 changes: 6 additions & 0 deletions docs/source/python_api/mlflow.types.rst
@@ -4,3 +4,9 @@ mlflow.types
.. automodule:: mlflow.types
:members:
:show-inheritance:

.. automodule:: mlflow.types.llm
:members:

.. automodule:: mlflow.types.llm._BaseDataclass
:undoc-members:
36 changes: 36 additions & 0 deletions mlflow/pyfunc/__init__.py
@@ -227,6 +227,7 @@
import yaml

import mlflow
import mlflow.pyfunc.loaders
import mlflow.pyfunc.model
from mlflow.environment_variables import (
_MLFLOW_TESTING,
@@ -254,6 +255,7 @@
RESOURCE_DOES_NOT_EXIST,
)
from mlflow.pyfunc.model import (
ChatModel,
PythonModel,
PythonModelContext, # noqa: F401
_log_warning_if_params_not_in_predict_signature,
@@ -263,6 +265,14 @@
)
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import (
CHAT_MODEL_INPUT_EXAMPLE,
CHAT_MODEL_INPUT_SCHEMA,
CHAT_MODEL_OUTPUT_SCHEMA,
ChatMessage,
ChatParams,
ChatResponse,
)
from mlflow.utils import (
PYTHON_VERSION,
_is_in_ipython_notebook,
@@ -1991,6 +2001,13 @@ def predict(model_input: List[str]) -> List[str]:

hints = None
if signature is not None:
if isinstance(python_model, ChatModel):
raise MlflowException(
"ChatModel subclasses have a standard signature that is set "
"automatically. Please remove the `signature` parameter from "
"the call to log_model() or save_model().",
error_code=INVALID_PARAMETER_VALUE,
)
mlflow_model.signature = signature
elif python_model is not None:
if callable(python_model):
@@ -1999,6 +2016,25 @@ def predict(model_input: List[str]) -> List[str]:
python_model, input_arg_index, input_example=input_example
):
mlflow_model.signature = signature
elif isinstance(python_model, ChatModel):
B-Step62 (Collaborator), Feb 2, 2024:

Can we do any validation/warning if the customer specifies a custom signature with ChatModel? If it doesn't comply with our pydantic schema, we may want to reject it here rather than at runtime.

Collaborator Author:

Oh yes, that's true. I'll throw a warning to say that the signature will be overridden and that it must conform to the spec.

Collaborator Author:

Whoa, this actually surfaced a bug in my implementation: if the user specifies a signature, the model doesn't get saved as a ChatModel because of the elif on line 2005 above. I think it's an elif because that block contains a lot of validation/signature inference logic that we can skip if the user provides the signature themselves. However, for ChatModel we always want to run these validations (e.g. output validation).

cc @B-Step62: what do you think about raising an exception when trying to save a ChatModel subclass with a signature? For example:

if signature is not None:
  if isinstance(python_model, ChatModel):
    raise MlflowException("ChatModel subclasses specify a signature automatically, please remove the provided signature from the log_model() or save_model() call.")
  mlflow_model.signature = signature
elif python_model is not None:
  # no change from this PR

Another option is a separate block for ChatModels, e.g.:

if isinstance(python_model, ChatModel):
  # move ChatModel logic to this block
  ...
elif signature is not None:
  # no change
  ...
elif python_model is not None:
  # no change
  ...

Collaborator:

Nice finding! I agree with throwing. A warning on the happy path is easily overlooked and almost invisible in an automated environment.
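
A minimal sketch of the agreed-upon behavior, assuming the exception message shown in the diff above; the class MyChatModel, its no-op predict(), and the save path are illustrative only:

    import mlflow
    from mlflow.models import ModelSignature
    from mlflow.types.llm import CHAT_MODEL_INPUT_SCHEMA, CHAT_MODEL_OUTPUT_SCHEMA


    class MyChatModel(mlflow.pyfunc.ChatModel):
        def predict(self, context, messages, params):
            ...  # a real implementation would return a ChatResponse


    # Passing any explicit signature together with a ChatModel subclass now
    # fails fast at save time, before any inference is attempted:
    mlflow.pyfunc.save_model(
        path="my_chat_model",
        python_model=MyChatModel(),
        signature=ModelSignature(CHAT_MODEL_INPUT_SCHEMA, CHAT_MODEL_OUTPUT_SCHEMA),
    )
    # -> MlflowException: ChatModel subclasses have a standard signature that is
    #    set automatically. Please remove the `signature` parameter ...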

mlflow_model.signature = ModelSignature(
CHAT_MODEL_INPUT_SCHEMA,
CHAT_MODEL_OUTPUT_SCHEMA,
)
input_example = CHAT_MODEL_INPUT_EXAMPLE

# perform output validation and throw if
# output is not coercible to ChatResponse
messages = [ChatMessage(**m) for m in input_example["messages"]]
params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
output = python_model.predict(None, messages, params)
Collaborator Author:

Is it a problem to perform inference during saving? I saw that we already do it when trying to infer the output signature, but since this is an LLM-specific API, inference can be fairly expensive. The input example specifies max_tokens=10, so hopefully it isn't too bad.

If it is a concern, maybe we can just skip output validation entirely (as far as I can tell, there wouldn't be another way to ensure the return type of the predict() method is actually a ChatResponse).

Member:

I think there are some risks:

  1. It may take a while (e.g. a few seconds) for the API request to finish.
  2. No guarantee that the LLM service is healthy. If OpenAI is down, this line would throw.

Collaborator:

+1. We shouldn't predict while saving the model; the error message would be confusing.

Collaborator Author:

From discussion offline, we'll keep the predict call, since we already do this in transformers and other places for output signature inference. I'll do some more testing here to make sure it's not a confusing experience.

if not isinstance(output, ChatResponse):
raise MlflowException(
"Failed to save ChatModel. Please ensure that the model's predict() method "
"returns a ChatResponse object. If your predict() method currently returns "
"a dict, you can instantiate a ChatResponse by unpacking the output, e.g. "
"`ChatResponse(**output)`",
)
elif isinstance(python_model, PythonModel):
input_arg_index = 1 # second argument
if signature := _infer_signature_from_type_hints(
1 change: 1 addition & 0 deletions mlflow/pyfunc/loaders/__init__.py
@@ -0,0 +1 @@
import mlflow.pyfunc.loaders.chat_model # noqa: F401
Collaborator Author:

Is this necessary? Or does Python load all files in the subdirectory into the module by default?

Member:

Python doesn't. If we want to do `from mlflow.pyfunc.loaders import chat_model`, we need this line; otherwise we don't.
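
As a small illustration (using only modules added in this PR): the re-export also enables attribute-style access to the submodule, which the new _get_pyfunc_loader_module helper in mlflow/pyfunc/model.py appears to rely on.

    import mlflow.pyfunc.loaders

    # Works only because loaders/__init__.py imports the submodule eagerly;
    # without that line, `chat_model` would not yet be set as an attribute of
    # the package and this access would raise AttributeError.
    print(mlflow.pyfunc.loaders.chat_model.__name__)
    # -> "mlflow.pyfunc.loaders.chat_model"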

69 changes: 69 additions & 0 deletions mlflow/pyfunc/loaders/chat_model.py
@@ -0,0 +1,69 @@
from typing import Any, Dict, Optional

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.pyfunc.model import (
_load_context_model_and_signature,
)
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
from mlflow.utils.annotations import experimental


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
Collaborator:

This looks the same as PythonModel's _load_pyfunc function (except for the wrapper it returns). Could we reuse the function and extract the final class as a parameter?

Collaborator Author:

Refactored the common part into _load_context_model_and_signature.

context, chat_model, signature = _load_context_model_and_signature(model_path, model_config)
return _ChatModelPyfuncWrapper(chat_model=chat_model, context=context, signature=signature)


@experimental
class _ChatModelPyfuncWrapper:
"""
Wrapper class that converts dict inputs to pydantic objects accepted by :class:`~ChatModel`.
"""

def __init__(self, chat_model, context, signature):
"""
Args:
chat_model: An instance of a subclass of :class:`~ChatModel`.
context: A :class:`~PythonModelContext` instance containing artifacts that
``chat_model`` may use when performing inference.
signature: :class:`~ModelSignature` instance describing model input and output.
"""
self.chat_model = chat_model
self.context = context
self.signature = signature

def _convert_input(self, model_input):
# model_input should be correct from signature validation, so just convert it to dict here
dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
Collaborator:

Does to_dict accept an orient param?

Collaborator Author:

It seems so: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_dict.html#pandas.DataFrame.to_dict

But I'm fairly new to pandas; is there something else I should use?
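
A quick sketch of what the conversion above does, with made-up column values for illustration: signature enforcement hands the wrapper a single-row DataFrame, orient="list" turns it into a {column -> list of values} dict, and taking value[0] unwraps the single row.

    import pandas as pd

    # One chat request, roughly as the wrapper receives it after signature
    # enforcement (values are illustrative).
    model_input = pd.DataFrame(
        {
            "messages": [[{"role": "user", "content": "hello"}]],
            "temperature": [0.5],
            "max_tokens": [10],
        }
    )

    dict_input = {
        key: value[0] for key, value in model_input.to_dict(orient="list").items()
    }
    # {'messages': [{'role': 'user', 'content': 'hello'}],
    #  'temperature': 0.5, 'max_tokens': 10}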


messages = [ChatMessage(**message) for message in dict_input.pop("messages", [])]
params = ChatParams(**dict_input)

return messages, params

def predict(
self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Args:
model_input: Model input data in the form of a chat request.
params: Additional parameters to pass to the model for inference.
Unused in this implementation, as the params are handled
via ``self._convert_input()``.

Returns:
Model predictions in :py:class:`~ChatResponse` format.
"""
messages, params = self._convert_input(model_input)
response = self.chat_model.predict(self.context, messages, params)

if not isinstance(response, ChatResponse):
# shouldn't happen since there is validation at save time ensuring that
# the output is a ChatResponse, so raise an exception if it isn't
raise MlflowException(
"Model returned an invalid response. Expected a ChatResponse, but "
f"got {type(response)} instead.",
error_code=INTERNAL_ERROR,
)

return response.to_dict()
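
For context, a hedged sketch of how this wrapper ends up being exercised through the standard pyfunc API once a ChatModel has been logged; the model URI and request contents are placeholders:

    import mlflow

    model = mlflow.pyfunc.load_model("models:/my_chat_model/1")
    response = model.predict(
        {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 10,
        }
    )
    # `response` is the dict form of a ChatResponse (an OpenAI-style chat
    # completion payload), produced by _ChatModelPyfuncWrapper.predict() above.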
53 changes: 51 additions & 2 deletions mlflow/pyfunc/model.py
@@ -21,6 +21,7 @@
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import _extract_type_hints
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
from mlflow.utils.annotations import experimental
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
@@ -192,6 +193,40 @@ def model_config(self):
return self._model_config


@experimental
class ChatModel(PythonModel, metaclass=ABCMeta):
"""
A subclass of :class:`~PythonModel` that makes it more convenient to implement models
that are compatible with popular LLM chat APIs. By subclassing :class:`~ChatModel`,
users can create MLflow models with a ``predict()`` method that is more convenient
for chat tasks than the generic :class:`~PythonModel` API. ChatModels automatically
define input/output signatures and an input example, so manually specifying these values
when calling :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>` is not necessary.

See the documentation of the ``predict()`` method below for details on the parameters and
outputs that are expected by the ``ChatModel`` API.
"""

@abstractmethod
def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
"""
Evaluates a chat input and produces a chat output.

Args:
messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
objects representing chat history.
params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
containing various parameters used to modify model behavior during
inference.

Returns:
A :py:class:`ChatResponse <mlflow.types.llm.ChatResponse>` object containing
the model's response(s), as well as other metadata.
"""


def _save_model_with_class_artifacts_params(
path,
python_model,
@@ -306,7 +341,7 @@ def _save_model_with_class_artifacts_params(

mlflow.pyfunc.add_to_model(
model=mlflow_model,
loader_module=__name__,
loader_module=_get_pyfunc_loader_module(python_model),
code=saved_code_subpath,
conda_env=_CONDA_ENV_FILE_NAME,
python_env=_PYTHON_ENV_FILE_NAME,
@@ -351,7 +386,9 @@ def _save_model_with_class_artifacts_params(
_PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
def _load_context_model_and_signature(
model_path: str, model_config: Optional[Dict[str, Any]] = None
):
pyfunc_config = _get_flavor_configuration(
model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
)
@@ -391,6 +428,12 @@ def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None)
context = PythonModelContext(artifacts=artifacts, model_config=model_config)
python_model.load_context(context=context)
signature = mlflow.models.Model.load(model_path).signature

return context, python_model, signature


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
context, python_model, signature = _load_context_model_and_signature(model_path, model_config)
return _PythonModelPyfuncWrapper(
python_model=python_model, context=context, signature=signature
)
@@ -467,3 +510,9 @@ def predict(self, model_input, params: Optional[Dict[str, Any]] = None):
)
_log_warning_if_params_not_in_predict_signature(_logger, params)
return self.python_model.predict(self.context, self._convert_input(model_input))


def _get_pyfunc_loader_module(python_model):
if isinstance(python_model, ChatModel):
return mlflow.pyfunc.loaders.chat_model.__name__
return __name__
Comment on lines +515 to +518
Collaborator Author:

What do we think of adding new pyfunc loaders to the mlflow.pyfunc.loaders module? I think it would be a clean way for us to implement future custom loaders (e.g. for RAGModel, CompletionModel).

Member:

Sounds good to me.
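
A purely hypothetical sketch of how that pattern might extend; CompletionModel and mlflow.pyfunc.loaders.completion_model do not exist in this PR and stand in for the future loaders being discussed:

    def _get_pyfunc_loader_module(python_model):
        # Each pyfunc subclass that needs its own wrapper gets its own loader
        # module under mlflow.pyfunc.loaders.
        if isinstance(python_model, ChatModel):
            return mlflow.pyfunc.loaders.chat_model.__name__
        # elif isinstance(python_model, CompletionModel):
        #     return mlflow.pyfunc.loaders.completion_model.__name__
        return __name__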

10 changes: 9 additions & 1 deletion mlflow/types/__init__.py
@@ -3,6 +3,14 @@
components to describe interface independent of other frameworks or languages.
"""

import mlflow.types.llm # noqa: F401
from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec

__all__ = ["Schema", "ColSpec", "DataType", "TensorSpec", "ParamSchema", "ParamSpec"]
__all__ = [
"Schema",
"ColSpec",
"DataType",
"TensorSpec",
"ParamSchema",
"ParamSpec",
]