|
13 | 13 | from dataclasses import dataclass |
14 | 14 | from functools import partial |
15 | 15 | from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union |
| 16 | +from urllib.parse import urlparse |
16 | 17 |
|
17 | 18 | import fitz |
18 | 19 | import markdown |
|
22 | 23 | from azure.ai.documentintelligence.models import AnalyzeDocumentRequest |
23 | 24 | from azure.ai.inference import EmbeddingsClient |
24 | 25 | from azure.core.credentials import AzureKeyCredential |
| 26 | +from azure.identity import AzureCliCredential |
| 27 | +from azure.keyvault.secrets import SecretClient |
25 | 28 | from azure.storage.blob import ContainerClient |
26 | 29 | from bs4 import BeautifulSoup |
27 | 30 | from dotenv import load_dotenv |
|
34 | 37 | # Configure environment variables |
35 | 38 | load_dotenv() # take environment variables from .env. |
36 | 39 |
|
| 40 | +# Key Vault name - replaced during deployment |
| 41 | +key_vault_name = 'kv_to-be-replaced' |
| 42 | + |
| 43 | + |
| 44 | +def get_secrets_from_kv(secret_name: str) -> str: |
| 45 | + """Retrieves a secret value from Azure Key Vault. |
| 46 | + |
| 47 | + Args: |
| 48 | + secret_name: Name of the secret |
| 49 | + |
| 50 | + Returns: |
| 51 | + The secret value |
| 52 | + """ |
| 53 | + kv_credential = AzureCliCredential() |
| 54 | + secret_client = SecretClient( |
| 55 | + vault_url=f"https://{key_vault_name}.vault.azure.net/", |
| 56 | + credential=kv_credential |
| 57 | + ) |
| 58 | + return secret_client.get_secret(secret_name).value |
| 59 | + |
| 60 | + |
37 | 61 | FILE_FORMAT_DICT = { |
38 | 62 | "md": "markdown", |
39 | 63 | "txt": "text", |
@@ -825,47 +849,33 @@ def get_payload_and_headers_cohere(text, aad_token) -> Tuple[Dict, Dict]: |
825 | 849 | def get_embedding( |
826 | 850 | text, embedding_model_endpoint=None, embedding_model_key=None, azure_credential=None |
827 | 851 | ): |
828 | | - endpoint = ( |
829 | | - embedding_model_endpoint |
830 | | - if embedding_model_endpoint |
831 | | - else os.environ.get("EMBEDDING_MODEL_ENDPOINT") |
832 | | - ) |
833 | | - |
834 | | - FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI") |
835 | | - |
836 | | - if azure_credential is None and (endpoint is None): |
837 | | - raise Exception( |
838 | | - "EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding" |
839 | | - ) |
| 852 | + # Get AI Project endpoint from Key Vault |
| 853 | + ai_project_endpoint = get_secrets_from_kv("AZURE-AI-AGENT-ENDPOINT") |
| 854 | + |
| 855 | + # Construct inference endpoint: https://aif-xyz.services.ai.azure.com/models |
| 856 | + inference_endpoint = f"https://{urlparse(ai_project_endpoint).netloc}/models" |
| 857 | + embedding_model = "text-embedding-ada-002" |
840 | 858 |
|
841 | 859 | try: |
842 | | - if FLAG_EMBEDDING_MODEL == "AOAI": |
843 | | - deployment_id = "embedding" |
844 | | - |
845 | | - if azure_credential is not None: |
846 | | - # Use managed identity credential with credential_scopes parameter |
847 | | - client = EmbeddingsClient( |
848 | | - endpoint=f"{endpoint}/openai/deployments/{deployment_id}", |
849 | | - credential=azure_credential, |
850 | | - credential_scopes=["https://cognitiveservices.azure.com/.default"] |
851 | | - ) |
852 | | - else: |
853 | | - # Use API key credential |
854 | | - api_key = ( |
855 | | - embedding_model_key |
856 | | - if embedding_model_key |
857 | | - else os.getenv("AZURE_OPENAI_API_KEY") |
858 | | - ) |
859 | | - client = EmbeddingsClient( |
860 | | - endpoint=f"{endpoint}/openai/deployments/{deployment_id}", |
861 | | - credential=AzureKeyCredential(api_key) |
862 | | - ) |
863 | | - response = client.embed(input=[text]) |
864 | | - return response.data[0].embedding |
| 860 | + if azure_credential is not None: |
| 861 | + embeddings_client = EmbeddingsClient( |
| 862 | + endpoint=inference_endpoint, |
| 863 | + credential=azure_credential, |
| 864 | + credential_scopes=["https://cognitiveservices.azure.com/.default"] |
| 865 | + ) |
| 866 | + else: |
| 867 | + api_key = embedding_model_key or os.getenv("AZURE_OPENAI_API_KEY") |
| 868 | + embeddings_client = EmbeddingsClient( |
| 869 | + endpoint=inference_endpoint, |
| 870 | + credential=AzureKeyCredential(api_key) |
| 871 | + ) |
| 872 | + |
| 873 | + response = embeddings_client.embed(model=embedding_model, input=[text]) |
| 874 | + return response.data[0].embedding |
865 | 875 |
|
866 | 876 | except Exception as e: |
867 | 877 | raise Exception( |
868 | | - f"Error getting embeddings with endpoint={endpoint} with error={e}" |
| 878 | + f"Error getting embeddings with endpoint={inference_endpoint} with error={e}" |
869 | 879 | ) |
870 | 880 |
|
871 | 881 |
|
|
0 commit comments