-
Notifications
You must be signed in to change notification settings - Fork 3k
/
azure_chat_prompt_execution_settings.py
102 lines (71 loc) · 3.59 KB
/
azure_chat_prompt_execution_settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import logging
from typing import Annotated, Any, Literal, Union
from pydantic import AliasGenerator, ConfigDict, Field
from pydantic.alias_generators import to_camel, to_snake
from pydantic.functional_validators import AfterValidator
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.kernel_pydantic import KernelBaseModel
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
class AzureChatRequestBase(KernelBaseModel):
    """Shared pydantic configuration for the Azure chat request models in this module.

    * Each field gets a camelCase validation alias and a snake_case serialization
      alias derived from its name.
    * Enum members are stored by value (``use_enum_values``).
    * Unknown keys are kept rather than rejected (``extra="allow"``).
    """

    model_config = ConfigDict(
        extra="allow",
        use_enum_values=True,
        alias_generator=AliasGenerator(
            serialization_alias=to_snake,
            validation_alias=to_camel,
        ),
    )
class ConnectionStringAuthentication(AzureChatRequestBase):
    """Connection-string authentication (used by the Cosmos DB data source parameters).

    ``type`` accepts either "ConnectionString" or "connection_string". The
    ``to_snake`` after-validator normalizes a validated value to snake_case.
    """

    type: Annotated[
        Literal["ConnectionString", "connection_string"],
        AfterValidator(to_snake),
    ] = "connection_string"
    # The credential itself; optional at validation time.
    connection_string: str | None = None
class ApiKeyAuthentication(AzureChatRequestBase):
    """API-key authentication (used by the Azure AI Search data source parameters).

    ``type`` accepts either "APIKey" or "api_key". The ``to_snake``
    after-validator normalizes a validated value to snake_case.
    """

    type: Annotated[
        Literal["APIKey", "api_key"],
        AfterValidator(to_snake),
    ] = "api_key"
    # The API key itself; optional at validation time.
    key: str | None = None
class AzureEmbeddingDependency(AzureChatRequestBase):
    """Embedding dependency of a data source, identified by deployment name.

    ``type`` accepts either "DeploymentName" or "deployment_name". The
    ``to_snake`` after-validator normalizes a validated value to snake_case.
    """

    type: Annotated[
        Literal["DeploymentName", "deployment_name"],
        AfterValidator(to_snake),
    ] = "deployment_name"
    # Name of the embedding deployment; optional at validation time.
    deployment_name: str | None = None
class DataSourceFieldsMapping(AzureChatRequestBase):
    """Optional mapping from document roles to source field names.

    Roles are title, URL, file path, content and vector. Every entry is
    optional. ``content_fields_separator`` defaults to a newline; it is
    presumably used to join multiple ``content_fields`` (verify against the
    service contract).
    """

    # Declaration order is kept as-is: it fixes the model's schema/serialization order.
    title_field: str | None = None
    url_field: str | None = None
    filepath_field: str | None = None
    content_fields: list[str] | None = None
    vector_fields: list[str] | None = None
    content_fields_separator: str | None = "\n"
class AzureDataSourceParameters(AzureChatRequestBase):
    """Parameters shared by the Azure data sources in this module.

    The Cosmos DB and Azure AI Search parameter models both subclass this one.
    Only ``index_name`` is required; every other field here has a default.
    """

    # Required: no default is declared.
    index_name: str
    index_language: str | None = None
    fields_mapping: DataSourceFieldsMapping | None = None
    in_scope: bool | None = True
    top_n_documents: int | None = 5
    semantic_configuration: str | None = None
    role_information: str | None = None
    filter: str | None = None
    # No range validation is declared for strictness in this model.
    strictness: int = 3
    embedding_dependency: AzureEmbeddingDependency | None = None
