Skip to content

Commit

Permalink
Fix langchain model saving with dict error message (#11822)
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan authored and BenWilson2 committed May 7, 2024
1 parent def3f85 commit c52493c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
4 changes: 3 additions & 1 deletion mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,9 @@ def _save_model(model, path, loader_fn, persist_dir):
"using `pip install cloudpickle>=2.1.0` "
"to ensure the model can be loaded correctly."
)
with register_pydantic_v1_serializer_cm():
# patch_langchain_type_to_cls_dict here as we attempt to load model
# if it's saved by `dict` method
with register_pydantic_v1_serializer_cm(), patch_langchain_type_to_cls_dict():
if isinstance(model, lc_runnables_types()):
return _save_runnables(model, path, loader_fn=loader_fn, persist_dir=persist_dir)
else:
Expand Down
24 changes: 14 additions & 10 deletions mlflow/langchain/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,21 @@ def _save_internal_runnables(runnable, path, loader_fn, persist_dir):
_MODEL_DATA_KEY: _MODEL_DATA_YAML_FILE_NAME,
_MODEL_LOAD_KEY: _CONFIG_LOAD_KEY,
}
path = path / _MODEL_DATA_YAML_FILE_NAME
model_path = path / _MODEL_DATA_YAML_FILE_NAME
# Save some simple runnables that langchain natively supports.
if hasattr(runnable, "save"):
runnable.save(path)
# TODO: check if `dict` is enough to load it back
runnable.save(model_path)
elif hasattr(runnable, "dict"):
runnable_dict = runnable.dict()
with open(path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
try:
runnable_dict = runnable.dict()
with open(model_path, "w") as f:
yaml.dump(runnable_dict, f, default_flow_style=False)
# if the model cannot be loaded back, then `dict` is not enough for saving.
_load_model_from_config(path, conf)
except Exception:
raise Exception("Cannot save runnable without `save` method.")
else:
return Exception(f"Runnable {runnable} is not supported for saving.")
raise Exception("Cannot save runnable without `save` or `dict` methods.")
return conf


Expand Down Expand Up @@ -320,7 +324,7 @@ def _save_runnable_with_steps(model, file_path: Union[Path, str], loader_fn=None
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[step] = f"{runnable} -- {e}"
unsaved_runnables[step] = f"{runnable.get_name()} -- {e}"

if unsaved_runnables:
raise MlflowException(f"Failed to save runnable sequence: {unsaved_runnables}.")
Expand Down Expand Up @@ -355,7 +359,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
runnable, save_runnable_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[f"{index}-{i}"] = f"{runnable} -- {e}"
unsaved_runnables[f"{index}-{i}"] = f"{runnable.get_name()} -- {e}"

# save default branch
default_branch_path = branches_path / _DEFAULT_BRANCH_NAME
Expand All @@ -365,7 +369,7 @@ def _save_runnable_branch(model, file_path, loader_fn, persist_dir):
model.default, default_branch_path, loader_fn, persist_dir
)
except Exception as e:
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default} -- {e}"
unsaved_runnables[_DEFAULT_BRANCH_NAME] = f"{model.default.get_name()} -- {e}"
if unsaved_runnables:
raise MlflowException(f"Failed to save runnable branch: {unsaved_runnables}.")

Expand Down

0 comments on commit c52493c

Please sign in to comment.