/
cohere.go
95 lines (78 loc) · 2.66 KB
/
cohere.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package embedding
import (
	"context"
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
)
// CohereEmbeddingProvider computes text embeddings through the Cohere embed
// API. Instances are normally built with NewCohereEmbeddingProvider.
type CohereEmbeddingProvider struct {
	// subType is passed to Cohere as the embedding model name, secretKey is
	// the API token the client authenticates with, and inputType is sent as
	// the request's InputType. Field order is unchanged from the original.
	subType, secretKey, inputType string
}
// GetPricing returns a human-readable pricing summary: the Cohere pricing
// page URL followed by a Markdown table of embedding rates per 1,000,000
// tokens.
func (p *CohereEmbeddingProvider) GetPricing() string {
	const pricingSummary = `URL:
https://cohere.com/pricing
Embedding models:
| Models | Per 1,000,000 tokens |
|---------|----------------------|
| default | $0.1 |
`
	return pricingSummary
}
// calculatePrice fills in res.Price and res.Currency based on res.TokenCount.
// The flat rate of 0.0001 is handed to getPrice as the per-1,000-token price
// (i.e. $0.1 per 1M tokens, the figure shown by GetPricing). The currency is
// always "USD". The returned error is always nil.
func (p *CohereEmbeddingProvider) calculatePrice(res *EmbeddingResult) error {
	const usdPerThousandTokens = 0.0001
	res.Price, res.Currency = getPrice(res.TokenCount, usdPerThousandTokens), "USD"
	return nil
}
// NewCohereEmbeddingProvider builds a provider that embeds text with the
// Cohere model named by subType and the given input type, and authenticates
// with secretKey. No validation is performed, so the returned error is
// always nil.
func NewCohereEmbeddingProvider(subType string, inputType string, secretKey string) (*CohereEmbeddingProvider, error) {
	provider := new(CohereEmbeddingProvider)
	provider.subType = subType
	provider.inputType = inputType
	provider.secretKey = secretKey
	return provider, nil
}
// QueryVector embeds a single text with Cohere. It returns the embedding
// converted to float32 and an EmbeddingResult that holds the billed token
// count and the computed price.
//
// A new Cohere client is created on every call, using p.secretKey as the
// token. Errors from the API call or from pricing are returned unchanged. If
// the response contains no embedding at all, an error is returned rather
// than panicking on an out-of-range index.
func (p *CohereEmbeddingProvider) QueryVector(text string, ctx context.Context) ([]float32, *EmbeddingResult, error) {
	client := cohereclient.NewClient(
		cohereclient.WithToken(p.secretKey),
	)

	embeddingResult, embeddings, err := cohereEmbed(ctx, client, &p.subType, &p.inputType, []string{text})
	if err != nil {
		return nil, nil, err
	}
	// Only one text was sent, so index 0 is the vector we want. The API could
	// still return an empty list, so check before indexing.
	if len(embeddings) == 0 {
		return nil, nil, fmt.Errorf("cohere returned no embeddings for model %q", p.subType)
	}

	if err := p.calculatePrice(embeddingResult); err != nil {
		return nil, nil, err
	}

	return float64ToFloat32(embeddings[0]), embeddingResult, nil
}
// cohereEmbed sends texts to the Cohere embed endpoint with the given model
// and input type. It returns an EmbeddingResult holding the billed input
// token count, together with the float64 embeddings exactly as the API
// returned them.
//
// An error is returned if the request fails or if the response carries no
// float embeddings. Missing billing metadata is tolerated and reported as a
// token count of 0, so a successful embedding is not lost just because usage
// information was omitted.
func cohereEmbed(ctx context.Context, client *cohereclient.Client, model *string, inputType *string, texts []string) (*EmbeddingResult, [][]float64, error) {
	resp, err := client.Embed(ctx, &cohere.EmbedRequest{
		Texts:     texts,
		Model:     model,
		InputType: (*cohere.EmbedInputType)(inputType),
	})
	if err != nil {
		return nil, nil, err
	}
	// EmbeddingsFloats is a pointer and may be nil. Presumably this happens
	// when the API answers with a different embedding representation. Report
	// it as an error instead of dereferencing nil.
	if resp == nil || resp.EmbeddingsFloats == nil {
		return nil, nil, fmt.Errorf("cohere embed: response for %d texts contains no float embeddings", len(texts))
	}
	floats := resp.EmbeddingsFloats

	// Meta, BilledUnits and InputTokens are all optional pointers. Fall back
	// to 0 billed tokens when any of them is absent.
	tokenCount := 0
	if meta := floats.Meta; meta != nil && meta.BilledUnits != nil && meta.BilledUnits.InputTokens != nil {
		tokenCount = int(*meta.BilledUnits.InputTokens)
	}

	return &EmbeddingResult{TokenCount: tokenCount}, floats.Embeddings, nil
}