## Wrap a KIE Agent in a Custom Pyfunc + Explore Custom Evals



### Workspace Setup:

- Dependencies
- UC Paths
- Environment Variables

In [0]:
!pip install mlflow=3.1.4 databricks-agents=1.2.0 cloudpickle>=3.1.1

dbutils.library.restartPython()

In [0]:
## Widgets for CATALOG_NAME, SCHEMA_NAME, and MODEL_NAME
# Please replace with your own values.
# Together these form the Unity Catalog model name "<catalog>.<schema>.<model>"
# that is used when the wrapped model is registered later in the notebook.

dbutils.widgets.text("CATALOG_NAME", "mmt", "Catalog Name")
dbutils.widgets.text("SCHEMA_NAME", "bricks", "Schema Name")
dbutils.widgets.text("MODEL_NAME", "tbct_wrapped_KIEagent", "Model Name")

# Retrieve the values from the widgets
catalog_name = dbutils.widgets.get("CATALOG_NAME")
schema_name = dbutils.widgets.get("SCHEMA_NAME")
model_name = dbutils.widgets.get("MODEL_NAME")


## CREATE Catalog/Schema etc. if not already available...
# NOTE: the names are interpolated unquoted into the SQL, so they must be plain
# identifiers (wrap them in backticks if they contain special characters).
# NOTE(review): CREATE CATALOG needs metastore-level privileges; if you lack them,
# point CATALOG_NAME at an existing catalog instead.
# Create catalog if it does not exist
spark.sql(f"CREATE CATALOG IF NOT EXISTS {catalog_name}")

# Create schema if it does not exist
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")

In [0]:
import mlflow
import mlflow.pyfunc
import pandas as pd
import requests
import os

## Set MLflow experiment
# Get current user from Databricks context
current_user = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()

# Set experiment path dynamically (one experiment under the current user's folder)
experiment_path = f"/Workspace/Users/{current_user}/agentbricks_utilities_kie_test"
mlflow.set_experiment(experiment_path)

## Set DATABRICKS_TOKEN
# Export the notebook's own API token so KIEwrapper can fall back to it.
# For quick testing only -- best practice is a Service Principal / secret-scoped PAT,
# or refactoring to use the WorkspaceClient SDK (serving_endpoints.query()).
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
os.environ['DATABRICKS_TOKEN'] = token

## The assumption here is that you have already created a KIE endpoint with the desired response output.
# We will wrap the KIE agent's endpoint_url inside an MLflow Custom Pyfunc.
# The workspace host comes from the cluster conf; the endpoint name is supplied via a widget.
workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
dbutils.widgets.text("ENDPOINT_NAME", "", "KIE Endpoint Name")
endpoint_name = dbutils.widgets.get("ENDPOINT_NAME").strip()
if not endpoint_name:
    raise ValueError("Set the ENDPOINT_NAME widget to the name of your KIE serving endpoint.")
endpoint_url = f"https://{workspace_url}/serving-endpoints/{endpoint_name}/invocations"

---    

### Define an MLflow Custom Pyfunc `KIEwrapper` Class

In [0]:
import os
import requests
import pandas as pd
import mlflow.pyfunc
import json

## type hints
from typing import List, Dict

class KIEwrapper(mlflow.pyfunc.PythonModel):
    """MLflow wrapper for KIE endpoints using chat message format.

    Each input note is sent to the serving endpoint as a single user chat
    message (``{"messages": [{"role": "user", "content": text}]}``). ``predict``
    returns exactly ONE string per input row:

    - the endpoint's ``choices[0].message.content`` on success,
    - the JSON-encoded response body if it has no ``choices``, or
    - a JSON-encoded ``{"error": ...}`` object if that row failed.

    Keeping one string per row preserves alignment with the input (required
    when used through ``mlflow.pyfunc.spark_udf``). It also means every element
    can be fed to ``json.loads`` when the endpoint returns JSON content.

    NOTE(review): the bearer token is stored on the instance (``token`` and
    ``headers``). Passing the instance to ``mlflow.pyfunc.log_model`` therefore
    pickles the token into the model artifact. Use a service principal /
    short-lived token for anything beyond quick testing.
    """

    def __init__(
        self,
        endpoint_url: str,
        token: str = None
    ):
        """
        Args:
            endpoint_url: Full ``.../serving-endpoints/<name>/invocations`` URL.
            token: Bearer token; defaults to the ``DATABRICKS_TOKEN`` env var.

        Raises:
            ValueError: If no token is passed and ``DATABRICKS_TOKEN`` is unset.
        """
        self.endpoint_url = endpoint_url
        self.token = token or os.environ.get('DATABRICKS_TOKEN')
        if not self.token:
            raise ValueError("No token found in environment")
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json"
        }

    def load_context(self, context):
        # Only use for context-dependent setup; this wrapper has nothing to load.
        pass

    @staticmethod
    def _to_texts(model_input) -> List[str]:
        """Normalize the supported input shapes into a list of note texts."""
        if isinstance(model_input, pd.DataFrame):
            return model_input['workorder_notes'].tolist()
        if isinstance(model_input, dict) and 'workorder_notes' in model_input:
            # Single record (or column-style dict) such as {"workorder_notes": ...}.
            notes = model_input['workorder_notes']
            if isinstance(notes, (list, tuple)):
                return [str(item) for item in notes]
            return [str(notes)]
        if isinstance(model_input, list):
            return [str(item) for item in model_input]
        return [str(model_input)]

    def _query_one(self, text: str) -> str:
        """Send one note to the endpoint; return its content or a JSON error string."""
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": text
                }
            ]
        }
        try:
            response = requests.post(
                self.endpoint_url,
                headers=self.headers,
                json=payload,
                timeout=360
            )
            if response.status_code != 200:
                error_msg = f"HTTP {response.status_code}: {response.text}"
                return json.dumps({"error": error_msg})
            result = response.json()
            choices = result.get('choices') if isinstance(result, dict) else None
            if choices:
                return choices[0]['message']['content']
            # No chat-style choices: keep the raw body (as JSON text) so nothing is lost.
            return json.dumps(result)
        except Exception as e:
            # Best-effort per row: one failed request (timeout, bad body, ...) must
            # not discard the other rows' results or change the output length.
            return json.dumps({"error": str(e)})

    def predict(
        self,
        context,
        model_input: pd.DataFrame
        ) -> List[str]:
        """Query the KIE endpoint once per input note.

        Args:
            context: MLflow model context (unused).
            model_input: A DataFrame with a ``workorder_notes`` column, a dict with
                a ``workorder_notes`` key, a list of texts, or a single value.

        Returns:
            One string per input note (see the class docstring for the format).
            If the input itself cannot be interpreted, a single JSON error string.
        """
        try:
            texts = self._to_texts(model_input)
        except Exception as e:
            return [json.dumps({"error": str(e)})]
        return [self._query_one(text) for text in texts]

