# Phi-4-reasoning Model Deployment and Inference

This notebook demonstrates how to:
1. Deploy the microsoft/Phi-4-reasoning model using MLflow extensions
2. Set up the model for serving
3. Perform inference on both text and image inputs
4. Process images in batch using Spark

## Prerequisites
- Databricks workspace with appropriate permissions
- Access to Hugging Face models
- Sufficient GPU resources for model deployment

## Cluster Configuration
- Runtime: 16.3 ML (includes Apache Spark 3.5.2, GPU, Scala 2.12)
- Node Type: Standard_NC40ads_H100_v5 [H100] (beta)
- 320 GB Memory
- 1 GPU
- 40 Cores

## Install Required Dependencies

Installing necessary packages including:
- OpenAI client
- vLLM for efficient model serving
- MLflow extensions for deployment
- Transformers (installed from GitHub source), Accelerate, and `qwen-vl-utils`

In [0]:
PIP_REQUIREMENTS = (
    "openai vllm==0.8.5.post1 optree "
    "git+https://github.com/huggingface/transformers accelerate "
    "mlflow==2.19.0 "
    "mlflow-extensions "
    "qwen-vl-utils"
)

%pip install {PIP_REQUIREMENTS}
dbutils.library.restartPython()

## Configuration

Set up the necessary configuration parameters for model deployment:
- Catalog, schema, and volume names for model registration and image storage
- Model and endpoint names
- Optional environment variables for vLLM

In [0]:
# Configuration parameters
CATALOG = "..."
SCHEMA = "..."
MODEL_NAME = "..."
ENDPOINT_NAME = "..."
VOLUME = "..."  # Unity Catalog Volume holding the images for batch inference

# Optional vLLM environment variables (uncomment if needed)
import os
# os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
# os.environ['VLLM_USE_V1'] = "0"

In [0]:
# Redefine the requirements: restartPython() above clears all Python variables
PIP_REQUIREMENTS = (
    "openai vllm==0.8.5.post1 optree "
    "git+https://github.com/huggingface/transformers accelerate "
    "mlflow==2.19.0 "
    "mlflow-extensions "
    "qwen-vl-utils"
)
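The requirements string is later converted into a list for `pip_config_override`. A minimal sketch of why splitting on arbitrary whitespace is safer than `split(" ")`: any doubled space in the string turns into an empty-string entry. The helper name `to_pip_list` is hypothetical.

```python
def to_pip_list(requirements: str) -> list[str]:
    """Turn a space-separated requirements string into a list of pip specifiers."""
    # str.split() with no argument splits on runs of whitespace and drops empty strings
    return requirements.split()


example = "openai vllm==0.8.5.post1 accelerate  mlflow==2.19.0"
print(example.split(" "))    # contains an empty string caused by the double space
print(to_pip_list(example))  # four clean specifiers
```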

In [0]:
from huggingface_hub import login
from mlflow_extensions.serving.engines import VLLMEngineProcess
from mlflow_extensions.serving.engines.vllm_engine import VLLMEngineConfig
from mlflow_extensions.databricks.deploy.ez_deploy import EzDeployConfig, ServingConfig, EzDeployVllmOpenCompat

# Authenticate with Hugging Face if the model requires it:
# login() prompts for a token, or pass one with login(token="...")
# login()

In [0]:
# Initialize the deployer with the vLLM OpenAI-compatible serving layer
deployer = EzDeployVllmOpenCompat(
    config=EzDeployConfig(
        # Model name/path on Hugging Face
        name="microsoft/Phi-4-reasoning",
        # Serve the model with a vLLM engine process
        engine_proc=VLLMEngineProcess,
        engine_config=VLLMEngineConfig(
            # Model identifier on Hugging Face
            model="microsoft/Phi-4-reasoning",
            # Maximum context length in tokens (prompt + generated output)
            max_model_len=30000,
            # Additional vLLM command-line flags
            vllm_command_flags={
                # Fraction of GPU memory vLLM may use (95%)
                "--gpu-memory-utilization": 0.95,
                # Disable caching of preprocessed multimodal inputs
                "--disable-mm-preprocessor-cache": None,
                # Let the model decide when to call tools
                "--enable-auto-tool-choice": None,
                # Parse tool calls using the Hermes format
                "--tool-call-parser": "hermes",
            },
        ),
        serving_config=ServingConfig(
            # Minimum memory required for model serving (in GB):
            # model weights, KV cache, intermediate activations, and overhead
            minimum_memory_in_gb=60,
        ),
        # Pip requirements defined earlier, split on any whitespace
        pip_config_override=PIP_REQUIREMENTS.split(),
    ),
    # Register the model under its fully qualified Unity Catalog name
    registered_model_name=f"{CATALOG}.{SCHEMA}.{MODEL_NAME}",
)
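`minimum_memory_in_gb` has to cover model weights plus KV cache, and `max_model_len` bounds both the KV cache and every request. A rough back-of-the-envelope sketch of that budgeting, under loudly stated assumptions: bf16 weights (2 bytes per parameter), about 14B parameters, and an architecture of 40 layers, 10 KV heads, and head dimension 128. These architecture values are assumptions for illustration; check the model's `config.json` before relying on them. The function names are hypothetical, and vLLM itself preallocates memory according to `--gpu-memory-utilization`, so this is only a sizing aid.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             bytes_per_value: int = 2) -> int:
    """Bytes of KV cache stored per token (keys and values for every layer)."""
    # Factor 2: one key vector and one value vector per KV head per layer
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def estimate_memory_gb(num_params: float, context_tokens: int, kv_bytes_per_token: int,
                       bytes_per_param: int = 2) -> float:
    """Rough GPU memory (GB) for weights plus the KV cache of one full-length sequence.

    Ignores activations, CUDA graphs, and framework overhead.
    """
    weights_bytes = num_params * bytes_per_param
    kv_bytes = context_tokens * kv_bytes_per_token
    return (weights_bytes + kv_bytes) / 1e9


def max_prompt_tokens(max_model_len: int, max_tokens: int) -> int:
    """Largest prompt that still leaves room for max_tokens of generated output."""
    return max_model_len - max_tokens


# Assumed values for illustration only; verify against the model's config.json
kv_per_token = kv_cache_bytes_per_token(num_layers=40, num_kv_heads=10, head_dim=128)
print(f"KV cache per token: {kv_per_token} bytes")
print(f"Weights + one 30,000-token sequence: {estimate_memory_gb(14e9, 30_000, kv_per_token):.1f} GB")
print(f"Prompt budget with max_tokens=10000: {max_prompt_tokens(30_000, 10_000)} tokens")
```

Under these assumptions the weights take about 28 GB and one full 30,000-token sequence about 6 GB of KV cache, which leaves headroom within the configured 60 GB. A request with `max_tokens=10000` leaves room for prompts of up to 20,000 tokens; the OpenAI-compatible server typically rejects requests whose prompt plus `max_tokens` exceeds `max_model_len`.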

## Model Registration and Deployment

Download and register the model in Unity Catalog.

In [0]:
# Download the model artifacts to local disk (a Unity Catalog Volume path also works)
deployer.artifacts = deployer._config.download_artifacts(local_dir="/tmp/")
# Flag the artifacts as already downloaded
deployer._downloaded = True

In [0]:
deployer.register()  # On serverless compute this raises an error because no GPUs are available; it can be ignored there


## Model Deployment to Serving Endpoint

Deploy the registered model to a serving endpoint. This will:
1. Create a new serving endpoint with the specified name
2. Load the model into memory
3. Make it available for inference requests

Note: `scale_to_zero=False` keeps at least one instance running at all times.
This avoids cold-start latency but incurs continuous compute costs.

In [0]:
deployer.deploy(ENDPOINT_NAME, scale_to_zero=False)

## Process Management

### Restarting Model Processes

Sometimes you may need to restart the model processes, for example:
- After making configuration changes
- If the model becomes unresponsive
- To free up GPU memory

The following code will:
1. Kill any existing VLLM processes
2. Kill any Ray processes (used for distributed computing)
3. Kill any multiprocessing processes

Run this cell whenever you need to restart the model processes.

In [0]:
from mlflow_extensions.testing.helper import kill_processes_containing

# Kill existing processes to free up resources
kill_processes_containing("vllm")  # Kill VLLM model serving processes
kill_processes_containing("ray")   # Kill Ray distributed computing processes
kill_processes_containing("from multiprocessing")  # Kill any multiprocessing processes
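`kill_processes_containing` comes from `mlflow_extensions.testing.helper`, and this notebook does not show how it works internally. As a hypothetical, Linux-only sketch of the same idea, one can scan `/proc/<pid>/cmdline`, keep the processes whose command line contains a substring, and signal them. The helper names `iter_linux_processes`, `find_matching_pids`, and `kill_matching` are illustrative, not the library's API.

```python
import os
import signal
from typing import Iterable, Iterator, List, Optional, Tuple


def iter_linux_processes() -> Iterator[Tuple[int, str]]:
    """Yield (pid, command line) pairs by reading /proc (Linux only)."""
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/cmdline", "rb") as f:
                # Arguments are NUL-separated in /proc/<pid>/cmdline
                cmdline = f.read().replace(b"\0", b" ").decode(errors="replace").strip()
        except OSError:
            continue  # the process exited or is not readable
        yield int(entry), cmdline


def find_matching_pids(processes: Iterable[Tuple[int, str]], pattern: str,
                       exclude: Optional[int] = None) -> List[int]:
    """Return PIDs whose command line contains `pattern`, skipping `exclude`."""
    return [pid for pid, cmd in processes if pattern in cmd and pid != exclude]


def kill_matching(pattern: str, sig: int = signal.SIGTERM) -> List[int]:
    """Send `sig` to every matching process except the current one."""
    pids = find_matching_pids(iter_linux_processes(), pattern, exclude=os.getpid())
    for pid in pids:
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            pass  # the process already exited
    return pids
```

Substring matching is broad: a pattern such as `"ray"` matches any command line that merely contains those letters, so reviewing the output of the pure `find_matching_pids` step before sending signals is a sensible precaution. Excluding the current PID keeps the notebook's own Python process from signaling itself.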

## Model Serving Setup

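Load the registered model on this cluster and connect to it for inference.
This section will:
1. Set the MLflow registry URI to Unity Catalog
2. Fetch the latest model version
3. Load the model, which starts a local vLLM server
4. Read the server's base URL for OpenAI-compatible requests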

In [0]:
import mlflow
from mlflow.tracking import MlflowClient

# Set up MLflow registry
mlflow.set_registry_uri('databricks-uc')

# Initialize MLflow client
client = MlflowClient()

# Get the latest version of the registered model
model_name = f"{CATALOG}.{SCHEMA}.{MODEL_NAME}"
versions = client.search_model_versions(f"name='{model_name}'")
if not versions:
    raise ValueError(f"No registered versions found for {model_name}")

# Versions are returned as strings, so compare them numerically
latest_version = max(int(v.version) for v in versions)

print(f"Using latest model version: {latest_version}")

# Load the registered model (this starts the local vLLM server)
model_uri = f"models:/{model_name}/{latest_version}"
pyfunc_model = mlflow.pyfunc.load_model(model_uri)
# The server URL is read from private attributes of the wrapped model
base_url = str(pyfunc_model.unwrap_python_model()._engine._server_http_client.base_url)

print("Model serving base URL:", base_url)
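MLflow reports `ModelVersion.version` as a string, so taking the maximum of raw version values compares text, not numbers. A minimal sketch of the pitfall and the fix; `latest_version_of` is a hypothetical helper, and in practice its input would be the `version` attributes of the results of `MlflowClient.search_model_versions`.

```python
def latest_version_of(version_strings: list[str]) -> int:
    """Return the highest model version, comparing numerically rather than as text."""
    if not version_strings:
        raise ValueError("no model versions registered")
    return max(int(v) for v in version_strings)


versions = ["1", "9", "10"]
print(max(versions))                # string comparison picks "9"
print(latest_version_of(versions))  # numeric comparison picks 10
```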

## Inference Examples

Demonstrate model inference capabilities with different types of inputs.

### Text-only Inference

Basic text completion example.

In [0]:
serving_payload = {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello! How is the weather today?"
        }
    ],
    "temperature": 1.0,
    "max_tokens": 10000,
}

response = pyfunc_model.predict(serving_payload)
print(response)
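Phi-4-reasoning writes a chain of thought before its final answer. Assuming, as the model's documentation describes, that the reasoning is wrapped in `<think>...</think>` tags (verify this against real responses, especially with a custom system prompt), a small helper can separate the two so only the answer is shown to users. `split_reasoning` is a hypothetical name; it would be applied to the message content string of a response.

```python
import re
from typing import Tuple

# Non-greedy match so only the first <think>...</think> block is captured
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def split_reasoning(text: str) -> Tuple[str, str]:
    """Split model output into (reasoning, answer).

    Assumes reasoning is wrapped in <think>...</think>. Without tags, the whole
    text is treated as the answer.
    """
    match = _THINK_RE.search(text)
    if match is None:
        if "<think>" in text:
            # Output was cut off (e.g. max_tokens reached) before the answer
            return text.split("<think>", 1)[1].strip(), ""
        return "", text.strip()
    reasoning = match.group(1).strip()
    answer = (text[:match.start()] + text[match.end():]).strip()
    return reasoning, answer


reasoning, answer = split_reasoning("<think>No live weather data is available.</think>I can't check live weather.")
print(answer)
```

Handling an unclosed `<think>` tag matters for reasoning models: with a small `max_tokens`, the response can end mid-reasoning and contain no answer at all.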

### Image Analysis

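Example of sending an image to the local server through its OpenAI-compatible API. Note that microsoft/Phi-4-reasoning accepts text input only, so image requests require a vision-capable model served the same way.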

In [0]:
import urllib.request
from PIL import Image
from io import BytesIO
from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(
    base_url=f"{base_url}/v1",
    api_key="DUMMY"
)

# Load and display test image
image_url = "https://www.arsenal.com/sites/default/files/styles/large_16x9/public/images/saka-celeb-bayern.png?h=3c8f2bed&auto=webp&itok=Twjeu8tug"
with urllib.request.urlopen(image_url) as url:
    img = Image.open(BytesIO(url.read()))
display(img)

# Perform image analysis
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Which football team does this player belong to?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url,
                        "detail": "high"
                    },
                },
            ],
        }
    ],
    temperature=0.0,
    max_tokens=150,
)

print(response.choices[0].message.content.strip())

In [0]:
import base64
from io import BytesIO
from openai import OpenAI

# Initialize the OpenAI client against the local vLLM server
client = OpenAI(
    base_url=f"{base_url}/v1",
    api_key="DUMMY",
)

# Path to local image file
img_path = "/Volumes/samantha_wise/gsk_vlm_poc/images/Goldfish-2-e1724099193229.png" # change here

# Read the local image file and encode it as base64 so the binary data
# can be sent as text inside the JSON request
with open(img_path, "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")
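The requests in this notebook embed images as `data:<mime>;base64,<payload>` URLs. A minimal sketch of a reusable helper that builds such a URL and derives the MIME type from the file name with the standard `mimetypes` module; `to_data_url` is a hypothetical name, and the PNG fallback is an assumption.

```python
import base64
import mimetypes


def to_data_url(image_bytes: bytes, filename: str = "image.png") -> str:
    """Encode raw image bytes as a base64 data URL for an OpenAI-style image_url field."""
    # Guess the MIME type from the file extension; fall back to PNG if unknown
    mime, _ = mimetypes.guess_type(filename)
    mime = mime or "image/png"
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime};base64,{encoded}"


print(to_data_url(b"\x89PNG", "example.png")[:30])
```

Deriving the MIME type from the path keeps the label accurate when JPEG files are processed, as in the batch section where files are filtered with `*.jpg`; some servers tolerate a mismatched label, but that is not guaranteed.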

In [0]:
# Make API call with base64 encoded image
# The image is passed as a data URL in the format:
# data:image/png;base64,<base64_string>
# This format allows embedding binary image data directly in the request
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user", 
            "content": [
                    {
                         "type": "text", 
                         "text": """OCR and give details and look at everything and do not hallucinate and think carefully """
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}",  # Pass base64 image as data URL
                            'detail': 'high'
                        }
                    }
            ]
        }
    ],
    max_tokens = 5000,
    temperature = 0.1,
    top_p = 0.95,
)

print(response.choices[0].message.content.strip())
#"OCR this image, provide the context of the image and return the output in table format with 4 columns namely PRODUCT_TYPE, PRODUCT_TEXT, PRODUCT_NUMBER and LEGEND"

## Batch Image Processing with Spark

Load image files from a Unity Catalog Volume with Auto Loader.

In [0]:
df_raw = (
    spark.readStream.format("cloudFiles")
    .option("cloudFiles.format", "binaryFile")
    .option("pathGlobFilter", "*.jpg")
    .load(f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}")
)

In [0]:
TABLE_NAME = "table_name"  # change here: Delta table of images with a binary "content" column

df_img = spark.table(f"{CATALOG}.{SCHEMA}.{TABLE_NAME}")
display(df_img)

In [0]:
import base64
from io import BytesIO

import pandas as pd
from openai import OpenAI
from pyspark.sql.functions import pandas_udf

prompt = "This image contains a human. Your task is to tell me what this person is doing and try to identify who they are." 


@pandas_udf("string")
def classify_img(images: pd.Series) -> pd.Series:
    def classify_one_image(img):
        client = OpenAI(
            base_url=f"{base_url}/v1",
            api_key="DUMMY"
        )

        image_file = BytesIO(img)
        image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{image_base64}"}
                        }
                    ]
                }
            ]
        )
        return response.choices[0].message.content.strip()
    
    return pd.Series([classify_one_image(img) for img in images])

# Example usage with Spark DataFrame
df_inference = df_img.repartition(4).withColumn("vLLM_predict", classify_img("content"))
display(df_inference)
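`classify_img` sends one request at a time within each partition. Because vLLM batches concurrent requests on the server, issuing several requests at once from each partition may improve throughput. A minimal sketch using only the standard library; `map_concurrently` is a hypothetical helper, and `max_workers=8` is an arbitrary starting point to tune.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def map_concurrently(fn: Callable[[T], R], items: Iterable[T], max_workers: int = 8) -> List[R]:
    """Apply fn to each item using a thread pool, returning results in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs, not completion order
        return list(pool.map(fn, items))


print(map_concurrently(lambda x: x * x, [3, 1, 2]))
```

Inside `classify_img`, the final line would become `return pd.Series(map_concurrently(classify_one_image, images, max_workers=8))`. Threads suit this case because each call spends most of its time waiting on HTTP; keep `max_workers` modest, since every partition sends its requests to the same server.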

## Cleanup

Stop the local vLLM server and related processes to release GPU memory.

In [0]:
from mlflow_extensions.testing.helper import kill_processes_containing

kill_processes_containing("vllm")
kill_processes_containing("ray")
kill_processes_containing("from multiprocessing")