Add Mistral AI as a new provider in LLM Deployment #11020

Merged: 6 commits, Feb 13, 2024
Changes from 4 commits
15 changes: 15 additions & 0 deletions docs/source/llms/deployments/index.rst
@@ -259,6 +259,9 @@ below can be used as a helpful guide when configuring a given endpoint for any n
| AWS Bedrock | - Amazon Titan | N/A | N/A |
| | - Third-party providers | | |
+--------------------------+--------------------------+--------------------------+--------------------------+
| Mistral | - mistral-tiny | N/A | - mistral-embed |
| | - mistral-small | | |
+--------------------------+--------------------------+--------------------------+--------------------------+


† Llama 2 is licensed under the `LLAMA 2 Community License <https://ai.meta.com/llama/license/>`_, Copyright © Meta Platforms, Inc. All Rights Reserved.
@@ -303,6 +306,7 @@ As of now, the MLflow Deployments Server supports the following providers:
* **huggingface text generation inference**: This is used for models deployed using `Huggingface Text Generation Inference <https://huggingface.co/docs/text-generation-inference/index>`_.
* **ai21labs**: This is used for models offered by `AI21 Labs <https://studio.ai21.com/foundation-models>`_.
* **bedrock**: This is used for models offered by `AWS Bedrock <https://aws.amazon.com/bedrock/>`_.
* **mistral**: This is used for models offered by `Mistral <https://docs.mistral.ai/>`_.

More providers are being added continually. Check the latest version of the MLflow Deployments Server Docs for the
most up-to-date list of supported providers.
@@ -511,6 +515,7 @@ Each endpoint has the following configuration parameters:
- "huggingface-text-generation-inference"
- "ai21labs"
- "bedrock"
- "mistral"

- **name**: This is an optional field to specify the name of the model.
- **config**: This contains provider-specific configuration details.
@@ -682,6 +687,16 @@ To match your user's interaction and security access requirements, adjust the ``
+----------------------------+----------+---------+-----------------------------------------------------------------------------------------------+


Mistral
+++++++

+--------------------------+----------+--------------------------+-------------------------------------------------------+
| Configuration Parameter | Required | Default | Description |
+==========================+==========+==========================+=======================================================+
| **mistral_api_key** | Yes | N/A | This is the API key for the Mistral service. |
+--------------------------+----------+--------------------------+-------------------------------------------------------+


An example configuration for Azure OpenAI is:

.. code-block:: yaml
15 changes: 15 additions & 0 deletions docs/source/llms/gateway/index.rst
@@ -302,6 +302,9 @@ below can be used as a helpful guide when configuring a given route for any newl
| AWS Bedrock | - Amazon Titan | N/A | N/A |
| | - Third-party providers | | |
+--------------------------+--------------------------+--------------------------+--------------------------+
| Mistral | - mistral-tiny | N/A | - mistral-embed |
| | - mistral-small | | |
+--------------------------+--------------------------+--------------------------+--------------------------+


† Llama 2 is licensed under the `LLAMA 2 Community License <https://ai.meta.com/llama/license/>`_, Copyright © Meta Platforms, Inc. All Rights Reserved.
@@ -343,6 +346,7 @@ As of now, the MLflow AI Gateway supports the following providers:
* **huggingface text generation inference**: This is used for models deployed using `Huggingface Text Generation Inference <https://huggingface.co/docs/text-generation-inference/index>`_.
* **ai21labs**: This is used for models offered by `AI21 Labs <https://studio.ai21.com/foundation-models>`_.
* **bedrock**: This is used for models offered by `AWS Bedrock <https://aws.amazon.com/bedrock/>`_.
* **mistral**: This is used for models offered by `Mistral <https://docs.mistral.ai/>`_.

More providers are being added continually. Check the latest version of the MLflow AI Gateway Docs for the
most up-to-date list of supported providers.
@@ -540,6 +544,7 @@ Each route has the following configuration parameters:
- "huggingface-text-generation-inference"
- "ai21labs"
- "bedrock"
- "mistral"

- **name**: This is an optional field to specify the name of the model.
- **config**: This contains provider-specific configuration details.
@@ -639,6 +644,16 @@ Top-level model configuration for AWS Bedrock routes must be one of the followin
+--------------------------+----------+------------------------------+-------------------------------------------------------+


Mistral
+++++++

+--------------------------+----------+--------------------------+-------------------------------------------------------+
| Configuration Parameter | Required | Default | Description |
+==========================+==========+==========================+=======================================================+
| **mistral_api_key** | Yes | N/A | This is the API key for the Mistral service. |
+--------------------------+----------+--------------------------+-------------------------------------------------------+


To use key-based authentication, define an AWS Bedrock route with the required fields below.
.. note::

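For the gateway flavor of this configuration, a Mistral route can be declared as follows (a sketch assuming the gateway's `routes`/`route_type` schema; the deployments-server equivalent ships with this PR as `examples/deployments/deployments_server/mistral/config.yaml`):

```yaml
routes:
  - name: completions
    route_type: llm/v1/completions
    model:
      provider: mistral
      name: mistral-tiny
      config:
        mistral_api_key: $MISTRAL_API_KEY
```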
1 change: 1 addition & 0 deletions examples/deployments/deployments_server/README.md
@@ -47,6 +47,7 @@ For full examples of configurations and supported endpoint types, see:
- [AI21 Labs](ai21labs/config.yaml)
- [PaLM](palm/config.yaml)
- [AzureOpenAI](azure_openai/config.yaml)
- [Mistral](mistral/config.yaml)

## Step 3: Setting Access Keys

13 changes: 13 additions & 0 deletions examples/deployments/deployments_server/mistral/README.md
@@ -0,0 +1,13 @@
## Example endpoint configuration for Mistral

For an example that configures both a completions endpoint and an embeddings endpoint for Mistral, see [the configuration](config.yaml) YAML file.

This configuration defines two endpoints, 'completions' and 'embeddings', backed by the 'mistral-tiny' and 'mistral-embed' models, respectively.

## Setting a Mistral API Key

This example requires a [Mistral API key](https://docs.mistral.ai/):

```sh
export MISTRAL_API_KEY=...
```
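
## Starting the Deployments Server

With the key exported, the server can be started against this directory's configuration (a minimal sketch, assuming the `mlflow deployments start-server` CLI; port 7000 matches the client used in [example.py](example.py)):

```sh
mlflow deployments start-server --config-path config.yaml --port 7000
```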
16 changes: 16 additions & 0 deletions examples/deployments/deployments_server/mistral/config.yaml
@@ -0,0 +1,16 @@
endpoints:
  - name: completions
    endpoint_type: llm/v1/completions
    model:
      provider: mistral
      name: mistral-tiny
      config:
        mistral_api_key: $MISTRAL_API_KEY

  - name: embeddings
    endpoint_type: llm/v1/embeddings
    model:
      provider: mistral
      name: mistral-embed
      config:
        mistral_api_key: $MISTRAL_API_KEY
34 changes: 34 additions & 0 deletions examples/deployments/deployments_server/mistral/example.py
@@ -0,0 +1,34 @@
from mlflow.deployments import get_deploy_client


def main():
    client = get_deploy_client("http://localhost:7000")

    print(f"Mistral endpoints: {client.list_endpoints()}\n")
    print(f"Mistral completions endpoint info: {client.get_endpoint(endpoint='completions')}\n")

    # Completions request
    response_completions = client.predict(
        endpoint="completions",
        inputs={
            "prompt": "How many average-size European ferrets can fit inside a standard olympic-sized swimming pool?",
            "temperature": 0.1,
        },
    )
    print(f"Mistral response for completions: {response_completions}")

    # Embeddings request
    response_embeddings = client.predict(
        endpoint="embeddings",
        inputs={
            "input": [
                "How does your culture celebrate the New Year, and how does it differ from other countries’ "
                "celebrations?"
            ]
        },
    )
    print(f"Mistral response for embeddings: {response_embeddings}")


if __name__ == "__main__":
    main()
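For reference, the completions response printed above follows the gateway's standard payload as built by `MistralAdapter.model_to_completions` later in this PR (a hypothetical example; the text and token counts are invented):

```python
# Hypothetical response_completions value; the field layout mirrors
# MistralAdapter.model_to_completions, and all values are invented.
example_response = {
    "object": "text_completion",
    "created": 1707000000,
    "model": "mistral-tiny",
    "choices": [{"index": 0, "text": "...", "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 21, "completion_tokens": 12, "total_tokens": 33},
}
```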
12 changes: 12 additions & 0 deletions mlflow/gateway/config.py
@@ -46,6 +46,7 @@ class Provider(str, Enum):
    # Note: The following providers are only supported on Databricks
    DATABRICKS_MODEL_SERVING = "databricks-model-serving"
    DATABRICKS = "databricks"
    MISTRAL = "mistral"

    @classmethod
    def values(cls):
@@ -215,6 +216,15 @@ class AWSBedrockConfig(ConfigModel):
    aws_config: Union[AWSRole, AWSIdAndKey, AWSBaseConfig]


class MistralConfig(ConfigModel):
    mistral_api_key: str

    # pylint: disable=no-self-argument
    @validator("mistral_api_key", pre=True)
    def validate_mistral_api_key(cls, value):
        return _resolve_api_key_from_input(value)


config_types = {
    Provider.COHERE: CohereConfig,
    Provider.OPENAI: OpenAIConfig,
@@ -225,6 +235,7 @@ class AWSBedrockConfig(ConfigModel):
    Provider.MLFLOW_MODEL_SERVING: MlflowModelServingConfig,
    Provider.PALM: PaLMConfig,
    Provider.HUGGINGFACE_TEXT_GENERATION_INFERENCE: HuggingFaceTextGenerationInferenceConfig,
    Provider.MISTRAL: MistralConfig,
}


@@ -284,6 +295,7 @@ class Model(ConfigModel):
            MlflowModelServingConfig,
            HuggingFaceTextGenerationInferenceConfig,
            PaLMConfig,
            MistralConfig,
        ]
    ] = None

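The `@validator` on `mistral_api_key` routes the raw config value through `_resolve_api_key_from_input`, which is what lets `config.yaml` reference the key as `$MISTRAL_API_KEY` instead of embedding it. A simplified sketch of that kind of resolution, assuming only environment-variable references (MLflow's internal helper may support additional input forms, such as file paths):

```python
import os


def resolve_api_key(value: str) -> str:
    # Sketch only: treat values beginning with "$" as environment variable
    # references, as with $MISTRAL_API_KEY in the example config.yaml.
    if value.startswith("$"):
        name = value[1:]
        if name not in os.environ:
            raise ValueError(f"Environment variable {name!r} is not set")
        return os.environ[name]
    return value
```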
2 changes: 2 additions & 0 deletions mlflow/gateway/providers/__init__.py
@@ -11,6 +11,7 @@ def get_provider(provider: Provider) -> Type[BaseProvider]:
    from mlflow.gateway.providers.bedrock import AWSBedrockProvider
    from mlflow.gateway.providers.cohere import CohereProvider
    from mlflow.gateway.providers.huggingface import HFTextGenerationInferenceServerProvider
    from mlflow.gateway.providers.mistral import MistralProvider
    from mlflow.gateway.providers.mlflow import MlflowModelServingProvider
    from mlflow.gateway.providers.mosaicml import MosaicMLProvider
    from mlflow.gateway.providers.openai import OpenAIProvider
@@ -26,6 +27,7 @@ def get_provider(provider: Provider) -> Type[BaseProvider]:
        Provider.MLFLOW_MODEL_SERVING: MlflowModelServingProvider,
        Provider.HUGGINGFACE_TEXT_GENERATION_INFERENCE: HFTextGenerationInferenceServerProvider,
        Provider.BEDROCK: AWSBedrockProvider,
        Provider.MISTRAL: MistralProvider,
    }
    if prov := provider_to_class.get(provider):
        return prov
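With the registry entry in place, resolving the new provider is a plain enum-to-class lookup (a short sketch using the names from this diff; constructing an instance from a parsed `RouteConfig` is elided, and the error path for unregistered providers is not shown in this hunk):

```python
from mlflow.gateway.config import Provider
from mlflow.gateway.providers import get_provider

# Look up the provider class registered for Provider.MISTRAL.
provider_cls = get_provider(Provider.MISTRAL)
print(provider_cls.__name__)  # "MistralProvider"
```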
164 changes: 164 additions & 0 deletions mlflow/gateway/providers/mistral.py
@@ -0,0 +1,164 @@
import time
from typing import Any, Dict

from fastapi.encoders import jsonable_encoder

from mlflow.gateway.config import MistralConfig, RouteConfig
from mlflow.gateway.providers.base import BaseProvider, ProviderAdapter
from mlflow.gateway.providers.utils import send_request
from mlflow.gateway.schemas import completions, embeddings


class MistralAdapter(ProviderAdapter):
    @classmethod
    def model_to_completions(cls, resp, config):
        # Response example (https://docs.mistral.ai/api/#operation/createChatCompletion)
        # ```
        # {
        #   "id": "string",
        #   "object": "string",
        #   "created": "integer",
        #   "model": "string",
        #   "choices": [
        #     {
        #       "index": "integer",
        #       "message": {
        #         "role": "string",
        #         "content": "string"
        #       },
        #       "finish_reason": "string",
        #     }
        #   ],
        #   "usage":
        #   {
        #     "prompt_tokens": "integer",
        #     "completion_tokens": "integer",
        #     "total_tokens": "integer",
        #   }
        # }
        # ```
        return completions.ResponsePayload(
            created=int(time.time()),
            object="text_completion",
            model=config.model.name,
            choices=[
                completions.Choice(
                    index=idx,
                    text=c["message"]["content"],
                    finish_reason=c["finish_reason"],
                )
                for idx, c in enumerate(resp["choices"])
            ],
            usage=completions.CompletionsUsage(
                prompt_tokens=resp["usage"]["prompt_tokens"],
                completion_tokens=resp["usage"]["completion_tokens"],
                total_tokens=resp["usage"]["total_tokens"],
            ),
        )

    @classmethod
    def model_to_embeddings(cls, resp, config):
        # Response example (https://docs.mistral.ai/api/#operation/createEmbedding):
        # ```
        # {
        #   "id": "string",
        #   "object": "string",
        #   "data": [
        #     {
        #       "object": "string",
        #       "embedding": [
        #         float,
        #         float
        #       ],
        #       "index": "integer",
        #     }
        #   ],
        #   "model": "string",
        #   "usage":
        #   {
        #     "prompt_tokens": "integer",
        #     "total_tokens": "integer",
        #   }
        # }
        # ```
        return embeddings.ResponsePayload(
            data=[
                embeddings.EmbeddingObject(
                    embedding=data["embedding"],
                    index=data["index"],
                )
                for data in resp["data"]
            ],
            model=config.model.name,
            usage=embeddings.EmbeddingsUsage(
                prompt_tokens=resp["usage"]["prompt_tokens"],
                total_tokens=resp["usage"]["total_tokens"],
            ),
        )

    @classmethod
    def completions_to_model(cls, payload, config):
        payload.pop("stop", None)
        payload.pop("n", None)
        payload["messages"] = [{"role": "user", "content": payload.pop("prompt")}]

        # The range of Mistral's temperature is 0-1, but ours is 0-2, so we scale it.
        if "temperature" in payload:
            payload["temperature"] = 0.5 * payload["temperature"]

        return payload

    @classmethod
    def embeddings_to_model(cls, payload, config):
        return payload


class MistralProvider(BaseProvider):
    NAME = "Mistral"

    def __init__(self, config: RouteConfig) -> None:
        super().__init__(config)
        if config.model.config is None or not isinstance(config.model.config, MistralConfig):
            raise TypeError(f"Unexpected config type {config.model.config}")
        self.mistral_config: MistralConfig = config.model.config

    @property
    def auth_headers(self) -> Dict[str, str]:
        return {"Authorization": f"Bearer {self.mistral_config.mistral_api_key}"}

    @property
    def base_url(self) -> str:
        return "https://api.mistral.ai/v1/"

    async def _request(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        return await send_request(
            headers=self.auth_headers,
            base_url=self.base_url,
            path=path,
            payload=payload,
        )

    async def completions(self, payload: completions.RequestPayload) -> completions.ResponsePayload:
        payload = jsonable_encoder(payload, exclude_none=True)
        self.check_for_model_field(payload)
        resp = await self._request(
            "chat/completions",
            {
                "model": self.config.model.name,
                **MistralAdapter.completions_to_model(payload, self.config),
            },
        )
        return MistralAdapter.model_to_completions(resp, self.config)

    async def embeddings(self, payload: embeddings.RequestPayload) -> embeddings.ResponsePayload:
        payload = jsonable_encoder(payload, exclude_none=True)
        self.check_for_model_field(payload)
        resp = await self._request(
            "embeddings",
            {
                "model": self.config.model.name,
                **MistralAdapter.embeddings_to_model(payload, self.config),
            },
        )
        return MistralAdapter.model_to_embeddings(resp, self.config)
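
To make `completions_to_model` concrete, here is a worked example of the translation it performs on a gateway-style payload (illustrative values; the steps mirror the adapter logic above rather than calling it):

```python
# Gateway-style completions payload, before translation.
payload = {"prompt": "Hello", "temperature": 1.0, "n": 1, "stop": ["###"]}

payload.pop("stop", None)  # these fields are dropped before forwarding
payload.pop("n", None)
# the prompt becomes a single user message in Mistral's chat format
payload["messages"] = [{"role": "user", "content": payload.pop("prompt")}]
if "temperature" in payload:
    # the gateway's 0-2 temperature range is halved into Mistral's 0-1 range
    payload["temperature"] = 0.5 * payload["temperature"]

assert payload == {
    "temperature": 0.5,
    "messages": [{"role": "user", "content": "Hello"}],
}
```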