# Custom Chat Models with MLflow

## Resources
- [MLflow PyFunc Chat Model Tutorial](https://mlflow.org/docs/latest/llms/transformers/tutorials/conversational/pyfunc-chat-model.html)
- [MLflow PyFunc ChatModel API](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatModel)

## Outline
- What makes a chat model different
- Why defining a Chat Model with Pyfunc is kind of hard
- ChatModel class makes it easier
  - Show equivalent ChatModel and PyFunc model definition
- Add-on: retriever in ChatModel
- Add-on: streaming

## Chat Models with MLflow

OpenAI's model input and output schemas have become the de facto standard among LLM providers. Compatibility with the OpenAI API spec enables models to integrate seamlessly with many different AI tools, evaluation systems, UIs, and more.

This post shows how to make [custom PyFunc models](https://mlflow.org/blog/custom-pyfunc) that conform to the OpenAI API spec.

## The ChatModel Class

As of MLflow 2.11, you can use the `ChatModel` class to define custom chat models. `ChatModel` is a subclass of `PythonModel` that automatically defines input/output signatures that are compatible with the OpenAI API spec. Let's try a simple example and wrap Google's Gemma 2B model with `ChatModel`:

In [None]:
from transformers import pipeline

# Build a text-generation pipeline around the instruction-tuned Gemma 2B
# checkpoint; device placement is delegated to `device_map="auto"`.
gemma = pipeline(
    task="text-generation",
    model="google/gemma-2b-it",
    device_map="auto",
)

In [None]:
import mlflow

# Keep every run from this walkthrough under a single experiment.
mlflow.set_experiment("ChatModel")

# Log the Gemma pipeline under the `llm/v1/chat` task and keep the returned
# model info so later cells can refer to the logged model by URI.
with mlflow.start_run() as run:
    model_info = mlflow.transformers.log_model(
        transformers_model=gemma,
        artifact_path="gemma-text-generation",
        task="llm/v1/chat",
    )

In [None]:
import mlflow

# Reload the model logged above through the generic pyfunc interface.
model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

In [None]:
# A single-turn conversation in the OpenAI chat message format.
messages = [
    {"role": "user", "content": "Tell me a short joke about AI."},
]
# Messages and generation parameters travel together in one request dictionary.
model.predict({"messages": messages, "max_tokens": 25})

In [None]:
# NOTE(review): the value returned here is never iterated; presumably this cell
# is meant to show how the stock Transformers flavor responds to a streaming
# request before the custom model below is introduced — confirm.
model.predict_stream({"messages": messages, "max_tokens": 25})

Suppose we want to customize our model. For example, let's implement the `predict_stream` method so our model can stream responses.

In [None]:
import uuid
from typing import Generator, List
import mlflow
from mlflow.pyfunc import ChatResponse, ChatMessage

class GemmaChatModel(mlflow.pyfunc.ChatModel):
    """Chat model that serves a previously logged Gemma text-generation pipeline.

    The pipeline is loaded from the ``chat_model_path`` artifact. Incoming
    messages are preprocessed (system messages become user turns and a blank
    assistant turn separates consecutive user-like turns), rendered with the
    tokenizer's chat template, and passed to the pipeline either in one call
    (``predict``) or incrementally (``predict_stream``).

    Only ``max_tokens``, ``temperature`` and ``top_p`` are forwarded from the
    request parameters to generation; other parameters are ignored.
    """

    # Value reported in the ``model`` field of every response.
    MODEL_NAME = "MyChatModel"
    # Generation budget used when a request does not set ``max_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 100
    # Roles that end up as a ``user`` turn in the rendered prompt.
    _USER_LIKE_ROLES = ("user", "system")

    def load_context(self, context):
        """Load the saved Transformers pipeline from ``context.artifacts``."""
        self.pipeline = mlflow.transformers.load_model(context.artifacts["chat_model_path"])

    def preprocess_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Return a copy of ``messages`` reshaped for the chat template.

        System messages are re-labelled as user messages. A blank assistant
        message is inserted only between two consecutive user-like messages;
        when a real assistant reply already follows, nothing is inserted, so a
        multi-turn history does not end up with two assistant turns in a row.
        """
        preprocessed = []
        followers = [*messages[1:], None]
        for message, following in zip(messages, followers):
            if message.role == "system":
                preprocessed.append(ChatMessage(role="user", content=message.content))
            else:
                preprocessed.append(message)

            needs_filler = (
                message.role in self._USER_LIKE_ROLES
                and following is not None
                and following.role in self._USER_LIKE_ROLES
            )
            if needs_filler:
                preprocessed.append(ChatMessage(role="assistant", content=" "))

        return preprocessed

    def _build_prompt(self, messages) -> str:
        """Render the preprocessed conversation into a single prompt string."""
        # The chat template is documented for lists of role/content dicts, so
        # the message objects are converted before rendering.
        conversation = [
            {"role": message.role, "content": message.content}
            for message in self.preprocess_messages(messages)
        ]
        return self.pipeline.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    def _generation_kwargs(self, params) -> dict:
        """Translate chat request parameters into pipeline generation kwargs."""
        if params is None:
            requested = {}
        elif isinstance(params, dict):
            requested = params
        else:
            requested = params.to_dict()

        kwargs = {"max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS}
        if requested.get("max_tokens") is not None:
            kwargs["max_new_tokens"] = requested["max_tokens"]

        temperature = requested.get("temperature")
        if temperature is not None:
            # A temperature of zero means greedy decoding; sampling with a
            # zero temperature is rejected by the generation code.
            if temperature > 0:
                kwargs["do_sample"] = True
                kwargs["temperature"] = temperature
            else:
                kwargs["do_sample"] = False

        if requested.get("top_p") is not None:
            kwargs["top_p"] = requested["top_p"]

        return kwargs

    def _count_tokens(self, text: str) -> int:
        """Count the tokens in ``text`` without adding extra special tokens."""
        return len(self.pipeline.tokenizer.encode(text, add_special_tokens=False))

    @staticmethod
    def _usage(prompt_tokens: int, completion_tokens: int) -> dict:
        """Build the token-usage dictionary attached to a response."""
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    def predict(self, context, messages, params):
        """Generate one complete assistant reply for ``messages``.

        Returns a ``ChatResponse`` with a single choice and token usage.
        """
        prompt = self._build_prompt(messages)

        # perform inference using the loaded pipeline
        output = self.pipeline(
            prompt, return_full_text=False, **self._generation_kwargs(params)
        )
        text = output[0]["generated_text"]

        response = {
            "id": str(uuid.uuid4()),
            "model": self.MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            "usage": self._usage(self._count_tokens(prompt), self._count_tokens(text)),
        }

        return ChatResponse(**response)

    def predict_stream(self, context, messages, params) -> Generator:
        """Stream the assistant reply for ``messages`` piece by piece.

        Yields ``ChatCompletionChunk`` objects: one per decoded text piece
        (with ``finish_reason`` unset and usage for that piece), followed by a
        final empty chunk with ``finish_reason="stop"`` and usage for the whole
        reply. Requires an MLflow version that provides
        ``mlflow.types.llm.ChatCompletionChunk``.

        Raises:
            Exception: any error raised by the pipeline during generation is
                re-raised once the stream has been drained.
        """
        # Imported here so the block is self-contained within this cell.
        from threading import Thread

        from mlflow.types.llm import ChatCompletionChunk
        from transformers import TextIteratorStreamer

        prompt = self._build_prompt(messages)
        response_id = str(uuid.uuid4())
        prompt_tokens = self._count_tokens(prompt)

        streamer = TextIteratorStreamer(
            self.pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = self._generation_kwargs(params)
        errors = []

        def _generate():
            try:
                self.pipeline(
                    prompt, return_full_text=False, streamer=streamer, **generation_kwargs
                )
            except Exception as exc:
                errors.append(exc)
                # Signal end-of-stream so the consuming loop does not block
                # forever after a failed generation.
                streamer.end()

        # Generation blocks until finished, so it runs in a worker thread while
        # this generator drains the streamer.
        worker = Thread(target=_generate, daemon=True)
        worker.start()

        pieces = []
        for new_text in streamer:
            if not new_text:
                continue
            pieces.append(new_text)

            chunk = {
                "id": response_id,
                "model": self.MODEL_NAME,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant", "content": new_text},
                        "finish_reason": None,
                    }
                ],
                "usage": self._usage(prompt_tokens, self._count_tokens(new_text)),
            }
            yield ChatCompletionChunk(**chunk)

        worker.join()
        if errors:
            raise errors[0]

        # Final yield with finish_reason "stop" and usage for the full reply
        final_chunk = {
            "id": response_id,
            "model": self.MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": ""},
                    "finish_reason": "stop",
                }
            ],
            "usage": self._usage(prompt_tokens, self._count_tokens("".join(pieces))),
        }

        yield ChatCompletionChunk(**final_chunk)

In [None]:
# Keep the custom model's run in the same experiment as the original pipeline.
mlflow.set_experiment("ChatModel")

with mlflow.start_run() as run:
    custom_model_info = mlflow.pyfunc.log_model(
        artifact_path="gemma-text-generation-custom",
        python_model=GemmaChatModel(),
        # Point at the pipeline logged earlier via its model URI rather than a
        # machine-specific absolute path, so the notebook runs on any machine
        # and tracking store.
        artifacts={"chat_model_path": model_info.model_uri},
    )

In [None]:
import mlflow

# Reload the custom chat model through the generic pyfunc interface.
model = mlflow.pyfunc.load_model(model_uri=custom_model_info.model_uri)

In [None]:
# Same single-turn request as before, now answered by the custom model.
messages = [
    {"role": "user", "content": "Tell me a short joke about AI."},
]
model.predict({"messages": messages, "max_tokens": 25})

In [None]:
import sys

# Prepare messages as plain role/content dictionaries, the same message format
# used for `model.predict` above.
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Tell me a short story about a brave adventurer, one sentence at a time."},
]

# Set up parameters
params = {
    "max_tokens": 1000,  # Adjust as needed
    "temperature": 0.7,
    "top_p": 0.9,
}

# Messages and generation parameters are sent together as one request payload.
request = {"messages": messages, **params}


def _delta_text(chunk):
    """Return the text carried by one streamed chunk (dict or object form)."""
    if isinstance(chunk, dict):
        content = chunk["choices"][0]["delta"].get("content")
    else:
        content = chunk.choices[0].delta.content
    # A chunk may carry no content; print nothing rather than "None".
    return content or ""


# Use predict_stream
print("Generating story:")
for response in model.predict_stream(request):
    # flush=True already pushes each piece to the output immediately.
    print(_delta_text(response), end="", flush=True)
