Skip to content

Commit

Permalink
Fix azure openai and docs (#10894)
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan committed Jan 26, 2024
1 parent ebb8b2d commit 628fba4
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 42 deletions.
2 changes: 1 addition & 1 deletion docs/source/llms/openai/guide/index.rst
Expand Up @@ -109,7 +109,7 @@ To successfully log a model targeting Azure OpenAI Service, specific environment
- **OPENAI_API_BASE**: The base endpoint for your Azure OpenAI resource (e.g., ``https://<your-service-name>.openai.azure.com/``). Within the Azure OpenAI documentation and guides, this key is referred to as ``AZURE_OPENAI_ENDPOINT`` or simply ``ENDPOINT``.
- **OPENAI_API_VERSION**: The API version to use for the Azure OpenAI Service. More information can be found in the `Azure OpenAI documentation <https://learn.microsoft.com/en-us/azure/ai-services/openai/reference>`_, including up-to-date lists of supported versions.
- **OPENAI_API_TYPE**: If using Azure OpenAI endpoints, this value should be set to ``"azure"``.
- **DEPLOYMENT_ID**: The deployment name that you chose when you deployed the model in Azure. To learn more, visit the `Azure OpenAI deployment documentation <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal>`_.
- **OPENAI_DEPLOYMENT_NAME**: The deployment name that you chose when you deployed the model in Azure. To learn more, visit the `Azure OpenAI deployment documentation <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal>`_.

Azure OpenAI Service in MLflow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
5 changes: 1 addition & 4 deletions examples/openai/azure_openai.py
@@ -1,5 +1,3 @@
import os

import openai
import pandas as pd

Expand All @@ -14,7 +12,7 @@
# OPENAI_API_VERSION e.g. 2023-05-15
export OPENAI_API_VERSION="<AZURE OPENAI API VERSION>"
export OPENAI_API_TYPE="azure"
export DEPLOYMENT_ID="<AZURE OPENAI DEPLOYMENT ID OR NAME>"
export OPENAI_DEPLOYMENT_NAME="<AZURE OPENAI DEPLOYMENT ID OR NAME>"
"""

with mlflow.start_run():
Expand All @@ -24,7 +22,6 @@
task=openai.ChatCompletion,
artifact_path="model",
messages=[{"role": "user", "content": "Tell me a joke about {animal}."}],
deployment_id=os.environ["DEPLOYMENT_ID"],
)

# Load native OpenAI model
Expand Down
27 changes: 14 additions & 13 deletions mlflow/openai/__init__.py
Expand Up @@ -653,15 +653,16 @@ def __init__(self, model):
self.api_token = _OAITokenHolder(self.api_config.api_type)
# If the same parameter exists in self.model & self.api_config,
# we use the parameter from self.model
self.envs = {
x: getattr(self.api_config, x)
for x in ["api_base", "api_version", "api_type", "engine", "deployment_id"]
if getattr(self.api_config, x) is not None and x not in self.model
}
api_type = self.model.get("api_type") or self.envs.get("api_type")
if api_type in ("azure", "azure_ad", "azuread"):
deployment_id = self.model.get("deployment_id") or self.envs.get("deployment_id")
if self.model.get("engine") or self.envs.get("engine"):
self.request_configs = {}
for x in ["api_base", "api_version", "api_type", "engine", "deployment_id"]:
if x in self.model:
self.request_configs[x] = self.model.pop(x)
elif value := getattr(self.api_config, x):
self.request_configs[x] = value

if self.request_configs.get("api_type") in ("azure", "azure_ad", "azuread"):
deployment_id = self.request_configs.get("deployment_id")
if self.request_configs.get("engine"):
# Avoid using both parameters as they serve the same purpose
# Invalid inputs:
# - Wrong engine + correct/wrong deployment_id
Expand Down Expand Up @@ -704,11 +705,11 @@ def get_params_list(self, data):
return data[self.formater.variables].to_dict(orient="records")

def _construct_request_url(self, task_url, default_url):
api_type = self.model.get("api_type") or self.envs.get("api_type")
api_type = self.request_configs.get("api_type")
if api_type in ("azure", "azure_ad", "azuread"):
api_base = self.envs.get("api_base")
api_version = self.envs.get("api_version")
deployment_id = self.envs.get("deployment_id")
api_base = self.request_configs.get("api_base")
api_version = self.request_configs.get("api_version")
deployment_id = self.request_configs.get("deployment_id")

return (
f"{api_base}/openai/deployments/{deployment_id}/"
Expand Down
5 changes: 0 additions & 5 deletions mlflow/utils/openai_utils.py
Expand Up @@ -211,11 +211,6 @@ def _validate_model_params(task, model, params):
)


def _exclude_params_from_envs(params, envs):
"""Params passed at inference time should override envs."""
return {k: v for k, v in envs.items() if k not in params} if params else envs


class _OAITokenHolder:
def __init__(self, api_type):
self._api_token = None
Expand Down
19 changes: 0 additions & 19 deletions tests/openai/test_openai_model_export.py
@@ -1,6 +1,5 @@
import importlib
import json
from copy import deepcopy
from unittest import mock

import numpy as np
Expand All @@ -17,7 +16,6 @@
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, ParamSchema, ParamSpec, Schema, TensorSpec
from mlflow.utils.openai_utils import (
_exclude_params_from_envs,
_mock_chat_completion_response,
_mock_models_retrieve_response,
_mock_request,
Expand Down Expand Up @@ -668,20 +666,3 @@ def test_engine_and_deployment_id_for_azure_openai(tmp_path, monkeypatch):
MlflowException, match=r"Either engine or deployment_id must be set for Azure OpenAI API"
):
mlflow.pyfunc.load_model(tmp_path)


@pytest.mark.parametrize(
("params", "envs"),
[
({"a": None, "b": "b"}, {"a": "a", "c": "c"}),
({"a": "a", "b": "b"}, {"a": "a", "d": "d"}),
({}, {"a": "a", "b": "b"}),
({"a": "a"}, {"b": "b"}),
],
)
def test_exclude_params_from_envs(params, envs):
original_envs = deepcopy(envs)
result = _exclude_params_from_envs(params, envs)
assert envs == original_envs
assert not any(key in params for key in result)
assert all(key in envs for key in result)

0 comments on commit 628fba4

Please sign in to comment.