common.ts
import { BaseLLM } from "@langchain/core/language_models/llms";
import {
Generation,
GenerationChunk,
LLMResult,
} from "@langchain/core/outputs";
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
GoogleVertexAILLMConnection,
GoogleVertexAIStream,
GoogleVertexAILLMResponse,
} from "../../utils/googlevertexai-connection.js";
import {
GoogleVertexAIBaseLLMInput,
GoogleVertexAIBasePrediction,
GoogleVertexAILLMPredictions,
GoogleVertexAIModelParams,
} from "../../types/googlevertexai-types.js";

/**
 * Interface representing an instance of text input to the Google Vertex
* AI model.
*/
interface GoogleVertexAILLMTextInstance {
content: string;
}

/**
 * Interface representing an instance of code input to the Google Vertex
* AI model.
*/
interface GoogleVertexAILLMCodeInstance {
prefix: string;
}

/**
* Type representing an instance of either text or code input to the
* Google Vertex AI model.
*/
type GoogleVertexAILLMInstance =
| GoogleVertexAILLMTextInstance
| GoogleVertexAILLMCodeInstance;

/**
 * Models the data returned from the API call.
*/
interface TextPrediction extends GoogleVertexAIBasePrediction {
content: string;
}

/**
 * Base class for Google Vertex AI LLMs.
 * Concrete subclasses must provide a GoogleVertexAILLMConnection
* with an appropriate auth client.
*/
export class BaseGoogleVertexAI<AuthOptions>
extends BaseLLM
implements GoogleVertexAIBaseLLMInput<AuthOptions>
{
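  /** Marks this class as serializable by LangChain. */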
lc_serializable = true;
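
  /** The Vertex AI model to use. Defaults to "text-bison". */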
model = "text-bison";
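
  /** Sampling temperature to use. Defaults to 0.7 (0.2 for code models). */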
temperature = 0.7;
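
  /** Maximum number of tokens to generate. Defaults to 1024 (64 for code-gecko models). */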
maxOutputTokens = 1024;
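
  /** Top-p sampling threshold. Defaults to 0.8. */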
topP = 0.8;
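
  /** Top-k: sample from the k most likely next tokens. Defaults to 40. */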
topK = 40;
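
  /** Connection used to make non-streaming calls to the model. */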
protected connection: GoogleVertexAILLMConnection<
BaseLanguageModelCallOptions,
GoogleVertexAILLMInstance,
TextPrediction,
AuthOptions
>;
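
  /** Connection used to make streaming calls to the model. */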
protected streamedConnection: GoogleVertexAILLMConnection<
BaseLanguageModelCallOptions,
GoogleVertexAILLMInstance,
TextPrediction,
AuthOptions
>;
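
  /** Attribute aliases used when this class is serialized. */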
get lc_aliases(): Record<string, string> {
return {
model: "model_name",
};
}
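
  /**
   * Constructs the LLM, first applying adjusted defaults for code models
   * (a lower temperature for `code-*` models and a smaller maxOutputTokens
   * for `code-gecko`), then any user-supplied fields.
   */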
constructor(fields?: GoogleVertexAIBaseLLMInput<AuthOptions>) {
super(fields ?? {});
this.model = fields?.model ?? this.model;
// Change the defaults for code models
if (this.model.startsWith("code-gecko")) {
this.maxOutputTokens = 64;
}
if (this.model.startsWith("code-")) {
this.temperature = 0.2;
}
this.temperature = fields?.temperature ?? this.temperature;
this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens;
this.topP = fields?.topP ?? this.topP;
this.topK = fields?.topK ?? this.topK;
}
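
  /** Returns the LLM type identifier, "vertexai". */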
_llmType(): string {
return "vertexai";
}
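
  /**
   * Streams the model's response to a single prompt, yielding a
   * GenerationChunk for each piece of output as it arrives.
   */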
  async *_streamResponseChunks(
    input: string,
    options: this["ParsedCallOptions"],
    _runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<GenerationChunk> {
    // Make the call as a streaming request
    const instance = this.formatInstance(input);
    const parameters = this.formatParameters();
    const result = await this.streamedConnection.request(
      [instance],
      parameters,
      options
    );
    // Get the streaming parser from the response
    const stream = result.data as GoogleVertexAIStream;
    // Loop until the stream is done, yielding each chunk the streaming
    // parser makes available
while (!stream.streamDone) {
const output = await stream.nextChunk();
const chunk =
output !== null
? new GenerationChunk(
this.extractGenerationFromPrediction(output.outputs[0])
)
: new GenerationChunk({
text: "",
generationInfo: { finishReason: "stop" },
});
yield chunk;
}
}
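
  /**
   * Generates completions for each prompt, issuing the requests in
   * parallel.
   */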
async _generate(
prompts: string[],
options: this["ParsedCallOptions"]
): Promise<LLMResult> {
const generations: Generation[][] = await Promise.all(
prompts.map((prompt) => this._generatePrompt(prompt, options))
);
return { generations };
}
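
  /**
   * Makes a single non-streaming request for a prompt and converts the
   * resulting prediction into Generations.
   */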
async _generatePrompt(
prompt: string,
options: this["ParsedCallOptions"]
): Promise<Generation[]> {
const instance = this.formatInstance(prompt);
const parameters = this.formatParameters();
const result = await this.connection.request(
[instance],
parameters,
options
);
const prediction = this.extractPredictionFromResponse(result);
return [this.extractGenerationFromPrediction(prediction)];
}

/**
* Formats the input instance as a text instance for the Google Vertex AI
* model.
* @param prompt Prompt to be formatted as a text instance.
* @returns A GoogleVertexAILLMInstance object representing the formatted text instance.
*/
formatInstanceText(prompt: string): GoogleVertexAILLMInstance {
return { content: prompt };
}

/**
* Formats the input instance as a code instance for the Google Vertex AI
* model.
* @param prompt Prompt to be formatted as a code instance.
* @returns A GoogleVertexAILLMInstance object representing the formatted code instance.
*/
formatInstanceCode(prompt: string): GoogleVertexAILLMInstance {
return { prefix: prompt };
}

/**
* Formats the input instance for the Google Vertex AI model based on the
* model type (text or code).
* @param prompt Prompt to be formatted as an instance.
* @returns A GoogleVertexAILLMInstance object representing the formatted instance.
*/
formatInstance(prompt: string): GoogleVertexAILLMInstance {
return this.model.startsWith("code-")
? this.formatInstanceCode(prompt)
: this.formatInstanceText(prompt);
}
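
  /**
   * Builds the model parameters (temperature, topK, topP, and
   * maxOutputTokens) sent with every request.
   */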
formatParameters(): GoogleVertexAIModelParams {
return {
temperature: this.temperature,
topK: this.topK,
topP: this.topP,
maxOutputTokens: this.maxOutputTokens,
};
}

/**
* Extracts the prediction from the API response.
* @param result The API response from which to extract the prediction.
* @returns A TextPrediction object representing the extracted prediction.
*/
extractPredictionFromResponse(
result: GoogleVertexAILLMResponse<TextPrediction>
): TextPrediction {
return (result?.data as GoogleVertexAILLMPredictions<TextPrediction>)
?.predictions[0];
}
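
  /**
   * Converts a TextPrediction into a Generation, keeping the complete
   * prediction as the generationInfo.
   */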
extractGenerationFromPrediction(prediction: TextPrediction): Generation {
return {
text: prediction.content,
generationInfo: prediction,
};
}
}
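
/*
 * Example usage (a minimal sketch, not part of this module). A concrete
 * subclass such as `GoogleVertexAI`, which is assumed here to supply the
 * authenticated connections this base class requires, can be used like any
 * other LangChain LLM:
 *
 *   const model = new GoogleVertexAI({ model: "text-bison", temperature: 0.2 });
 *   const completion = await model.invoke("Write a haiku about the ocean.");
 */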