Skip to content

Commit

Permalink
check for any Runnable in call_api()
Browse files Browse the repository at this point in the history
Signed-off-by: Ishaan Mehta <45380942+ishaan-mehta@users.noreply.github.com>
  • Loading branch information
ishaan-mehta committed Mar 28, 2024
1 parent 2248511 commit 8cf51c3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
5 changes: 2 additions & 3 deletions mlflow/langchain/api_request_parallel_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ def call_api(
Calls the LangChain API and stores results.
"""
from langchain.schema import BaseRetriever

from mlflow.langchain.utils import lc_runnables_types
from langchain_core.runnables import Runnable

_logger.debug(f"Request #{self.index} started with payload: {self.request_json}")

Expand All @@ -191,7 +190,7 @@ def call_api(
response = [
{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs
]
elif isinstance(self.lc_model, lc_runnables_types()):
elif isinstance(self.lc_model, Runnable):
if isinstance(self.request_json, dict):
# This is a temporary fix for the case when spark_udf converts
# input into pandas dataframe with column name, while the model
Expand Down
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,6 @@ extend-exclude = [
"tests/protos",
]

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 88

[tool.ruff.lint]
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
select = [
Expand Down

0 comments on commit 8cf51c3

Please sign in to comment.