-
Notifications
You must be signed in to change notification settings - Fork 42
/
cohere-rerank.go
121 lines (103 loc) · 2.88 KB
/
cohere-rerank.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package transformer
import (
"context"
"os"
coherego "github.com/henomis/cohere-go"
"github.com/henomis/cohere-go/model"
"github.com/henomis/cohere-go/request"
"github.com/henomis/cohere-go/response"
"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/types"
)
// CohereRerankModel is an alias for model.RerankModel from the cohere-go
// package, re-exported so callers can name a rerank model through this
// package's identifiers.
type CohereRerankModel = model.RerankModel
// Defaults applied by NewCohereRerank.
//
// defaultCohereRerankTopN is a sentinel rather than a real limit: Rerank
// compares the configured topN against it and, when they match, uses the
// number of documents passed to the call instead.
const (
	defaultCohereRerankMaxChunksPerDoc = 10
	defaultCohereRerankTopN            = -1
	defaultCohereRerankModel           = CohereRerankModelEnglishV20
)

// CohereRerankScoreMetdataKey is the metadata key under which the relevance
// score of a reranked document is stored.
//
// NOTE(review): "Metdata" looks like a misspelling of "Metadata"; the name is
// exported, so renaming it would change the public API.
const CohereRerankScoreMetdataKey = "cohere-rerank-score"

// Rerank model identifiers, re-exported from the cohere-go model package.
const (
	CohereRerankModelEnglishV20      CohereRerankModel = model.RerankModelEnglishV20
	CohereRerankModelMultilingualV20 CohereRerankModel = model.RerankModelMultilingualV20
)
// CohereRerank reorders documents by relevance to a query using the Cohere
// rerank endpoint.
//
// Fields:
//   - client: the cohere-go client used to issue rerank requests.
//   - maxChunksPerDoc: sent with every request as MaxChunksPerDoc.
//   - topN: sent with every request as TopN; the value
//     defaultCohereRerankTopN means "use the number of documents".
//   - model: the configured rerank model.
type CohereRerank struct {
	client          *coherego.Client
	maxChunksPerDoc int
	topN            int
	model           CohereRerankModel
}
// NewCohereRerank returns a CohereRerank initialised with the package
// defaults for max chunks per document, topN and model. The client is created
// from the value of the COHERE_API_KEY environment variable; use WithAPIKey
// to supply a key explicitly.
func NewCohereRerank() *CohereRerank {
	apiKey := os.Getenv("COHERE_API_KEY")

	reranker := new(CohereRerank)
	reranker.client = coherego.New(apiKey)
	reranker.maxChunksPerDoc = defaultCohereRerankMaxChunksPerDoc
	reranker.topN = defaultCohereRerankTopN
	reranker.model = defaultCohereRerankModel

	return reranker
}
// WithMaxChunksPerDoc overrides the MaxChunksPerDoc value sent with each
// rerank request. It returns the receiver so calls can be chained.
func (c *CohereRerank) WithMaxChunksPerDoc(maxChunksPerDoc int) *CohereRerank {
	c.maxChunksPerDoc = maxChunksPerDoc

	return c
}
// WithAPIKey replaces the underlying client with a new one created from
// apiKey. It returns the receiver so calls can be chained.
func (c *CohereRerank) WithAPIKey(apiKey string) *CohereRerank {
	client := coherego.New(apiKey)
	c.client = client

	return c
}
// WithTopN sets the TopN value sent with each rerank request. Passing the
// default sentinel (-1) makes Rerank use the number of documents it is given.
// It returns the receiver so calls can be chained.
func (c *CohereRerank) WithTopN(topN int) *CohereRerank {
	c.topN = topN

	return c
}
// WithModel stores the rerank model to use. It returns the receiver so calls
// can be chained.
//
// NOTE(review): within this file the stored model is not read when building
// the rerank request — confirm whether it should be forwarded.
func (c *CohereRerank) WithModel(model CohereRerankModel) *CohereRerank {
	c.model = model

	return c
}
// Rerank sends query and the content of documents to the Cohere rerank
// endpoint and returns the documents in the order of the returned results.
// Each returned document carries its relevance score in its metadata under
// CohereRerankScoreMetdataKey.
//
// When topN is still the default sentinel, the number of documents in this
// call is used as TopN. Any error from the client is returned unchanged,
// together with a nil slice.
//
// NOTE(review): the configured model (c.model) is not forwarded in the
// request built here — confirm whether request.Rerank should carry it.
func (c *CohereRerank) Rerank(
	ctx context.Context,
	query string, documents []document.Document,
) ([]document.Document, error) {
	// Resolve the sentinel into a per-call local instead of overwriting
	// c.topN: writing it back would lose the sentinel after the first call and
	// cap every later call at the size of the first batch of documents.
	topN := c.topN
	if topN == defaultCohereRerankTopN {
		topN = len(documents)
	}

	// The request takes pointers; hand it copies rather than addresses of the
	// receiver's own fields.
	maxChunksPerDoc := c.maxChunksPerDoc

	resp := &response.Rerank{}
	err := c.client.Rerank(
		ctx,
		&request.Rerank{
			ReturnDocuments: false,
			MaxChunksPerDoc: &maxChunksPerDoc,
			Query:           query,
			Documents:       c.documentsToStringSlice(documents),
			TopN:            &topN,
		},
		resp,
	)
	if err != nil {
		return nil, err
	}

	return c.rerankDocuments(documents, resp.Results), nil
}
// documentsToStringSlice extracts the Content of every document, preserving
// order, so it can be sent as the Documents field of a rerank request.
func (c *CohereRerank) documentsToStringSlice(documents []document.Document) []string {
	contents := make([]string, 0, len(documents))
	for idx := range documents {
		contents = append(contents, documents[idx].Content)
	}

	return contents
}
// rerankDocuments builds the reranked document list from results. For each
// result, in order, it takes the input document at result.Index, records
// result.RelevanceScore in its metadata under CohereRerankScoreMetdataKey and
// appends a document with that content and metadata to the output.
//
// Results whose index does not refer to one of the input documents are
// skipped. When an input document already has metadata, the score is written
// into that same map, so the input document observes it as well; a new map is
// created only when the input metadata is nil.
func (c *CohereRerank) rerankDocuments(
	documents []document.Document,
	results []model.RerankResult,
) []document.Document {
	rerankedDocuments := make([]document.Document, 0, len(results))

	for _, result := range results {
		index := result.Index
		// The index comes from the remote response; ignore entries that do not
		// point at a submitted document instead of panicking on the lookup.
		if index < 0 || index >= len(documents) {
			continue
		}

		metadata := documents[index].Metadata
		if metadata == nil {
			metadata = make(types.Meta)
		}
		metadata[CohereRerankScoreMetdataKey] = result.RelevanceScore

		rerankedDocuments = append(
			rerankedDocuments,
			document.Document{
				Content:  documents[index].Content,
				Metadata: metadata,
			},
		)
	}

	return rerankedDocuments
}