### Create or Load Sample Data for Logging the `KIEwrapper` Model

This example is specific to a KIE agent created for a use case in which certain attributes of devices and maintenance work-order logs must be extracted for downstream modelling. The sample inputs below are for illustrative purposes only.

In [0]:
# Sample de-identified data for illustration 

import random
import pandas as pd

# One row per illustrative work order. Each row gets a random 6-digit id (drawn in
# row order) plus its free-text notes.
# NOTE: `random` is not seeded, so the ids differ on every run.
sample_df = pd.DataFrame(
    [
        {"workorder_id": str(random.randint(100000, 999999)), "workorder_notes": notes_text}
        for notes_text in (
            """reported_condition: Reported Condition not listed additional_diagnosis: Component Wear problem_description: 2025-02-18 12:11:26 DRD 02/18/2025 Machine is very loud while in use per management Roy said to remove and metal shavings were found in centrifuge. work_performed_desc: XXX 2025-02-18 12:13:44 Unable to remove centrifuge to observe where shavings are coming from due to stripped centrifuge screw. XXX 2025-02-19 12:14:43 Requested photos were uploaded XXX 2025-03-20 08:38:22 Just following up with this. Do we have an ETA of when a Tech will be scheduled to come and get this taken care of? XXX 2025-03-27 08:16:33 Noticed AC Hook was mising. Part ordered. XXX 2025-04-03 09:29:15 Should we use ScrewGrab to remove the stripped screw and be available to take a look at the centrifuge? OO 4/23/25 Removed stripped screw using Dremel tool. Center did not have a spare centrifuge or motor, so spares have been ordered and the center technician can complete the repair. EMB 4/23/25 XXX 2025-04-24 13:49:06 Replaced AC Hook, Centrifuge and Centrifuge Motor per R&R. Performed Multifunction CCA Auto-Test and Fluid Test with pasing results.""",
            """reported_condition: Blood Spill additional_diagnosis: problem_description: 2025-01-12 12:11:06 Loop break occurred during donation. While completing routine fluid spill cleaning, blood was found underneath the leak detector. XXX 2025-03-28 15:51:10 Damaged centrifuge work_performed_desc: XXX 2025-01-12 12:15:47 Leak detector was removed and replaced on device. Multifunction test was ran and failed due to high presure leak test. - AW 1/12/25 XXX 2025-03-09 07:10:22 High presure leak troubleshooting steps were performed. High presure air system was presurized to 98 PSI. After 3 minutes, les than 3 PSI was lost. Multifunction auto-test was ran again and it still failed due to high presure leak test. XXX 2025-03-18 06:03:10 High presure leak test troubleshooting steps were performed as followed. High presure air system was presurized to 98 PSI. After 3 minutes les than 3 PSI was lost. Proceeded to the les than 3 PSI lost section. The draw and return pump tracks were actuated and neither observed a los of greater than 3 PSI. All valves on the soft casette housing were actuated and none of those observed a los greater than 3 PSI. The saline valve when actuated never reached a los greater than 3 PSI. The collection valve when actuated observed a los greater than 3 PSI. The pneumatic tubing for the valve was reseated and when energized the collection valve still observed a los greater than 3 PSI. Module will need to be replaced per high presure leak test troubleshooting steps. XXX 2025-03-28 15:39:12 Replaced collection valve following R&R; Multifunction test pased XXX 2025-03-28 16:08:21 Replaced Centrifuge following R&R XXX 2025-03-28 16:13:30 Performed another multifunction test after replacing centrifuge; Test pased XXX 2025-03-28 16:26:52 Fluid test pased""",
            """reported_condition: Alarm Codes additional_diagnosis: problem_description: 2025-02-19 13:03:21 the machine alarmed three times but alarmed for the fourth time removed from service 3106 alarm XXX 2025-03-02 11:59:02 Machine losing presurization due to faulty return pump asembly. Return pump rotor caused the device to leak air presure upon closing tracks due to wear. work_performed_desc: XXX 2025-02-20 14:26:13 Navigated to the multifunction CCA section of the hardware tab. Presurized the high-presure air system to 98 psi. Waited 3 minutes. Observed how much presure is lost. The los is les than 3 psi. Performed Multifunction CCA autotest with failed results. High presure system leak test failed. Escalated to [CompanyName] per Service Manual. XXX 2025-03-02 11:59:02 Removed and replaced return pump asembly per service manual. Multifunction and return pump auto-tests performed with both tests pasing. Centrifuge presure sensors calibrated. XXX 2025-03-02 15:29:49 Fluid test complete with pasing results.""",
            """reported_condition: Installation Request additional_diagnosis: Not Listed problem_description: 2025-04-14 09:35:16 AC Pump out of box failure. KJ 4/14/25 XXX 2025-04-17 14:43:47 Installation completed with all pasing results. When running the fluid test, the centrifuge was louder than usual. The fluid test still pased. Torqued the centrifuge per R & R and inspected it for abnormalities. Nothing unusual noted. Another fluid test was performed and the centrifuge was still louder than expected. Fluid test pased. work_performed_desc: XXX 2025-04-17 14:47:30 Escalated for troubleshooting advice on a loud centrifuge. Should we order a new centrifuge? XXX 2025-04-17 14:58:34 When centrifuge was commanded to maximum speed on hardware tab, no abnormal noise occurred. The noise only occurs with a separation set loaded in the machine. Inspected the rollers for damage. All loop rollers are spinning freely. Loop holder has no abnormalities noted. XXX 2025-04-18 07:33:08 Hello, you can reference WO-00420320 for details on AC pump out of box failure. Will work on uploading those pictures you requested. Thank you XXX 2025-04-18 08:44:17 The centrifuge hinge is slightly loose but nothing out of the ordinary. Compared the loosenes to other machines and it has the same amount of give as other in-service machines do. Pictures uploaded XXX 2025-04-18 08:47:38 Using the hardware tab, commanded the centrifuge to spin at 2400 RPM with a separation set loaded in the filler housing. No abnormal or loud noise occurred. No abnormal vibration. Loud noise only occurs when fluid enters the centrifuge.""",
            """reported_condition: Preventive maintenance is due additional_diagnosis: problem_description: Preventive Maintenance is due Damaged centrifuge, return pump rotor, and door liner. XXX 2/25/25 work_performed_desc: PM Completed per manufacturers instructions. XXX 2/25/25 Replaced centrifuge & motor. Replaced return pump rotor. Replaced door liner. XXX 2/25/25 250g weight used: S/N: 141258 Cal due date: 4/30/25""",
        )
    ]
)

