-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
premai.ts
121 lines (102 loc) · 3.58 KB
/
premai.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import Prem from "@premai/prem-sdk";
/**
 * Interface for PremEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the PremEmbeddings class.
 */
export interface PremEmbeddingsParams extends EmbeddingsParams {
  /**
   * The Prem API key to use for requests.
   * @default process.env.PREM_API_KEY
   */
  apiKey?: string;

  /**
   * Base URL for the Prem API.
   * NOTE(review): declared here but not read by PremEmbeddings in this
   * file — confirm whether it should be forwarded to the Prem client.
   */
  baseUrl?: string;

  /**
   * The ID of the project to use.
   */
  project_id?: number | string;

  /**
   * The model to generate the embeddings.
   */
  model: string;

  /**
   * Encoding requested for the returned embeddings. Forwarded as-is to the
   * Prem embeddings endpoint.
   */
  encoding_format?: "float" | "base64";

  /**
   * Maximum number of documents to include in a single embeddings request.
   * PremEmbeddings uses 128 when this is not provided.
   */
  batchSize?: number;
}
/**
 * Class for generating embeddings using the Prem AI API. Extends the
 * Embeddings class and implements PremEmbeddingsParams.
 */
export class PremEmbeddings extends Embeddings implements PremEmbeddingsParams {
  /** Prem SDK client used for every embeddings request. */
  client: Prem;

  /** Maximum number of documents sent in one embeddings request. */
  batchSize = 128;

  /** Declared to satisfy PremEmbeddingsParams; the constructor does not assign it. */
  apiKey?: string;

  /** Numeric Prem project ID sent with every request. */
  project_id: number;

  /** Name of the embedding model sent with every request. */
  model: string;

  /** Optional encoding forwarded to the Prem embeddings endpoint. */
  encoding_format?: "float" | "base64";

  /**
   * Creates a new PremEmbeddings instance.
   * @param fields Configuration. `apiKey` falls back to the `PREM_API_KEY`
   * environment variable and `project_id` falls back to `PREM_PROJECT_ID`.
   * @throws Error if no API key can be resolved.
   * @throws Error if no usable numeric project ID can be resolved.
   */
  constructor(fields: PremEmbeddingsParams) {
    super(fields);

    // `||` (not `??`) on purpose: an empty-string key is treated as missing.
    const apiKey = fields?.apiKey || getEnvironmentVariable("PREM_API_KEY");
    if (!apiKey) {
      throw new Error(
        `Prem API key not found. Please set the PREM_API_KEY environment variable or provide the key into "apiKey"`
      );
    }

    // The params interface allows `number | string`, and the environment
    // variable is always a string, so string values from either source are
    // parsed the same way.
    const rawProjectId =
      fields?.project_id ?? getEnvironmentVariable("PREM_PROJECT_ID");
    const projectId =
      typeof rawProjectId === "string"
        ? Number.parseInt(rawProjectId, 10)
        : rawProjectId;

    // `!projectId` rejects 0 and NaN (unparseable string); -1 is also rejected.
    if (typeof projectId !== "number" || !projectId || projectId === -1) {
      throw new Error(
        `Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or provide the key into "project_id"`
      );
    }

    this.client = new Prem({
      apiKey,
    });

    this.project_id = projectId;
    this.model = fields.model ?? this.model;
    this.encoding_format = fields.encoding_format ?? this.encoding_format;
    // Honor a caller-supplied batch size; otherwise keep the class default.
    this.batchSize = fields.batchSize ?? this.batchSize;
  }

  /**
   * Method to generate embeddings for an array of documents. Splits the
   * documents into batches of at most `batchSize` and issues one Prem API
   * request per batch; all batch requests are started together and awaited
   * with `Promise.all`.
   * @param texts Array of documents to generate embeddings for.
   * @returns Promise that resolves to a 2D array of embeddings for each document.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(texts, this.batchSize);

    const batchResponses = await Promise.all(
      batches.map((batch) =>
        this.caller.call(async () =>
          this.client.embeddings.create({
            input: batch,
            model: this.model,
            encoding_format: this.encoding_format,
            project_id: this.project_id,
          })
        )
      )
    );

    // Flatten the per-batch responses in request order, taking one embedding
    // per input text.
    // NOTE(review): assumes the API returns `data` in the same order as
    // `input` — confirm against the Prem API reference.
    const embeddings: number[][] = [];
    for (let i = 0; i < batchResponses.length; i += 1) {
      const batch = batches[i];
      const { data: batchResponse } = batchResponses[i];
      for (let j = 0; j < batch.length; j += 1) {
        embeddings.push(batchResponse[j].embedding);
      }
    }
    return embeddings;
  }

  /**
   * Method to generate an embedding for a single document. Calls the
   * embedDocuments method with the document as the only input.
   * @param text Document to generate an embedding for.
   * @returns Promise that resolves to an embedding for the document.
   */
  async embedQuery(text: string): Promise<number[]> {
    const [embedding] = await this.embedDocuments([text]);
    return embedding;
  }
}