Adding the ability to log LLM artifacts
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
sunishsheth2009 committed Apr 12, 2023
1 parent 4728afe commit 2d35413
Showing 4 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlflow/__init__.py
@@ -72,6 +72,7 @@
from mlflow import pmdarima
from mlflow import diviner
from mlflow import transformers
from mlflow import llm

_model_flavors_supported = [
    "catboost",
9 changes: 9 additions & 0 deletions mlflow/llm.py
@@ -0,0 +1,9 @@
"""
The ``mlflow.llm`` module provides a utility for Large Language Models (LLMs).
"""

from mlflow.tracking.llm_utils import log_predictions

__all__ = [
    "log_predictions",
]
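
For orientation, a minimal sketch of the new public surface in use (the run context and plain-string arguments here are illustrative, not part of this commit; they mirror the docstring example further below):

import mlflow

with mlflow.start_run():
    mlflow.llm.log_predictions(
        inputs=["What is MLflow?"],
        outputs=["MLflow is an open source platform for the ML lifecycle."],
        prompts=["Answer the user's question: {input}"],
    )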
95 changes: 95 additions & 0 deletions mlflow/tracking/llm_utils.py
@@ -0,0 +1,95 @@
import csv
import logging
import os
import tempfile
from typing import Dict, List, Union

import mlflow
from mlflow.tracking.client import MlflowClient
from mlflow.utils.annotations import experimental

_logger = logging.getLogger(__name__)


@experimental
def log_predictions(
    inputs: List[Union[str, Dict[str, str]]],
    outputs: List[str],
    prompts: List[Union[str, Dict[str, str]]],
) -> None:
    """
    Log a batch of inputs, outputs and prompts for the current evaluation run.
    If no run is active, this method will create a new active run.

    :param inputs: List of input strings or input dictionaries.
    :param outputs: List of output strings.
    :param prompts: List of prompt strings or prompt dictionaries.
    :returns: None

    .. test-code-block:: python
        :caption: Example

        import mlflow

        inputs = [
            {
                "question": "How do I create a Databricks cluster with UC access?",
                "context": "Databricks clusters are ...",
            },
        ]
        outputs = [
            "<Instructions for cluster creation with UC enabled>",
        ]
        prompts = [
            "Get Databricks documentation to answer all the questions: {input}",
        ]

        # Log llm predictions
        with mlflow.start_run():
            mlflow.llm.log_predictions(inputs, outputs, prompts)
    """
    if len(inputs) <= 0 or len(inputs) != len(outputs) or len(inputs) != len(prompts):
        raise ValueError(
            "The length of inputs, outputs and prompts must be the same and not empty."
        )

    artifact_path = None
    predictions = []
    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    LLM_ARTIFACT_NAME = "llm_predictions.csv"

    for row in zip(inputs, outputs, prompts):
        predictions.append(row)

    with tempfile.TemporaryDirectory() as tmpdir:
        artifacts = [f.path for f in MlflowClient().list_artifacts(run_id)]
        if LLM_ARTIFACT_NAME in artifacts:
            # The artifact already exists: download it so the new rows are appended.
            artifact_path = mlflow.artifacts.download_artifacts(
                run_id=run_id, artifact_path=LLM_ARTIFACT_NAME, dst_path=tmpdir
            )
            _logger.info(
                "Appending new inputs to already existing artifact "
                f"{LLM_ARTIFACT_NAME} for run {run_id}."
            )
        else:
            # If the artifact doesn't exist, we need to write the header.
            predictions.insert(0, ["inputs", "outputs", "prompts"])
            artifact_path = os.path.join(tmpdir, LLM_ARTIFACT_NAME)
            _logger.info(f"Creating a new {LLM_ARTIFACT_NAME} for run {run_id}.")

        if os.path.exists(artifact_path):
            # Warn once the accumulated CSV grows past 1000 rows.
            with open(artifact_path, newline="") as llm_prediction:
                num_existing_predictions = sum(1 for _ in csv.reader(llm_prediction))
            if num_existing_predictions + len(predictions) > 1000:
                _logger.warning(
                    f"Trying to log a {LLM_ARTIFACT_NAME} with more than "
                    "1000 records. It might slow down performance."
                )

        with open(artifact_path, "a", encoding="UTF8", newline="") as llm_prediction:
            writer = csv.writer(llm_prediction)
            writer.writerows(predictions)

        mlflow.tracking.fluent.log_artifact(artifact_path)
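
Because each call downloads any existing llm_predictions.csv and re-logs it with the new rows appended, the accumulated predictions can be read back through the standard artifact APIs. A minimal sketch, assuming `run_id` refers to a run that has already logged predictions:

import csv

import mlflow

# Fetch the accumulated CSV for the run and print its rows.
path = mlflow.artifacts.download_artifacts(
    run_id=run_id, artifact_path="llm_predictions.csv"
)
with open(path, newline="") as f:
    for row in csv.reader(f):
        print(row)  # header ["inputs", "outputs", "prompts"], then one row per prediction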
58 changes: 58 additions & 0 deletions tests/tracking/test_llm_utils.py
@@ -0,0 +1,58 @@
import csv

import mlflow
import pytest

from mlflow.utils.file_utils import local_file_uri_to_path


def test_llm_predictions_logging():
    inputs = [
        {
            "question": "How do I create a Databricks cluster with UC enabled?",
            "context": "Databricks clusters are amazing",
        }
    ]

    outputs = [
        "<Instructions for cluster creation with UC enabled>",
    ]

    prompts = [
        "Get Databricks documentation to answer all the questions: {input}",
    ]

    artifact_file_name = "llm_predictions.csv"
    with mlflow.start_run():
        # Empty inputs are rejected.
        with pytest.raises(
            ValueError,
            match="The length of inputs, outputs and prompts must be the same and not empty.",
        ):
            mlflow.llm.log_predictions([], [], [])

        # Mismatched lengths are rejected as well.
        with pytest.raises(
            ValueError,
            match="The length of inputs, outputs and prompts must be the same and not empty.",
        ):
            mlflow.llm.log_predictions(
                [], ["<Instructions for cluster creation with UC enabled>"], []
            )

        mlflow.llm.log_predictions(inputs, outputs, prompts)
        artifact_path = local_file_uri_to_path(mlflow.get_artifact_uri(artifact_file_name))

        with open(artifact_path, newline="") as csvfile:
            predictions = list(csv.reader(csvfile))

        # length of header + length of inputs
        assert len(predictions) == 2
        assert predictions[1][0] == str(inputs[0])
        assert predictions[1][1] == outputs[0]
        assert predictions[1][2] == prompts[0]

        # A second call appends to the same artifact instead of overwriting it.
        mlflow.llm.log_predictions(inputs, outputs, prompts)

        with open(artifact_path, newline="") as csvfile:
            predictions = list(csv.reader(csvfile))

        assert len(predictions) == 3
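
One detail worth noting: `csv.writer` converts non-string cells with `str()` before writing, which is why the first-column assertion above compares against `str(inputs[0])`. A quick standalone illustration:

import csv
import io

buf = io.StringIO()
csv.writer(buf).writerow([{"question": "q"}, "out", "prompt"])
print(buf.getvalue())  # {'question': 'q'},out,prompt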