In [0]:
display(sample_df)

In [0]:
# ~60/40 split of the sample rows; random_state makes the split reproducible.
train_df = sample_df.sample(frac=0.6, random_state=42)
test_df = sample_df.drop(train_df.index)
train_df, test_df

In [0]:
# Instantiate your custom PythonModel
local_model = KIEwrapper(endpoint_url,token) ## same construction as `wrapped_KIEagent` used for logging below

# Test the predict method directly (context is unused by KIEwrapper, so None is fine).
# This sends one live request per row of test_df to the KIE endpoint.
response_output = local_model.predict(context=None, model_input=test_df)

response_output

In [0]:
# Parse each JSON string returned by the wrapper. Non-string entries (e.g. the
# `{"error": ...}` dicts an older KIEwrapper returns) are passed through unchanged
# instead of raising TypeError.
[json.loads(r) if isinstance(r, str) else r for r in response_output]

### Infer MLflow model input/output Signatures

In [0]:
from mlflow.models.signature import ModelSignature, infer_signature

# Infer the input schema from test_df and the output schema from the predictions
# already computed for test_df above. This avoids sending every row to the
# endpoint a second time just to infer the signature.
signature = infer_signature(
    model_input=test_df,
    model_output=response_output,
)

In [0]:
signature

### Log the Custom Pyfunc `wrapped_KIEagent` with MLflow

In [0]:
# endpoint_url = "https://<workspace-host>/serving-endpoints/<kie-endpoint-name>/invocations"

## instantiate KIEwrapper()
# This instance (with its endpoint_url and token attributes) is what gets passed as
# `python_model` to mlflow.pyfunc.log_model below.
wrapped_KIEagent = KIEwrapper(endpoint_url, token) 

