Skip to content

Commit

Permalink
fix OpenAIEmbeddings to support Azure OpenAI Service custom endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
zioproto committed Apr 29, 2023
1 parent 32793f9 commit 279a59b
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion langchain/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(
deployment="your-embeddings-deployment-name",
model="your-embeddings-model-name"
model="your-embeddings-model-name",
api_base="https://your-endpoint.openai.azure.com/",
api_type="azure",
)
text = "This is a test query."
query_result = embeddings.embed_query(text)
Expand All @@ -104,6 +106,13 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model # to support Azure OpenAI Service custom deployment names
openai_api_version: str = "2022-12-01"
openai_api_base: Optional[
str
] = None # to support Azure OpenAI Service custom endpoints
openai_api_type: Optional[
str
] = None # to support Azure OpenAI Service custom endpoints
embedding_ctx_length: int = 8191
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
Expand All @@ -125,6 +134,23 @@ def validate_environment(cls, values: Dict) -> Dict:
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
openai_api_base = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
openai_api_type = get_from_dict_or_env(
values,
"openai_api_type",
"OPENAI_API_TYPE",
default="",
)
openai_api_version = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
)
openai_organization = get_from_dict_or_env(
values,
"openai_organization",
Expand All @@ -137,6 +163,11 @@ def validate_environment(cls, values: Dict) -> Dict:
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
if openai_api_base:
openai.api_base = openai_api_base
openai.api_version = openai_api_version
if openai_api_type:
openai.api_type = openai_api_type
values["client"] = openai.Embedding
except ImportError:
raise ValueError(
Expand Down

0 comments on commit 279a59b

Please sign in to comment.