-
Notifications
You must be signed in to change notification settings - Fork 4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the ability to log LLM artifacts (#8204)
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
- Loading branch information
1 parent
9cf7198
commit 2bac047
Showing
4 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
The ``mlflow.llm`` module provides a utility for Large Language Models (LLMs). | ||
""" | ||
|
||
from mlflow.tracking.llm_utils import log_predictions | ||
|
||
__all__ = [ | ||
"log_predictions", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import csv | ||
import logging | ||
import mlflow | ||
import os | ||
import tempfile | ||
|
||
from mlflow.tracking.client import MlflowClient | ||
from mlflow.utils.annotations import experimental | ||
from typing import Dict, List, Union | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
@experimental
def log_predictions(
    inputs: List[Union[str, Dict[str, str]]],
    outputs: List[str],
    prompts: List[Union[str, Dict[str, str]]],
) -> None:
    """
    Log a batch of inputs, outputs and prompts for the current evaluation run.
    If no run is active, this method will create a new active run.

    The rows are stored in a CSV artifact named ``llm_predictions.csv`` with the columns
    ``inputs``, ``outputs`` and ``prompts``. If the run already lists that artifact, it is
    downloaded, the new rows are appended to it and the file is logged again; otherwise a
    new file that starts with a header row is created.

    :param inputs: Union of either List of input strings or List of input dictionary
    :param outputs: List of output strings
    :param prompts: Union of either List of prompt strings or List of prompt dictionary
    :returns: None
    :raises ValueError: If ``inputs`` is empty, or if ``inputs``, ``outputs`` and
                        ``prompts`` do not all have the same length.

    .. test-code-block:: python
        :caption: Example

        import mlflow

        inputs = [
            {
                "question": "How do I create a Databricks cluster with UC access?",
                "context": "Databricks clusters are ...",
            },
        ]

        outputs = [
            "<Instructions for cluster creation with UC enabled>",
        ]

        prompts = [
            "Get Databricks documentation to answer all the questions: {input}",
        ]

        # Log llm predictions
        with mlflow.start_run():
            mlflow.llm.log_predictions(inputs, outputs, prompts)
    """
    if len(inputs) <= 0 or len(inputs) != len(outputs) or len(inputs) != len(prompts):
        raise ValueError(
            "The length of inputs, outputs and prompts must be the same and not empty."
        )

    LLM_ARTIFACT_NAME = "llm_predictions.csv"
    # The artifact is written as UTF8, so it has to be read back with the same encoding;
    # relying on the platform default can fail on non-ASCII predictions.
    LLM_ARTIFACT_ENCODING = "UTF8"

    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    # One CSV row per (input, output, prompt) triple.
    predictions = list(zip(inputs, outputs, prompts))

    with tempfile.TemporaryDirectory() as tmpdir:
        artifacts = [f.path for f in MlflowClient().list_artifacts(run_id)]
        if LLM_ARTIFACT_NAME in artifacts:
            artifact_path = mlflow.artifacts.download_artifacts(
                run_id=run_id, artifact_path=LLM_ARTIFACT_NAME, dst_path=tmpdir
            )
            _logger.info(
                "Appending new inputs to already existing artifact "
                f"{LLM_ARTIFACT_NAME} for run {run_id}."
            )
        else:
            # If the artifact doesn't exist, we need to write the header.
            predictions.insert(0, ["inputs", "outputs", "prompts"])
            artifact_path = os.path.join(tmpdir, LLM_ARTIFACT_NAME)
            _logger.info(f"Creating a new {LLM_ARTIFACT_NAME} for run {run_id}.")

        # Rows already present in the downloaded artifact (header included); stays 0 when
        # the file is being created by this call.
        num_existing_predictions = 0
        if os.path.exists(artifact_path):
            with open(
                artifact_path, encoding=LLM_ARTIFACT_ENCODING, newline=""
            ) as llm_prediction:
                num_existing_predictions = sum(1 for _ in csv.reader(llm_prediction))

        # Evaluated for new files as well, so a large first batch is also reported.
        if num_existing_predictions + len(predictions) > 1000:
            _logger.warning(
                f"Trying to log a {LLM_ARTIFACT_NAME} with length "
                "more than 1000 records. It might slow down performance."
            )

        with open(
            artifact_path, "a", encoding=LLM_ARTIFACT_ENCODING, newline=""
        ) as llm_prediction:
            writer = csv.writer(llm_prediction)
            writer.writerows(predictions)
        # Must happen before the temporary directory (and the file in it) is removed.
        mlflow.tracking.fluent.log_artifact(artifact_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import mlflow | ||
import pytest | ||
|
||
from mlflow.utils.file_utils import local_file_uri_to_path | ||
|
||
|
||
def test_llm_predictions_logging():
    """
    Check ``mlflow.llm.log_predictions`` end to end inside a single run: invalid batches are
    rejected, the first valid call writes a header plus one row, and a second call appends.
    """
    import csv

    def read_logged_rows(path):
        # Parse the logged CSV artifact into a list of rows (header row included).
        with open(path, newline="") as csvfile:
            return list(csv.reader(csvfile))

    length_error = "The length of inputs, outputs and prompts must be the same and not empty."
    artifact_file_name = "llm_predictions.csv"

    inputs = [
        {
            "question": "How do I create a Databricks cluster with UC enabled?",
            "context": "Databricks clusters are amazing",
        }
    ]
    outputs = ["<Instructions for cluster creation with UC enabled>"]
    prompts = ["Get Databricks documentation to answer all the questions: {input}"]

    with mlflow.start_run():
        # Empty batch is rejected.
        with pytest.raises(ValueError, match=length_error):
            mlflow.llm.log_predictions([], [], [])

        # Mismatched lengths are rejected.
        with pytest.raises(ValueError, match=length_error):
            mlflow.llm.log_predictions(
                [], ["<Instructions for cluster creation with UC enabled>"], []
            )

        # First valid call creates the artifact.
        mlflow.llm.log_predictions(inputs, outputs, prompts)
        artifact_path = local_file_uri_to_path(mlflow.get_artifact_uri(artifact_file_name))

        rows = read_logged_rows(artifact_path)
        # length of header + length of inputs
        assert len(rows) == 2
        first_record = rows[1]
        assert first_record[0] == str(inputs[0])
        assert first_record[1] == outputs[0]
        assert first_record[2] == prompts[0]

        # Second call appends one more row to the same artifact.
        mlflow.llm.log_predictions(inputs, outputs, prompts)

        rows = read_logged_rows(artifact_path)
        assert len(rows) == 3