
Commit

Fix params and model_config handling for llm/v1/xxx Transformers model (#12401)

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
B-Step62 committed Jun 19, 2024
1 parent 46a62eb commit 9f1adbc
Showing 4 changed files with 211 additions and 79 deletions.
23 changes: 11 additions & 12 deletions mlflow/transformers/__init__.py
@@ -64,12 +64,12 @@
_METADATA_LLM_INFERENCE_TASK_KEY,
_SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK,
_get_default_task_for_llm_inference_task,
convert_data_messages_with_chat_template,
convert_messages_to_prompt,
infer_signature_from_llm_inference_task,
postprocess_output_for_llm_inference_task,
postprocess_output_for_llm_v1_embedding_task,
preprocess_llm_embedding_params,
preprocess_llm_inference_params,
preprocess_llm_inference_input,
)
from mlflow.transformers.model_io import (
_COMPONENTS_BINARY_DIR_NAME,
@@ -1655,20 +1655,19 @@ def predict(self, data, params: Optional[Dict[str, Any]] = None):
Returns:
Model predictions.
"""
if self.llm_inference_task == _LLM_INFERENCE_TASK_CHAT:
convert_data_messages_with_chat_template(data, self.pipeline.tokenizer)
data, params = preprocess_llm_inference_params(data, self.flavor_config)
elif self.llm_inference_task == _LLM_INFERENCE_TASK_COMPLETIONS:
data, params = preprocess_llm_inference_params(data, self.flavor_config)
elif self.llm_inference_task == _LLM_INFERENCE_TASK_EMBEDDING:
data, params = preprocess_llm_embedding_params(data)

# NB: This `predict` method updates the model_config several times. To make the predict
        # call idempotent, we keep the original self.model_config immutable and create a deep
# copy of it at every predict call.
model_config = copy.deepcopy(dict(self.model_config))
params = self._merge_model_config_with_params(model_config, params)

model_config = self._merge_model_config_with_params(model_config, params)
if self.llm_inference_task == _LLM_INFERENCE_TASK_CHAT:
data, params = preprocess_llm_inference_input(data, params, self.flavor_config)
data = [convert_messages_to_prompt(msgs, self.pipeline.tokenizer) for msgs in data]
elif self.llm_inference_task == _LLM_INFERENCE_TASK_COMPLETIONS:
data, params = preprocess_llm_inference_input(data, params, self.flavor_config)
elif self.llm_inference_task == _LLM_INFERENCE_TASK_EMBEDDING:
data, params = preprocess_llm_embedding_params(data)

if isinstance(data, pd.DataFrame):
input_data = self._convert_pandas_to_dict(data)
@@ -1700,7 +1699,7 @@ def predict(self, data, params: Optional[Dict[str, Any]] = None):
_validate_input_dictionary_contains_only_strings_and_lists_of_strings(x)
for x in input_data
)
return self._predict(input_data, model_config)
return self._predict(input_data, params)

def _predict(self, data, model_config):
import transformers
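
With this change the merged `params`, rather than `model_config`, are what `_predict` receives, so values supplied at inference time can override the defaults saved with the model. The sketch below shows how that is expected to surface for pyfunc callers; the gpt2 pipeline, run setup, and parameter values are illustrative assumptions, not part of this commit.

import mlflow
from transformers import pipeline

# Small text-generation pipeline purely for illustration.
generator = pipeline("text-generation", model="gpt2")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=generator,
        artifact_path="model",
        task="llm/v1/completions",
        # Defaults baked in at logging time.
        model_config={"temperature": 0.5, "max_new_tokens": 64},
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)

# `params` passed at predict time override the saved model_config, and fields in
# the payload itself (e.g. max_tokens) take precedence over both.
loaded.predict(
    {"prompt": "Hello!", "max_tokens": 32},
    params={"temperature": 0.9},
)
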
89 changes: 55 additions & 34 deletions mlflow/transformers/llm_inference_utils.py
@@ -4,10 +4,13 @@
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from mlflow.exceptions import MlflowException
from mlflow.models import ModelSignature
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE
from mlflow.transformers.flavor_config import FlavorKey
from mlflow.types.llm import (
CHAT_MODEL_INPUT_SCHEMA,
CHAT_MODEL_OUTPUT_SCHEMA,
@@ -55,6 +58,11 @@
),
}

_LLM_INFERENCE_TASK_TO_DATA_FIELD = {
_LLM_INFERENCE_TASK_CHAT: "messages",
_LLM_INFERENCE_TASK_COMPLETIONS: "prompt",
}
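
The mapping above pins down which request field carries the model input for each task. Purely for illustration (the values are assumptions), the two payload shapes look like this:

# llm/v1/chat: the input arrives under "messages"
chat_payload = {"messages": [{"role": "user", "content": "Hi"}], "temperature": 0.7}

# llm/v1/completions: the input arrives under "prompt"
completions_payload = {"prompt": "Hi", "max_tokens": 32}
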


def infer_signature_from_llm_inference_task(
inference_task: str, signature: Optional[ModelSignature] = None
@@ -73,25 +81,31 @@ def infer_signature_from_llm_inference_task(
return inferred_signature


def convert_data_messages_with_chat_template(data, tokenizer):
    """For the Chat inference task, apply chat template to messages to create prompt."""
    if "messages" in data.columns:
        messages = data.pop("messages").tolist()[0]
    else:
        raise MlflowException("The 'messages' field is required for the Chat inference task.")

    try:
        messages_str = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        raise MlflowException(f"Failed to apply chat template: {e}")

    data["prompt"] = messages_str


def convert_messages_to_prompt(messages: List[Dict], tokenizer) -> str:
    """For the Chat inference task, apply the chat template to messages to create the prompt.

    Args:
        messages: List of messages, e.g. [{"role": "user", "content": "xxx"}, ...]
        tokenizer: The tokenizer object used for inference.

    Returns:
        The prompt string built from the messages.
    """
    if not (isinstance(messages, list) and all(isinstance(msg, dict) for msg in messages)):
        raise MlflowException(
            f"Input messages should be list of dictionaries, but got: {type(messages)}.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        raise MlflowException(f"Failed to apply chat template: {e}")


def preprocess_llm_inference_params(
data,
def preprocess_llm_inference_input(
data: pd.DataFrame,
params: Optional[Dict[str, Any]] = None,
flavor_config: Optional[Dict[str, Any]] = None,
) -> Tuple[List[Any], Dict[str, Any]]:
"""
@@ -106,27 +120,34 @@ def preprocess_llm_inference_params(
"`data` is expected to be a pandas DataFrame for MLflow inference task after signature "
f"enforcement, but got type: {type(data)}."
)

updated_data = []
params = {}

for column in data.columns:
if column in ["prompt", "messages"]:
updated_data = data[column].tolist()
else:
param = data[column].tolist()[0]
if column == "max_tokens":
params["max_new_tokens"] = param
elif column == "stop":
source_model_name = (
flavor_config.get("source_model_name") if flavor_config else None
)
if stop := _get_stopping_criteria(param, source_model_name):
params["stopping_criteria"] = stop
else:
params[column] = param

return updated_data, params
flavor_config = flavor_config or {}
params = params or {}
    # Pandas converts None to np.nan internally, which is not preferred
data = data.replace(np.nan, None).to_dict(orient="list")

    # Extract the list of input data (prompt or messages) to pass to the LLM
task = flavor_config[_LLM_INFERENCE_TASK_KEY]
input_col = _LLM_INFERENCE_TASK_TO_DATA_FIELD.get(task)
if input_col not in data:
raise MlflowException(
f"Transformer model saved with `{task}` task excepts `{input_col}`"
"to be passed as input data.",
error_code=BAD_REQUEST,
)
update_data = data.pop(input_col)

    # The remaining fields in the input payload go to params and override the default ones
params_in_data = {k: v[0] for k, v in data.items() if v[0] is not None}
params = {**params, **params_in_data}

if max_tokens := params.pop("max_tokens", None):
params["max_new_tokens"] = max_tokens
if stop := params.pop("stop", None):
params["stopping_criteria"] = _get_stopping_criteria(
stop,
flavor_config.get(FlavorKey.MODEL_NAME),
)
return update_data, params
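
A quick worked example of the precedence the function now implements (all values are made up): fields in the payload override caller-supplied `params`, and `max_tokens` is renamed to the Hugging Face `max_new_tokens`.

import pandas as pd

payload = pd.DataFrame({"prompt": ["Hello!"], "temperature": [0.1], "max_tokens": [64]})

data, params = preprocess_llm_inference_input(
    payload,
    params={"temperature": 0.7, "top_p": 0.9},
    flavor_config={"inference_task": "llm/v1/completions"},
)
assert data == ["Hello!"]
assert params == {"temperature": 0.1, "top_p": 0.9, "max_new_tokens": 64}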


def _get_stopping_criteria(stop: Optional[Union[str, List[str]]], model_name: Optional[str] = None):
153 changes: 124 additions & 29 deletions tests/transformers/test_transformers_llm_inference_utils.py
@@ -1,4 +1,5 @@
import uuid
from collections import namedtuple
from typing import Dict, List
from unittest import mock

@@ -14,9 +15,9 @@
_get_output_and_usage_from_tensor,
_get_stopping_criteria,
_get_token_usage,
convert_data_messages_with_chat_template,
convert_messages_to_prompt,
infer_signature_from_llm_inference_task,
preprocess_llm_inference_params,
preprocess_llm_inference_input,
)
from mlflow.types.llm import (
CHAT_MODEL_INPUT_SCHEMA,
@@ -61,40 +62,134 @@ def apply_chat_template(self, messages: List[Dict[str, str]], **kwargs):


def test_apply_chat_template():
    tokenizer = DummyTokenizer()

    data1 = pd.DataFrame(
        {
            "messages": pd.Series(
                [[{"role": "A", "content": "one"}, {"role": "B", "content": "two"}]]
            ),
            "random": ["value"],
        }
    )

    # Test that the function modifies the data in place for Chat task
    convert_data_messages_with_chat_template(data1, tokenizer)

    expected_data = pd.DataFrame({"random": ["value"], "prompt": ["one two"]})
    pd.testing.assert_frame_equal(data1, expected_data)

    data1 = [{"role": "A", "content": "one"}, {"role": "B", "content": "two"}]
    prompt = convert_messages_to_prompt(data1, DummyTokenizer())
    assert prompt == "one two"

    with pytest.raises(MlflowException, match=r"Input messages should be list of"):
        convert_messages_to_prompt([["one", "two"]], DummyTokenizer())


def test_preprocess_llm_inference_params():
    data = pd.DataFrame(
        {
            "prompt": ["Hello world!"],
            "temperature": [0.7],
            "max_tokens": [100],
            # do not pass this to params as it is None
            "stop": None,
        }
    )

    data, params = preprocess_llm_inference_params(data, flavor_config=None)

    # Test that OpenAI params are separated from data and replaced with Hugging Face params
    assert data == ["Hello world!"]
    assert params == {"max_new_tokens": 100, "temperature": 0.7}


_TestCase = namedtuple("_TestCase", ["data", "params", "expected_data", "expected_params"])

@pytest.mark.parametrize(
"case",
[
# Case 0: Data only includes prompt
_TestCase(
data={"prompt": ["Hello world!"]},
params={},
expected_data=["Hello world!"],
expected_params={},
),
# Case 1: Data includes prompt and params
_TestCase(
data={
"prompt": ["Hello world!"],
"temperature": [0.7],
"max_tokens": [100],
"stop": None,
},
params={},
expected_data=["Hello world!"],
expected_params={
"temperature": 0.7,
# max_tokens is replaced with max_new_tokens
"max_new_tokens": 100,
# do not pass `stop` to params as it is None
},
),
# Case 2: Params are passed if not specified in data
_TestCase(
data={
"prompt": ["Hello world!"],
},
params={
"temperature": 0.7,
"max_tokens": 100,
"stop": ["foo", "bar"],
},
expected_data=["Hello world!"],
expected_params={
"temperature": 0.7,
"max_new_tokens": 100,
# Stopping criteria is _StopSequenceMatchCriteria instance
# "stop": ...
},
),
# Case 3: Data overrides params
_TestCase(
data={
"messages": ["Hello world!"],
"temperature": [0.1],
"max_tokens": [100],
"stop": [["foo", "bar"]],
},
params={
"temperature": [0.2],
"max_tokens": [200],
"stop": ["foo", "bar", "baz"],
},
expected_data=["Hello world!"],
expected_params={
"temperature": 0.1,
"max_new_tokens": 100,
},
),
# Case 4: Batch input
_TestCase(
data={
"prompt": ["Hello!", "Hi", "Hola"],
"temperature": [0.1, 0.2, 0.3],
"max_tokens": [None, 200, 300],
},
params={
"temperature": 0.4,
"max_tokens": 400,
},
expected_data=["Hello!", "Hi", "Hola"],
            # The values in the first data row are used, falling back to params
expected_params={
"temperature": 0.1,
"max_new_tokens": 400,
},
),
],
)
def test_preprocess_llm_inference_input(case):
data = pd.DataFrame(case.data)

task = "llm/v1/completions" if "prompt" in case.data else "llm/v1/chat"
flavor_config = {"inference_task": task, "source_model_name": "test"}

with mock.patch(
"mlflow.transformers.llm_inference_utils._get_stopping_criteria"
) as mock_get_stopping_criteria:
data, params = preprocess_llm_inference_input(data, case.params, flavor_config)

    assert data == case.expected_data
    if "stopping_criteria" in params:
        assert params.pop("stopping_criteria") is not None
        mock_get_stopping_criteria.assert_called_once_with(["foo", "bar"], "test")
    assert params == case.expected_params


def test_preprocess_llm_inference_input_raise_if_key_invalid():
# Missing input key
with pytest.raises(MlflowException, match=r"Transformer model saved with"):
preprocess_llm_inference_input(
pd.DataFrame({"invalid_key": [1, 2, 3]}),
flavor_config={"inference_task": "llm/v1/completions"},
)

# Unmatched key (should be "messages" for chat task)
with pytest.raises(MlflowException, match=r"Transformer model saved with"):
preprocess_llm_inference_input(
pd.DataFrame({"prompt": ["Hi"]}), flavor_config={"inference_task": "llm/v1/chat"}
)


@mock.patch("transformers.AutoTokenizer.from_pretrained")
