langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (langchain-ai#13699)

## Description

Related to mlflow/mlflow#10420. The MLflow AI Gateway will be deprecated and replaced by the `mlflow.deployments` module. Happy to split this PR if it's too large.

To try out the changes in this PR:

```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```
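
For reference, here is a rough before/after sketch of the migration for users. The old gateway class's `gateway_uri`/`route` parameter names are assumptions based on the existing MLflow AI Gateway integration; the new classes' parameters come from the testing plan below.

```python
# Before: MLflow AI Gateway integration (to be deprecated).
# `gateway_uri` and `route` are assumed parameter names, shown only for comparison.
from langchain.chat_models import ChatMLflowAIGateway

chat_old = ChatMLflowAIGateway(gateway_uri="http://127.0.0.1:5000", route="chat")

# After: classes backed by mlflow.deployments (added in this PR).
from langchain.chat_models import ChatDatabricks, ChatMlflow

chat_new = ChatMlflow(target_uri="http://127.0.0.1:5000", endpoint="chat")
chat_dbx = ChatDatabricks(target_uri="databricks", endpoint="chat")
```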

## Dependencies

Install mlflow from mlflow/mlflow#10420:

```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```

## Testing plan

The following script runs successfully both locally and on Databricks:

<details><summary>Click</summary>
<p>

```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY

Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""
from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

###############################
# MLflow
###############################
chat = ChatMlflow(
    target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",
    params={"temperature": 0.1},
)
print(llm("I am"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")

# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    chat = ChatDatabricks(
        target_uri="databricks", endpoint=name, params={"temperature": 0.1}
    )
    print(chat([HumanMessage(content="hello")]))
finally:
    client.delete_endpoint(endpoint=name)

# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "text-embedding-ada-002",
                    "provider": "openai",
                    "task": "llm/v1/embeddings",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
    print(embeddings.embed_query("hello")[:3])
    print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
    client.delete_endpoint(endpoint=name)

# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-3.5-turbo-instruct",
                    "provider": "openai",
                    "task": "llm/v1/completions",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    llm = Databricks(
        endpoint_name=name,
        model_kwargs={"temperature": 0.1},
    )
    print(llm("I am"))
finally:
    client.delete_endpoint(endpoint=name)


# Foundation model - chat
chat = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])

# Foundation model - completions
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))
llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

```

Output:

```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:

Why don't scientists trust atoms?

Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
 a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]

hello back
 Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```

</p>
</details>

The existing workflow (querying a manually created Databricks serving endpoint backed by a registered model) doesn't break:

<details><summary>Click</summary>
<p>

```python
import uuid

import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema


class MyModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return str(uuid.uuid4())


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "model",
        python_model=MyModel(),
        pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "prompt"),
                    ColSpec("string", "stop"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec(name=None, type="string"),
                ]
            ),
        ),
        registered_model_name=f"lang-{uuid.uuid4()}",
    )

# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks

llm = Databricks(endpoint_name="<name>")
llm("hello")  # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```

</p>
</details> 

## Follow-up tasks

(This PR is too large. I'll file a separate one for follow-up tasks.)

- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and
`docs/docs/integrations/providers/databricks.md`.

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2 people authored and laudanum123 committed Dec 3, 2023
1 parent 5cf287c commit 77e98cd
Showing 14 changed files with 666 additions and 22 deletions.
4 changes: 4 additions & 0 deletions libs/langchain/langchain/chat_models/__init__.py
@@ -24,6 +24,7 @@
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.databricks import ChatDatabricks
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.everlyai import ChatEverlyAI
from langchain.chat_models.fake import FakeListChatModel
@@ -37,6 +38,7 @@
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow import ChatMlflow
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
@@ -52,10 +54,12 @@
"AzureChatOpenAI",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatDatabricks",
"ChatEverlyAI",
"ChatAnthropic",
"ChatCohere",
"ChatGooglePalm",
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatOllama",
"ChatVertexAI",
46 changes: 46 additions & 0 deletions libs/langchain/langchain/chat_models/databricks.py
@@ -0,0 +1,46 @@
import logging
from urllib.parse import urlparse

from langchain.chat_models.mlflow import ChatMlflow

logger = logging.getLogger(__name__)


class ChatDatabricks(ChatMlflow):
    """`Databricks` chat models API.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatDatabricks

            chat = ChatDatabricks(
                target_uri="databricks",
                endpoint="chat",
                temperature=0.1,
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "databricks-chat"

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )
217 changes: 217 additions & 0 deletions libs/langchain/langchain/chat_models/mlflow.py
@@ -0,0 +1,217 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import (
    Field,
    PrivateAttr,
)

logger = logging.getLogger(__name__)


class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            "stop": stop or self.stop,
            "max_tokens": self.max_tokens,
            **self.extra_params,
        }

        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            raise ValueError(
                "Function messages are not supported by Databricks. Please"
                " create a feature request at https://github.com/mlflow/mlflow/issues."
            )
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for choice in response["choices"]:
            message = ChatMlflow._convert_dict_to_message(choice["message"])
            usage = choice.get("usage", {})
            gen = ChatGeneration(
                message=message,
                generation_info=usage,
            )
            generations.append(gen)

        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)
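
A small usage sketch of the conversion helpers above, showing the request/response shapes `ChatMlflow` exchanges with the deployments client via `_generate` and `_create_chat_result` (the sample response dict below is made up, following the `llm/v1/chat` shape used in the testing plan):

```python
from langchain.chat_models.mlflow import ChatMlflow
from langchain.schema.messages import HumanMessage

# Request side: _generate converts each message to a role/content dict.
print(ChatMlflow._convert_message_to_dict(HumanMessage(content="hello")))
# {'role': 'user', 'content': 'hello'}

# Response side: _create_chat_result turns the endpoint response into a ChatResult.
response = {
    "choices": [{"message": {"role": "assistant", "content": "Hi!"}, "usage": {}}],
    "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
result = ChatMlflow._create_chat_result(response)
print(result.generations[0].message.content)  # Hi!
print(result.llm_output)  # {'prompt_tokens': 1, 'completion_tokens': 2, 'total_tokens': 3}
```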
6 changes: 6 additions & 0 deletions libs/langchain/langchain/chat_models/mlflow_ai_gateway.py
@@ -1,5 +1,6 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional

@@ -59,6 +60,11 @@ class ChatMLflowAIGateway(BaseChatModel):
"""

def __init__(self, **kwargs: Any):
warnings.warn(
"`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
"`ChatDatabricks` instead.",
DeprecationWarning,
)
try:
import mlflow.gateway
except ImportError as e:
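
A quick way to observe the new deprecation warning in isolation (the constructor arguments are illustrative, and anything the rest of `__init__` raises afterwards does not matter here since the warning is emitted first):

```python
import warnings

from langchain.chat_models import ChatMLflowAIGateway

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    try:
        # Illustrative arguments; a running gateway is not needed just to see the warning.
        ChatMLflowAIGateway(gateway_uri="http://127.0.0.1:5000", route="chat")
    except Exception:
        pass

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```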
4 changes: 4 additions & 0 deletions libs/langchain/langchain/embeddings/__init__.py
@@ -26,6 +26,7 @@
from langchain.embeddings.clarifai import ClarifaiEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.dashscope import DashScopeEmbeddings
from langchain.embeddings.databricks import DatabricksEmbeddings
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.edenai import EdenAiEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
@@ -50,6 +51,7 @@
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.minimax import MiniMaxEmbeddings
from langchain.embeddings.mlflow import MlflowEmbeddings
from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
@@ -78,6 +80,7 @@
"CacheBackedEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"DatabricksEmbeddings",
"ElasticsearchEmbeddings",
"FastEmbedEmbeddings",
"HuggingFaceEmbeddings",
@@ -87,6 +90,7 @@
"JinaEmbeddings",
"LlamaCppEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowEmbeddings",
"MlflowAIGatewayEmbeddings",
"ModelScopeEmbeddings",
"TensorflowHubEmbeddings",
45 changes: 45 additions & 0 deletions libs/langchain/langchain/embeddings/databricks.py
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Iterator, List
from urllib.parse import urlparse

from langchain.embeddings.mlflow import MlflowEmbeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class DatabricksEmbeddings(MlflowEmbeddings):
    """Wrapper around embeddings LLMs in Databricks.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import DatabricksEmbeddings

            embeddings = DatabricksEmbeddings(
                target_uri="databricks",
                endpoint="embeddings",
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )

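A quick illustration of the `_chunk` helper above, presumably used to batch embedding inputs before sending them to the endpoint (the batch size of 2 is arbitrary, for illustration only):

```python
texts = ["alpha", "beta", "gamma", "delta", "epsilon"]
for batch in _chunk(texts, size=2):
    print(batch)
# ['alpha', 'beta']
# ['gamma', 'delta']
# ['epsilon']
```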