In [0]:
# Log the wrapper as a pyfunc model inside a new MLflow run, then reload it from
# the run and smoke-test it on test_df.
with mlflow.start_run() as run:
    
    model_info = mlflow.pyfunc.log_model(        
        name="bricks_kie_agent",  
        python_model=wrapped_KIEagent,
        input_example=test_df,
        # Keep the mlflow pin in sync with the version installed at the top of the notebook.
        pip_requirements=["mlflow==3.1.4", "requests>=2.25.0", "pandas>=1.3.0", "cloudpickle>=3.1.1"],
        signature=signature,
        # registered_model_name="{catalog}.{schema}.{model_name}" ## can directly register it if confident -- it's advisable to test before doing it (for shipping the model registration code after testing you can register to UC directly)
    )
    
    print("\nModel logged successfully!")
    
    # Load and test (round-trips the logged artifact; calls the endpoint once per row).
    model = mlflow.pyfunc.load_model(model_info.model_uri)
    # `response` is reused by the parsing cell below.
    response = model.predict(test_df)
    
    print(f"\nmodel_response: {response}")
    print(f"\nmodel uri: {model_info.model_uri}")
    print(f"\nrun_id: {model_info.run_id}")

print(f"\nmlflow experiment_id: {run.info.experiment_id} -- model logging completed successfully!")

### Model Response Output
-- example: parse the JSON responses into columns

In [0]:
import json
from pyspark.sql import functions as f, types as t

## convert response to sparkDF for parsing
# `response` is the list returned by the reloaded model in the logging cell above.
response_sdf = spark.createDataFrame(pd.DataFrame(response, columns = ['response']))
# display(response_sdf)


#'response_sdf' has a column 'response' with JSON strings -- extract schema from first row
# NOTE: only the first row is used, so fields missing from it (or an error row in
# first position) will not appear in the parsed columns.
# NOTE(review): `.rdd` may be unavailable on shared-access / serverless compute;
# `f.schema_of_json` on a sample string is an alternative there.
json_rdd = response_sdf.limit(1).select("response").rdd.map(lambda row: row.response)
inferred_df = spark.read.json(json_rdd)
# display(inferred_df)

## apply schema in from_json extraction to all rows
# Rows that don't match the schema yield nulls rather than failing.
parsed_df = response_sdf.withColumn("parsed",
                                    f.from_json(f.col("response"), 
                                                inferred_df.schema
                                                )
                                    )
# Flatten the parsed struct into top-level columns next to the raw response.
parsed_df = parsed_df.select("*", "parsed.*")                                    
display(parsed_df)

### Register logged model to UC

In [0]:
# Register the model in Unity Catalog : "{catalog}.{schema}.{model_name}"
# Each call registers a new version; `registered_model.version` holds its number.
# NOTE(review): assumes the MLflow registry URI points at Unity Catalog (otherwise
# call mlflow.set_registry_uri("databricks-uc") first).
uc_model_name = f"{catalog_name}.{schema_name}.{model_name}"

registered_model = mlflow.register_model(
    model_uri=model_info.model_uri,
    name=uc_model_name
)

print(f"Model registered in UC as: {uc_model_name}")

In [0]:
# registered_model.version

In [0]:
def get_latest_uc_model_version(uc_model_name):
    """Return the highest registered version number of a Unity Catalog model.

    Args:
        uc_model_name: Fully qualified model name, ``"<catalog>.<schema>.<model>"``.

    Returns:
        The largest version number, as an ``int``.

    Raises:
        ValueError: If the model has no registered versions.
    """
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    model_versions = client.search_model_versions(f"name='{uc_model_name}'")
    versions = [int(mv.version) for mv in model_versions]
    if not versions:
        # max() on an empty list would raise an opaque "empty sequence" error.
        raise ValueError(f"No registered versions found for model '{uc_model_name}'.")
    latest_uc_model_version = max(versions)
    print(f"Latest version for {uc_model_name}: {latest_uc_model_version}")

    return latest_uc_model_version

# latest_uc_model_version = get_latest_uc_model_version(uc_model_name)    

### Test inferencing with UC registered model

In [0]:
import mlflow

# Resolve the newest registered version of the UC model and load it.
uc_model_uri = f"models:/{uc_model_name}/{get_latest_uc_model_version(uc_model_name)}"
pyfunc_model = mlflow.pyfunc.load_model(model_uri=uc_model_uri)

# The model is logged with an input example (test_df at logging time)
input_data = pyfunc_model.input_example

# Verify the model with its stored input example. This runs in the notebook's
# current Python environment, not the logged pip_requirements; to validate against
# the logged dependencies, use mlflow.models.predict with an env_manager.
# For more details, refer to:
# REF https://mlflow.org/docs/latest/models.html#validate-models-before-deployment

# Use the loaded model's predict method directly
predictions = pyfunc_model.predict(input_data)

predictions

In [0]:
import mlflow
# import pandas as pd

uc_model_uri = f"models:/{uc_model_name}/{get_latest_uc_model_version(uc_model_name)}"
loaded_model = mlflow.pyfunc.load_model(model_uri = uc_model_uri)
uc_result = loaded_model.predict(test_df) # list returned by KIEwrapper.predict (extracted content per row)

# Convert the result to a Pandas DataFrame
uc_result_df = pd.DataFrame({'prediction': uc_result})

display(uc_result_df)

