-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
rerank.ts
129 lines (119 loc) · 3.86 KB
/
rerank.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import { DocumentInterface } from "@langchain/core/documents";
import { BaseDocumentCompressor } from "@langchain/core/retrievers/document_compressors";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { CohereClient } from "cohere-ai";
export interface CohereRerankArgs {
  /**
   * Cohere API key used to authenticate the client.
   * When omitted, the `COHERE_API_KEY` environment variable is consulted.
   * @default {process.env.COHERE_API_KEY}
   */
  apiKey?: string;

  /**
   * Identifier of the rerank model to request.
   * @default {"rerank-english-v2.0"}
   */
  model?: string;

  /**
   * Number of top-ranked documents to ask the API for.
   * @default {3}
   */
  topN?: number;

  /**
   * Upper bound on chunks per document, forwarded to the API as-is.
   * Left unset (undefined) when not supplied.
   */
  maxChunksPerDoc?: number;
}
/**
 * Document compressor that uses `Cohere Rerank API`.
 *
 * The API key comes from `fields.apiKey` or, failing that, from the
 * `COHERE_API_KEY` environment variable. Construction throws when neither
 * yields a value.
 */
export class CohereRerank extends BaseDocumentCompressor {
  /** Model name sent with each rerank request unless overridden per call. */
  model = "rerank-english-v2.0";

  /** Number of results requested from the API unless overridden per call. */
  topN = 3;

  /** Cohere SDK client used for all rerank requests. */
  client: CohereClient;

  /** Forwarded to the API as `maxChunksPerDoc`; undefined when not configured. */
  maxChunksPerDoc: number | undefined;

  /**
   * @param {CohereRerankArgs} fields Optional configuration; every field falls back to a default.
   *
   * @throws {Error} If no API key is passed and `COHERE_API_KEY` is not set.
   */
  constructor(fields?: CohereRerankArgs) {
    super();
    const token = fields?.apiKey ?? getEnvironmentVariable("COHERE_API_KEY");
    if (!token) {
      throw new Error("No API key provided for CohereRerank.");
    }
    this.client = new CohereClient({
      token,
    });
    this.model = fields?.model ?? this.model;
    this.topN = fields?.topN ?? this.topN;
    this.maxChunksPerDoc = fields?.maxChunksPerDoc;
  }

  /**
   * Compress documents using Cohere's rerank API.
   *
   * The input documents are never modified: each returned document is a
   * shallow copy whose `metadata` is a fresh object with `relevanceScore`
   * added. Results keep the order in which the API returned them.
   *
   * @param {Array<DocumentInterface>} documents A sequence of documents to compress.
   * @param {string} query The query to use for compressing the documents.
   *
   * @returns {Promise<Array<DocumentInterface>>} A sequence of compressed documents
   * (empty when `documents` is empty; no API call is made in that case).
   */
  async compressDocuments(
    documents: Array<DocumentInterface>,
    query: string
  ): Promise<Array<DocumentInterface>> {
    // Nothing to rank: skip the remote call entirely.
    if (documents.length === 0) {
      return [];
    }
    const { results } = await this.client.rerank({
      model: this.model,
      query,
      documents: documents.map((doc) => doc.pageContent),
      topN: this.topN,
      maxChunksPerDoc: this.maxChunksPerDoc,
    });
    const finalResults: Array<DocumentInterface> = [];
    for (const result of results) {
      // `result.index` refers back to the position in the input array.
      const doc = documents[result.index];
      // Copy instead of mutating so the caller's documents (and any earlier
      // results that share them) keep their original metadata.
      finalResults.push({
        ...doc,
        metadata: { ...doc.metadata, relevanceScore: result.relevanceScore },
      });
    }
    return finalResults;
  }

  /**
   * Asks the Cohere rerank API to score the given documents against a query
   * and returns the index/score pairs in the order the API returned them.
   *
   * @param {Array<DocumentInterface | string | Record<string, string>>} documents A list of documents as strings, DocumentInterfaces or objects with a `pageContent` key.
   * @param {string} query The query to use for reranking the documents.
   * @param options Per-call overrides; each falls back to the instance value.
   * @param {string} options.model The name of the model to use.
   * @param {number} options.topN How many documents to return.
   * @param {number} options.maxChunksPerDoc The maximum number of chunks per document.
   *
   * @returns {Promise<Array<{ index: number; relevanceScore: number }>>} One entry per API result, where
   * `index` is the position of the document in `documents` (empty when `documents` is empty; no API call
   * is made in that case).
   */
  async rerank(
    documents: Array<DocumentInterface | string | Record<string, string>>,
    query: string,
    options?: {
      model?: string;
      topN?: number;
      maxChunksPerDoc?: number;
    }
  ): Promise<Array<{ index: number; relevanceScore: number }>> {
    // Nothing to rank: skip the remote call entirely.
    if (documents.length === 0) {
      return [];
    }
    // Plain strings are sent as-is; anything else contributes its `pageContent`.
    const docs = documents.map((doc) => {
      if (typeof doc === "string") {
        return doc;
      }
      return doc.pageContent;
    });
    const model = options?.model ?? this.model;
    const topN = options?.topN ?? this.topN;
    const maxChunksPerDoc = options?.maxChunksPerDoc ?? this.maxChunksPerDoc;
    const { results } = await this.client.rerank({
      model,
      query,
      documents: docs,
      topN,
      maxChunksPerDoc,
    });
    // Expose only the two fields callers are promised, dropping any extras.
    const resultObjects = results.map((result) => ({
      index: result.index,
      relevanceScore: result.relevanceScore,
    }));
    return resultObjects;
  }
}