# Run Inference


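Runs parallel inference over a versioned prompt dataset against a Model Serving endpoint, then publishes the run ID, endpoint name, and success rate as job task values for downstream tasks.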

In [None]:
%pip install -e ./src


In [None]:
%restart_python


In [None]:
import logging
import uuid
from verdict.inference.inference_runner import InferenceRunner
from verdict.data.prompt_dataset import PromptDatasetManager

# Route INFO-and-above records to the root logger so the progress messages in
# later cells are emitted.
# NOTE(review): basicConfig() does nothing if the root logger already has
# handlers attached (e.g. by the notebook runtime) — confirm the messages
# actually show up in cell output.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

In [None]:
# Widget parameters
dbutils.widgets.text("model_endpoint", "your-model-endpoint", "Model Endpoint")
dbutils.widgets.text("dataset_version", "v1", "Dataset Version")
dbutils.widgets.text("candidate_version", "", "Candidate Version (auto-detected if empty)")
dbutils.widgets.text("catalog_name", "verdict", "Catalog Name")

# Strip surrounding whitespace: values typed or pasted into widgets/parameters
# often carry stray spaces, which would otherwise yield an invalid endpoint or
# catalog name and make a blank candidate_version look like a real version.
model_endpoint = dbutils.widgets.get("model_endpoint").strip()
dataset_version = dbutils.widgets.get("dataset_version").strip()
# Empty (or whitespace-only) means "not specified"; it is passed on as None.
candidate_version = dbutils.widgets.get("candidate_version").strip() or None
catalog_name = dbutils.widgets.get("catalog_name").strip()

# Fail fast on missing required parameters rather than deep inside inference.
for _widget_name, _widget_value in (
    ("model_endpoint", model_endpoint),
    ("dataset_version", dataset_version),
    ("catalog_name", catalog_name),
):
    if not _widget_value:
        raise ValueError(f"Widget '{_widget_name}' must not be empty")

In [None]:
# Echo the resolved run configuration so it is visible in the job/notebook log.
for _label, _value in (
    ("Starting inference for endpoint", model_endpoint),
    ("Dataset version", dataset_version),
    ("Catalog", catalog_name),
):
    logger.info(f"{_label}: {_value}")

In [None]:
# Initialize managers — both are scoped to the catalog chosen via the
# catalog_name widget. Constructed in the same order as before (runner first).
inference_runner, dataset_manager = (
    InferenceRunner(catalog_name=catalog_name),
    PromptDatasetManager(catalog_name=catalog_name),
)

In [None]:
# Load the requested prompt dataset version, report its size, and stop early
# when it is empty — there would be nothing to send to the endpoint.
prompts_df = dataset_manager.load_dataset(dataset_version)
prompt_count = prompts_df.count()
logger.info(
    f"Loaded {prompt_count} prompts from dataset version {dataset_version}"
)
if prompt_count == 0:
    raise ValueError(
        f"No prompts found in dataset version {dataset_version}"
    )

In [None]:
# Run inference
# Throughput knobs are exposed as widgets (defaults 100 / 10) so job runs can
# tune them without editing the notebook.
dbutils.widgets.text("batch_size", "100", "Batch Size")
dbutils.widgets.text("max_workers", "10", "Max Workers")
batch_size = int(dbutils.widgets.get("batch_size"))
max_workers = int(dbutils.widgets.get("max_workers"))
if batch_size < 1 or max_workers < 1:
    raise ValueError(
        f"batch_size and max_workers must be positive "
        f"(got batch_size={batch_size}, max_workers={max_workers})"
    )

run_id = str(uuid.uuid4())
# Log the run id before the potentially long call so the run can be traced
# even if inference fails partway through.
logger.info(
    f"Run ID: {run_id} (batch_size={batch_size}, max_workers={max_workers})"
)
results_df = inference_runner.run_inference(
    endpoint_name=model_endpoint,
    prompt_dataset=prompts_df,
    model_version=candidate_version,
    run_id=run_id,
    batch_size=batch_size,
    max_workers=max_workers,
)

In [None]:
# Summary statistics: every result row counts toward the total; rows whose
# status equals 'success' are successes and all remaining rows are failures.
# NOTE(review): if results_df is lazily evaluated, each count() re-executes its
# plan — confirm run_inference returns materialized results (or cache them) so
# the endpoint is not queried again.
total, success = results_df.count(), results_df.filter("status = 'success'").count()
failed = total - success

logger.info(f"Inference complete: {success}/{total} successful ({failed} failed)")

In [None]:
# Display results
# Guard the percentage: with an empty result set (total == 0) the plain
# division would raise ZeroDivisionError and abort before task values are set.
success_pct = 100 * success / total if total > 0 else 0.0

print(f"\nRun ID: {run_id}")
print(f"Model Endpoint: {model_endpoint}")
print(f"Success Rate: {success}/{total} ({success_pct:.1f}%)")

results_df.select("prompt_id", "response", "latency_ms", "status").display()

In [None]:
# Publish run metadata as job task values so downstream tasks can read it.
# success_rate is a fraction in [0, 1], reported as 0 when there were no results.
for _key, _task_value in (
    ("run_id", run_id),
    ("model_endpoint", model_endpoint),
    ("success_rate", success / total if total > 0 else 0),
):
    dbutils.jobs.taskValues.set(_key, _task_value)