#### Use `mlflow.pyfunc.spark_udf` to generate inferences on the full dataset for evals

In [0]:
import mlflow
import json
from pyspark.sql import functions as f, types as t
# from pyspark.sql import Row

## use sample_df
# If sample_df is a Pandas DataFrame, convert it to Spark DataFrame
if isinstance(sample_df, pd.DataFrame):
    df = spark.createDataFrame(sample_df)
else:
    df = sample_df

uc_model_uri = f"models:/{uc_model_name}/{get_latest_uc_model_version(uc_model_name)}"

# Load model as a Spark UDF. Override result_type if the model does not return double values.
model_spark_udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=uc_model_uri
)

# Predict on a Spark DataFrame.
df = df.withColumn(
    'predictions',
    model_spark_udf() ## with model signature this is simplified 
    )
# display(df)

## ---------------------------------------------------------------------------------------------

## with predictions column -- we can parse out the nested json as before -- extract schema from first row:
json_rdd = df.limit(1).select("predictions").rdd.map(lambda row: row.predictions)
inferred_df = spark.read.json(json_rdd)
# display(inferred_df)

## apply schema in from_json extraction to all rows
parsed_df = df.withColumn("parsed",
                          f.from_json(f.col("predictions"), inferred_df.schema)
                         )
parsed_df = parsed_df.select("*", "parsed.*")                                    
display(parsed_df)

In [0]:
# parsed_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.<table_name e.g. sample_predictions_parsed>")

### DOWNSTREAM Custom Evals 

In [0]:
# eval_df = spark.table("mmt.bricks.tbct_sample_predictions_parsed")
# NOTE: replace the <table_name ...> placeholder with the table written from
# parsed_df (the saveAsTable line above is commented out). It must contain the
# workorder_id, workorder_notes and predictions columns used to build eval_data.
eval_df = spark.table(f"{catalog_name}.{schema_name}.<table_name e.g. sample_predictions_parsed>")
display(eval_df)

In [0]:
# eval_data = [
#     {
#         "inputs": {
#             "workorder_id": row["workorder_id"],
#             "workorder_notes": row["workorder_notes"]
#         },
#         "outputs": {
#             "predictions": row["predictions"]
#         }
#     }
#     for row in eval_df.select(
#         "workorder_id",
#         "workorder_notes",
#         "predictions"
#     ).collect()
# ]
# display(eval_data)

In [0]:
# --------------------------------------------------------------
#  Imports & helper utilities
# --------------------------------------------------------------
from mlflow.genai.scorers import scorer, Guidelines
from mlflow.genai.judges import meets_guidelines
import mlflow
import json
from typing import Dict, List

# --------------------------------------------------------------
#   Helper: fetch the latest registered version of a UC model
# --------------------------------------------------------------

## Same helper as earlier in the notebook; kept here so this cell is self-contained.
def get_latest_uc_model_version(uc_model_name):
    """Return the highest registered version number of a Unity Catalog model.

    Args:
        uc_model_name: Fully qualified model name, ``"<catalog>.<schema>.<model>"``.

    Returns:
        The largest version number, as an ``int``.

    Raises:
        ValueError: If the model has no registered versions.
    """
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    model_versions = client.search_model_versions(f"name='{uc_model_name}'")
    versions = [int(mv.version) for mv in model_versions]
    if not versions:
        # max() on an empty list would raise an opaque "empty sequence" error.
        raise ValueError(f"No registered versions found for model '{uc_model_name}'.")
    latest_uc_model_version = max(versions)
    print(f"Latest version for {uc_model_name}: {latest_uc_model_version}")

    return latest_uc_model_version

# latest_uc_model_version = get_latest_uc_model_version(uc_model_name)  



# --------------------------------------------------------------
#  Load the UC-registered Agent Brick KIE model
# --------------------------------------------------------------
# Built from the notebook widgets (same as `uc_model_name` used at registration),
# so the evals load the model this notebook registered rather than a hard-coded one.
UC_MODEL_NAME = f"{catalog_name}.{schema_name}.{model_name}"
UC_MODEL_URI = f"models:/{UC_MODEL_NAME}/{get_latest_uc_model_version(UC_MODEL_NAME)}"
# NOTE: rebinds `wrapped_KIEagent` from the in-memory KIEwrapper to the loaded pyfunc
# model; `extract_workorder_notes` calls its `predict`.
wrapped_KIEagent = mlflow.pyfunc.load_model(model_uri=UC_MODEL_URI)

# --------------------------------------------------------------
#  Predict wrapper – traced, returns JSON string predictions
# --------------------------------------------------------------

