Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix langchain model saving with dict error message #11822

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,9 @@ def _save_model(model, path, loader_fn, persist_dir):
"using `pip install cloudpickle>=2.1.0` "
"to ensure the model can be loaded correctly."
)
with register_pydantic_v1_serializer_cm():
# patch_langchain_type_to_cls_dict here as we attempt to load model
# if it's saved by `dict` method
with register_pydantic_v1_serializer_cm(), patch_langchain_type_to_cls_dict():
if isinstance(model, lc_runnables_types()):
return _save_runnables(model, path, loader_fn=loader_fn, persist_dir=persist_dir)
else:
Expand Down
24 changes: 14 additions & 10 deletions mlflow/langchain/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,21 @@ def _save_internal_runnables(runnable, path, loader_fn, persist_dir):
_MODEL_DATA_KEY: _MODEL_DATA_YAML_FILE_NAME,
_MODEL_LOAD_KEY: _CONFIG_LOAD_KEY,
}
path = path / _MODEL_DATA_YAML_FILE_NAME
model_path = path / _MODEL_DATA_YAML_FILE_NAME
# Save some simple runnables that langchain natively supports.
if hasattr(runnable, "save"):
runnable.save(path)
# TODO: check if `dict` is enough to load it back
runnable.save(model_path)
elif hasattr(runnable, "dict"):
runnable_dict = runnable.dict()
with open(path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
try:
runnable_dict = runnable.dict()
with open(model_path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
# if the model cannot be loaded back, then `dict` is not enough for saving.
_load_model_from_config(path, conf)
except Exception:
raise Exception("Cannot save runnable without `save` method.")
else:
return Exception(f"Runnable {runnable} is not supported for saving.")
raise Exception("Cannot save runnable without `save` or `dict` methods.")
return conf


Expand Down Expand Up @@ -320,7 +324,7 @@ def _save_runnable_with_steps(model, file_path: Union[Path, str], loader_fn=None
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[step] = f"{runnable} -- {e}"
unsaved_runnables[step] = f"{runnable.get_name()} -- {e}"

if unsaved_runnables:
raise MlflowException(f"Failed to save runnable sequence: {unsaved_runnables}.")
Expand Down Expand Up @@ -355,7 +359,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[f"{index}-{i}"] = f"{runnable} -- {e}"
unsaved_runnables[f"{index}-{i}"] = f"{runnable.get_name()} -- {e}"

# save default branch
default_branch_path = branches_path / _DEFAULT_BRANCH_NAME
Expand All @@ -365,7 +369,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
model.default, default_branch_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default} -- {e}"
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default.get_name()} -- {e}"
if unsaved_runnables:
raise MlflowException(f"Failed to save runnable branch: {unsaved_runnables}.")

Expand Down
62 changes: 62 additions & 0 deletions tests/langchain/test_langchain_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,3 +2843,65 @@ def test_langchain_model_not_streamable():
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
with pytest.raises(MlflowException, match="This model does not support predict_stream method"):
loaded_model.predict_stream({"product": "shoe"})


@pytest.mark.skipif(
    Version(langchain.__version__) < Version("0.0.311"),
    reason="feature not existing",
)
def test_langchain_model_save_exception(fake_chat_model):
    """Logging a chain with an unserializable PromptTemplate must fail loudly.

    A prompt built with ``partial_variables`` cannot be saved by langchain's
    native ``save``/``dict`` path, so ``log_model`` is expected to raise an
    MlflowException naming the offending step.
    """
    from langchain.prompts import PromptTemplate
    from langchain.schema.output_parser import StrOutputParser

    # `partial_variables` is what makes this prompt unserializable.
    template = PromptTemplate.from_template(
        "What's your favorite {industry} company in {country}?", partial_variables={"country": "US"}
    )
    sequence = template | fake_chat_model | StrOutputParser()

    # Sanity check: the chain itself works before we try to log it.
    assert sequence.invoke({"industry": "tech"}) == "Databricks"

    with pytest.raises(
        MlflowException, match=r"Failed to save runnable sequence: {'0': 'PromptTemplate -- "
    ):
        with mlflow.start_run():
            mlflow.langchain.log_model(sequence, "model_path", input_example={"industry": "tech"})


@pytest.mark.skipif(
    Version(langchain.__version__) < Version("0.0.311"),
    reason="feature not existing",
)
def test_langchain_model_save_throws_exception_on_unsupported_runnables(fake_chat_model):
    """Logging a chain containing an unsaveable runnable must raise MlflowException.

    ``.with_listeners()`` wraps the inner sequence in a binding that has no
    working ``save`` method and no loadable ``dict`` representation, so the
    error message must identify the failing step.
    """
    from langchain.schema.output_parser import StrOutputParser
    from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
    # NOTE: ChatPromptTemplate is imported once, from langchain_core, rather
    # than being imported from langchain.prompts and immediately shadowed.
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    def retrieve_history(input):
        # Inject an empty chat history while forwarding question/name unchanged.
        return {"history": [], "question": input["question"], "name": input["name"]}

    # `.with_listeners()` makes the middle step unsaveable — that is the
    # behavior under test.
    chain = (
        {"question": itemgetter("question"), "name": itemgetter("name")}
        | (RunnableLambda(retrieve_history) | prompt | fake_chat_model).with_listeners()
        | StrOutputParser()
        | RunnablePassthrough()
    )
    input_example = {"question": "Who owns MLflow?", "name": ""}
    # Sanity check: the chain itself is functional before we attempt to log it.
    assert chain.invoke(input_example) == "Databricks"

    with pytest.raises(
        MlflowException,
        match=r"Failed to save runnable sequence: {'1': 'RunnableSequence "
        r"-- Cannot save runnable without `save` method.'",
    ), mlflow.start_run():
        mlflow.langchain.log_model(
            chain,
            artifact_path="chain",
        )
Loading