Clean up langchain tests #10845

Merged
merged 1 commit on Jan 18, 2024
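
This PR removes the fake chat model test helpers (_fake_simple_chat_model and _fake_mlflow_question_classifier) from mlflow/langchain/utils.py and redefines them as pytest fixtures (fake_chat_model and fake_classifier_chat_model) in tests/langchain/test_langchain_model_export.py, so the test-only code no longer lives in the mlflow package.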
58 changes: 1 addition & 57 deletions mlflow/langchain/utils.py
@@ -6,7 +6,7 @@
import types
from functools import lru_cache
from importlib.util import find_spec
-from typing import Any, List, NamedTuple, Optional
+from typing import NamedTuple

import cloudpickle
import yaml
@@ -470,59 +470,3 @@ def _load_base_lcs(

model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **kwargs)
return model
-
-
-# This is an internal function that is used to generate
-# a fake chat model for testing purposes.
-# cloudpickle can not pickle a pydantic model defined
-# within the same scope, so put it here.
-def _fake_simple_chat_model():
-    from langchain.callbacks.manager import CallbackManagerForLLMRun
-    from langchain.chat_models.base import SimpleChatModel
-    from langchain.schema.messages import BaseMessage
-
-    class FakeChatModel(SimpleChatModel):
-        """Fake Chat Model wrapper for testing purposes."""
-
-        def _call(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-        ) -> str:
-            return "Databricks"
-
-        @property
-        def _llm_type(self) -> str:
-            return "fake chat model"
-
-    return FakeChatModel
-
-
-def _fake_mlflow_question_classifier():
-    from langchain.callbacks.manager import CallbackManagerForLLMRun
-    from langchain.chat_models.base import SimpleChatModel
-    from langchain.schema.messages import BaseMessage
-
-    class FakeMLflowClassifier(SimpleChatModel):
-        """Fake Chat Model wrapper for testing purposes."""
-
-        def _call(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-        ) -> str:
-            if "MLflow" in messages[0].content.split(":")[1]:
-                return "yes"
-            if "cat" in messages[0].content.split(":")[1]:
-                return "no"
-            return "unknown"
-
-        @property
-        def _llm_type(self) -> str:
-            return "fake mlflow classifier"
-
-    return FakeMLflowClassifier
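
Context for the removed comment above: cloudpickle pickles a class defined inside a function by value, so the pickled object does not depend on its defining module being importable at load time; the deleted helpers used that factory pattern to make the pydantic-based fake chat models picklable. Below is a minimal, hypothetical sketch of the behavior (a plain class stands in for the chat model; none of these names are part of MLflow or this PR):

import cloudpickle

def make_fake_model():
    # A class created inside a function is serialized by value, so loading
    # the pickle does not require importing the module that defined it.
    class FakeModel:
        def predict(self):
            return "Databricks"

    return FakeModel

FakeModel = make_fake_model()
restored = cloudpickle.loads(cloudpickle.dumps(FakeModel()))
assert restored.predict() == "Databricks"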
117 changes: 79 additions & 38 deletions tests/langchain/test_langchain_model_export.py
@@ -199,6 +199,60 @@ def _call(self, inputs: Dict[str, str], run_manager=None) -> Dict[str, str]:
return {"baz": "bar"}

+
+@pytest.fixture
+def fake_chat_model():
+    from langchain.callbacks.manager import CallbackManagerForLLMRun
+    from langchain.chat_models.base import SimpleChatModel
+    from langchain.schema.messages import BaseMessage
+
+    class FakeChatModel(SimpleChatModel):
+        """Fake Chat Model wrapper for testing purposes."""
+
+        def _call(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            return "Databricks"
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake chat model"
+
+    return FakeChatModel()
+
+
+@pytest.fixture
+def fake_classifier_chat_model():
+    from langchain.callbacks.manager import CallbackManagerForLLMRun
+    from langchain.chat_models.base import SimpleChatModel
+    from langchain.schema.messages import BaseMessage
+
+    class FakeMLflowClassifier(SimpleChatModel):
+        """Fake Chat Model wrapper for testing purposes."""
+
+        def _call(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            if "MLflow" in messages[0].content.split(":")[1]:
+                return "yes"
+            if "cat" in messages[0].content.split(":")[1]:
+                return "no"
+            return "unknown"
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake mlflow classifier"
+
+    return FakeMLflowClassifier()
+

def test_langchain_native_save_and_load_model(model_path):
model = create_openai_llmchain()
mlflow.langchain.save_model(model, model_path)
@@ -891,13 +945,11 @@ def mul_two(x):
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_predict_with_callbacks():
+def test_predict_with_callbacks(fake_chat_model):
from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
class TestCallbackHandler(BaseCallbackHandler):
def __init__(self):
super().__init__()
@@ -911,9 +963,8 @@ def on_llm_start(
) -> Any:
self.num_llm_start_calls += 1

-    chat_model = _fake_simple_chat_model()()
prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
-    chain = prompt | chat_model | StrOutputParser()
+    chain = prompt | fake_chat_model | StrOutputParser()
# Test the basic functionality of the chain
assert chain.invoke({"industry": "tech"}) == "Databricks"

@@ -956,15 +1007,12 @@ def on_llm_start(
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_predict_with_callbacks_supports_chat_response_conversion():
+def test_predict_with_callbacks_supports_chat_response_conversion(fake_chat_model):
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
-    chat_model = _fake_simple_chat_model()()
prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
-    chain = prompt | chat_model | StrOutputParser()
+    chain = prompt | fake_chat_model | StrOutputParser()
# Test the basic functionality of the chain
assert chain.invoke({"industry": "tech"}) == "Databricks"

@@ -1204,17 +1252,14 @@ def test_save_load_complex_runnable_sequence():
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_save_load_simple_chat_model(spark):
+def test_save_load_simple_chat_model(spark, fake_chat_model):
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
prompt = ChatPromptTemplate.from_template(
"What is a good name for a company that makes {product}?"
)
-    chat_model = _fake_simple_chat_model()()
-    chain = prompt | chat_model | StrOutputParser()
+    chain = prompt | fake_chat_model | StrOutputParser()
assert chain.invoke({"product": "MLflow"}) == "Databricks"
# signature is required for spark_udf
signature = infer_signature({"product": "MLflow"}, "Databricks")
@@ -1250,15 +1295,11 @@ def test_save_load_simple_chat_model(spark):
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_save_load_rag(tmp_path, spark):
+def test_save_load_rag(tmp_path, spark, fake_chat_model):
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
-    chat_model = _fake_simple_chat_model()()
-
# Create the vector db, persist the db to a local fs folder
loader = TextLoader("tests/langchain/state_of_the_union.txt")
documents = loader.load()
@@ -1284,7 +1325,7 @@ def load_retriever(persist_directory):
"question": RunnablePassthrough(),
}
| prompt
-        | chat_model
+        | fake_chat_model
| StrOutputParser()
)
question = "What is a good name for a company that makes MLflow?"
@@ -1367,17 +1408,14 @@ def test_runnable_branch_save_load():
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_complex_runnable_branch_save_load():
+def test_complex_runnable_branch_save_load(fake_chat_model, fake_classifier_chat_model):
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableBranch, RunnableLambda

-    from mlflow.langchain.utils import _fake_mlflow_question_classifier, _fake_simple_chat_model
-
-    chat_model = _fake_mlflow_question_classifier()()
prompt = ChatPromptTemplate.from_template("{question_is_relevant}\n{query}")
# Need to add prompt here as the chat model doesn't accept dict input
-    answer_model = prompt | _fake_simple_chat_model()()
+    answer_model = prompt | fake_chat_model

decline_to_answer = RunnableLambda(
lambda x: "I cannot answer questions that are not about MLflow."
@@ -1398,7 +1436,7 @@
chain = (
{
"question_is_relevant": is_question_about_mlflow_prompt
-            | chat_model
+            | fake_classifier_chat_model
| StrOutputParser(),
"query": itemgetter("query"),
}
@@ -1448,12 +1486,10 @@ def test_complex_runnable_branch_save_load():
@pytest.mark.skipif(
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
-def test_chat_with_history(spark):
+def test_chat_with_history(spark, fake_chat_model):
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
prompt_with_history_str = """
Here is a history between you and a human: {chat_history}

@@ -1464,8 +1500,6 @@
input_variables=["chat_history", "question"], template=prompt_with_history_str
)

-    chat_model = _fake_simple_chat_model()()
-
def extract_question(input):
return input[-1]["content"]

@@ -1478,7 +1512,7 @@ def extract_history(input):
"chat_history": itemgetter("messages") | RunnableLambda(extract_history),
}
| prompt_with_history
-        | chat_model
+        | fake_chat_model
| StrOutputParser()
)

@@ -1528,14 +1562,17 @@ def extract_history(input):
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
def test_predict_with_builtin_pyfunc_chat_conversion():
+    from langchain.chat_models.base import SimpleChatModel
from langchain.schema.output_parser import StrOutputParser

-    from mlflow.langchain.utils import _fake_simple_chat_model
-
-    class ChatModel(_fake_simple_chat_model()):
+    class ChatModel(SimpleChatModel):
def _call(self, messages, stop, run_manager, **kwargs): # pylint: disable=signature-differs
return "\n".join([f"{message.type}: {message.content}" for message in messages])
+
+        @property
+        def _llm_type(self) -> str:
+            return "chat model"

input_example = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
@@ -1614,12 +1651,16 @@ def _call(self, messages, stop, run_manager, **kwargs):  # pylint: disable=signature-differs
Version(langchain.__version__) < Version("0.0.311"), reason="feature not existing"
)
def test_predict_with_builtin_pyfunc_chat_conversion_for_aimessage_response():
-    from mlflow.langchain.utils import _fake_simple_chat_model
+    from langchain.chat_models.base import SimpleChatModel

-    class ChatModel(_fake_simple_chat_model()):
+    class ChatModel(SimpleChatModel):
        def _call(self, messages, stop, run_manager, **kwargs):  # pylint: disable=signature-differs
return "You own MLflow"
+
+        @property
+        def _llm_type(self) -> str:
+            return "chat model"

input_example = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
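
A note on how the new fixtures are consumed: pytest injects a fixture into a test by matching the test's parameter name to the fixture name, and function-scoped fixtures (the default) are rebuilt for every test that requests them. A minimal, self-contained sketch of the mechanism (hypothetical names, not from the MLflow suite):

import pytest

@pytest.fixture
def greeting():
    # Function-scoped by default: built fresh for each test that asks for it.
    return "Databricks"

def test_greeting(greeting):
    assert greeting == "Databricks"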