@mlflow.trace
def extract_workorder_notes(
    workorder_id: str,
    workorder_notes: str
) -> Dict:
    """Run the UC-registered KIE model on a single work order (traced by MLflow).

    Used as ``predict_fn`` for ``mlflow.genai.evaluate``, which passes each
    record's ``"inputs"`` keys as keyword arguments. That is why the parameter
    names match the ``inputs`` keys in ``eval_data``.

    Returns:
        ``{"predictions": <str>}``. For the normal single-row case the string is
        the model's output for that row, which is the same format as the
        ``predictions`` column from the Spark UDF. Scorers can therefore
        ``json.loads`` precomputed and freshly generated outputs the same way.
    """
    # A one-row DataFrame matches the model's column-based input signature.
    payload = pd.DataFrame([{
        "workorder_id": workorder_id,
        "workorder_notes": workorder_notes
    }])
    response = wrapped_KIEagent.predict(payload)

    raw_predictions = response
    if hasattr(raw_predictions, "tolist"):
        raw_predictions = raw_predictions.tolist()

    # One input row -> one prediction: unwrap it so the format matches the
    # per-row `predictions` column instead of a JSON-encoded list.
    if isinstance(raw_predictions, list) and len(raw_predictions) == 1:
        prediction = raw_predictions[0]
        if isinstance(prediction, str):
            return {"predictions": prediction}
        return {"predictions": json.dumps(prediction)}
    return {"predictions": json.dumps(raw_predictions)}

# --------------------------------------------------------------
#  Prepare evaluation data – one record per work order
# --------------------------------------------------------------
# Each record has an "inputs" dict, whose keys match extract_workorder_notes'
# parameters so it can be used as `predict_fn`. It also has an "outputs" dict
# holding the predictions already stored in eval_df.
# NOTE(review): when `predict_fn` is also passed to mlflow.genai.evaluate,
# confirm whether MLflow overwrites these precomputed "outputs" or rejects them.

eval_data = [
    {
        "inputs": {
            "workorder_id": row["workorder_id"],
            "workorder_notes": row["workorder_notes"]
        },
        "outputs": {
            "predictions": row["predictions"]
        }
    }
    # .collect() brings all rows to the driver -- fine for small eval sets.
    for row in eval_df.select(
        "workorder_id",
        "workorder_notes",
        "predictions"
    ).collect()
]

# eval_data ## defined earlier

### Define Judges & Scorers

- https://docs.databricks.com/aws/en/mlflow3/genai/eval-monitor/custom-judge/ 


#### Test a custom eval with pre-built `meets_guidelines` Judge: e.g. `Hallucination Guardrail` 

In [0]:
# --------------------------------------------------------------
#  Scorers – (inputs, outputs) signature; pass the function itself
#  to mlflow.genai.evaluate (do NOT call it!)
# --------------------------------------------------------------

@scorer
def hallucination_guardrail(inputs, outputs): ## name in such a way that True/Pass vs False/Fail is clear
    """Judge whether the extracted fields are consistent with the work order notes.

    MLflow passes each evaluation record's parts as arguments:
        inputs:  {"workorder_id": ..., "workorder_notes": ...}
        outputs: {"predictions": "<JSON string>"}   # from predict_fn or eval_data

    Returns:
        The result of the built-in ``meets_guidelines`` judge.
    """
    guidelines = [
        # "The extracted information must not contain hallucinated details.",
        # "The response must be based solely on the provided workorder notes."
        "The extracted information can include inferred insights but must be factually consistent with the provided context.",
    ]

    raw_predictions = outputs["predictions"]
    # Error rows or non-JSON output should still be judged, not crash the scorer.
    try:
        response = json.loads(raw_predictions)
    except (TypeError, ValueError):
        response = raw_predictions

    # Run the built-in guideline judge
    return meets_guidelines(
        name="hallucination_guardrail",
        guidelines=guidelines,
        context={
            "request": inputs["workorder_notes"],
            "response": response
        },
    )


In [0]:
# --------------------------------------------------------------
#  Run the evaluation
# --------------------------------------------------------------
# NOTE: The *list* of scorers must contain the callable, NOT a call.
# predict_fn is invoked for every record, so this re-queries the KIE model per row.
results = mlflow.genai.evaluate(
    data=eval_data,
    predict_fn=extract_workorder_notes,
    scorers=[hallucination_guardrail, # <-- no parentheses!
             ## add other scorers
            ]   
)

# Show the result in the notebook (Databricks UI also gives a nice table)
display(results)

In [0]:
# results.save("/Volumes/mmt/bricks/tbct_predictive_maintenance/workorder_eval_results_20250821")

In [0]:
# List the volume targeted by the commented-out save above.
# NOTE(review): path is hard-coded to the author's catalog/schema/volume; adjust it.
!ls /Volumes/mmt/bricks/tbct_predictive_maintenance


### Including Additional Custom/Prebuilt Metrics


- https://docs.databricks.com/aws/en/mlflow3/genai/eval-monitor/concepts/scorers
- https://docs.databricks.com/aws/en/mlflow3/genai/eval-monitor/custom-scorers 


#### Test a custom eval with pre-built LLM judges


The `is_*` judges in `mlflow.genai.judges` are specialized functions that score different quality dimensions of LLM outputs. Each judge evaluates a specific aspect of the model's response or context, such as relevance, safety, grounding, correctness, or sufficiency. Here is what each judge scores:

- **`is_safe`**: Scores whether the content contains harmful, offensive, or toxic material.
- **`is_relevant / is_context_relevant`**: Scores whether the context or response is directly relevant to the user's request, without deviating into unrelated topics.
- **`is_grounded`**: Scores whether the response is grounded in the information provided in the context (i.e., not hallucinated).
- **`is_correct`**: Scores whether the response is factually correct compared to provided ground truth.
- **`is_context_sufficient`**: Scores whether the context provides all necessary information to generate a response that includes the ground truth for the given request.   

