Resolving comments
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
sunishsheth2009 committed Apr 24, 2023
1 parent 5cedf47 commit 8658b64
Showing 5 changed files with 92 additions and 53 deletions.
7 changes: 6 additions & 1 deletion docs/source/models.rst
@@ -1779,11 +1779,16 @@ interpreted as a generic Python function for inference via :py:func:`mlflow.pyfu
 You can also use the :py:func:`mlflow.langchain.load_model()` function to load a saved or logged MLflow
 Model with the ``langchain`` flavor as a dictionary of the model's attributes.
 
-Example:
+Example: Log a LangChain LLMChain
 
 .. literalinclude:: ../../examples/langchain/simple_chain.py
    :language: python
 
+Example: Log a LangChain Agent
+
+.. literalinclude:: ../../examples/langchain/simple_agent.py
+   :language: python
+
 Diviner (``diviner``)
 ^^^^^^^^^^^^^^^^^^^^^
 The ``diviner`` model flavor enables logging of
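The two example files referenced above are not part of this commit. As a rough sketch, the LLMChain example presumably looks something like the following (the prompt text and temperature are illustrative, and an OPENAI_API_KEY is assumed to be set):

    import mlflow
    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    # Build a trivial single-prompt chain.
    llm = OpenAI(temperature=0.9)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=llm, prompt=prompt)

    # Log the chain with the langchain flavor ...
    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(chain, "langchain_model")

    # ... and score it through the generic pyfunc interface.
    loaded = mlflow.pyfunc.load_model(logged_model.model_uri)
    print(loaded.predict([{"product": "colorful socks"}]))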
93 changes: 54 additions & 39 deletions mlflow/langchain/__init__.py
@@ -59,7 +59,7 @@
 _MODEL_DATA_KEY = "model_data"
 _AGENT_PRIMITIVES_FILE_NAME = "agent_primitive_args.json"
 _AGENT_PRIMITIVES_DATA_KEY = "agent_primitive_data"
-_AGENT_DATA_FILE_NAME = "agent.json"
+_AGENT_DATA_FILE_NAME = "agent.yaml"
 _AGENT_DATA_KEY = "agent_data"
 _TOOLS_DATA_FILE_NAME = "tools.pkl"
 _TOOLS_DATA_KEY = "tools_data"
@@ -291,27 +291,35 @@ def log_model(
"""
import langchain

if (
type(lc_model) != langchain.chains.llm.LLMChain
and type(lc_model) != langchain.agents.agent.AgentExecutor
if not isinstance(
lc_model, (langchain.chains.llm.LLMChain, langchain.agents.agent.AgentExecutor)
):
raise TypeError(
raise mlflow.MlflowException.invalid_parameter_value(
"MLflow langchain flavor only supports logging langchain.chains.llm.LLMChain and "
+ f"langchain.agents.agent.AgentExecutor instances, found {type(lc_model)}"
)

_SUPPORTED_LLMS = {langchain.llms.openai.OpenAI, langchain.llms.huggingface_hub.HuggingFaceHub}
if (
type(lc_model) == langchain.chains.llm.LLMChain
isinstance(lc_model, langchain.chains.llm.LLMChain)
and type(lc_model.llm) not in _SUPPORTED_LLMS
) or (
type(lc_model) == langchain.agents.agent.AgentExecutor
):
logger.warning(
"MLflow does not guarantee support for LLMChains outside of HuggingFaceHub and "
"OpenAI, found %s",
type(lc_model.llm).__name__,
)

if (
isinstance(lc_model, langchain.agents.agent.AgentExecutor)
and type(lc_model.agent.llm_chain.llm) not in _SUPPORTED_LLMS
):
logger.warning(
"MLflow does not guarantee support for LLMChains outside of HuggingFaceHub and "
"OpenAI, found %s",
str(type(lc_model.llm)),
type(lc_model.agent.llm_chain.llm).__name__,
)

return Model.log(
artifact_path=artifact_path,
flavor=mlflow.langchain,
@@ -343,31 +351,31 @@ def _save_model(model, path):
         if model.agent:
             agent_data_path = os.path.join(path, _AGENT_DATA_FILE_NAME)
             model.save_agent(agent_data_path)
-            model_data_kwargs = {**model_data_kwargs, _AGENT_DATA_KEY: _AGENT_DATA_FILE_NAME}
+            model_data_kwargs[_AGENT_DATA_KEY] = _AGENT_DATA_FILE_NAME
 
         if model.tools:
             tools_data_path = os.path.join(path, _TOOLS_DATA_FILE_NAME)
             with open(tools_data_path, "wb") as f:
                 cloudpickle.dump(model.tools, f)
-            model_data_kwargs = {**model_data_kwargs, _TOOLS_DATA_KEY: _TOOLS_DATA_FILE_NAME}
+            model_data_kwargs[_TOOLS_DATA_KEY] = _TOOLS_DATA_FILE_NAME
         else:
             raise mlflow.MlflowException.invalid_parameter_value(
                 "For initializing the AgentExecutor, tools must be provided."
             )
 
         key_to_ignore = ["llm_chain", "agent", "tools", "callback_manager"]
-        temp_dict = {}
-        for k, v in model.__dict__.items():
-            if k not in key_to_ignore:
-                temp_dict[k] = v
+        temp_dict = {k: v for k, v in model.__dict__.items() if k not in key_to_ignore}
 
         agent_primitive_path = os.path.join(path, _AGENT_PRIMITIVES_FILE_NAME)
         with open(agent_primitive_path, "w") as config_file:
             json.dump(temp_dict, config_file, indent=4)
 
-        model_data_kwargs = {
-            **model_data_kwargs,
-            _AGENT_PRIMITIVES_DATA_KEY: _AGENT_PRIMITIVES_FILE_NAME,
-        }
+        model_data_kwargs[_AGENT_PRIMITIVES_DATA_KEY] = _AGENT_PRIMITIVES_FILE_NAME
     else:
-        logger.error("Could not save model.")
-        pass
+        raise mlflow.MlflowException.invalid_parameter_value(
+            "MLflow langchain flavor only supports logging langchain.chains.llm.LLMChain and "
+            + f"langchain.agents.agent.AgentExecutor instances, found {type(model)}"
+        )
 
     return model_data_kwargs

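For orientation, a logged AgentExecutor's artifact directory would now contain roughly the following files (a sketch assembled from the constants above; the model.yaml name for the serialized llm_chain is an assumption, since _MODEL_DATA_FILE_NAME is defined outside this hunk):

    langchain_model/
    ├── MLmodel                      # flavor configuration
    ├── model.yaml                   # serialized agent.llm_chain (_MODEL_DATA_FILE_NAME, assumed)
    ├── agent.yaml                   # model.save_agent(...) output (previously agent.json)
    ├── tools.pkl                    # cloudpickled model.tools
    └── agent_primitive_args.json    # remaining AgentExecutor attributes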
@@ -383,12 +391,22 @@ def _load_model(path, agent_path=None, tools_path=None, agent_primitive_path=None):
     from langchain.agents import initialize_agent
 
     llm = load_chain(path)
-    with open(tools_path, "rb") as f:
-        tools = cloudpickle.load(f)
-    with open(agent_primitive_path, "r") as config_file:
-        args = json.load(config_file)
+    tools = []
+    kwargs = {}
 
+    if os.path.exists(tools_path):
+        with open(tools_path, "rb") as f:
+            tools = cloudpickle.load(f)
+    else:
+        raise mlflow.MlflowException(
+            "Missing file for tools which is required to build the AgentExecutor object."
+        )
 
-    model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **args)
+    if os.path.exists(agent_primitive_path):
+        with open(agent_primitive_path, "r") as config_file:
+            kwargs = json.load(config_file)
+
+    model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **kwargs)
     return model


@@ -439,19 +457,16 @@ def _load_model_from_local_fs(local_model_path):
     lc_model_path = os.path.join(
         local_model_path, flavor_conf.get(_MODEL_DATA_KEY, _MODEL_DATA_FILE_NAME)
     )
-    agent_model_path = None
-    if flavor_conf.get(_AGENT_DATA_KEY):
-        agent_model_path = os.path.join(local_model_path, flavor_conf.get(_AGENT_DATA_KEY))
-
-    tools_model_path = None
-    if flavor_conf.get(_TOOLS_DATA_KEY):
-        tools_model_path = os.path.join(local_model_path, flavor_conf.get(_TOOLS_DATA_KEY))
-
-    agent_primitive_path = None
-    if flavor_conf.get(_AGENT_PRIMITIVES_DATA_KEY):
-        agent_primitive_path = os.path.join(
-            local_model_path, flavor_conf.get(_AGENT_PRIMITIVES_DATA_KEY)
-        )
-
+    agent_model_path = tools_model_path = agent_primitive_path = None
+    if agent_path := flavor_conf.get(_AGENT_DATA_KEY):
+        agent_model_path = os.path.join(local_model_path, agent_path)
+
+    if tools_path := flavor_conf.get(_TOOLS_DATA_KEY):
+        tools_model_path = os.path.join(local_model_path, tools_path)
+
+    if primitive_path := flavor_conf.get(_AGENT_PRIMITIVES_DATA_KEY):
+        agent_primitive_path = os.path.join(local_model_path, primitive_path)
+
     return _load_model(lc_model_path, agent_model_path, tools_model_path, agent_primitive_path)

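Taken together, a minimal end-to-end sketch of the agent round trip this module now supports (the tool names mirror the examples referenced in models.rst and are illustrative; serpapi needs a SERPAPI_API_KEY and OpenAI an OPENAI_API_KEY):

    import mlflow
    from langchain.agents import initialize_agent, load_tools
    from langchain.llms import OpenAI

    llm = OpenAI(temperature=0)
    tools = load_tools(["serpapi", "llm-math"], llm=llm)
    agent = initialize_agent(tools, llm, agent="zero-shot-react-description")

    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(agent, "langchain_model")

    # Loading rebuilds the AgentExecutor from agent.yaml, tools.pkl, and
    # agent_primitive_args.json via initialize_agent.
    loaded = mlflow.pyfunc.load_model(logged_model.model_uri)
    print(loaded.predict([{"input": "What is 2 raised to the .023 power?"}]))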
2 changes: 1 addition & 1 deletion mlflow/ml-package-versions.yml
@@ -492,7 +492,7 @@ langchain:
   install_dev: |
     pip install git+https://github.com/hwchase17/langchain
   models:
-    minimum: "0.0.139"
+    minimum: "0.0.140"
     maximum: "0.0.143"
   requirements:
     ">= 0.0.0": ["pyspark", "transformers", "tensorflow", "openai", "google-search-results"]
2 changes: 2 additions & 0 deletions mlflow/openai/utils.py
@@ -1,5 +1,6 @@
 import json
 import re
+import mlflow
 import requests
 from unittest import mock
 from contextlib import contextmanager
@@ -13,6 +14,7 @@ def __init__(self, status_code, json_data):
         self.status_code = status_code
         self.content = json.dumps(json_data).encode()
         self.headers = {"Content-Type": "application/json"}
+        self.text = mlflow.__version__
 
 
 def _chat_completion_json_sample(content):
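The new text attribute is best explained by the serving test below: pyfunc_serve_and_score_model issues real HTTP calls through requests, which these utilities mock wholesale, and the scoring-server client appears to verify readiness by comparing a /version-style response body against mlflow.__version__. A hypothetical sketch of the check the mock now satisfies (the exact endpoint is an assumption, not shown in this diff):

    # Inside the serving client, roughly:
    resp = requests.get(f"{server_url}/version")
    assert resp.text == mlflow.__version__  # the mock echoes this value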
41 changes: 29 additions & 12 deletions tests/langchain/test_langchain_model_export.py
@@ -2,6 +2,7 @@
 import mlflow
 import pytest
 import transformers
+import json
 
 from contextlib import contextmanager
 from langchain.prompts import PromptTemplate
@@ -12,6 +13,7 @@
 from langchain.chains.base import Chain
 from pyspark.sql import SparkSession
 from typing import Any, List, Mapping, Optional, Dict
+from tests.helper_functions import pyfunc_serve_and_score_model
 
 from mlflow.openai.utils import (
     _mock_chat_completion_response,
@@ -233,21 +235,36 @@ def test_langchain_agent_model_predict():
         ],
         "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
     }
-    model = create_model("openaiagent")
-    with mlflow.start_run():
-        logged_model = mlflow.langchain.log_model(model, "langchain_model")
-    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
+    langchain_input = {
+        "input": "What was the high temperature in SF yesterday in Fahrenheit? "
+        "What is that number raised to the .023 power?"
+    }
     with _mock_request(return_value=_MockResponse(200, langchain_agent_output)):
+        model = create_model("openaiagent")
+        with mlflow.start_run():
+            logged_model = mlflow.langchain.log_model(model, "langchain_model")
+        loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
-        result = loaded_model.predict(
-            [
-                {
-                    "input": "What was the high temperature in SF yesterday in Fahrenheit? "
-                    "What is that number raised to the .023 power?"
-                }
-            ]
-        )
+        result = loaded_model.predict([langchain_input])
         assert result == [TEST_CONTENT]
 
+    inference_payload = json.dumps({"inputs": langchain_input})
+    langchain_agent_output_serving = {"predictions": langchain_agent_output}
+    with _mock_request(return_value=_MockResponse(200, langchain_agent_output_serving)):
+        import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
+        from mlflow.deployments import PredictionsResponse
+
+        response = pyfunc_serve_and_score_model(
+            logged_model.model_uri,
+            data=inference_payload,
+            content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
+            extra_args=["--env-manager", "local"],
+        )
+
+        assert (
+            PredictionsResponse.from_json(response.content.decode("utf-8"))
+            == langchain_agent_output_serving
+        )
 
 
 def test_unsupported_chain_types():
     chain = FakeChain()
