New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate AI gateway #10420
Merged
Merged
Migrate AI gateway #10420
Changes from 10 commits
Commits
Show all changes
59 commits
Select commit
Hold shift + click to select a range
f536157
Remove unused loggger
harupy dce3cb6
Add databricks deployments client skeleton + example (#10421)
harupy 856de75
Set up CI (#10422)
harupy 94ba4d2
Change git labels from gateway -> deployments (#10428)
dbczumar e450b10
Update gateway embeddings request / response format (#10430)
dbczumar 3fd534f
Fix typo
harupy ee8daca
Fix typo
harupy c32b9a8
Add `MLflowDeploymentClient` (still empty) (#10447)
harupy 8566e3f
Implement CRUD for serving-endpoints (#10425)
harupy d220673
Use `Literal` for string constants (#10457)
harupy 2845e82
Update gateway chat request / response format (#10454)
dbczumar 1390a2d
Update gateway completions request / response format (#10465)
dbczumar ac98a64
Deprecation warning for `mlflow.gateway` (#10460)
harupy 12f2279
Mark deployment clients as experimental (#10468)
harupy 9f1d547
[1] Copy gateway server CLI to mlflow deployments start-server (#10426)
dbczumar d53c558
[2] Add /endpoints APIs to MLflow Deployments server (#10466)
dbczumar 5ccb7f3
Fix `start_server` (#10470)
harupy 9fcb238
[3] Support "endpoints" in deployments server YAML conf (#10467)
dbczumar 1877ef1
Suppress Pydantic warnings (#10483)
B-Step62 fdb2779
Implement `MLflowDeploymentClient` (#10458)
harupy 8ca8429
Set timeout default to `None` (#10487)
harupy 6f97551
Rename `MLflowDeploymentClient` to `MlflowDeploymentClient` (#10500)
harupy 5472a35
Fix `MlflowDeploymentClient.list_endpoints` to auto-paginate (#10501)
harupy 9fc93d3
Use `/rate-limits` if `config` only contains `rate_limits` (#10502)
harupy bf0782c
Replace `TODO` in `MlflowDeploymentClient` (#10496)
harupy 787d5f0
Replace `TODO` in `DatabricksDeploymentClient` (#10499)
harupy cbb627e
Add `genai` extra (#10516)
harupy 4caba21
Fix FastAPI docs (#10518)
harupy 58972f7
Deployments server docs (#10517)
harupy cd0d458
[FEAT] Migrating prompt lab to use MLflow deployments (#10515)
sunishsheth2009 ebf1c29
Fix request and response schema (#10523)
harupy 86ec60d
Remove fluent API (#10524)
harupy faff091
Fix getting-started guide (#10522)
harupy 6382432
Support `endpoints:/my-endpoint` in LLM-as-judge metrics (#10528)
prithvikannan ef013ae
Fix schema (#10541)
harupy cff30c9
Fix client API section (#10542)
harupy a0b30e5
Fix refs (#10544)
harupy 2e9f299
Use `endpoints` and `endpoint_type` (#10545)
harupy 9ac5337
Fix REST examples (#10547)
harupy 8b36901
Fix langchain example (#10548)
harupy 4c9a908
Update `docs/source/llms/deployments/index.rst` (#10550)
harupy 4602e81
Update `docs/source/llms/index.rst` (#10551)
harupy 54751eb
Use `MLFLOW_DEPLOYMENTS_TARGET` in `gateway_proxy_handler` (#10554)
harupy c8dfc06
Update `docs/source/llms/prompt-engineering/index.rst` (#10552)
harupy d2945a9
Quick fix for promptlab gateway migration (#10563)
daniellok-db 631ce09
Fix completions params (#10565)
harupy ea341fd
Fix incorrect Embeddings param: `text` -> `input` (#10566)
harupy 2f0a1fc
Replace `candidates` with `choices` (#10569)
harupy 1861d5b
More gateway replacements (#10568)
harupy beaf89f
[Docs] Updating the docs for prompt engineering with MLflow deploymen…
sunishsheth2009 22e791f
Gateway anthropic small fix for "n" parameter (#10576)
dbczumar 0998384
Fix example links (#10561)
harupy 96a6d2b
[Bug-fix] Fixing the error state bug for Mlflow deployments (#10575)
sunishsheth2009 77252fe
Update examples to use MLflow Deployments (#10558)
dbczumar e502ddf
Fix deployments examples (#10582)
harupy 7e99592
Support completions endpoints (#10577)
prithvikannan 4ec83b0
merge master
prithvikannan abb76c9
Update `docs/source/llms/gateway/migration.rst` (#10498)
harupy db1e9e8
Update quickstart quides (#10593)
harupy File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
name: Deployments | ||
|
||
on: | ||
pull_request: | ||
push: | ||
branches: | ||
- master | ||
- branch-[0-9]+.[0-9]+ | ||
|
||
permissions: | ||
contents: read | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
defaults: | ||
run: | ||
shell: bash --noprofile --norc -exo pipefail {0} | ||
|
||
jobs: | ||
deployments: | ||
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false | ||
runs-on: ubuntu-latest | ||
timeout-minutes: 30 | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: ./.github/actions/untracked | ||
- uses: ./.github/actions/setup-python | ||
- name: Install dependencies | ||
run: | | ||
pip install .[gateway] \ | ||
pytest pytest-timeout pytest-asyncio httpx psutil | ||
- name: Run tests | ||
run: | | ||
pytest tests/deployments/databricks tests/deployments/mlflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
Usage | ||
----- | ||
databricks secrets create-scope <scope> | ||
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY | ||
python examples/deployments/databricks.py --secret <scope>/openai-api-key | ||
----- | ||
""" | ||
import argparse | ||
import uuid | ||
|
||
from mlflow.deployments import get_deploy_client | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--secret", type=str, help="Secret (e.g. secrets/scope/key)") | ||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
client = get_deploy_client("databricks") | ||
name = f"test-endpoint-{uuid.uuid4()}" | ||
client.create_endpoint( | ||
name=name, | ||
config={ | ||
"served_entities": [ | ||
{ | ||
"name": "test", | ||
"external_model": { | ||
"name": "gpt-4", | ||
"provider": "openai", | ||
"task": "llm/v1/chat", | ||
"openai_config": { | ||
"openai_api_key": "{{" + args.secret + "}}", | ||
}, | ||
}, | ||
} | ||
], | ||
"tags": [ | ||
{ | ||
"key": "foo", | ||
"value": "bar", | ||
} | ||
], | ||
"rate_limits": [ | ||
{ | ||
"key": "user", | ||
"renewal_period": "minute", | ||
"calls": 5, | ||
} | ||
], | ||
}, | ||
) | ||
try: | ||
client.update_endpoint( | ||
endpoint=name, | ||
config={ | ||
"served_entities": [ | ||
{ | ||
"name": "test", | ||
"external_model": { | ||
"name": "gpt-4", | ||
"provider": "openai", | ||
"task": "llm/v1/chat", | ||
"openai_config": { | ||
"openai_api_key": "{{" + args.secret + "}}", | ||
}, | ||
}, | ||
} | ||
], | ||
"rate_limits": [ | ||
{ | ||
"key": "user", | ||
"renewal_period": "minute", | ||
"calls": 10, | ||
} | ||
], | ||
}, | ||
) | ||
print(client.list_endpoints()[:3]) | ||
print(client.get_endpoint(endpoint=name)) | ||
print( | ||
client.predict( | ||
endpoint=name, | ||
inputs={ | ||
"messages": [ | ||
{"role": "user", "content": "Hello!"}, | ||
], | ||
"max_tokens": 128, | ||
}, | ||
), | ||
) | ||
finally: | ||
client.delete_endpoint(endpoint=name) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from mlflow.environment_variables import _EnvironmentVariable | ||
|
||
# TODO: Move this to mlflow.environment_variables before merging to master | ||
# Specifies the timeout for deployment client APIs to declare a request has timed out | ||
MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT = _EnvironmentVariable( | ||
"MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT", int, 120 | ||
) | ||
|
||
# Abridged retryable error codes for deployments clients. | ||
# These are modified from the standard MLflow Tracking server retry codes for the MLflowClient to | ||
# remove timeouts from the list of the retryable conditions. A long-running timeout with | ||
# retries for the proxied providers generally indicates an issue with the underlying query or | ||
# the model being served having issues responding to the query due to parameter configuration. | ||
MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES = frozenset( | ||
[ | ||
429, # Too many requests | ||
500, # Server Error | ||
502, # Bad Gateway | ||
503, # Service Unavailable | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import posixpath | ||
from typing import Any, Dict, Optional | ||
|
||
from mlflow.deployments import BaseDeploymentClient | ||
from mlflow.deployments.constants import ( | ||
MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES, | ||
MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT, | ||
) | ||
from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_TIMEOUT | ||
from mlflow.utils import AttrDict | ||
from mlflow.utils.databricks_utils import get_databricks_host_creds | ||
from mlflow.utils.rest_utils import augmented_raise_for_status, http_request | ||
|
||
|
||
class DatabricksEndpoint(AttrDict): | ||
pass | ||
|
||
|
||
class DatabricksDeploymentClient(BaseDeploymentClient): | ||
""" | ||
TODO | ||
""" | ||
|
||
def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None): | ||
""" | ||
.. warning:: | ||
|
||
This method is not implemented for `DatabricksDeploymentClient`. | ||
""" | ||
raise NotImplementedError | ||
|
||
def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None): | ||
""" | ||
.. warning:: | ||
|
||
This method is not implemented for `DatabricksDeploymentClient`. | ||
""" | ||
raise NotImplementedError | ||
|
||
def delete_deployment(self, name, config=None, endpoint=None): | ||
""" | ||
.. warning:: | ||
|
||
This method is not implemented for `DatabricksDeploymentClient`. | ||
""" | ||
raise NotImplementedError | ||
|
||
def list_deployments(self, endpoint=None): | ||
""" | ||
.. warning:: | ||
|
||
This method is not implemented for `DatabricksDeploymentClient`. | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_deployment(self, name, endpoint=None): | ||
""" | ||
.. warning:: | ||
|
||
This method is not implemented for `DatabricksDeploymentClient`. | ||
""" | ||
raise NotImplementedError | ||
|
||
def _call_endpoint( | ||
self, | ||
*, | ||
method: str, | ||
prefix: str = "/api/2.0", | ||
route: Optional[str] = None, | ||
json_body: Optional[Dict[str, Any]] = None, | ||
timeout: int = MLFLOW_HTTP_REQUEST_TIMEOUT.get(), | ||
): | ||
call_kwargs = {} | ||
if method.lower() == "get": | ||
call_kwargs["params"] = json_body | ||
else: | ||
call_kwargs["json"] = json_body | ||
|
||
response = http_request( | ||
host_creds=get_databricks_host_creds(self.target_uri), | ||
endpoint=posixpath.join(prefix, "serving-endpoints", route or ""), | ||
method=method, | ||
timeout=timeout, | ||
raise_on_status=False, | ||
retry_codes=MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES, | ||
**call_kwargs, | ||
) | ||
augmented_raise_for_status(response) | ||
return DatabricksEndpoint(response.json()) | ||
|
||
def predict(self, deployment_name=None, inputs=None, endpoint=None): | ||
""" | ||
TODO | ||
""" | ||
return self._call_endpoint( | ||
method="POST", | ||
prefix="/", | ||
route=posixpath.join(endpoint, "invocations"), | ||
json_body=inputs, | ||
timeout=MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get(), | ||
) | ||
|
||
def create_endpoint(self, name, config=None): | ||
""" | ||
TODO | ||
""" | ||
config = config.copy() if config else {} # avoid mutating config | ||
extras = {} | ||
for key in ("tags", "rate_limits"): | ||
if tags := config.pop(key, None): | ||
extras[key] = tags | ||
payload = {"name": name, "config": config, **extras} | ||
return self._call_endpoint(method="POST", json_body=payload) | ||
|
||
def update_endpoint(self, endpoint, config=None): | ||
""" | ||
TODO | ||
""" | ||
return self._call_endpoint( | ||
method="PUT", route=posixpath.join(endpoint, "config"), json_body=config | ||
) | ||
|
||
def delete_endpoint(self, endpoint): | ||
""" | ||
TODO | ||
""" | ||
return self._call_endpoint(method="DELETE", route=endpoint) | ||
|
||
def list_endpoints(self): | ||
""" | ||
TODO | ||
""" | ||
return self._call_endpoint(method="GET").endpoints | ||
|
||
def get_endpoint(self, endpoint): | ||
""" | ||
TODO | ||
""" | ||
return self._call_endpoint(method="GET", route=endpoint) | ||
|
||
|
||
def run_local(name, model_uri, flavor=None, config=None): | ||
pass | ||
|
||
|
||
def target_help(): | ||
pass |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reminder