-
Notifications
You must be signed in to change notification settings - Fork 2k
/
rerank.ts
127 lines (117 loc) · 3.73 KB
/
rerank.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import { DocumentInterface } from "@langchain/core/documents";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { CohereClient } from "cohere-ai";
/**
 * Constructor options for {@link CohereRerank}. Every field is optional;
 * omitted fields fall back to the defaults noted below.
 */
export interface CohereRerankArgs {
  /**
   * Cohere API key used to authenticate requests.
   * When omitted, the `COHERE_API_KEY` environment variable is read instead.
   *
   * @defaultValue `process.env.COHERE_API_KEY`
   */
  apiKey?: string;

  /**
   * Identifier of the Cohere rerank model to call.
   *
   * @defaultValue `"rerank-english-v2.0"`
   */
  model?: string;

  /**
   * Upper bound on the number of ranked documents returned per call.
   *
   * @defaultValue `3`
   */
  topN?: number;

  /**
   * Maximum number of chunks each document may be split into on the
   * Cohere side. When omitted, this wrapper leaves the value `undefined`.
   */
  maxChunksPerDoc?: number;
}
/**
 * Document compressor that uses the Cohere Rerank API.
 *
 * The text of each document is sent to Cohere together with a query. The API
 * replies with the indices of the most relevant documents and a relevance
 * score for each.
 */
export class CohereRerank {
  model = "rerank-english-v2.0";

  topN = 3;

  client: CohereClient;

  maxChunksPerDoc: number | undefined;

  /**
   * @param fields Optional configuration. `apiKey` falls back to the
   *   `COHERE_API_KEY` environment variable; `model` and `topN` fall back to
   *   the field initializers above.
   * @throws Error when no API key is supplied and `COHERE_API_KEY` is unset.
   */
  constructor(fields?: CohereRerankArgs) {
    const token = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY");
    if (!token) {
      throw new Error("No API key provided for CohereRerank.");
    }
    this.client = new CohereClient({
      token,
    });
    this.model = fields?.model ?? this.model;
    this.topN = fields?.topN ?? this.topN;
    this.maxChunksPerDoc = fields?.maxChunksPerDoc;
  }

  /**
   * Compress documents using Cohere's rerank API.
   *
   * Returns the top-ranked documents in the order produced by `rerank`. Each
   * returned document is a shallow copy of its input whose `metadata` gains a
   * `relevanceScore` entry. The caller's documents, and any `metadata` objects
   * they share, are left untouched.
   *
   * @param documents Documents to rank; only `pageContent` is sent to Cohere.
   * @param query The query the documents are ranked against.
   * @returns The top-ranked documents with `metadata.relevanceScore` set, or
   *   `[]` when `documents` is empty.
   * @throws Error if the API reports an index outside `documents`.
   */
  async compressDocuments(
    documents: Array<DocumentInterface>,
    query: string
  ): Promise<Array<DocumentInterface>> {
    const ranked = await this.rerank(documents, query);
    return ranked.map(({ index, relevanceScore }) => {
      const doc = documents[index];
      if (doc === undefined) {
        throw new Error(
          `Cohere rerank returned index ${index}, but only ${documents.length} documents were provided.`
        );
      }
      // Copy instead of mutating: callers may reuse their documents, and
      // several documents may share one metadata object.
      return {
        ...doc,
        metadata: { ...doc.metadata, relevanceScore },
      };
    });
  }

  /**
   * Rank documents by their relevance to the provided query.
   *
   * @param documents Strings, DocumentInterfaces, or objects with a
   *   `pageContent` key.
   * @param query The query used for reranking.
   * @param options Per-call overrides. Each omitted field falls back to the
   *   instance's `model`, `topN` or `maxChunksPerDoc`.
   * @returns `{ index, relevanceScore }` pairs in the order the Cohere API
   *   returns them (most relevant first), where `index` points into
   *   `documents`. Returns `[]` without calling the API when `documents` is
   *   empty.
   * @throws Error when a non-string entry has no string `pageContent`.
   */
  async rerank(
    documents: Array<DocumentInterface | string | Record<string, string>>,
    query: string,
    options?: {
      model?: string;
      topN?: number;
      maxChunksPerDoc?: number;
    }
  ): Promise<Array<{ index: number; relevanceScore: number }>> {
    if (documents.length === 0) {
      // Nothing to rank; avoid a pointless (and possibly rejected) API call.
      return [];
    }
    const texts = documents.map((doc, position) =>
      CohereRerank.toText(doc, position)
    );
    const { results } = await this.client.rerank({
      model: options?.model ?? this.model,
      query,
      documents: texts,
      topN: options?.topN ?? this.topN,
      maxChunksPerDoc: options?.maxChunksPerDoc ?? this.maxChunksPerDoc,
    });
    // Strip any extra fields from the API result so callers receive only the
    // documented `{ index, relevanceScore }` shape.
    return results.map(({ index, relevanceScore }) => ({
      index,
      relevanceScore,
    }));
  }

  /** Extract the text to send to Cohere from one `rerank` input entry. */
  private static toText(
    doc: DocumentInterface | string | Record<string, string>,
    position: number
  ): string {
    if (typeof doc === "string") {
      return doc;
    }
    const { pageContent } = doc;
    if (typeof pageContent !== "string") {
      throw new Error(
        `Document at index ${position} has no string "pageContent" field; expected a string, a DocumentInterface, or an object with a "pageContent" key.`
      );
    }
    return pageContent;
  }
}