-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
togetherai.ts
197 lines (172 loc) · 5.53 KB
/
togetherai.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Constructor options for the TogetherAIEmbeddings class. Builds on the base
 * `EmbeddingsParams` with the settings specific to the TogetherAI integration.
 */
export interface TogetherAIEmbeddingsParams extends EmbeddingsParams {
  /**
   * TogetherAI API key. When omitted, the `TOGETHER_AI_API_KEY` environment
   * variable is read instead.
   * @default {process.env.TOGETHER_AI_API_KEY}
   */
  apiKey?: string;

  /**
   * Identifier of the TogetherAI embedding model to call.
   * @default {"togethercomputer/m2-bert-80M-8k-retrieval"}
   */
  model?: string;

  /**
   * Alias for `model`. When both are supplied, `model` takes precedence.
   * @default {"togethercomputer/m2-bert-80M-8k-retrieval"}
   */
  modelName?: string;

  /**
   * How many texts `embedDocuments` embeds together as one group. Each text
   * is still sent to the API in its own request; the requests of one group
   * are issued together and awaited before the next group starts.
   * @default {512}
   */
  batchSize?: number;

  /**
   * When `true`, every newline (`\n`) in an input text is replaced with a
   * single space before it is embedded. May not be suitable for all use cases.
   * @default {false}
   */
  stripNewLines?: boolean;

  /**
   * Timeout to use when making requests to TogetherAI. Leave unset for no
   * timeout.
   * @default {undefined}
   */
  timeout?: number;
}
/**
 * Expected shape of a successful response from the TogetherAI embeddings
 * endpoint, as consumed by TogetherAIEmbeddings. The response body is not
 * validated at runtime, so this type is an assumption about the API payload.
 * @ignore
 */
interface TogetherAIEmbeddingsResult {
  model: string;
  object: string;
  request_id: string;
  data: {
    index: number;
    object: "embedding";
    embedding: number[];
  }[];
}
/**
 * Class for generating embeddings using the TogetherAI API. Extends the
 * Embeddings class and implements TogetherAIEmbeddingsParams.
 *
 * Every text is embedded with its own HTTP request made through the inherited
 * `this.caller`. `embedDocuments` issues those requests in groups of
 * `batchSize`, awaiting each group before starting the next.
 * @example
 * ```typescript
 * const embeddings = new TogetherAIEmbeddings({
 *   apiKey: process.env.TOGETHER_AI_API_KEY, // Default value
 *   model: "togethercomputer/m2-bert-80M-8k-retrieval", // Default value
 * });
 * const res = await embeddings.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * ```
 */
export class TogetherAIEmbeddings
  extends Embeddings
  implements TogetherAIEmbeddingsParams
{
  modelName = "togethercomputer/m2-bert-80M-8k-retrieval";

  model = "togethercomputer/m2-bert-80M-8k-retrieval";

  apiKey: string;

  batchSize = 512;

  stripNewLines = false;

  /** Per-attempt request timeout; applied only when set to a positive number. */
  timeout?: number;

  private readonly embeddingsAPIUrl =
    "https://api.together.xyz/api/v1/embeddings";

  /**
   * @param fields Optional configuration; unset fields keep their defaults.
   * @throws Error when no API key is passed and `TOGETHER_AI_API_KEY` is unset.
   */
  constructor(fields?: Partial<TogetherAIEmbeddingsParams>) {
    super(fields ?? {});

    const apiKey =
      fields?.apiKey ?? getEnvironmentVariable("TOGETHER_AI_API_KEY");
    if (!apiKey) {
      throw new Error("TOGETHER_AI_API_KEY not found.");
    }
    this.apiKey = apiKey;
    // `model` takes precedence over its alias `modelName`; both are kept in sync.
    this.modelName = fields?.model ?? fields?.modelName ?? this.model;
    this.model = this.modelName;
    this.timeout = fields?.timeout;
    this.batchSize = fields?.batchSize ?? this.batchSize;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
  }

  private constructHeaders() {
    return {
      accept: "application/json",
      "content-type": "application/json",
      Authorization: `Bearer ${this.apiKey}`,
    };
  }

  private constructBody(input: string) {
    return {
      model: this.model,
      input,
    };
  }

  /** Applies the `stripNewLines` option to a single input text. */
  private preprocessText(text: string): string {
    return this.stripNewLines ? text.replace(/\n/g, " ") : text;
  }

  /**
   * Returns the first embedding in an API response.
   * @throws Error when the response carries no embedding entries.
   */
  private static extractEmbedding(
    response: TogetherAIEmbeddingsResult
  ): number[] {
    // The payload is not validated at runtime, so guard against a missing/empty `data`.
    const first = response.data?.[0];
    if (first === undefined) {
      throw new Error("Together AI returned a response with no embedding data.");
    }
    return first.embedding;
  }

  /** Pretty-prints a JSON error body; falls back to the raw text when it is not JSON. */
  private static formatErrorBody(raw: string): string {
    try {
      return JSON.stringify(JSON.parse(raw), null, 2);
    } catch {
      return raw;
    }
  }

  /**
   * Generates embeddings for an array of documents. Each text gets its own
   * request. Requests are issued in groups of `batchSize`, and each group is
   * awaited before the next starts.
   * @param texts Documents to embed.
   * @returns One embedding per input text, in the same order as `texts`.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      texts.map((text) => this.preprocessText(text)),
      this.batchSize
    );

    const embeddings: number[][] = [];
    for (const batch of batches) {
      // Promise.all preserves input order, so results line up with `texts`.
      const responses = await Promise.all(
        batch.map((text) => this.embeddingWithRetry(text))
      );
      for (const response of responses) {
        embeddings.push(TogetherAIEmbeddings.extractEmbedding(response));
      }
    }
    return embeddings;
  }

  /**
   * Generates an embedding for a single query text.
   * @param text Text to embed.
   * @returns The embedding vector for `text`.
   */
  async embedQuery(text: string): Promise<number[]> {
    const response = await this.embeddingWithRetry(this.preprocessText(text));
    return TogetherAIEmbeddings.extractEmbedding(response);
  }

  /**
   * Sends one embedding request to TogetherAI via `this.caller`.
   * @param input The (already preprocessed) text to embed.
   * @returns The parsed API response.
   * @throws Error for non-2xx responses, including the HTTP status and body.
   */
  private async embeddingWithRetry(
    input: string
  ): Promise<TogetherAIEmbeddingsResult> {
    const body = JSON.stringify(this.constructBody(input));
    const headers = this.constructHeaders();
    const { timeout } = this;

    return this.caller.call(async () => {
      // Fresh controller per attempt so every attempt gets its own timeout window.
      let signal: AbortSignal | undefined;
      let timer: ReturnType<typeof setTimeout> | undefined;
      if (timeout !== undefined && timeout > 0) {
        const controller = new AbortController();
        signal = controller.signal;
        timer = setTimeout(() => controller.abort(), timeout);
      }

      try {
        const fetchResponse = await fetch(this.embeddingsAPIUrl, {
          method: "POST",
          headers,
          body,
          signal,
        });
        if (fetchResponse.ok) {
          return (await fetchResponse.json()) as TogetherAIEmbeddingsResult;
        }
        // Read as text first: error bodies are not guaranteed to be JSON.
        const details = TogetherAIEmbeddings.formatErrorBody(
          await fetchResponse.text()
        );
        throw new Error(
          `Error getting embeddings from Together AI (HTTP ${fetchResponse.status}). ${details}`
        );
      } finally {
        // Don't let the timer outlive the request.
        if (timer !== undefined) {
          clearTimeout(timer);
        }
      }
    });
  }
}