Support subclasses of chains for langchain flavor #8453

Merged: 12 commits, May 26, 2023

Collaborator

@liangz1 liangz1 commented May 17, 2023

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

This PR extends the mlflow langchain flavor to support saving and loading all subclasses of Chain.

Before:

class of model | model log & load (langchain & pyfunc) | extract llm type in model metadata | model predict | pyfunc spark_udf
LLMChain
AgentExecutor
other

After:

class of model | model log & load (langchain & pyfunc) | extract llm info in model metadata | model predict | pyfunc spark_udf
subclasses of LLMChain
subclasses of AgentExecutor
subclasses of Chain (see known exceptions below): ❌ (llm info exists in model.yaml), ✅*, ✅*
Chains containing memory

* Because all of the concrete Chains currently have SerDe bugs, we cannot test any of them. The code path should work, and we can add tests after langchain fixes the bugs.

Here is a list of Chains that langchain supports loading.
https://github.com/hwchase17/langchain/blob/0c3de0a0b32fadb8caf3e6d803287229409f9da9/langchain/chains/loading.py#L409

type_to_loader_dict = {
    "api_chain": _load_api_chain, # ❌ SISO requires "requests_wrapper" in kwargs
    "hyde_chain": _load_hyde_chain, # ❌ SISO requires "embeddings" in kwargs
    "llm_chain": _load_llm_chain, # SISO ✅
    "llm_bash_chain": _load_llm_bash_chain, # SISO ❌ Bug, fixing by langchain
    "llm_checker_chain": _load_llm_checker_chain, # ❌ SISO, but contains SequentialChain, which does not support saving
    "llm_math_chain": _load_llm_math_chain, # SISO ❌ Bug, fixing by langchain
    "llm_requests_chain": _load_llm_requests_chain, # ❌ SISO requires "requests_wrapper" in kwargs
    "pal_chain": _load_pal_chain, # SIMO ❌ Bug, reported to langchain ["result", "intermediate_steps"]
    "qa_with_sources_chain": _load_qa_with_sources_chain, # ❌ MIMO [self.input_docs_key, self.question_key] -> [self.answer_key, self.sources_answer_key, "source_documents"]
    "stuff_documents_chain": _load_stuff_documents_chain, # MISO ❌ Bug, fixing by langchain
    "map_reduce_documents_chain": _load_map_reduce_documents_chain, # ❌ MIMO ["result", "intermediate_steps"]
    "map_rerank_documents_chain": _load_map_rerank_documents_chain, # ❌ MIMO ["result", "intermediate_steps"]
    "refine_documents_chain": _load_refine_documents_chain, # ❌ MIMO ["result", "intermediate_steps"]
    "sql_database_chain": _load_sql_database_chain, # ❌ SIMO ["result", "intermediate_steps"], requires "database" in kwargs 
    "vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain, # ❌ SIMO requires "vectorstore" in kwargs
    "vector_db_qa": _load_vector_db_qa, # ❌ SIMO requires "vectorstore" in kwargs
}
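
The SISO / SIMO / MISO / MIMO tags in the comments above abbreviate single or multiple input and single or multiple output, judged by how many input and output keys a chain declares. As a minimal sketch (the helper name `io_shape` and the sample input keys are illustrative assumptions, not MLflow or langchain code), the tag could be derived like this:

```python
def io_shape(input_keys, output_keys):
    """Return "SISO", "SIMO", "MISO" or "MIMO" for the given key lists."""
    if not input_keys or not output_keys:
        raise ValueError("a chain needs at least one input key and one output key")
    inputs = "SI" if len(input_keys) == 1 else "MI"
    outputs = "SO" if len(output_keys) == 1 else "MO"
    return inputs + outputs


# Hypothetical key lists: only ["result", "intermediate_steps"] is taken
# from the comments above, the other names are made up for illustration.
print(io_shape(["question"], ["text"]))
print(io_shape(["question"], ["result", "intermediate_steps"]))
print(io_shape(["input_documents", "question"], ["output_text"]))
```

Anything tagged `...MO` returns more than one output key, which is the case this PR leaves to a follow-up.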

Among them, MLflow cannot yet support the chains blocked by the following issues (a chain can be counted under more than one reason):

  1. Multiple outputs: 8 chains are blocked. MLflow plans to support this in a follow-up PR.
  2. Needs kwargs during mlflow.pyfunc.load_model: 6 chains are blocked. MLflow plans to support this in a follow-up PR.
  3. langchain bug: 4 chains are blocked (LLMBashChain, LLMMathChain, StuffDocumentsChain, PALChain). Filed issues: langchain-ai/langchain#5131 (HypotheticalDocumentEmbedder loading fails), langchain-ai/langchain#5160 (StuffDocumentsChain input_keys does not contain "question"), langchain-ai/langchain#5224 (PALChain loading fails).
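
The counts above can be rechecked against the annotated `type_to_loader_dict`. The sketch below is only a transcription of those inline comments into data; the `ANNOTATIONS` table and the blocker labels (`kwargs`, `langchain_bug`, `unsaveable_subchain`) are names invented here for illustration.

```python
# Each entry: chain type -> (I/O tag, blockers noted in the comments above).
ANNOTATIONS = {
    "api_chain": ("SISO", {"kwargs"}),
    "hyde_chain": ("SISO", {"kwargs"}),
    "llm_chain": ("SISO", set()),
    "llm_bash_chain": ("SISO", {"langchain_bug"}),
    "llm_checker_chain": ("SISO", {"unsaveable_subchain"}),
    "llm_math_chain": ("SISO", {"langchain_bug"}),
    "llm_requests_chain": ("SISO", {"kwargs"}),
    "pal_chain": ("SIMO", {"langchain_bug"}),
    "qa_with_sources_chain": ("MIMO", set()),
    "stuff_documents_chain": ("MISO", {"langchain_bug"}),
    "map_reduce_documents_chain": ("MIMO", set()),
    "map_rerank_documents_chain": ("MIMO", set()),
    "refine_documents_chain": ("MIMO", set()),
    "sql_database_chain": ("SIMO", {"kwargs"}),
    "vector_db_qa_with_sources_chain": ("SIMO", {"kwargs"}),
    "vector_db_qa": ("SIMO", {"kwargs"}),
}

# A chain is blocked by "multiple outputs" when its tag ends in "MO".
multiple_outputs = [n for n, (io, _) in ANNOTATIONS.items() if io.endswith("MO")]
needs_kwargs = [n for n, (_, why) in ANNOTATIONS.items() if "kwargs" in why]
langchain_bug = [n for n, (_, why) in ANNOTATIONS.items() if "langchain_bug" in why]
# Unblocked: single output and no recorded blocker.
unblocked = [n for n, (io, why) in ANNOTATIONS.items() if io.endswith("SO") and not why]

print(len(multiple_outputs), len(needs_kwargs), len(langchain_bug), unblocked)
```

Some chains fall under two reasons at once (for example `sql_database_chain` has multiple outputs and needs kwargs), so the three counts do not add up to the number of blocked chains.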

Known exceptions

Examples of Chain classes that are known not to be supported as of langchain==0.0.176:

  1. class ConversationChain(LLMChain): Chain to have a conversation and load context from memory. Reason: (from langchain) Saving of memory is not yet supported.

How is this patch tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests (describe details, including test results, below)

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly in the documentation preview.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

The MLflow langchain flavor now supports logging and loading all subclasses of Chain, provided that langchain implements their serialization / deserialization methods.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
@github-actions github-actions bot added area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs. labels May 17, 2023
@mlflow-automation
Collaborator

mlflow-automation commented May 17, 2023

Documentation preview for fc7920d will be available here when this CircleCI job completes successfully.

@WeichenXu123
Collaborator

@liangz1

Could you also manually test this with Dolly? See #8460.

(We don't need to add dolly tests in CI)

@minkin-koantek minkin-koantek left a comment

worth trying

@WeichenXu123
Collaborator

@minkin-koantek could you help test?

You can install this branch of the PR via:

pip install git+https://github.com/liangz1/mlflow.git@support-more-chains

@minkin-koantek

I will test it in the next 12 hr thx!

@liangz1
Collaborator Author

liangz1 commented May 19, 2023

@WeichenXu123 Thanks for linking the issue.
@minkin-koantek If you don't mind sharing the code to reproduce this issue #8460, I can also help take a look. Thanks!

Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
@minkin-koantek

> @WeichenXu123 Thanks for linking the issue. @minkin-koantek If you don't mind sharing the code to reproduce this issue #8460, I can also help take a look. Thanks!

The setup is based on a Databricks Dolly script from dbdemos. If I am unsuccessful in getting the model to register, I will isolate it so you can try it.

@minkin-koantek

minkin-koantek commented May 19, 2023

The model I tested is of type langchain.chains.combine_documents.stuff.StuffDocumentsChain.

Code used to deploy:


import mlflow

# The def line was missing from the pasted snippet; it is inferred from the call at the bottom.
def publish_model_to_mlflow(llm_model):
    with mlflow.start_run() as run:
        # Save model to MLflow
        # Note that this only saves the langchain pipeline (we could also add the ChatBot with a custom Model Wrapper class)
        # See https://mlflow.org/docs/latest/models.html#custom-python-models for an example
        # The vector database lives outside of your model
        mlflow.langchain.log_model(llm_model, artifact_path="model")
        model_registered = mlflow.register_model(f"runs:/{run.info.run_id}/model", "gardening-dolly-7b-bot")

    # Move the model to production
    client = mlflow.tracking.MlflowClient()
    print(f"registering model version {model_registered.version} as production model")
    client.transition_model_version_stage("gardening-dolly-7b-bot", model_registered.version, stage="Production", archive_existing_versions=True)

publish_model_to_mlflow(qa_chain)

Resulting message after fixing promotion code:

Registered model 'gardening-dolly-7b-bot' already exists. Creating a new version of this model...
2023/05/19 14:13:37 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: gardening-dolly-7b-bot, version 3
Created version '3' of model 'gardening-dolly-7b-bot'.
registering model version 3 as production model

@minkin-koantek

minkin-koantek commented May 19, 2023

We're not out of the woods yet.

I tried the standard batch-inference usage of the model, with MLflow 2.3.3.dev0 installed, using this code:

import mlflow
logged_model = 'runs:/beca191e59f54a2e9777186774058e96/model'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)

# Predict on a Pandas DataFrame.
import pandas as pd
loaded_model.predict(df.toPandas())

This was the error:

2023/05/19 16:05:18 WARNING mlflow.pyfunc: Detected one or more mismatches between the model's dependencies and the current Python environment:
 - numpy (current: 1.21.5, required: numpy==1.24.3)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
2023/05/19 16:05:19 WARNING mlflow.langchain.api_request_parallel_processor: Request #0 failed with AttributeError("'str' object has no attribute 'page_content'")
2023/05/19 16:05:19 WARNING mlflow.langchain.api_request_parallel_processor: Request #1 failed with AttributeError("'str' object has no attribute 'page_content'")
2023/05/19 16:05:19 WARNING mlflow.langchain.api_request_parallel_processor: Request #2 failed with AttributeError("'str' object has no attribute 'page_content'")
2023/05/19 16:05:19 WARNING mlflow.langchain.api_request_parallel_processor: Request #3 failed with AttributeError("'str' object has no attribute 'page_content'")
2023/05/19 16:05:19 WARNING mlflow.langchain.api_request_parallel_processor: Request #4 failed with AttributeError("'str' object has no attribute 'page_content'")


> Entering new StuffDocumentsChain chain...


> Entering new StuffDocumentsChain chain...


> Entering new StuffDocumentsChain chain...


> Entering new StuffDocumentsChain chain...


> Entering new StuffDocumentsChain chain...
MlflowException: 5 tasks failed. See logs for details.
---------------------------------------------------------------------------
MlflowException                           Traceback (most recent call last)
File <command-3944480421186138>:9
      7 # Predict on a Pandas DataFrame.
      8 import pandas as pd
----> 9 loaded_model.predict(df.toPandas())

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266e36af-e081-49ac-a9b9-36047b4ff7a4/lib/python3.10/site-packages/mlflow/pyfunc/__init__.py:427, in PyFuncModel.predict(self, data)
    424         if _MLFLOW_OPENAI_TESTING.get():
    425             raise
--> 427 return self._predict_fn(data)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266e36af-e081-49ac-a9b9-36047b4ff7a4/lib/python3.10/site-packages/mlflow/langchain/__init__.py:444, in _LangChainModelWrapper.predict(self, data)
    440 else:
    441     raise mlflow.MlflowException.invalid_parameter_value(
    442         "Input must be a pandas DataFrame or a list of strings or a list of dictionaries",
    443     )
--> 444 return process_api_requests(lc_model=self.lc_model, requests=messages)
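
One plausible reading of the AttributeError in the log above is that the chain expects document objects exposing a `page_content` attribute, while each DataFrame cell reaches it as a plain string. The sketch below reproduces that failure mode with stand-ins only: `FakeDocument` and `stuff_documents` are hypothetical names and are not langchain or MLflow code.

```python
from dataclasses import dataclass


@dataclass
class FakeDocument:
    """Hypothetical stand-in for a document object with a page_content field."""
    page_content: str


def stuff_documents(docs):
    """Join the text of every document into one string, reading .page_content."""
    return "\n\n".join(doc.page_content for doc in docs)


# Works: every item exposes .page_content.
print(stuff_documents([FakeDocument("Water tomatoes daily."), FakeDocument("Prune in spring.")]))

# Fails like the log above: a bare string has no .page_content attribute.
try:
    stuff_documents(["Water tomatoes daily."])
except AttributeError as exc:
    print(f"AttributeError: {exc}")
```

If that reading is right, the rows would need to be converted into document objects before they reach the chain; whether the flavor or the caller should do that is not settled in this thread.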

Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
@liangz1 liangz1 requested review from harupy and dbczumar May 24, 2023 15:15
@liangz1 liangz1 changed the title [WIP] Support more chains for langchain flavor Support subclasses of chains for langchain flavor May 24, 2023
Comment on lines +281 to +288
def test_langchain_native_log_and_load_qa_with_sources_chain():
    # StuffDocumentsChain is a subclass of Chain
    model = create_qa_with_sources_chain()
    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(model, "langchain_model")

    loaded_model = mlflow.langchain.load_model(logged_model.model_uri)
    assert model == loaded_model
Collaborator

@dbczumar dbczumar May 24, 2023

Can we add a test case for pyfunc predict / Spark UDF too (I know we're tracking these internally)?

Collaborator

@dbczumar dbczumar left a comment

LGTM once predict() and spark_udf() coverage is added :)

@dbczumar
Collaborator

Thanks @liangz1 !

Member

@harupy harupy left a comment

LGTM!

@liangz1 liangz1 merged commit 8772635 into mlflow:master May 26, 2023
37 checks passed
Labels
area/models MLmodel format, model serialization/deserialization, flavors rn/feature Mention under Features in Changelogs.

7 participants