-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
zhipuai.ts
129 lines (110 loc) · 3.66 KB
/
zhipuai.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { encodeApiKey } from "../utils/zhipuai.js";
/**
* Interface that extends EmbeddingsParams and defines additional
* parameters specific to the ZhipuAIEmbeddings class.
*
* Every field is optional; the ZhipuAIEmbeddings constructor supplies a
* default (or an environment-variable fallback) for each one.
*/
export interface ZhipuAIEmbeddingsParams extends EmbeddingsParams {
/**
* Model name to use. The type only admits the literal "embedding-2",
* which is also the default used by ZhipuAIEmbeddings.
*/
modelName?: "embedding-2";
/**
* ZhipuAI API key to use. When omitted, ZhipuAIEmbeddings reads the
* ZHIPUAI_API_KEY environment variable instead and throws if neither
* source yields a key.
*/
apiKey?: string;
/**
* Whether to strip new lines from the input text. When enabled, every
* "\n" in the input is replaced with a single space before the request
* is sent. ZhipuAIEmbeddings defaults this to true.
*/
stripNewLines?: boolean;
}
/**
* One element of the `data` array in a ZhipuAIEmbeddingsResult.
*/
interface EmbeddingData {
// The embedding vector; this is the value embedQuery returns to callers.
embedding: number[];
// Presumably the position of the corresponding input within the request
// (never read in this file) - verify against the ZhipuAI API reference.
index: number;
// Type tag string supplied by the API; not inspected in this file.
object: string;
}
/**
* Shape of the `usage` field in a ZhipuAIEmbeddingsResult.
*
* NOTE(review): none of these fields are read in this file; going by their
* names they are token counts reported by the API - confirm the exact
* semantics against the ZhipuAI API reference.
*/
interface TokenUsage {
completion_tokens: number;
prompt_tokens: number;
total_tokens: number;
}
/**
* Declared shape of the JSON body returned by the ZhipuAI embeddings endpoint
* on an HTTP 200 response. The body is not validated at runtime; this type
* describes what ZhipuAIEmbeddings expects to receive.
*/
export interface ZhipuAIEmbeddingsResult {
// Model identifier string as reported in the response.
model: string;
// Embedding entries; ZhipuAIEmbeddings.embedQuery reads the first element.
data: EmbeddingData[];
// Type tag string supplied by the API; not inspected in this file.
object: string;
// Token accounting for the request; not inspected in this file.
usage: TokenUsage;
}
/**
 * Embeddings implementation that calls the ZhipuAI embeddings HTTP endpoint.
 *
 * One HTTP request is issued per input text. The API key is taken from the
 * constructor fields or, failing that, from the ZHIPUAI_API_KEY environment
 * variable.
 */
export class ZhipuAIEmbeddings
  extends Embeddings
  implements ZhipuAIEmbeddingsParams
{
  /** Model sent in the request body; defaults to "embedding-2". */
  modelName: ZhipuAIEmbeddingsParams["modelName"] = "embedding-2";

  /** API key passed to encodeApiKey to build the Authorization header. */
  apiKey?: string;

  /** When true, each "\n" in the input is replaced by a space before sending. */
  stripNewLines = true;

  /** Endpoint that every embedding request is POSTed to. */
  private readonly embeddingsAPIURL =
    "https://open.bigmodel.cn/api/paas/v4/embeddings";

  /**
   * @param fields Optional overrides for the model name, API key and
   * newline-stripping behaviour; also forwarded to the Embeddings base class.
   * @throws Error if no API key is supplied and ZHIPUAI_API_KEY is not set.
   */
  constructor(fields?: ZhipuAIEmbeddingsParams) {
    super(fields ?? {});

    this.modelName = fields?.modelName ?? this.modelName;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
    this.apiKey = fields?.apiKey ?? getEnvironmentVariable("ZHIPUAI_API_KEY");

    if (!this.apiKey) {
      throw new Error("ZhipuAI API key not found");
    }
  }

  /**
   * Private method to make a request to the ZhipuAI API to generate
   * embeddings. The request is executed through `this.caller`, so any
   * retry behaviour is whatever that caller is configured with.
   * @param input The input text to embed.
   * @returns Promise that resolves to the parsed JSON body of a 200 response.
   * @throws Error if the endpoint answers with any status other than 200.
   */
  private async embeddingWithRetry(
    input: string
  ): Promise<ZhipuAIEmbeddingsResult> {
    const text = this.stripNewLines ? input.replace(/\n/g, " ") : input;
    const body = JSON.stringify({ input: text, model: this.modelName });
    const headers = {
      Accept: "application/json",
      "Content-Type": "application/json",
      Authorization: encodeApiKey(this.apiKey),
    };

    return this.caller.call(async () => {
      const fetchResponse = await fetch(this.embeddingsAPIURL, {
        method: "POST",
        headers,
        body,
      });

      if (fetchResponse.status === 200) {
        return fetchResponse.json();
      }

      // Read the error body as text first: a non-JSON body (for example an
      // HTML page from a gateway) would make `.json()` throw a SyntaxError
      // and hide the actual HTTP failure.
      const rawBody = await fetchResponse.text();
      let detail: string;
      try {
        // JSON bodies are pretty-printed exactly as before.
        detail = JSON.stringify(JSON.parse(rawBody), null, 2);
      } catch {
        detail = `HTTP ${fetchResponse.status} ${fetchResponse.statusText}: ${rawBody}`;
      }

      throw new Error(`Error getting embeddings from ZhipuAI. ${detail}`);
    });
  }

  /**
   * Method to generate an embedding for a single document. Calls the
   * embeddingWithRetry method with the document as the input.
   * @param text Document to generate an embedding for.
   * @returns Promise that resolves to an embedding for the document.
   * @throws Error if the response does not contain at least one embedding.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { data } = await this.embeddingWithRetry(text);

    // The response body is not validated at runtime, so guard against a
    // missing or empty `data` array instead of failing with a TypeError.
    const first = data?.[0];
    if (!first) {
      throw new Error("ZhipuAI embeddings response contained no embedding data");
    }

    return first.embedding;
  }

  /**
   * Method that takes an array of documents as input and returns a promise
   * that resolves to a 2D array of embeddings for each document. It calls
   * the embedQuery method for each document in the array, so one request is
   * made per document. Results are returned in input order, and the promise
   * rejects if any single request fails.
   * @param documents Array of documents for which to generate embeddings.
   * @returns Promise that resolves to a 2D array of embeddings for each input document.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return Promise.all(documents.map((doc) => this.embedQuery(doc)));
  }
}