-
Notifications
You must be signed in to change notification settings - Fork 6
/
cohere.go
139 lines (115 loc) · 3.49 KB
/
cohere.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package embedding
import (
"context"
"errors"
"github.com/avast/retry-go"
cohere "github.com/cohere-ai/cohere-go/v2"
cohereclient "github.com/cohere-ai/cohere-go/v2/client"
core "github.com/cohere-ai/cohere-go/v2/core"
"github.com/hupe1980/golc/internal/util"
"github.com/hupe1980/golc/schema"
)
// Compile time check to ensure Cohere satisfies the Embedder interface.
// The typed nil pointer is never used; the assignment only makes the build
// fail if *Cohere stops implementing schema.Embedder.
var _ schema.Embedder = (*Cohere)(nil)
// CohereClient is the subset of the Cohere SDK client used by the Cohere
// embedder. NewCohereFromClient accepts any implementation of it.
type CohereClient interface {
	// Embed sends req to the Cohere embed endpoint and returns its response.
	Embed(ctx context.Context, req *cohere.EmbedRequest, options ...core.RequestOption) (*cohere.EmbedResponse, error)
}
// CohereOptions contains options for configuring the Cohere instance.
type CohereOptions struct {
	// Model is the name of the Cohere embedding model to use.
	Model string
	// Truncate selects how over-long input texts are truncated before they are
	// embedded: "NONE", "START" or "END". The value is validated each time a
	// request is built, so an unsupported value surfaces as an error from the
	// embedding call.
	Truncate string
	// MaxRetries is the attempt budget for an embedding request. It is handed
	// to the retry loop as its attempt count, so the first request counts
	// towards it.
	MaxRetries uint `map:"max_retries,omitempty"`
}
// Cohere implements schema.Embedder on top of the Cohere embed API.
// Construct it with NewCohere or NewCohereFromClient.
type Cohere struct {
	// client performs the actual embed calls.
	client CohereClient
	// opts holds the model, truncation and retry settings applied to requests.
	opts CohereOptions
}
// NewCohere creates a new Cohere embedder that authenticates with apiKey.
//
// It builds a Cohere SDK client using apiKey as the token and then applies
// the defaults and optFns exactly as NewCohereFromClient does. It never
// returns an error: apiKey is not validated here.
func NewCohere(apiKey string, optFns ...func(o *CohereOptions)) *Cohere {
	client := cohereclient.NewClient(cohereclient.WithToken(apiKey))

	return NewCohereFromClient(client, optFns...)
}
// NewCohereFromClient wraps an existing Cohere client in a Cohere embedder.
//
// The options start from these defaults: Model "embed-english-v2.0",
// Truncate "NONE" and MaxRetries 3. Each function in optFns is then applied
// in order and may override any of them.
func NewCohereFromClient(client CohereClient, optFns ...func(o *CohereOptions)) *Cohere {
	cfg := CohereOptions{
		Model:      "embed-english-v2.0",
		Truncate:   "NONE",
		MaxRetries: 3,
	}

	for i := range optFns {
		optFns[i](&cfg)
	}

	return &Cohere{opts: cfg, client: client}
}
// BatchEmbedText embeds a list of texts and returns their embeddings as
// float32 vectors, in the order the API returns them.
//
// The configured Truncate value is validated before anything else, so an
// unsupported value yields an error without contacting the API. An empty
// texts slice returns an empty result without making a request. An error is
// returned, instead of panicking, if the response does not carry its
// embeddings grouped by type.
func (e *Cohere) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) {
	truncate, err := cohere.NewEmbedRequestTruncateFromString(e.opts.Truncate)
	if err != nil {
		return nil, err
	}

	// Nothing to embed: skip the network round trip entirely.
	if len(texts) == 0 {
		return [][]float32{}, nil
	}

	res, err := e.embedWithRetry(ctx, &cohere.EmbedRequest{
		Model:    util.AddrOrNil(e.opts.Model),
		Truncate: truncate.Ptr(),
		Texts:    texts,
		// Only float embeddings are requested; they are read back from the
		// EmbeddingsByType part of the response below.
		EmbeddingTypes: []cohere.EmbeddingType{
			cohere.EmbeddingTypeFloat,
		},
	})
	if err != nil {
		return nil, err
	}

	// EmbedResponse is a union; if the by-type variant is absent, report it
	// instead of dereferencing a nil pointer.
	if res == nil || res.EmbeddingsByType == nil {
		return nil, errors.New("cohere: embed response does not contain embeddings by type")
	}

	floats := res.EmbeddingsByType.Embeddings.Float
	embeddings := make([][]float32, len(floats))
	for i, f := range floats {
		embeddings[i] = util.Float64ToFloat32(f)
	}

	return embeddings, nil
}
// embedWithRetry calls the Cohere embed endpoint. It retries with a fixed delay
// when the error is a Cohere API error with status 429 or 500; any other error
// stops retrying immediately.
//
// e.opts.MaxRetries is used as the attempt budget, and the initial call counts
// towards it. A value of 0 is treated as a single attempt. The retry loop is
// bound to ctx, so cancelling the context also aborts the waits between
// attempts.
func (e *Cohere) embedWithRetry(ctx context.Context, req *cohere.EmbedRequest) (*cohere.EmbedResponse, error) {
	attempts := e.opts.MaxRetries
	if attempts == 0 {
		// retry-go's Attempts(0) is not "call once" (newer versions retry
		// without limit), so map "no retries" to exactly one request.
		attempts = 1
	}

	var res *cohere.EmbedResponse

	err := retry.Do(
		func() error {
			r, cErr := e.client.Embed(ctx, req)
			if cErr != nil {
				return cErr
			}
			res = r
			return nil
		},
		retry.Context(ctx),
		retry.Attempts(attempts),
		retry.DelayType(retry.FixedDelay),
		retry.RetryIf(isRetryableCohereError),
	)
	if err != nil {
		return nil, err
	}

	return res, nil
}

// isRetryableCohereError reports whether err wraps a Cohere API error with a
// retryable status code: 429 (rate limited) or 500 (server error).
func isRetryableCohereError(err error) bool {
	var apiErr *core.APIError
	if !errors.As(err, &apiErr) {
		return false
	}

	switch apiErr.StatusCode {
	case 429, 500:
		return true
	default:
		return false
	}
}
// EmbedText embeds a single text and returns its embedding.
//
// It sends a one-element batch through BatchEmbedText. If the response
// contains no embedding, it returns an error instead of panicking on an
// out-of-range index.
func (e *Cohere) EmbedText(ctx context.Context, text string) ([]float32, error) {
	embeddings, err := e.BatchEmbedText(ctx, []string{text})
	if err != nil {
		return nil, err
	}

	if len(embeddings) == 0 {
		return nil, errors.New("cohere: no embedding returned for text")
	}

	return embeddings[0], nil
}