-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
fireworks.ts
162 lines (134 loc) · 4.45 KB
/
fireworks.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
/**
 * Options understood by the FireworksEmbeddings class, layered on top of the
 * generic `EmbeddingsParams` accepted by every LangChain embeddings class.
 */
export interface FireworksEmbeddingsParams extends EmbeddingsParams {
  /** Identifier of the Fireworks AI embedding model to call. */
  modelName: string;

  /**
   * Upper bound on how many documents are packed into a single embeddings
   * request. The Fireworks AI API caps this at 8.
   */
  batchSize?: number;
}
/**
 * JSON body sent to the Fireworks AI embeddings endpoint.
 */
export interface CreateFireworksEmbeddingRequest {
  /** Identifier of the embedding model to use. */
  model: string;
  /**
   * The text to embed: either a single string, or an array of strings that
   * are embedded together in one request (callers in this file read back one
   * embedding per string, in order).
   */
  input: string | string[];
}
/**
 * A class for generating embeddings using the Fireworks AI API.
 *
 * Texts are sent as JSON `POST` requests to `${basePath}/embeddings` and
 * authenticated with a bearer token. The token comes from the `apiKey`
 * constructor field or, failing that, the `FIREWORKS_API_KEY` environment
 * variable. Every request goes through `this.caller.call(...)`, the caller
 * provided by the `Embeddings` base class.
 */