class AzureCosmosDBDataSourceParameters(AzureDataSourceParameters):
    """Parameters for an Azure Cosmos DB data source, on top of the shared parameters."""

    authentication: ConnectionStringAuthentication | None = None
    database_name: str | None = None
    container_name: str | None = None
    # NOTE(review): this field has the same type as the inherited ``embedding_dependency``.
    # Confirm which of the two callers and the service actually use before removing
    # either one; renaming would break existing callers.
    embedding_dependency_type: AzureEmbeddingDependency | None = None
class AzureCosmosDBDataSource(AzureChatRequestBase):
    """Azure Cosmos DB entry for ``ExtraBody.data_sources``.

    The ``type`` literal is the tag that the ``DataSource`` discriminated union
    uses to select this model.
    """

    type: Literal["azure_cosmos_db"] = "azure_cosmos_db"
    parameters: AzureCosmosDBDataSourceParameters
class AzureAISearchDataSourceParameters(AzureDataSourceParameters):
    """Parameters for an Azure AI Search data source, on top of the shared parameters.

    The ``to_snake`` after-validator stores ``query_type`` in snake_case. For
    example, "vectorSimpleHybrid" becomes "vector_simple_hybrid". The snake_case
    spellings are therefore accepted as input as well. Without them, a dumped
    model (or a re-assigned, already-normalized value) would fail validation.
    This matches the authentication and embedding-dependency models, which also
    list both spellings.
    """

    endpoint: str | None = None
    query_type: Annotated[
        Literal[
            "simple",
            "semantic",
            "vector",
            "vectorSimpleHybrid",
            "vector_simple_hybrid",
            "vectorSemanticHybrid",
            "vector_semantic_hybrid",
        ],
        AfterValidator(to_snake),
    ] = "simple"
    authentication: ApiKeyAuthentication | None = None
class AzureAISearchDataSource(AzureChatRequestBase):
    """Azure AI Search entry for ``ExtraBody.data_sources``.

    The ``type`` literal is the tag that the ``DataSource`` discriminated union
    uses to select this model.
    """

    type: Literal["azure_search"] = "azure_search"
    # Declared as ``dict``, with the parameters model attached only as ``Annotated``
    # metadata. The Cosmos DB source instead types ``parameters`` as its model directly.
    # NOTE(review): it is unclear whether this value is validated against
    # AzureAISearchDataSourceParameters or kept as a raw dict. That depends on how
    # the installed pydantic version treats a model class used as metadata; confirm
    # before relying on either.
    parameters: Annotated[dict, AzureAISearchDataSourceParameters]
# Tagged union of the supported data sources. pydantic picks the concrete model
# by matching the incoming ``type`` value against each member's ``type`` Literal.
DataSource = Annotated[
    Union[AzureAISearchDataSource, AzureCosmosDBDataSource],
    Field(discriminator="type"),
]
class ExtraBody(KernelBaseModel):
    """Extra request-body content for the Azure chat endpoint.

    Holds the data sources plus optional input/output languages. This model
    derives from ``KernelBaseModel`` directly, not from ``AzureChatRequestBase``.
    The language fields therefore declare their camelCase serialization aliases
    explicitly.
    """

    data_sources: list[DataSource] | None = None
    input_language: str | None = Field(default=None, serialization_alias="inputLanguage")
    output_language: str | None = Field(default=None, serialization_alias="outputLanguage")

    def __getitem__(self, item: str) -> Any:
        """Return the attribute named ``item``, so ``body["data_sources"]`` works like attribute access.

        Args:
            item: Attribute name to look up.

        Returns:
            The attribute's value.

        Raises:
            AttributeError: Propagated from ``getattr`` when no such attribute exists.
                NOTE(review): mapping-style callers usually expect ``KeyError`` here.
        """
        return getattr(self, item)
class AzureChatPromptExecutionSettings(OpenAIChatPromptExecutionSettings):
    """Specific settings for the Azure OpenAI Chat Completion endpoint.

    Builds on the OpenAI chat settings and declares two fields of its own:

    * ``response_format``: an optional plain string.
    * ``extra_body``: either a raw ``dict`` or an :class:`ExtraBody` model
      (data sources plus input/output languages).
    """

    response_format: str | None = None
    extra_body: dict[str, Any] | ExtraBody | None = None