# UC Model Evaluation and Promotion (`mlflow.genai.scorers`)

Evaluate the latest registered Unity Catalog (UC) model version with LLM-judge
scorers (`RelevanceToQuery`, `Safety`) and promote it to the `champion` alias
if the selected evaluation metric meets or exceeds the promotion threshold.


In [0]:
# Install the packages needed to load and evaluate the registered agent model:
# MLflow (skinny build with Databricks extras), databricks-langchain, langchain,
# and psycopg (binary wheel).
%pip install -U -qqqq mlflow-skinny[databricks] databricks-langchain langchain psycopg[binary]

# Restart the Python process so the freshly installed package versions are the ones imported below.
dbutils.library.restartPython()

In [0]:
# Notebook / job parameters: where the model lives in Unity Catalog and how
# the promotion decision is made.
for _widget_name, _widget_default, _widget_label in (
    ("catalog", "mmt", "Catalog"),
    ("schema", "LS_agent", "Schema"),
    ("model_base_name", "lifesciences_agent", "Model Base Name"),
    ("promotion_threshold", "0.7", "Promotion Threshold"),
):
    dbutils.widgets.text(_widget_name, _widget_default, _widget_label)
del _widget_name, _widget_default, _widget_label

dbutils.widgets.dropdown(
    "evaluation_metric",
    "relevance",
    ["relevance", "safety", "accuracy"],
    "Evaluation Metric",
)

catalog, schema, model_base_name = (
    dbutils.widgets.get(key) for key in ("catalog", "schema", "model_base_name")
)

# Three-level Unity Catalog name: <catalog>.<schema>.<model>
UC_MODEL_NAME = ".".join((catalog, schema, model_base_name))

print(f"Using catalog: {catalog}, schema: {schema}")
print(f"Model base name: {model_base_name}")
print(f"FQ UC Model Name: {UC_MODEL_NAME}")

# Promotion settings: the widget stores text, so convert the threshold to a float.
promotion_threshold = float(dbutils.widgets.get("promotion_threshold"))
evaluation_metric = dbutils.widgets.get("evaluation_metric")

print(f"Threshold: {promotion_threshold}")
print(f"Metric: {evaluation_metric}")


In [0]:
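# Optional debug helper: inspect the registered model's signature.
# Uncomment and run only after the cell that creates `evaluator` (further below).
# model_uri = f"models:/{UC_MODEL_NAME}/{evaluator.get_latest_version()}"
# m = mlflow.pyfunc.load_model(model_uri)
# print(m.metadata.signature)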

In [0]:
import json
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.genai import evaluate as genai_evaluate
from mlflow.genai.scorers import RelevanceToQuery, Safety
from mlflow.tracking import MlflowClient


@dataclass
class EvaluationResult:
    """Outcome of evaluating a single registered model version.

    Attributes:
        version: The registered model version that was evaluated.
        metrics: Metric name -> score, as produced by the evaluator.
        passed: True when the selected metric met the promotion threshold.
        message: Human-readable summary of the promotion decision.
    """

    version: str
    metrics: Dict[str, float]
    passed: bool
    message: str


class ModelEvaluator(ABC):
    """Evaluate the latest version of a UC-registered model and promote it.

    Subclasses provide the evaluation dataset and the scoring logic. ``run``
    evaluates the highest-numbered registered version and, when the configured
    metric meets ``promotion_threshold`` (inclusive), points the ``champion``
    alias at that version.
    """

    def __init__(
        self,
        model_name: str,
        catalog: str,
        schema: str,
        promotion_threshold: float = 0.7,
        evaluation_metric: str = "relevance",
    ):
        """
        Args:
            model_name: Fully qualified UC model name (``catalog.schema.model``).
            catalog: UC catalog; kept for subclasses (e.g. to build prompts).
            schema: UC schema; kept for subclasses.
            promotion_threshold: Minimum score on ``evaluation_metric`` required
                for promotion (compared with ``>=``).
            evaluation_metric: Key in the dict returned by ``evaluate_model``
                that drives the promotion decision.
        """
        self.model_name = model_name
        self.catalog = catalog
        self.schema = schema
        self.promotion_threshold = promotion_threshold
        self.evaluation_metric = evaluation_metric
        # Set the global registry URI first (it is what "models:/..." URIs are
        # resolved against when loading), and pass it to the client explicitly
        # so the client never falls back to a non-UC default registry.
        mlflow.set_registry_uri("databricks-uc")
        self.client = MlflowClient(registry_uri="databricks-uc")

    def get_latest_version(self) -> str:
        """Return the highest registered version number of the model, as a string.

        Raises:
            ValueError: If the model has no registered versions.
        """
        versions = self.client.search_model_versions(f"name='{self.model_name}'")
        if not versions:
            raise ValueError(f"No versions found for {self.model_name}")
        # Versions are numeric strings; compare numerically so "10" > "9".
        return max(versions, key=lambda v: int(v.version)).version

    def get_champion_version(self) -> Optional[str]:
        """Return the version currently holding the ``champion`` alias, or None.

        None is returned when the registry reports an error for the alias
        lookup (typically because no champion has been assigned yet).
        """
        try:
            champion = self.client.get_model_version_by_alias(
                self.model_name, "champion"
            )
        except MlflowException:
            # Registry-side failure (e.g. alias not set yet): treat as "no
            # champion" so a passing version can still be promoted. Non-MLflow
            # errors (programming mistakes) are no longer silently swallowed.
            return None
        return champion.version

    @abstractmethod
    def create_evaluation_dataset(self) -> List[Dict[str, Any]]:
        """Return the evaluation samples used by ``evaluate_model``."""

    @abstractmethod
    def evaluate_model(self, model_uri: str) -> Dict[str, float]:
        """Score the model at ``model_uri`` and return metric name -> score."""

    def promote_to_champion(self, version: str, metrics: Dict[str, float]) -> None:
        """Point the ``champion`` alias at ``version`` and record why.

        The version's description is overwritten with the metrics and a
        timestamp from ``time.strftime`` (the process's local time).
        """
        metrics_str = ", ".join(f"{k}={v:.3f}" for k, v in metrics.items())
        comment = f"Promoted by evaluation notebook. Metrics: {metrics_str}"

        self.client.set_registered_model_alias(self.model_name, "champion", version)

        self.client.update_model_version(
            name=self.model_name,
            version=version,
            description=f"{comment}\nPromoted at {time.strftime('%Y-%m-%d %H:%M:%S')}",
        )

        print(f"Promoted version {version} to champion")

    def run(self) -> EvaluationResult:
        """Evaluate the latest version and promote it when it meets the threshold.

        Returns:
            EvaluationResult describing the evaluated version, its metrics,
            whether it passed, and what action was taken.
        """
        latest_version = self.get_latest_version()
        print(f"Evaluating version {latest_version}")

        model_uri = f"models:/{self.model_name}/{latest_version}"

        metrics = self.evaluate_model(model_uri)
        print(f"\nEvaluation metrics: {metrics}")

        if self.evaluation_metric not in metrics:
            # A misnamed metric would otherwise look like a genuine low score.
            print(
                f"WARNING: evaluation metric '{self.evaluation_metric}' not in "
                f"{sorted(metrics)}; scoring it as 0.0"
            )
        primary_score = metrics.get(self.evaluation_metric, 0.0)
        passed = primary_score >= self.promotion_threshold

        if not passed:
            message = (
                f"Version {latest_version} did not meet threshold "
                f"({primary_score:.3f} < {self.promotion_threshold})"
            )
        elif self.get_champion_version() == latest_version:
            message = f"Version {latest_version} is already champion"
        else:
            self.promote_to_champion(latest_version, metrics)
            message = f"Version {latest_version} promoted to champion"

        return EvaluationResult(
            version=latest_version, metrics=metrics, passed=passed, message=message
        )


class LifeSciencesEvaluator(ModelEvaluator):
    """Evaluate a UC-registered Responses Agent model with ``mlflow.genai.evaluate``.

    The model is scored with the ``RelevanceToQuery`` and ``Safety`` judges.
    Their mean scores are mapped onto the metric names offered by the
    ``evaluation_metric`` widget: ``relevance``, ``safety`` and ``accuracy``.
    """

    def create_evaluation_dataset(self) -> List[Dict[str, Any]]:
        """
        Create a small evaluation dataset for ``mlflow.genai.evaluate``.

        Each sample has:
        - ``inputs``: dict passed to the model, shaped
          ``{"input": [{"role": ..., "content": ...}, ...]}``.
        - ``expected_response``: short reference answer. It is forwarded to
          ``mlflow.genai.evaluate`` as an expectation. The scorers used here do
          not read it.
        """
        return [
            {
                "inputs": {
                    "input": [
                        {
                            "role": "user",
                            "content": (
                                f"How many genes are in the "
                                f"{self.catalog}.{self.schema}.genes_knowledge table?"
                            ),
                        }
                    ]
                },
                "expected_response": "genes_knowledge table",
            },
            {
                "inputs": {
                    "input": [
                        {
                            "role": "user",
                            "content": "What is the average confidence for proteins?",
                        }
                    ]
                },
                "expected_response": "average confidence proteins",
            },
            {
                "inputs": {
                    "input": [
                        {"role": "user", "content": "Tell me about kinase proteins."}
                    ]
                },
                "expected_response": "kinase proteins catalyze phosphorylation",
            },
            {
                "inputs": {
                    "input": [
                        {"role": "user", "content": "What compounds target kinases?"}
                    ]
                },
                "expected_response": "compounds target kinases",
            },
            {
                "inputs": {
                    "input": [
                        {"role": "user", "content": "Explain cell signaling pathways."}
                    ]
                },
                "expected_response": "cell signaling pathways involve cascades",
            },
        ]

    @staticmethod
    def _text_output(text: str) -> Dict[str, Any]:
        """Wrap plain text as a single assistant message in ``{'output': [...]}`` form."""
        return {
            "output": [
                {
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": text}],
                }
            ]
        }

    def _load_model(self, model_uri: str) -> Any:
        """Load the pyfunc model for ``model_uri`` once and reuse it on later calls.

        Loading a registered model downloads its artifacts. Caching per URI
        avoids repeating that work for every evaluation row.
        """
        cache = getattr(self, "_model_cache", None)
        if cache is None:
            cache = self._model_cache = {}
        model = cache.get(model_uri)
        if model is None:
            model = mlflow.pyfunc.load_model(model_uri)
            cache[model_uri] = model
        return model

    @classmethod
    def _normalize_prediction(cls, raw: Any) -> Dict[str, Any]:
        """Coerce a raw model prediction into a ``{'output': [...]}`` dict.

        Cases:
          1. ``raw`` is already ``{'output': [...]}``: returned unchanged.
          2. ``raw`` is a JSON string: parsed, then handled like 1, or wrapped
             if it is a list; otherwise its ``str`` form is treated as text.
          3. ``raw`` is a non-JSON string: treated as plain text.
          4. ``raw`` is a list: assumed to already be a list of messages.
          5. Anything else: its ``str`` form is treated as plain text.
        """
        if isinstance(raw, dict) and "output" in raw:
            return raw

        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except ValueError:
                # Not JSON (json.JSONDecodeError subclasses ValueError): plain text.
                return cls._text_output(raw)
            if isinstance(parsed, dict) and "output" in parsed:
                return parsed
            if isinstance(parsed, list):
                return {"output": parsed}
            return cls._text_output(str(parsed))

        if isinstance(raw, list):
            return {"output": raw}

        return cls._text_output(str(raw))

    def _predict_uc_model_raw(self, model_uri: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Call the UC Responses Agent model the way its serving wrapper expects,
        with a dict whose ``input`` key holds the messages list.

        Even though the underlying pyfunc signature shows input_json/output_json,
        the wrapper accepts ``{'input': [...]}`` directly. Its result is
        normalized to ``{'output': [...]}`` by ``_normalize_prediction``.

        Args:
            model_uri: ``models:/<name>/<version>`` URI of the model to call.
            inputs: Sample inputs, e.g. ``{"input": [{"role": ..., "content": ...}]}``.
        """
        messages = inputs["input"]
        model = self._load_model(model_uri)

        # IMPORTANT: pass a dict, not a DataFrame
        raw = model.predict({"input": messages})
        return self._normalize_prediction(raw)

    def _extract_text_from_prediction(self, prediction: Any) -> str:
        """Return the text of the last assistant message in ``prediction``.

        ``prediction`` is expected in ``{'output': [...]}`` form (a one-element
        list of such dicts is also accepted). An assistant message's content
        may be a plain string or a list of ``{"type": "output_text" | "text",
        "text": ...}`` parts. Returns "" when no assistant text is found.
        """
        if isinstance(prediction, list) and prediction:
            prediction = prediction[0]

        if not isinstance(prediction, dict):
            return "" if prediction is None else str(prediction)

        output = prediction.get("output")
        if not isinstance(output, list) or not output:
            return ""

        # Walk backwards so the final assistant message wins.
        for msg in reversed(output):
            if not isinstance(msg, dict) or msg.get("role") != "assistant":
                continue
            content = msg.get("content")
            if isinstance(content, str):
                if content:
                    return content
                continue
            if isinstance(content, list):
                texts = [
                    c.get("text", "")
                    for c in content
                    if isinstance(c, dict) and c.get("type") in ("output_text", "text")
                ]
                texts = [t for t in texts if t]
                if texts:
                    return "\n".join(texts)

        return ""

    @staticmethod
    def _read_metric(metrics: Dict[str, Any], key: str) -> float:
        """Return ``metrics[key]`` as a float, or 0.0 with a warning if it is absent.

        Missing scores fail closed. A scorer that produced no result must never
        count as a perfect score when deciding on promotion.
        """
        value = metrics.get(key)
        if value is None:
            print(f"WARNING: metric '{key}' missing from evaluation results; using 0.0")
            return 0.0
        return float(value)

    def evaluate_model(self, model_uri: str) -> Dict[str, float]:
        """Run ``mlflow.genai.evaluate`` on the model and map its metrics.

        Returns:
            ``{"relevance", "accuracy", "safety"}`` mean scores. There is no
            dedicated accuracy scorer, so ``accuracy`` mirrors ``relevance``.
        """
        raw_dataset = self.create_evaluation_dataset()
        print(f"Running mlflow.genai.evaluate on {len(raw_dataset)} examples")

        # mlflow.genai.evaluate calls predict_fn with the keys of each row's
        # "inputs" dict as keyword arguments. The outer dict therefore maps
        # predict_fn's `inputs` parameter to the model payload. Reference
        # answers go under "expectations", the key genai.evaluate reads them from.
        eval_dataset = [
            {
                "inputs": {"inputs": row["inputs"]},
                "expectations": {"expected_response": row["expected_response"]},
            }
            for row in raw_dataset
        ]

        def predict_fn(inputs: Dict[str, Any]) -> str:
            # inputs is row["inputs"], e.g. {"input": [ {role, content}, ... ]}
            raw_pred = self._predict_uc_model_raw(model_uri, inputs)
            return self._extract_text_from_prediction(raw_pred)

        eval_result = genai_evaluate(
            data=eval_dataset,
            predict_fn=predict_fn,
            scorers=[RelevanceToQuery(), Safety()],
        )

        print("\nRaw metrics from mlflow.genai.evaluate:")
        for k, v in eval_result.metrics.items():
            print(f"  {k}: {v}")

        relevance = self._read_metric(eval_result.metrics, "relevance_to_query/mean")
        safety = self._read_metric(eval_result.metrics, "safety/mean")

        metrics = {
            "relevance": relevance,
            "accuracy": relevance,  # alias: no separate accuracy scorer is run
            "safety": safety,
        }

        print("\nMapped promotion metrics:", metrics)
        return metrics


# Not used by the workflow example; in theory it works with the same approach
# as LifeSciencesEvaluator (untested here).
class LifeSciencesGenieEvaluator(LifeSciencesEvaluator):
    """LifeSciencesEvaluator extended with two cross-table / ranking prompts."""

    def create_evaluation_dataset(self) -> List[Dict[str, Any]]:
        """Return the base evaluation samples followed by the Genie-style samples."""
        base_samples = super().create_evaluation_dataset()

        def user_turn(text: str) -> Dict[str, Any]:
            # Single user message in the {"input": [...]} shape the model expects.
            return {"input": [{"role": "user", "content": text}]}

        extra_samples = [
            {
                "inputs": user_turn(
                    "Compare average confidence scores between genes and proteins."
                ),
                "expected_response": "A comparison of average confidence scores between genes and proteins.",
            },
            {
                "inputs": user_turn(
                    "Show me top compounds with highest confidence scores."
                ),
                "expected_response": "A list or table of top compounds ordered by confidence score.",
            },
        ]
        return [*base_samples, *extra_samples]

In [0]:
# Evaluate the newest registered version and promote it if it clears the threshold.
evaluator = LifeSciencesEvaluator(
    model_name=UC_MODEL_NAME,
    catalog=catalog,
    schema=schema,
    promotion_threshold=promotion_threshold,
    evaluation_metric=evaluation_metric,
)

result = evaluator.run()

# Human-readable summary of the evaluation outcome.
separator = "=" * 60
print(f"\n{separator}")
print("Evaluation Result (UC ResponsesAgent + mlflow.genai)")
print(separator)
print(f"Version: {result.version}")
print(f"Passed: {result.passed}")
print("\nMetrics:")
for metric_name, metric_value in result.metrics.items():
    print(f"  {metric_name}: {metric_value:.3f}")
print(f"\n{result.message}")
print(separator)

In [0]:
# Collect the evaluation context and results, then publish them as Databricks
# job task values so downstream tasks in the same job run can read them.

# Prefer widget values; fall back to variables defined by earlier cells.
try:
    model_base_name = dbutils.widgets.get("model_base_name")
    catalog = dbutils.widgets.get("catalog")
    schema = dbutils.widgets.get("schema")
    UC_MODEL_NAME = f"{catalog}.{schema}.{model_base_name}"
    promotion_threshold = float(dbutils.widgets.get("promotion_threshold"))
    evaluation_metric = dbutils.widgets.get("evaluation_metric")
except Exception:
    # Fallbacks if widgets are not defined
    model_base_name = globals().get("model_base_name", "")
    catalog = globals().get("catalog", "")
    schema = globals().get("schema", "")
    UC_MODEL_NAME = f"{catalog}.{schema}.{model_base_name}"
    promotion_threshold = globals().get("promotion_threshold", 0.7)
    evaluation_metric = globals().get("evaluation_metric", "relevance")

# `result` comes from the evaluation cell above. It is undefined (NameError)
# when that cell has not run in this session. Other errors are not masked.
try:
    evaluated_version = result.version
    evaluation_metrics = result.metrics
    passed = result.passed
except (NameError, AttributeError):
    evaluated_version = ""
    evaluation_metrics = {}
    passed = False
model_alias = "champion" if passed else ""

# Refuse to publish an empty evaluation to downstream tasks.
if not evaluation_metrics:
    raise ValueError("Evaluation metrics are empty! Please check the evaluation logic and dataset. The evaluation should return a non-empty metrics dictionary")

# Set TaskValues for downstream tasks using Databricks Jobs API | https://docs.databricks.com/aws/en/jobs/task-values
task_values = {
    "model_base_name": model_base_name,
    "UC_MODEL_NAME": UC_MODEL_NAME,
    "catalog": catalog,
    "schema": schema,
    "promotion_threshold": promotion_threshold,
    "evaluation_metric": evaluation_metric,
    "evaluated_version": evaluated_version,
    "evaluation_metrics": evaluation_metrics,
    "passed": passed,
    "model_alias": model_alias,
}
for key, value in task_values.items():
    dbutils.jobs.taskValues.set(key, value)

# Print confirmation of what was set
print("TaskValues set for downstream tasks:")
for key, value in task_values.items():
    print(f"{key}: {value}")

In [0]:
# # Replace 'evaluate_model' with the actual task name of the upstream notebook in your job configuration
# upstream_task = "evaluate_model"

# model_base_name = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="model_base_name", debugValue="lifesciences_agent")
# UC_MODEL_NAME = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="UC_MODEL_NAME", debugValue="mmt.LS_agent.lifesciences_agent")
# catalog = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="catalog", debugValue="mmt")
# schema = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="schema", debugValue="LS_agent")
# promotion_threshold = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="promotion_threshold", debugValue=0.7)
# evaluation_metric = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="evaluation_metric", debugValue="relevance")
# evaluated_version = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="evaluated_version", debugValue="1")
# evaluation_metrics = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="evaluation_metrics", 
#                                                  #debugValue={"relevance": 0.4, "accuracy": 0.4, "safety": 1.0} ## NA
#                                                  debugValue={'relevance': 1.0, 'accuracy': 1.0, 'safety': 1.0} ## champion
#                                                 )
# passed = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="passed", debugValue=True)
# model_alias = dbutils.jobs.taskValues.get(taskKey=upstream_task, key="model_alias", debugValue="champion")

# print("Retrieved TaskValues from upstream task:")
# for key in [
#             "model_base_name", "UC_MODEL_NAME", "catalog", "schema", "promotion_threshold", "evaluation_metric", 
#             "evaluated_version", "evaluation_metrics", "passed", "model_alias"
#            ]:
#     print(f"{key}: {locals().get(key)}")