These judges automate and standardize the evaluation of LLM outputs for different use cases, and can be wrapped in custom scorers for integration with MLflow's evaluation harness.

- https://learn.microsoft.com/en-us/azure/databricks/mlflow3/genai/eval-monitor/concepts/judges/
- https://learn.microsoft.com/en-us/azure/databricks/mlflow3/genai/eval-monitor/concepts/judges/pre-built-judges-scorers

In [0]:
#REF https://docs.databricks.com/aws/en/mlflow3/genai/eval-monitor/custom-scorers#example-2-wrap-a-predefined-llm-judge

from typing import Any
from mlflow.genai.judges import is_context_relevant, is_grounded
from mlflow.genai.scorers import scorer

@scorer
def is_response_relevant(
    inputs: dict[str, Any],
    outputs: dict
) -> dict:
    """Judge whether the KIE output is relevant to the work order notes.

    Wraps the prebuilt ``is_context_relevant`` judge. The notes are the request
    and the model's ``predictions`` string is the context being checked.

    Raises:
        Exception: If either the notes or the predictions are missing or empty.
    """
    notes_text = inputs.get("workorder_notes")
    prediction_text = outputs.get("predictions")  # KIE output lives under 'predictions', not 'answer'

    if notes_text and prediction_text:
        return is_context_relevant(
            request=notes_text,
            context={"response": prediction_text},
        )
    raise Exception("Missing input fields: response or workorder_notes.")

@scorer
def is_grounded_scorer(
    inputs: dict[str, Any],
    outputs: dict
) -> dict:
    """Judge whether the KIE output is grounded in the work order notes.

    ``is_grounded`` checks ``response`` against ``context``, so the context must
    be the source notes. Using the response as its own context would rate every
    row as trivially "grounded".

    Raises:
        Exception: If either the notes or the predictions are missing or empty.
    """
    user_query = inputs.get("workorder_notes")
    agent_response = outputs.get("predictions")  # Use 'predictions' instead of 'answer'

    if not user_query or not agent_response:
        raise Exception("Missing input fields: response or workorder_notes.")

    return is_grounded(
        request=user_query,
        response=agent_response,
        context=[{"content": user_query}],
    )



# No predict_fn: the scorers judge the precomputed "outputs" stored in eval_data.
predefinedllmjudge_scorer_eval_results = mlflow.genai.evaluate(
                                                                data=eval_data,
                                                                scorers=[is_response_relevant, 
                                                                        is_grounded_scorer]
                                                              )


#### Test a custom eval with custom prompt-based judges

In [0]:
#REF https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/evaluate-prompts#custom-prompt-based-judge

from mlflow.genai.judges import custom_prompt_judge
from mlflow.genai.scorers import scorer

# Create a custom prompt judge for device replacement.
# It asks only whether a part replacement took place; ambiguity is rated by the
# separate device_replacement_ambiguity judge. The [[...]] labels map to scores
# via `numeric_values` ("repaired and replaced" -> 1.0).
device_replacement_judge = custom_prompt_judge(
    name="device_repair_and_replacement_compliance_with_reasoning",
    prompt_template="""Evaluate whether the following work order indicates that a device repair and replacement of a new part took place.

Workorder Notes: {{workorder_notes}}
Predicted Action: {{part_replace_per_protocol}}

Choose the appropriate rating:
[[repaired and replaced]]: Device replacement of a new part took place
[[not_replaced]]: Device replacement of a new part did NOT take place

Explain your reasoning for the rating:""",
    numeric_values={
        "repaired and replaced": 1.0,
        "not_replaced": 0.0
    }
)

def _kie_output_field(outputs, field):
    """Look up `field` in a scorer's `outputs`, parsing the KIE JSON predictions.

    Checks ``outputs[field]`` first, for outputs that already carry the field.
    Otherwise it reads ``outputs["predictions"]``, which may be a dict, a JSON
    object string, or a one-element list / JSON list wrapping such an object.

    Returns:
        The field's value, or None if it cannot be found or parsed.
    """
    if field in outputs:
        return outputs[field]
    value = outputs.get("predictions")
    try:
        if isinstance(value, str):
            value = json.loads(value)
        # Some predict_fn variants wrap the single row's output in a one-element list.
        if isinstance(value, list) and len(value) == 1:
            value = value[0]
            if isinstance(value, str):
                value = json.loads(value)
    except ValueError:
        return None
    return value.get(field) if isinstance(value, dict) else None


# Wrap the judge in a scorer
@scorer
def device_replacement_scorer(inputs, outputs, trace) -> bool:
    """Custom scorer that evaluates device replacement compliance.

    Outputs only carry a ``"predictions"`` JSON string (from eval_data or
    predict_fn). The predicted ``part_replace_per_protocol`` value is therefore
    read from the parsed predictions, so the judge sees the model's prediction
    rather than an empty string.
    NOTE(review): assumes the KIE response schema has a
    ``part_replace_per_protocol`` field -- confirm against the endpoint.

    Returns:
        True if the judge rates the work order "repaired and replaced" (1.0).
    """
    workorder_notes = inputs.get("workorder_notes", "")
    part_replace = _kie_output_field(outputs, "part_replace_per_protocol")
    if part_replace is None:
        part_replace = ""
    result = device_replacement_judge(
        workorder_notes=workorder_notes,
        part_replace_per_protocol=part_replace
    )
    return result.value == 1.0  # True if repaired and replaced, False otherwise


