-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the ability to log LLM artifacts
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
- Loading branch information
1 parent
93ee352
commit 3042a74
Showing
4 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
The ``mlflow.llm`` module provides a utility for Large Language Models (LLMs). | ||
""" | ||
|
||
from mlflow.tracking.llm_utils import log_predictions | ||
|
||
__all__ = [ | ||
"log_predictions", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import csv | ||
import logging | ||
import mlflow | ||
import os | ||
import shutil | ||
import tempfile | ||
|
||
from typing import Dict, List, Union | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
def log_predictions(
    inputs: List[Union[str, Dict[str, str]]],
    outputs: List[str],
    prompts: List[Union[str, Dict[str, str]]],
) -> None:
    """
    Log a batch of inputs, outputs and prompts for the current evaluation run.
    If no run is active, this method will create a new active run.

    :param inputs: Union of either List of input strings or List of input dictionary
    :param outputs: List of output strings
    :param prompts: Union of either List of prompt strings or List of prompt dictionary
    :returns: None

    :raises ValueError: If ``inputs``, ``outputs`` and ``prompts`` do not all have
        the same length.

    .. test-code-block:: python
        :caption: Example

        import mlflow

        inputs = [
            {
                "question": "How do I create a Databricks cluster with UC access?",
                "context": "Databricks clusters are ...",
            },
        ]
        outputs = [
            "<Instructions for cluster creation with UC enabled>",
        ]
        prompts = [
            "Get Databricks documentation to answer all the questions: {input}",
        ]

        # Log llm predictions
        with mlflow.start_run():
            mlflow.llm.log_predictions(inputs, outputs, prompts)
    """
    if not (len(inputs) == len(outputs) == len(prompts)):
        raise ValueError("The length of inputs, outputs and prompts must be the same.")

    # Seed with the CSV header; replaced below if a previous artifact exists.
    predictions = [["inputs", "outputs", "prompts"]]
    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    LLM_ARTIFACT_NAME = "llm_predictions.csv"

    try:
        # If a predictions CSV was already logged for this run, load its rows so
        # repeated calls accumulate records instead of overwriting earlier ones.
        artifact_path = mlflow.artifacts.download_artifacts(
            run_id=run_id, artifact_path=LLM_ARTIFACT_NAME
        )

        with open(artifact_path, newline="") as llm_prediction:
            predictions = list(csv.reader(llm_prediction))

        # Lazy %-style args avoid formatting work when INFO logging is disabled.
        _logger.info(
            "Appending new inputs to already existing artifact %s for run %s.",
            artifact_path,
            run_id,
        )
    except OSError:
        # Download failed because no artifact exists yet — start a fresh CSV.
        _logger.info("Creating a new LLM artifact for run %s.", run_id)

    # zip pairs each input with its corresponding output and prompt; the lists
    # were validated above to be the same length.
    predictions.extend(
        [inp, out, prompt] for inp, out, prompt in zip(inputs, outputs, prompts)
    )

    if len(predictions) > 1000:
        _logger.warning(
            "Trying to log a LLM artifact with length "
            "more than 1000 records. It might slow down performance."
        )

    # Write the combined rows to a temp file and log it as the run artifact.
    # TemporaryDirectory guarantees cleanup even if logging raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, LLM_ARTIFACT_NAME)
        with open(filepath, "w", encoding="UTF8", newline="") as llm_prediction:
            writer = csv.writer(llm_prediction)
            writer.writerows(predictions)
        mlflow.tracking.fluent.log_artifact(filepath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import mlflow | ||
|
||
|
||
def test_llm_predictions_logging(tmpdir):
    import csv

    def _read_rows(path):
        # Parse the logged CSV artifact into a list of rows.
        with open(path, newline="") as csvfile:
            return list(csv.reader(csvfile))

    inputs = [
        {
            "question": "How do I create a Databricks cluster with UC enabled?",
            "context": "Databricks clusters are amazing",
        }
    ]
    outputs = [
        "<Instructions for cluster creation with UC enabled>",
    ]
    prompts = [
        "Get Databricks documentation to answer all the questions: {input}",
    ]

    artifact_file_name = "llm_predictions.csv"
    with mlflow.start_run():
        mlflow.llm.log_predictions(inputs, outputs, prompts)
        artifact_uri = mlflow.get_artifact_uri(artifact_file_name)

        rows = _read_rows(artifact_uri)

        # length of header + length of inputs
        assert len(rows) == 2
        assert rows[1][0] == str(inputs[0])
        assert rows[1][1] == outputs[0]
        assert rows[1][2] == prompts[0]

        # A second call must append to the existing artifact, not replace it.
        mlflow.llm.log_predictions(inputs, outputs, prompts)

        rows = _read_rows(artifact_uri)
        assert len(rows) == 3