Skip to content

Commit

Permalink
Chat model predict_stream support (#12128)
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan committed May 27, 2024
1 parent 24872d1 commit e1e0047
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
24 changes: 20 additions & 4 deletions mlflow/pyfunc/loaders/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterator, Optional

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
Expand Down Expand Up @@ -68,14 +68,30 @@ def predict(
"""
messages, params = self._convert_input(model_input)
response = self.chat_model.predict(self.context, messages, params)
return self._response_to_dict(response)

def _response_to_dict(self, response: ChatResponse) -> Dict[str, Any]:
    """Serialize a model response to a plain dict.

    Args:
        response: The value returned by the wrapped chat model. Must be a
            :py:class:`~ChatResponse`; save-time validation should guarantee
            this, so anything else indicates an internal error.

    Returns:
        The ``dict`` form of the response.

    Raises:
        MlflowException: If ``response`` is not a ``ChatResponse``.
    """
    if isinstance(response, ChatResponse):
        return response.to_dict()
    # Save-time validation should make this unreachable; surface it loudly
    # as an internal error rather than returning a malformed payload.
    raise MlflowException(
        "Model returned an invalid response. Expected a ChatResponse, but "
        f"got {type(response)} instead.",
        error_code=INTERNAL_ERROR,
    )

def predict_stream(
    self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None
) -> Iterator[Dict[str, Any]]:
    """
    Stream predictions for a chat request.

    Args:
        model_input: Model input data in the form of a chat request.
        params: Additional parameters to pass to the model for inference.
            Unused in this implementation, as the params are handled
            via ``self._convert_input()``.

    Returns:
        Iterator over model predictions in :py:class:`~ChatResponse` format.
    """
    messages, chat_params = self._convert_input(model_input)
    chunks = self.chat_model.predict_stream(self.context, messages, chat_params)
    # Lazily serialize each streamed chunk so callers can consume the
    # stream incrementally.
    return (self._response_to_dict(chunk) for chunk in chunks)
25 changes: 24 additions & 1 deletion mlflow/pyfunc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterator, List, Optional

import cloudpickle
import yaml
Expand Down Expand Up @@ -243,6 +243,29 @@ def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> C
the model's response(s), as well as other metadata.
"""

def predict_stream(
    self, context, messages: List[ChatMessage], params: ChatParams
) -> Iterator[ChatResponse]:
    """
    Evaluate a chat input and yield chat output chunks.

    Override this method to implement true streaming; the default
    implementation simply yields the result of :py:meth:`predict` as a
    single chunk.

    Args:
        messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
            A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
            objects representing chat history.
        params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
            A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
            containing various parameters used to modify model behavior during
            inference.

    Returns:
        An iterator over :py:class:`ChatResponse <mlflow.types.llm.ChatResponse>`
        objects containing the model's response(s), as well as other metadata.
    """
    # Fall back to the non-streaming path: one full response, one chunk.
    full_response = self.predict(context, messages, params)
    yield full_response


def _save_model_with_class_artifacts_params(
path,
Expand Down
14 changes: 14 additions & 0 deletions tests/pyfunc/test_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,17 @@ def test_chat_model_works_with_infer_signature_multi_input_example(tmp_path):
**DEFAULT_PARAMS,
**params_subset,
}


def test_chat_model_predict_stream(tmp_path):
    """A saved chat model's predict_stream yields chunks in chat-response format."""
    mlflow.pyfunc.save_model(python_model=TestChatModel(), path=tmp_path)
    loaded = mlflow.pyfunc.load_model(tmp_path)

    chat_history = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    first_chunk = next(loaded.predict_stream({"messages": chat_history}))

    # TestChatModel echoes the serialized message history as its content.
    assert first_chunk["choices"][0]["message"]["content"] == json.dumps(chat_history)

0 comments on commit e1e0047

Please sign in to comment.