-
Notifications
You must be signed in to change notification settings - Fork 2k
/
llms.ts
111 lines (97 loc) · 3.22 KB
/
llms.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai";
import { type BaseLLMParams } from "@langchain/core/language_models/llms";
import { OpenAI } from "../llms.js";
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js";
import type {
OpenAIInput,
AzureOpenAIInput,
OpenAICoreRequestOptions,
LegacyOpenAIInput,
} from "../types.js";
/**
 * Variant of the `OpenAI` LLM wrapper that talks to the service through the
 * `AzureOpenAI` client from the `openai` package.
 *
 * The constructor accepts both the Azure-prefixed option names
 * (`azureOpenAIApiKey`, `azureOpenAIApiVersion`,
 * `azureOpenAIApiDeploymentName`) and the shorter aliases (`openAIApiKey`,
 * `openAIApiVersion`, `deploymentName`).
 */
export class AzureOpenAI extends OpenAI {
  /**
   * Maps the alias field names accepted by the constructor to snake_case key
   * names.
   * NOTE(review): presumably consumed by the base class's serialization
   * logic; confirm against the base class.
   */
  get lc_aliases(): Record<string, string> {
    return {
      openAIApiKey: "openai_api_key",
      openAIApiVersion: "openai_api_version",
      openAIBasePath: "openai_api_base",
    };
  }

  /**
   * @param fields Optional configuration. The object is shallow-copied before
   *   the aliases are resolved, so the caller's object is never mutated.
   *   When an alias (`deploymentName`, `openAIApiKey`, `openAIApiVersion`) is
   *   provided it is used for the matching Azure field; otherwise an
   *   explicitly supplied Azure-prefixed value is kept as-is.
   */
  constructor(
    fields?: Partial<OpenAIInput> & {
      openAIApiKey?: string;
      openAIApiVersion?: string;
      openAIBasePath?: string;
      deploymentName?: string;
    } & Partial<AzureOpenAIInput> &
      BaseLLMParams & {
        configuration?: ClientOptions & LegacyOpenAIInput;
      }
  ) {
    const newFields = fields ? { ...fields } : fields;
    if (newFields) {
      // Fall back to the Azure-prefixed value when the alias is absent.
      // Assigning the alias unconditionally would replace an explicitly
      // provided Azure value with `undefined`.
      newFields.azureOpenAIApiDeploymentName =
        newFields.deploymentName ?? newFields.azureOpenAIApiDeploymentName;
      newFields.azureOpenAIApiKey =
        newFields.openAIApiKey ?? newFields.azureOpenAIApiKey;
      newFields.azureOpenAIApiVersion =
        newFields.openAIApiVersion ?? newFields.azureOpenAIApiVersion;
    }
    super(newFields);
  }

  /**
   * Lazily creates the Azure client on first use and builds the per-request
   * options.
   *
   * @param options Per-call request options; they take precedence over the
   *   values spread from `this.clientConfig`.
   * @returns The merged request options. When an Azure API key is set, an
   *   `api-key` header and an `api-version` query parameter are added as
   *   defaults that caller-supplied headers/query entries can override.
   */
  protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
    if (!this.client) {
      const openAIEndpointConfig: OpenAIEndpointConfig = {
        azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
        azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
        azureOpenAIApiKey: this.azureOpenAIApiKey,
        azureOpenAIBasePath: this.azureOpenAIBasePath,
        baseURL: this.clientConfig.baseURL,
      };
      const endpoint = getEndpoint(openAIEndpointConfig);

      const params = {
        ...this.clientConfig,
        baseURL: endpoint,
        timeout: this.timeout,
        // NOTE(review): client-level retries are disabled here; presumably
        // retrying is handled by the caller. Confirm.
        maxRetries: 0,
      };

      // Only pass the API key when no Azure AD token provider is configured.
      if (!this.azureADTokenProvider) {
        params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
      }

      // Drop an empty base URL instead of passing a falsy value to the client.
      if (!params.baseURL) {
        delete params.baseURL;
      }

      this.client = new AzureOpenAIClient({
        apiVersion: this.azureOpenAIApiVersion,
        azureADTokenProvider: this.azureADTokenProvider,
        ...params,
      });
    }

    const requestOptions = {
      ...this.clientConfig,
      ...options,
    } as OpenAICoreRequestOptions;

    if (this.azureOpenAIApiKey) {
      // Spread the existing values last so caller-provided entries win.
      requestOptions.headers = {
        "api-key": this.azureOpenAIApiKey,
        ...requestOptions.headers,
      };
      requestOptions.query = {
        "api-version": this.azureOpenAIApiVersion,
        ...requestOptions.query,
      };
    }
    return requestOptions;
  }

  /**
   * Returns the base class's JSON representation with the Azure-specific
   * entries removed from its `kwargs` object (when that object exists).
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  toJSON(): any {
    const json = super.toJSON() as unknown;

    function isRecord(obj: unknown): obj is Record<string, unknown> {
      return typeof obj === "object" && obj != null;
    }

    if (isRecord(json) && isRecord(json.kwargs)) {
      delete json.kwargs.azure_openai_base_path;
      delete json.kwargs.azure_openai_api_deployment_name;
      delete json.kwargs.azure_openai_api_key;
      delete json.kwargs.azure_openai_api_version;
      delete json.kwargs.azure_open_ai_base_path;
    }
    return json;
  }
}