Skip to content

Commit

Permalink
Fix langchain model saving with dict error message (#11822)
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan committed Apr 25, 2024
1 parent f16f9fb commit 4fd44db
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 11 deletions.
4 changes: 3 additions & 1 deletion mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,9 @@ def _save_model(model, path, loader_fn, persist_dir):
"using `pip install cloudpickle>=2.1.0` "
"to ensure the model can be loaded correctly."
)
with register_pydantic_v1_serializer_cm():
# patch_langchain_type_to_cls_dict here as we attempt to load model
# if it's saved by `dict` method
with register_pydantic_v1_serializer_cm(), patch_langchain_type_to_cls_dict():
if isinstance(model, lc_runnables_types()):
return _save_runnables(model, path, loader_fn=loader_fn, persist_dir=persist_dir)
else:
Expand Down
24 changes: 14 additions & 10 deletions mlflow/langchain/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,21 @@ def _save_internal_runnables(runnable, path, loader_fn, persist_dir):
_MODEL_DATA_KEY: _MODEL_DATA_YAML_FILE_NAME,
_MODEL_LOAD_KEY: _CONFIG_LOAD_KEY,
}
path = path / _MODEL_DATA_YAML_FILE_NAME
model_path = path / _MODEL_DATA_YAML_FILE_NAME
# Save some simple runnables that langchain natively supports.
if hasattr(runnable, "save"):
runnable.save(path)
# TODO: check if `dict` is enough to load it back
runnable.save(model_path)
elif hasattr(runnable, "dict"):
runnable_dict = runnable.dict()
with open(path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
try:
runnable_dict = runnable.dict()
with open(model_path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
# if the model cannot be loaded back, then `dict` is not enough for saving.
_load_model_from_config(path, conf)
except Exception:
raise Exception("Cannot save runnable without `save` method.")
else:
return Exception(f"Runnable {runnable} is not supported for saving.")
raise Exception("Cannot save runnable without `save` or `dict` methods.")
return conf


Expand Down Expand Up @@ -320,7 +324,7 @@ def _save_runnable_with_steps(model, file_path: Union[Path, str], loader_fn=None
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[step] = f"{runnable} -- {e}"
unsaved_runnables[step] = f"{runnable.get_name()} -- {e}"

if unsaved_runnables:
raise MlflowException(f"Failed to save runnable sequence: {unsaved_runnables}.")
Expand Down Expand Up @@ -355,7 +359,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[f"{index}-{i}"] = f"{runnable} -- {e}"
unsaved_runnables[f"{index}-{i}"] = f"{runnable.get_name()} -- {e}"

# save default branch
default_branch_path = branches_path / _DEFAULT_BRANCH_NAME
Expand All @@ -365,7 +369,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
model.default, default_branch_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default} -- {e}"
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default.get_name()} -- {e}"
if unsaved_runnables:
raise MlflowException(f"Failed to save runnable branch: {unsaved_runnables}.")

Expand Down
62 changes: 62 additions & 0 deletions tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,3 +2843,65 @@ def test_langchain_model_not_streamable():
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
with pytest.raises(MlflowException, match="This model does not support predict_stream method"):
loaded_model.predict_stream({"product": "shoe"})


@pytest.mark.skipif(
    Version(langchain.__version__) < Version("0.0.311"),
    reason="feature not existing",
)
def test_langchain_model_save_exception(fake_chat_model):
    """Logging a chain containing an unserializable step surfaces a clear error.

    The error message is expected to name the first failing step ('0', the
    PromptTemplate) so users can identify which runnable blocked serialization.
    """
    from langchain.prompts import PromptTemplate
    from langchain.schema.output_parser import StrOutputParser

    template = PromptTemplate.from_template(
        "What's your favorite {industry} company in {country}?", partial_variables={"country": "US"}
    )
    sequence = template | fake_chat_model | StrOutputParser()
    # Sanity-check the chain works before attempting to log it.
    assert sequence.invoke({"industry": "tech"}) == "Databricks"

    # Saving should fail with a message that pinpoints the offending step.
    with pytest.raises(
        MlflowException, match=r"Failed to save runnable sequence: {'0': 'PromptTemplate -- "
    ), mlflow.start_run():
        mlflow.langchain.log_model(sequence, "model_path", input_example={"industry": "tech"})


@pytest.mark.skipif(
    Version(langchain.__version__) < Version("0.0.311"),
    reason="feature not existing",
)
def test_langchain_model_save_throws_exception_on_unsupported_runnables(fake_chat_model):
    """Logging a chain with an unsupported nested runnable raises MlflowException.

    A sub-sequence wrapped in ``with_listeners()`` has neither a ``save`` method
    nor a round-trippable ``dict`` representation, so MLflow must fail with an
    error that names the step ('1', the RunnableSequence) and the reason.
    """
    from langchain.schema.output_parser import StrOutputParser
    from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

    # NOTE: import ChatPromptTemplate only from langchain_core; the original
    # also imported it from langchain.prompts, which the second import shadowed.
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    def retrieve_history(input):
        # Inject an empty chat history while passing the other fields through.
        return {"history": [], "question": input["question"], "name": input["name"]}

    chain = (
        {"question": itemgetter("question"), "name": itemgetter("name")}
        # with_listeners() produces a runnable that cannot be serialized.
        | (RunnableLambda(retrieve_history) | prompt | fake_chat_model).with_listeners()
        | StrOutputParser()
        | RunnablePassthrough()
    )
    input_example = {"question": "Who owns MLflow?", "name": ""}
    # Sanity-check the chain works before attempting to log it.
    assert chain.invoke(input_example) == "Databricks"

    with pytest.raises(
        MlflowException,
        match=r"Failed to save runnable sequence: {'1': 'RunnableSequence "
        r"-- Cannot save runnable without `save` method.'",
    ), mlflow.start_run():
        mlflow.langchain.log_model(
            chain,
            artifact_path="chain",
        )

0 comments on commit 4fd44db

Please sign in to comment.