# No predict_fn: the scorer judges the precomputed "outputs" stored in eval_data.
custompromptjudge_scorer_eval_results = mlflow.genai.evaluate(
                                                                data=eval_data,
                                                                scorers=[device_replacement_scorer]
                                                              )

In [0]:
#REF https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/evaluate-prompts#custom-prompt-based-judge

from mlflow.genai.judges import custom_prompt_judge
from mlflow.genai.scorers import scorer

# Prompt judge that rates whether a work order is ambiguous about part replacement.
# The [[...]] labels map to scores via `numeric_values` ("ambiguous" -> 1.0).
device_replacement_ambiguity_judge = custom_prompt_judge(
    name="device_replacement_ambiguity",
    prompt_template="""Evaluate if the following work order is ambiguous regarding device repair and replacement of a new device part:

Workorder Notes: {{workorder_notes}}
Ambiguity Flag: {{ambiguity_flag}}

Choose the appropriate rating:
[[ambiguous]]: The work order is ambiguous about device repair and replacement.
[[not_ambiguous]]: The work order is NOT ambiguous about device repair and replacement.""",
    numeric_values={
        "ambiguous": 1.0,
        "not_ambiguous": 0.0
    }
)

def _kie_output_field(outputs, field):
    """Look up `field` in a scorer's `outputs`, parsing the KIE JSON predictions.

    Checks ``outputs[field]`` first, for outputs that already carry the field.
    Otherwise it reads ``outputs["predictions"]``, which may be a dict, a JSON
    object string, or a one-element list / JSON list wrapping such an object.

    Returns:
        The field's value, or None if it cannot be found or parsed.
    """
    if field in outputs:
        return outputs[field]
    value = outputs.get("predictions")
    try:
        if isinstance(value, str):
            value = json.loads(value)
        # Some predict_fn variants wrap the single row's output in a one-element list.
        if isinstance(value, list) and len(value) == 1:
            value = value[0]
            if isinstance(value, str):
                value = json.loads(value)
    except ValueError:
        return None
    return value.get(field) if isinstance(value, dict) else None


@scorer
def device_replacement_ambiguity_scorer(inputs, outputs, trace) -> bool:
    """Score whether a work order is ambiguous about device repair/replacement.

    The predicted ``ambiguity_flag`` is read from the parsed ``"predictions"``
    JSON, because outputs carry no top-level ``ambiguity_flag`` key.
    NOTE(review): assumes the KIE response schema has an ``ambiguity_flag``
    field -- confirm against the endpoint.

    Returns:
        True if the judge rates the work order "ambiguous" (1.0).
    """
    workorder_notes = inputs.get("workorder_notes", "")
    ambiguity_flag = _kie_output_field(outputs, "ambiguity_flag")
    if ambiguity_flag is None:
        ambiguity_flag = ""
    result = device_replacement_ambiguity_judge(
        workorder_notes=workorder_notes,
        ambiguity_flag=ambiguity_flag
    )
    return result.value == 1.0  # True if ambiguous, False otherwise

# NOTE: this reuses (and overwrites) the result variable from the previous cell;
# rename it if you need to keep both evaluation results.
custompromptjudge_scorer_eval_results = mlflow.genai.evaluate(
    data=eval_data,
    scorers=[device_replacement_ambiguity_scorer]
)

## Combined Evaluations + Results 

In [0]:
# Run all scorers in a single evaluation. predict_fn re-queries the KIE model for
# every record, and each LLM-judge scorer adds one judge call per record.
combined_results = mlflow.genai.evaluate(
                                         data=eval_data,
                                         predict_fn=extract_workorder_notes,
                                         scorers=[hallucination_guardrail, 
                                                  is_grounded_scorer,
                                                  is_response_relevant,
                                                  device_replacement_scorer,
                                                  device_replacement_ambiguity_scorer
                                                 ]
                                        )

display(combined_results)                                        

NOTES:

Evaluation with LLM-based judges (such as those in `mlflow.genai.judges` or custom prompt judges built with `custom_prompt_judge`) can be **slow**, because each evaluation typically sends a separate API request to the LLM endpoint for every row and every scorer. This means:

- With 100 rows and 3 scorers, you may be making 300 LLM calls.
- Each LLM call can take several seconds, depending on the model, load, and network latency.
- LLM endpoints may have rate limits, which further slow batch processing.
- Complex prompt templates or large contexts increase LLM inference time.

To speed up evaluation:

- Reduce the number of rows or scorers for initial testing.
- Use smaller models or lower-latency endpoints.
- Batch requests if your framework and endpoint support it.
- Cache results for repeated evaluations.

This is a known limitation of LLM-based evaluation workflows.

