From d41c7751b71635815a95999acb5524f4dd7c9b09 Mon Sep 17 00:00:00 2001 From: Sunish Sheth Date: Thu, 21 Mar 2024 12:34:10 -0700 Subject: [PATCH] Deleting the global context for config path in langchain (#11494) Signed-off-by: Sunish Sheth Signed-off-by: mlflow-automation Co-authored-by: mlflow-automation --- mlflow/langchain/__init__.py | 16 ++++++++++++++-- tests/langchain/test_langchain_model_export.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/mlflow/langchain/__init__.py b/mlflow/langchain/__init__.py index d68f0b1eaa115..bfd5f35c656e4 100644 --- a/mlflow/langchain/__init__.py +++ b/mlflow/langchain/__init__.py @@ -17,6 +17,7 @@ import logging import os import warnings +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Union import cloudpickle @@ -800,10 +801,21 @@ def load_model(model_uri, dst_path=None): return _load_model_from_local_fs(local_model_path) -def _load_code_model(code_path: Optional[str] = None): +@contextmanager +def _config_path_context(code_path: Optional[str] = None): _set_config_path(code_path) + try: + yield + finally: + _set_config_path(None) + - import chain # noqa: F401 +def _load_code_model(code_path: Optional[str] = None): + with _config_path_context(code_path): + try: + import chain # noqa: F401 + except ImportError as e: + raise mlflow.MlflowException("Failed to import LangChain model.") from e return mlflow.langchain._rag_utils.__databricks_rag_chain__ diff --git a/tests/langchain/test_langchain_model_export.py b/tests/langchain/test_langchain_model_export.py index c02a71786b463..4d54006e1ca07 100644 --- a/tests/langchain/test_langchain_model_export.py +++ b/tests/langchain/test_langchain_model_export.py @@ -2178,7 +2178,9 @@ def test_save_load_chain_as_code(): code_paths=["tests/langchain/state_of_the_union.txt"], ) + assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None loaded_model = mlflow.langchain.load_model(model_info.model_uri) + assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None answer = "Databricks" assert loaded_model.invoke(input_example) == answer pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) @@ -2367,7 +2369,9 @@ def test_save_load_chain_as_code_optional_code_path(): code_paths=[], ) + assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None loaded_model = mlflow.langchain.load_model(model_info.model_uri) + assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None answer = "Databricks" assert loaded_model.invoke(input_example) == answer pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) @@ -2393,3 +2397,13 @@ def test_save_load_chain_as_code_optional_code_path(): assert langchain_flavor["databricks_dependency"] == { "databricks_chat_endpoint_name": ["fake-endpoint"] } + + +def test_config_path_context(): + with mlflow.langchain._config_path_context("tests/langchain/config.yml"): + assert ( + mlflow.langchain._rag_utils.__databricks_rag_config_path__ + == "tests/langchain/config.yml" + ) + + assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None