# Custom Chat Models with MLflow

## Resources
- [MLflow PyFunc Chat Model Tutorial](https://mlflow.org/docs/latest/llms/transformers/tutorials/conversational/pyfunc-chat-model.html)
- [MLflow PyFunc ChatModel API](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatModel)

## Outline
- What makes a chat model different
- Why defining a Chat Model with Pyfunc is kind of hard
- ChatModel class makes it easier
  - Show equivalent ChatModel and PyFunc model definition
- Add-on: retriever in ChatModel
- Add-on: streaming

## Chat Models with MLflow

OpenAI's model input and output schemas have become the defacto standards among LLM providers. Compatibility with the OpenAI API spec enables models to integrate seamlessly with many different AI tools, evaluation systems, UIs, and more.

This post shows how to make [custom PyFunc models](https://mlflow.org/blog/custom-pyfunc) that conform to the OpenAI API spec.

## The ChatModel Class

As of MLflow 2.11, you can use the `ChatModel` class to define custom chat models. `ChatModel` is a subclass of `PythonModel` that automatically defines input/output signatures that are compatible with the OpenAI API spec. Let's try a simple example and wrap Google's Gemma 2B model with `ChatModel`:

In [1]:
from transformers import pipeline

gemma = pipeline("text-generation", model="google/gemma-2b-it",
                 device_map="auto")

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
import mlflow

mlflow.set_experiment("ChatModel")

with mlflow.start_run() as run:
    model_info = mlflow.transformers.log_model(
        artifact_path="gemma-text-generation", transformers_model=gemma,
        task="llm/v1/chat"
    )



In [13]:
import mlflow

# Load the previously saved MLflow model
model = mlflow.pyfunc.load_model(model_info.model_uri)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [24]:
messages = [{"role": "user", "content": "Tell me a short joke about AI."}]
model.predict({"messages": messages, "max_tokens": 25})

[{'id': 'a5aa96c2-2ddb-4c2b-a199-0a0a59fde9d8',
  'object': 'chat.completion',
  'created': 1719587832,
  'model': 'google/gemma-2b-it',
  'usage': {'prompt_tokens': 17, 'completion_tokens': 24, 'total_tokens': 41},
  'choices': [{'index': 0,
    'finish_reason': 'stop',
    'message': {'role': 'assistant',
     'content': "What do you call an AI that's too smart?\n\n... A chatbot with a mind of its own!"}}]}]

In [28]:
model.predict_stream({"messages": messages, "max_tokens": 25})

MlflowException: This model does not support predict_stream method.

Suppose we want to customize our model. For example, let's implement the predict_stream method so our model can stream responses.

In [20]:
import uuid
from typing import Generator, List
import mlflow
from mlflow.pyfunc import ChatResponse, ChatMessage

class GemmaChatModel(mlflow.pyfunc.ChatModel):
    def load_context(self, context):
        # load our previously-saved Transformers pipeline from context.artifacts
        self.pipeline = mlflow.transformers.load_model(context.artifacts["chat_model_path"])

    def preprocess_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        preprocessed = []
        for i, message in enumerate(messages):
            if message.role == "system":
                preprocessed.append(ChatMessage(role="user", content=message.content))
            else:
                preprocessed.append(message)
            
            # If we just added a user message and it's not the last one, add a blank assistant message
            if message.role in ["user", "system"] and i < len(messages) - 1:
                preprocessed.append(ChatMessage(role="assistant", content=" "))

        return preprocessed

    def predict(self, context, messages, params):
        tokenizer = self.pipeline.tokenizer
        preprocessed_messages = self.preprocess_messages(messages)
        prompt = tokenizer.apply_chat_template(preprocessed_messages, tokenize=False, add_generation_prompt=True)

        # perform inference using the loaded pipeline
        output = self.pipeline(prompt, return_full_text=False, generation_kwargs=params.to_dict())
        text = output[0]["generated_text"]
        id = str(uuid.uuid4())

        # construct token usage information
        prompt_tokens = len(tokenizer.encode(prompt))
        completion_tokens = len(tokenizer.encode(text))
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

        response = {
            "id": id,
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            "usage": usage,
        }

        return ChatResponse(**response)

    def predict_stream(self, context, messages, params) -> Generator[ChatResponse, None, None]:
        tokenizer = self.pipeline.tokenizer
        preprocessed_messages = self.preprocess_messages(messages)
        prompt = tokenizer.apply_chat_template(preprocessed_messages, tokenize=False, add_generation_prompt=True)

        id = str(uuid.uuid4())
        prompt_tokens = len(tokenizer.encode(prompt))
        accumulated_text = ""

        # perform streaming inference using the loaded pipeline
        for output in self.pipeline(prompt, return_full_text=False, generation_kwargs=params.to_dict(), streaming=True,
                                    max_new_tokens=1000):
            new_text = output["generated_text"]
            accumulated_text += new_text

            # construct token usage information for this chunk
            completion_tokens = len(tokenizer.encode(new_text))
            usage = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }

            response = {
                "id": id,
                "model": "MyChatModel",
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant", "content": new_text},
                        "finish_reason": None,
                    }
                ],
                "usage": usage,
            }

            yield ChatResponse(**response)

        # Final yield with finish_reason "stop"
        final_response = {
            "id": id,
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": ""},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": len(tokenizer.encode(accumulated_text)),
                "total_tokens": prompt_tokens + len(tokenizer.encode(accumulated_text)),
            },
        }

        yield ChatResponse(**final_response)

In [21]:
mlflow.set_experiment("ChatModel")


with mlflow.start_run() as run:
    custom_model_info = mlflow.pyfunc.log_model(
        artifact_path="gemma-text-generation-custom", python_model=GemmaChatModel(),
        artifacts = {"chat_model_path": "file:///Users/daniel.liden/git/llmops-examples/mlruns/387827672116713370/410c3b2e8b3747348c73f15b48a5b82d/artifacts/gemma-text-generation"}
    )

2024/06/28 13:37:27 INFO mlflow.pyfunc: Predicting on input example to validate output


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]



ValueError: Input length of input_ids is 27, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.

In [1]:
import mlflow

model = mlflow.pyfunc.load_model("file:///Users/daniel.liden/git/llmops-examples/mlruns/387827672116713370/191be5952e1c4bc5816165c39b1c2451/artifacts/gemma-text-generation-custom")



Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [4]:
messages = [{"role": "user", "content": "Tell me a short joke about AI."}]
model.predict({"messages": messages, "max_tokens": 25})



{'id': 'db8701a1-b169-4793-8e85-5e57d2f92e49',
 'model': 'MyChatModel',
 'choices': [{'index': 0,
   'message': {'role': 'assistant', 'content': ''},
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 14, 'completion_tokens': 1, 'total_tokens': 15},
 'object': 'chat.completion',
 'created': 1719595408}

In [10]:
def format_messages(messages: List[ChatMessage]) -> str:
    formatted_messages = []
    for message in messages:
        if message.role == "user" or message.role == "system":
            formatted_messages.append(f"Human: {message.content}")
        elif message.role == "assistant":
            formatted_messages.append(f"Assistant: {message.content}")
    return "\n".join(formatted_messages) + "\nAssistant:"

In [11]:
format_messages(messages)

AttributeError: 'dict' object has no attribute 'role'

'Human: Tell me a short joke about AI.\nAssistant:'