# Part 3: Custom MLflow Model

MLflow for GenAI Guide Overview

This is part three of a four-part guide for using MLflow to experiment with and deploy your generative AI projects.

In part one, we started with the problem of generating social media posts following the style of a set of example posts. We used MLflow autologging and tracing to keep track of our early informal experiments as we developed a simple prototype. In part two, we used MLflow's evaluation tools to set up and run a more structured experiment comparing the performance of a couple of different models and prompts.

Now we are ready to prepare our GenAI application for production. In this part of the guide, we will:

- Wrap our model logic in a custom MLflow `PyFunc` model, allowing us to keep track of model versions and configurations, specify how we would like tracing handled in different environments, and validate the model prior to serving.
- Create a staging and production model in the model registry and register our custom model as a model version in the registry.

## Prerequisites and Setup

This guide assumes you have read and worked through parts [one](../part_1) and [two](../part_2), and will frequently reference code and concepts introduced there. It also assumes you have followed the setup instructions detailed there.

## Encapsulating our Model Logic

At this point, we have prototyped our GenAI application and refined that prototype through the evaluation process, settling on a model and prompt that generate posts that are accurate and that effectively copy the style of the example posts.

Now that we have a working application, we need to prepare it for production use. We need to ensure that the application's interface is standardized and well defined so any downstream parts of the application, such as a frontend GUI, can effectively use the model. We also need to make sure the model environment and dependencies are recorded and packaged with the model in order to make sure the model can be deployed across different environments without dependency conflicts. Furthermore, encapsulating our application logic in a custom PyFunc model will allow us to keep track of model versions and configurations, specify how we would like tracing handled in different environments, and validate the model prior to serving.

[MLflow models](/model) address these needs by providing a standardized way to save model logic, configurations, and dependencies. They also provide a standardized interface and deployment system, ensuring the model can be deployed and used across different environments.

| **Component** | Without MLflow Model | With MLflow Model |
|-----------|--------------|---------------------------|
| **Code Organization** | Scattered across files, ad-hoc structure | Single, well-defined class with standard interface |
| **Dependencies** | Manually tracked and installed | Automatically captured and packaged with model |
| **Configuration** | Hard-coded or managed separately | Version-controlled with model artifacts |
| **Tracing** | Autologging or ad-hoc implementation | Configurable, standardized tracing controls |
| **Input/Output** | Undefined schema, inconsistent validation | Type-hinted interface with automatic validation |
| **Deployment** | Custom deployment logic needed | Standard MLflow deployment options |
| **Versioning** | Manual version tracking | Automatic versioning through Model Registry |
| **Environment** | "Works on my machine" | Reproducible across environments |

### ChatModel or PyFunc?

MLflow offers the [ChatModel](/llms/chat-model-intro/) class for setting up custom GenAI models with a chat interface. While this may seem like the obvious choice for our application, we will actually be using a custom [PyFunc](/model/python_model) model instead. Though our application *uses* a chat model, it is not, itself, a chat model: the inputs and outputs are not lists of messages.

The ChatModel interface is a great choice for developing and deploying conversational models with standardized input/output schemas following the OpenAI spec. PyFunc models, on the other hand, offer full control over the model interface.
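
To make the distinction concrete, here is a small sketch of the two input shapes. The chat-style payload follows the OpenAI message convention. The `SocialPostRequest` dataclass is a hypothetical stand-in for the structured input our application takes; the guide defines the real schema with Pydantic further on.

```python
from dataclasses import dataclass, asdict

# Chat-style input: an ordered list of role/content messages.
chat_input = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a post about MLflow."},
]


# Structured input for our application (hypothetical stand-in for the
# Pydantic schema defined later in this guide).
@dataclass
class SocialPostRequest:
    example_posts: list[str]
    context_url: str
    additional_instructions: str


request = SocialPostRequest(
    example_posts=["Example post one.", "Example post two."],
    context_url="https://www.example.com",
    additional_instructions="Keep it short.",
)

# The structured request has named fields rather than a message history.
print(sorted(asdict(request)))
```

Because the request is a record of named fields and not a conversation, an interface built around message lists would be an awkward fit.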

### Defining the Custom Model

There are a few key considerations we will make when defining the custom model.

1. **Tracing:** We want to make sure tracing is easily configurable. There may, for example, be circumstances where tracing is not permissible in a production environment because of user data privacy concerns. To address cases like this, we will make tracing configurable via an environment variable.

2. **Configuration:** We will make some of the application's configuration, such as the system prompt and model provider, available to the model via a configuration dictionary. This is a good idea for parameters we expect we may want to change in the future. It will allow us to log updated versions of the model without needing to change the custom PyFunc model code.

3. **Models from Code:** We will save our model logic in a separate file called `model.py` and include the line `set_model(SocialPoster())` at the end of the file. When we log the model, we will pass the path to `model.py` as the `python_model` argument. This is the [models from code](/model/models-from-code) approach to logging models, which is a good fit for GenAI applications using LLM APIs, where there are no model weights that need to be serialized.

4. **Input and Output Schema:** MLflow 2.20 introduced the ability to define model signatures for custom PyFunc models [using type hints](/model/python_model#model-signature-inference-based-on-type-hints) in the `predict` method. We will use this approach to define the input and output schema for our model.
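
The environment-variable switch from the first consideration can be tried on its own. The sketch below mirrors the parsing used in this guide's model code: only the string `true` (in any letter case) enables tracing, and an unset variable means disabled. The helper name `tracing_enabled_from_env` is made up for this illustration.

```python
import os


def tracing_enabled_from_env(var_name: str = "MLFLOW_TRACING_ENABLED") -> bool:
    """Return True only when the variable is set to 'true' (case-insensitive)."""
    return os.getenv(var_name, "false").lower() == "true"


# Any capitalization of "true" turns tracing on.
os.environ["MLFLOW_TRACING_ENABLED"] = "True"
print(tracing_enabled_from_env())

# Other values, or an unset variable, leave tracing off.
os.environ["MLFLOW_TRACING_ENABLED"] = "1"
print(tracing_enabled_from_env())

os.environ.pop("MLFLOW_TRACING_ENABLED")
print(tracing_enabled_from_env())
```

Defaulting to disabled is the conservative choice here: a deployment that forgets to set the variable will not record user inputs.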

#### The PythonModel Class

In [0]:
!pip install markdownify openai


In [0]:
%restart_python

The model is defined in `model.py`, which is shown in full in the code cell that follows this walkthrough. Let's go through the code in detail.

- The `SocialPoster` class inherits from `PythonModel`. This is the base class for all custom PyFunc models. It requires a `predict` method that takes a `context` and a list of `model_input` and returns a list of `model_output`.

- The core application logic is implemented in the `predict` method. This method:
  1. Checks whether tracing is enabled and sets it according to the `tracing_enabled` attribute using the `mlflow.tracing.enable()` and `mlflow.tracing.disable()` methods.
  2. Starts a new tracing span with the name "predict" and the type "CHAIN". This span will be the parent span for the rest of the model's execution.
  3. Assembles the prompt from the example posts, context, and additional instructions using the `_webpage_to_markdown` and `_generate_prompt` helper functions.
  4. Generates a post using the `_generate_post` helper function.
  5. Returns the generated post.

- The `predict` method calls the `_webpage_to_markdown`, `_generate_prompt`, and `_generate_post` helper functions that define specific steps in the application logic. We defined these functions in part 1 of this guide. Here, we add the `@mlflow.trace` decorator to the functions to make sure they are traced, giving us visibility into the application's execution flow and enabling us to diagnose any issues that may arise.

- The `load_context` method is called once when the model is loaded and is responsible for initializing model-specific attributes and resources. In our case, it loads configuration values from the model context, sets up tracing based on an environment variable, and initializes the appropriate client based on the configured model provider. This initialization approach ensures that our model has all necessary resources and configurations ready before any predictions are made, while keeping sensitive information like API keys separate from the model artifacts.
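
The prompt-assembly step can be tried in isolation, without MLflow or an LLM client. The sketch below is a standalone approximation of what `_generate_prompt` does; the function name `build_messages` and the short prompt strings are placeholders for illustration.

```python
def build_messages(
    system_prompt, prompt_template, example_posts, context, additional_instructions
):
    """Number the example posts, fill the template, and wrap it as chat messages."""
    numbered = "\n".join(
        f"Example {i + 1}:\n{post}" for i, post in enumerate(example_posts)
    )
    user_prompt = prompt_template.format(
        example_posts=numbered,
        context=context,
        additional_instructions=additional_instructions,
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


messages = build_messages(
    system_prompt="You write social media posts.",
    prompt_template=(
        "example posts:\n{example_posts}\n\n"
        "context:\n{context}\n\n"
        "additional instructions:\n{additional_instructions}"
    ),
    example_posts=["First post.", "Second post."],
    context="# Example Domain",
    additional_instructions="Keep it short.",
)
print(messages[1]["content"])
```

The result is a two-message list (system, then user), which is the shape the chat completions call in `_generate_post` expects.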


In [0]:
import mlflow
from mlflow.pyfunc import PythonModel
from mlflow.models import set_model
from markdownify import markdownify
import requests
import os
from openai import OpenAI
from pydantic import BaseModel
from typing import List

class SocialPostInput(BaseModel):
    example_posts: List[str]
    context_url: str
    additional_instructions: str

class SocialPostOutput(BaseModel):
    post: str

class SocialPoster(PythonModel):
    def __init__(self):
        self.tracing_enabled = False

    @mlflow.trace(span_type="FUNCTION")
    def _webpage_to_markdown(self, url):
        # Get webpage content
        response = requests.get(url)
        html_content = response.text

        # Convert to markdown
        markdown_content = markdownify(html_content)

        return markdown_content

    @mlflow.trace(span_type="FUNCTION")
    def _generate_prompt(
        self, example_posts, context, additional_instructions
    ):
        """Generate a prompt for the LLM based on the example posts, context, and additional instructions."""
        example_posts = "\n".join(
            [f"Example {i+1}:\n{post}" for i, post in enumerate(example_posts)]
        )
        prompt = self.prompt_template.format(
            example_posts=example_posts,
            context=context,
            additional_instructions=additional_instructions,
        )

        formatted_prompt = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]

        return formatted_prompt

    @mlflow.trace(span_type="LLM")
    def _generate_post(self, messages):
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=1000,
        )
        return response.choices[0].message.content

    def load_context(self, context):
        self.system_prompt = context.model_config["system_prompt"]
        self.prompt_template = context.model_config["prompt_template"]
        self.model_provider = context.model_config["model_provider"]
        self.model_name = context.model_config["model_name"]
        self.tracing_enabled = os.getenv("MLFLOW_TRACING_ENABLED", "false").lower() == "true"

        if self.model_provider == "openai":
            self.client = OpenAI()
        elif self.model_provider == "google":
            self.client = OpenAI(
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=os.getenv("GEMINI_API_KEY")
            )
        else:
            raise ValueError(f"Unsupported model provider: {self.model_provider}")

    def predict(self, context, model_input: list[SocialPostInput]) -> list[SocialPostOutput]:
        # only one input for illustration purposes
        model_input = model_input[0].model_dump()
        # check whether tracing is enabled and set it according to self.tracing_enabled
        if mlflow.tracing.provider.is_tracing_enabled() != self.tracing_enabled:
            if self.tracing_enabled:
                mlflow.tracing.enable()
            else:
                mlflow.tracing.disable()

        with mlflow.start_span(name="predict", span_type="CHAIN") as parent_span:
            parent_span.set_inputs(model_input)
            example_posts = model_input.get("example_posts")
            context_url = model_input.get("context_url")
            markdown_context = self._webpage_to_markdown(context_url)
            additional_instructions = model_input.get("additional_instructions")

            prompt = self._generate_prompt(example_posts, markdown_context, additional_instructions)
            post = self._generate_post(prompt)
            parent_span.set_outputs({"post": post})
        return [{"post": post}]

set_model(SocialPoster())

### Setting up the Model Configuration

Now that we have defined the custom model, we need to set up the model configuration. Though optional, using a configuration that is separate from the application logic gives us an easy way to see and update aspects of the model's behavior without needing to change the model code.

|  | Configuration as part of Model Code | External Configuration |
|------------|------------------------|----------------------|
| **Flexibility** | Configuration changes require updating core model logic in code | Frequently updated elements like prompts and model settings can be modified without code changes |
| **Maintainability** | Configuration mixed with application logic can be difficult to identify and modify | Clear separation of concerns between model logic and configuration |
| **Reproducibility** | Configuration choices buried in code | Configuration choices easy to identify in each MLflow run |

This separation of configuration from code is particularly valuable in GenAI applications, where we often need to experiment with different prompts, model parameters, and even model providers. The external configuration approach makes it easier to track these experiments while keeping the core model logic clean and maintainable.

#### Writing the Configuration Dict

The `load_context` method we defined expects a `model_config` dictionary with the following keys:

- `system_prompt`: The system prompt to use for the model.
- `prompt_template`: The prompt template to use for the model.
- `model_provider`: The model provider to use for the model.
- `model_name`: The name of the model to use for the model.

Consistent with the evaluation-based results from part 2, we will configure the model to use the `gemini-2.0-flash-exp` model from Google along with a detailed system prompt. We will save all of these expected elements in a `config` dictionary, which we will save along with the model during model logging.


In [0]:
system_prompt = """You are a social media content specialist with expertise in matching writing styles and voice across platforms. Your task is to:

1. Analyze the provided example post(s) by examining:

   - Writing style, tone, and voice
   - Sentence structure and length
   - Use of hashtags, emojis, and formatting
   - Engagement techniques and calls-to-action

2. Generate a new LinkedIn post about the given topic that matches:

   - The identified writing style and tone
   - Similar structure and formatting choices
   - Equivalent use of platform features and hashtags
   - Comparable engagement elements

3. Return only the generated post, formatted exactly as it would appear on LinkedIn, without any additional commentary or explanations."""

prompt_template = """
example posts:

{example_posts}

context:

{context}

additional instructions:

{additional_instructions}
"""

config = {
    "system_prompt": system_prompt,
    "prompt_template": prompt_template,
    "model_provider": "google",
    "model_name": "gemini-2.0-flash-exp",
}
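
Before logging, it can be worth checking that a configuration dictionary contains every key that `load_context` reads and names a provider the model code handles. The helper below is not part of MLflow or of the original guide; `check_config` and its constants are hypothetical names for this sketch.

```python
# Keys read by load_context, and the providers it has branches for.
REQUIRED_KEYS = ("system_prompt", "prompt_template", "model_provider", "model_name")
SUPPORTED_PROVIDERS = ("openai", "google")


def check_config(config: dict) -> list[str]:
    """Return a list of problems; an empty list means the config looks usable."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in config]
    provider = config.get("model_provider")
    if provider is not None and provider not in SUPPORTED_PROVIDERS:
        problems.append(f"unsupported model provider: {provider}")
    return problems


# A deliberately broken example: no model_name, and an unhandled provider.
broken = {
    "system_prompt": "...",
    "prompt_template": "{context}",
    "model_provider": "azure",
}
print(check_config(broken))
# ['missing key: model_name', 'unsupported model provider: azure']
```

Running a check like this at logging time surfaces a misconfigured model earlier than waiting for `load_context` to fail when the model is loaded.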

### Logging the model

Now that we have defined the `SocialPoster` class (saved in `model.py`) and the `config` dictionary, we have everything we need to log the model. We will use the `mlflow.pyfunc.log_model` function to log the model.


In [0]:
code_path = "model.py"

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        "social_poster",
        python_model=code_path,
        model_config=config,
    )

2025/04/13 20:59:08 INFO mlflow.models.signature: Inferring model signature from type hints


Uploading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]


This code logs the model to the MLflow tracking server. It stores the model code defined in `model.py` together with the configuration from the `config` dictionary. You can inspect all of this information in the MLflow UI.

**Logged Model in the MLflow UI**

In the MLflow UI, we can inspect the model code, dependencies, and configuration.

*Model Code*

![Logged Model](/images/llms/mlflow-for-genai/9_logged_model.png)

*Model Configuration*

![Logged Model Configuration](/images/llms/mlflow-for-genai/10_model_config.png)

Though we have successfully logged the model, we still aren't certain that it will work as expected in production. Let's verify that the model is ready for deployment.

## Validating the model prior to deployment

We need to be sure that our model will work as expected once we deploy it to a serving endpoint. To this end, we will construct a simple input example in the format we intend to use in production and validate it with the `validate_serving_input` function.


In [0]:
import os
from mlflow.models import (
    convert_input_example_to_serving_input,
    validate_serving_input,
)

sample_input = [{
    "example_posts": ["Example 1: This is an example post.", "Example 2: This is another example post."],
    "context_url": "https://www.example.com",
    "additional_instructions": "The post should be concise and to the point..."
}]

# Convert the input example into the JSON payload a serving endpoint receives
serving_payload = convert_input_example_to_serving_input(input_example=sample_input)

model_uri = model_info.model_uri

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

Which returns:


In [0]:
[{'post': 'Example Domain. This domain is for illustrative examples. You can use it freely. More info: [https://www.iana.org/domains/example](https://www.iana.org/domains/example)\n'}]

In [0]:
from mlflow.models import validate_serving_input

os.environ["GEMINI_API_KEY"] = dbutils.secrets.get(scope="mlflow_genai", key="GEMINI_API_KEY")

os.environ["MLFLOW_TRACING_ENABLED"] = "true"
validate_serving_input(model_uri, serving_payload)

Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

[{'post': 'Example Domain is reserved for illustrative examples. Use it freely in your documentation without asking for permission.\n\nLearn more about its purpose and usage guidelines: [https://www.iana.org/domains/example](https://www.iana.org/domains/example)\n\n#ExampleDomain #Documentation #IANA #ReservedDomain #Examples\n'}]

Trace(request_id=tr-e37713e9c1b343d1b7c0a53259a2cad6)
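
For intuition about what a serving payload looks like, the MLflow scoring server accepts several JSON layouts, one of which wraps the data under an `inputs` key. The sketch below hand-builds such a payload with the standard library. It illustrates the shape only and is not guaranteed to match character for character what `convert_input_example_to_serving_input` returns.

```python
import json

sample_input = [{
    "example_posts": [
        "Example 1: This is an example post.",
        "Example 2: This is another example post.",
    ],
    "context_url": "https://www.example.com",
    "additional_instructions": "The post should be concise and to the point...",
}]

# Wrap the records under "inputs" and serialize to a JSON string.
payload = json.dumps({"inputs": sample_input})

# A server would decode the string back into the same list of records.
decoded = json.loads(payload)
print(decoded["inputs"][0]["context_url"])
```

The point of validating against a payload like this is that it exercises the same JSON round trip a deployed endpoint performs, which a direct Python call to `predict` would skip.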


## Registering the Model

Now that we have created and validated our custom PyFunc model, let's register it to the MLflow model registry. The model registry is a great way to manage models and model versions. It will allow us to track the history of our model versions, define separate staging and production models, assign model aliases, and more.

First, a quick note on terminology:

- A **registered model** is a container for a group of model versions. It does not need to have any actual model artifacts in it.
- A **model version** is registered to a registered model. It contains the actual model artifacts. One registered model can have many model versions.

For example, you might have a "sentiment-analyzer" registered model that contains multiple versions as you iterate and improve your model over time. You might also have separate registered models for different environments (like staging and production) to manage the deployment lifecycle of your models.
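
The container/version relationship can be pictured with a toy in-memory model. This is purely illustrative and is not how the MLflow registry is implemented; `ToyRegisteredModel` and the `runs:/abc123/social_poster` source string are made up for this sketch.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRegisteredModel:
    """A named container: versions map to artifact sources, aliases map to versions."""

    name: str
    versions: dict[int, str] = field(default_factory=dict)
    aliases: dict[str, int] = field(default_factory=dict)

    def add_version(self, source: str) -> int:
        # Versions are numbered sequentially, starting at 1.
        version = len(self.versions) + 1
        self.versions[version] = source
        return version

    def resolve(self, alias: str) -> str:
        # Follow an alias to the artifact source of the version it points at.
        return self.versions[self.aliases[alias]]


staging = ToyRegisteredModel("social-ai-staging")
print(staging.versions)  # a freshly created registered model holds no versions

v1 = staging.add_version("runs:/abc123/social_poster")  # hypothetical source
staging.aliases["latest-model"] = v1
print(staging.resolve("latest-model"))
```

An alias is just a movable pointer: re-pointing it at a newer version changes what consumers load without changing the name they ask for.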

Now, let's register empty models for our staging and production environments:


In [0]:
from mlflow import MlflowClient
client = MlflowClient()

client.create_registered_model("mlflow_lightening_session.dev.social-ai-staging")
client.create_registered_model("mlflow_lightening_session.dev.social-ai-production")

<RegisteredModel: aliases={}, creation_timestamp=1744578933645, description='', last_updated_timestamp=1744578933645, latest_versions=None, name='mlflow_lightening_session.dev.social-ai-production', tags={}>


Again, these registered models are *containers* and do not yet contain any model artifacts. These are simply used to organize our model versions. In the MLflow UI, we can see the registered models without associated versions.

![Registered Models](/images/llms/mlflow-for-genai/12_registered_models.png)

Now we can register our logged model version to the staging model:


In [0]:
mv = client.create_model_version(
    name="mlflow_lightening_session.dev.social-ai-staging",
    source=model_info.model_uri)



Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

Uploading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

In [0]:
mv.name

'mlflow_lightening_session.dev.social-ai-staging'

In [0]:
client.set_registered_model_alias(
    name="mlflow_lightening_session.dev.social-ai-staging",
    alias="latest-model",
    version=mv.version,
)

In [0]:
model_uri = "models:/mlflow_lightening_session.dev.social-ai-staging@latest-model"
model = mlflow.pyfunc.load_model(model_uri)
model.predict(sample_input)

Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

[{'post': 'Example Domain is reserved for illustrative examples. Use it freely in your documentation without asking! 📚\n\n[More information...](https://www.iana.org/domains/example)\n\n#ExampleDomain #Documentation #OpenSource #FreeResources\n'}]

Trace(request_id=tr-8562954008b84c399e40ea3c0513a230)

## Conclusion

In this part of the guide, we have encapsulated our model logic in a custom PyFunc model, set up tracing, validated the model, and registered it with MLflow. Specifically, we have seen how to:

- Define a custom PyFunc model and set up tracing.
- Define the PyFunc model predict method's signature with a Pydantic class as a type hint.
- Set up tracing to be configurable via an environment variable.
- Set up an external configuration to manage configurable aspects of the model we expect we might want to change in the future.
- Validate the model prior to deployment.
- Register the model with MLflow's model registry.

In the next and final part of this guide, we will deploy the model to a staging environment, promote it to production, introduce a new challenger model to evaluate against the production model, and then promote the challenger model to production.

