Skip to content

Commit

Permalink
Deleting the global context for config path in langchain (#11494)
Browse files Browse the repository at this point in the history
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
Signed-off-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
Co-authored-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
  • Loading branch information
2 people authored and BenWilson2 committed Mar 21, 2024
1 parent 2d53d94 commit d41c775
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 14 additions & 2 deletions mlflow/langchain/__init__.py
Expand Up @@ -17,6 +17,7 @@
import logging
import os
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

import cloudpickle
Expand Down Expand Up @@ -800,10 +801,21 @@ def load_model(model_uri, dst_path=None):
return _load_model_from_local_fs(local_model_path)


def _load_code_model(code_path: Optional[str] = None):
@contextmanager
def _config_path_context(code_path: Optional[str] = None):
_set_config_path(code_path)
try:
yield
finally:
_set_config_path(None)


import chain # noqa: F401
def _load_code_model(code_path: Optional[str] = None):
with _config_path_context(code_path):
try:
import chain # noqa: F401
except ImportError as e:
raise mlflow.MlflowException("Failed to import LangChain model.") from e

return mlflow.langchain._rag_utils.__databricks_rag_chain__

Expand Down
14 changes: 14 additions & 0 deletions tests/langchain/test_langchain_model_export.py
Expand Up @@ -2178,7 +2178,9 @@ def test_save_load_chain_as_code():
code_paths=["tests/langchain/state_of_the_union.txt"],
)

assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None
loaded_model = mlflow.langchain.load_model(model_info.model_uri)
assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None
answer = "Databricks"
assert loaded_model.invoke(input_example) == answer
pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Expand Down Expand Up @@ -2367,7 +2369,9 @@ def test_save_load_chain_as_code_optional_code_path():
code_paths=[],
)

assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None
loaded_model = mlflow.langchain.load_model(model_info.model_uri)
assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None
answer = "Databricks"
assert loaded_model.invoke(input_example) == answer
pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
Expand All @@ -2393,3 +2397,13 @@ def test_save_load_chain_as_code_optional_code_path():
assert langchain_flavor["databricks_dependency"] == {
"databricks_chat_endpoint_name": ["fake-endpoint"]
}


def test_config_path_context():
with mlflow.langchain._config_path_context("tests/langchain/config.yml"):
assert (
mlflow.langchain._rag_utils.__databricks_rag_config_path__
== "tests/langchain/config.yml"
)

assert mlflow.langchain._rag_utils.__databricks_rag_config_path__ is None

0 comments on commit d41c775

Please sign in to comment.