export class FireworksEmbeddings
  extends Embeddings
  implements FireworksEmbeddingsParams
{
  /** Fireworks model identifier sent with every request. */
  modelName = "nomic-ai/nomic-embed-text-v1.5";

  /**
   * Maximum number of texts `embedDocuments` puts into a single request.
   * The Fireworks AI API allows at most 8.
   */
  batchSize = 8;

  private apiKey: string;

  /** Root URL of the Fireworks inference API; `apiUrl` is derived from it. */
  basePath?: string = "https://api.fireworks.ai/inference/v1";

  /** Full URL of the embeddings endpoint, computed once in the constructor. */
  apiUrl: string;

  /**
   * Extra HTTP headers sent with every request. They are spread after the
   * default headers, so they can override `Content-Type` or `Authorization`.
   */
  headers?: Record<string, string>;

  /**
   * Constructor for the FireworksEmbeddings class.
   * @param fields - Optional configuration. `apiKey` falls back to the
   *   `FIREWORKS_API_KEY` environment variable. `modelName`, `batchSize`,
   *   `basePath` and `headers` override the instance defaults when given.
   * @throws Error if no API key is supplied or found in the environment.
   */
  constructor(
    fields?: Partial<FireworksEmbeddingsParams> & {
      verbose?: boolean;
      apiKey?: string;
      basePath?: string;
      headers?: Record<string, string>;
    }
  ) {
    const fieldsWithDefaults = { ...fields };
    super(fieldsWithDefaults);

    // `||` rather than `??`: an empty-string key also falls back to the env var.
    const apiKey =
      fieldsWithDefaults.apiKey || getEnvironmentVariable("FIREWORKS_API_KEY");
    if (!apiKey) {
      throw new Error("Fireworks AI API key not found");
    }

    this.modelName = fieldsWithDefaults.modelName ?? this.modelName;
    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize;
    this.basePath = fieldsWithDefaults.basePath ?? this.basePath;
    this.headers = fieldsWithDefaults.headers ?? this.headers;
    this.apiKey = apiKey;
    this.apiUrl = `${this.basePath}/embeddings`;
  }

  /**
   * Generates embeddings for an array of texts.
   *
   * The texts are split into chunks of at most `batchSize`, and one request
   * per chunk is started at once via `Promise.all`. Results are concatenated
   * in chunk order. Within a chunk they are taken in the order the API returns
   * them, which is assumed to match the input order.
   * @param texts - An array of strings to generate embeddings for.
   * @returns A Promise that resolves to one embedding per input text.
   * @throws Error if a response contains fewer embeddings than texts sent.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(texts, this.batchSize);
    const batchResponses = await Promise.all(
      batches.map((batch) =>
        this.embeddingWithRetry({ model: this.modelName, input: batch })
      )
    );

    const embeddings: number[][] = [];
    for (let i = 0; i < batchResponses.length; i += 1) {
      const batch = batches[i];
      const { data } = batchResponses[i];
      // Fail with a descriptive error instead of a TypeError on `undefined`
      // when the API returns fewer vectors than texts were sent.
      if (!Array.isArray(data) || data.length < batch.length) {
        const received = Array.isArray(data) ? data.length : 0;
        throw new Error(
          `Fireworks AI returned ${received} embeddings for a batch of ${batch.length} texts`
        );
      }
      for (let j = 0; j < batch.length; j += 1) {
        embeddings.push(data[j].embedding);
      }
    }
    return embeddings;
  }

  /**
   * Generates an embedding for a single text.
   * @param text - A string to generate an embedding for.
   * @returns A Promise that resolves to an array of numbers representing the embedding.
   * @throws Error if the response contains no embedding.
   */
  async embedQuery(text: string): Promise<number[]> {
    const { data } = await this.embeddingWithRetry({
      model: this.modelName,
      input: text,
    });
    if (!Array.isArray(data) || data.length === 0) {
      throw new Error("Fireworks AI returned no embedding for the query");
    }
    return data[0].embedding;
  }

  /**
   * Makes a request to the Fireworks AI API to generate embeddings.
   *
   * The request runs inside `this.caller.call`, so retry behaviour is decided
   * by the base class's caller. On a non-2xx response, the thrown `Error`
   * carries the HTTP `response` as a property. Presumably the caller's
   * failure handler inspects its status; confirm against `AsyncCaller`.
   * @param request - The JSON body to send.
   * @returns A Promise that resolves to the parsed JSON response body.
   */
  private async embeddingWithRetry(
    request: CreateFireworksEmbeddingRequest
  ): Promise<{ data: Array<{ embedding: number[] }> }> {
    const makeCompletionRequest = async () => {
      const response = await fetch(this.apiUrl, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${this.apiKey}`,
          ...this.headers,
        },
        body: JSON.stringify(request),
      });
      if (!response.ok) {
        // Read raw text first so a non-JSON error body (e.g. an HTML page from
        // a proxy) cannot replace the HTTP status with a JSON SyntaxError.
        const body = await response.text().catch(() => "");
        const message = FireworksEmbeddings.extractErrorMessage(body);
        throw Object.assign(
          new Error(
            `Error ${response.status}: ${message ?? "Unspecified error"}`
          ),
          { response }
        );
      }
      return response.json();
    };
    return this.caller.call(makeCompletionRequest);
  }

  /**
   * Extracts a human-readable message from an error response body.
   *
   * It handles three cases:
   * - `{"error": "..."}` returns the string.
   * - `{"error": {"message": "..."}}` returns the nested message.
   * - Any other `error` value is returned JSON-encoded.
   *
   * A non-JSON body is returned as trimmed text. An empty body, or JSON with
   * no `error` field, yields `undefined`.
   */
  private static extractErrorMessage(body: string): string | undefined {
    let parsed: unknown;
    try {
      parsed = JSON.parse(body);
    } catch {
      const text = body.trim();
      return text === "" ? undefined : text;
    }
    const detail =
      typeof parsed === "object" && parsed !== null
        ? (parsed as { error?: unknown }).error
        : undefined;
    if (detail === undefined || detail === null) {
      return undefined;
    }
    if (typeof detail === "string") {
      return detail;
    }
    if (
      typeof detail === "object" &&
      typeof (detail as { message?: unknown }).message === "string"
    ) {
      return (detail as { message: string }).message;
    }
    return JSON.stringify(detail);
  }
}