-
Notifications
You must be signed in to change notification settings - Fork 2
/
claude.go
102 lines (88 loc) · 2.16 KB
/
claude.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package mutators
import (
"io"
"strconv"
"github.com/batmac/ccat/pkg/log"
"github.com/batmac/ccat/pkg/miniclaude"
"github.com/batmac/ccat/pkg/secretprovider"
)
// init registers the "claude" mutator. It takes up to three optional
// string arguments (max reply tokens, model, pre-prompt). It is flagged as
// slow so that its output is streamed as soon as possible.
func init() {
	singleRegister("claude", claude,
		withDescription("ask Anthropic Claude, X:<1000> max replied tokens, optional second arg is the model, optional third arg is the preprompt (needs a valid key in $ANTHROPIC_API_KEY)"),
		withConfigBuilder(stdConfigStrings(0, 3)),
		withHintSlow(), // output asap (when no other mutator is used)
		withCategory("external APIs"),
	)
}
// claude sends the whole content of r as a single user message to the
// Anthropic Messages API and streams the reply text to w.
//
// conf is the []string produced by the registered config builder:
//   - args[0]: max reply tokens (default 1000; an unparsable or non-positive
//     value is logged and the default is kept)
//   - args[1]: model name (default miniclaude.ModelClaude3Haiku)
//   - args[2]: pre-prompt, sent as the system prompt with ":\n" appended
//
// The API key is fetched with secretprovider.GetSecret("anthropic",
// "ANTHROPIC_API_KEY"). If it is empty, the process exits via log.Fatal. The
// key value "CI" skips the network call and emits the fixed reply "fake".
//
// It returns the number of bytes written to w. It also returns the first
// error from reading r or writing to w, or the error returned by the
// streaming call.
func claude(w io.WriteCloser, r io.ReadCloser, conf any) (int64, error) {
	maxTokens, model, prePrompt := parseClaudeArgs(conf.([]string))
	log.Debugln("model: ", model)
	log.Debugln("maxTokens: ", maxTokens)
	log.Debugln("prePrompt (system): ", prePrompt)

	key, err := secretprovider.GetSecret("anthropic", "ANTHROPIC_API_KEY")
	if err != nil {
		// Not fatal by itself: an empty key is what decides below.
		log.Debugln("secretprovider: ", err)
	}
	if key == "" {
		log.Fatal("ANTHROPIC_API_KEY environment variable is not set")
	}

	prompt, err := io.ReadAll(r)
	if err != nil {
		return 0, err
	}

	mr := &miniclaude.MessagesRequest{
		Model: model,
		Messages: []miniclaude.Message{
			{
				Role: miniclaude.RoleUser,
				Content: []miniclaude.ContentBlock{
					{
						Type: miniclaude.ContentTypeText,
						Text: string(prompt),
					},
				},
			},
		},
		MaxTokens: maxTokens,
	}
	if prePrompt != "" {
		mr.System = prePrompt
	}

	request := miniclaude.NewMessagesRequest()
	request.APIKey = key

	// The producer reports its final status here. The channel is buffered so
	// the goroutine can always deliver it and exit, even if nobody reads it.
	done := make(chan error, 1)
	go func() {
		if key == "CI" {
			log.Println("ANTHROPIC_API_KEY is set to CI, using fake response")
			request.C <- "fake"
			request.C <- ""
			close(request.C)
			done <- nil
			return
		}
		done <- request.Stream(mr)
	}()

	var total int64
	var writeErr error
	for s := range request.C {
		if s == "" {
			log.Debugln("empty string")
			continue
		}
		if writeErr != nil {
			// Keep draining so the producer is never blocked on a send.
			continue
		}
		n, err := w.Write([]byte(s))
		total += int64(n)
		writeErr = err
	}

	streamErr := <-done
	if writeErr != nil {
		return total, writeErr
	}
	return total, streamErr
}

// parseClaudeArgs returns the max reply tokens, model and system pre-prompt
// from the mutator arguments, applying defaults for missing or empty entries.
func parseClaudeArgs(args []string) (maxTokens int, model, prePrompt string) {
	const defaultMaxTokens = 1000

	maxTokens = defaultMaxTokens
	if len(args) > 0 && args[0] != "" {
		// strconv.Atoi yields 0 on failure, so only accept a successfully
		// parsed, positive value; otherwise keep the default.
		n, err := strconv.Atoi(args[0])
		switch {
		case err != nil:
			log.Println("first arg: ", err)
		case n <= 0:
			log.Println("first arg: max tokens must be positive, got ", n)
		default:
			maxTokens = n
		}
	}

	model = miniclaude.ModelClaude3Haiku
	if len(args) > 1 && args[1] != "" {
		model = args[1]
	}

	if len(args) > 2 && args[2] != "" {
		prePrompt = args[2] + ":\n"
	}
	return maxTokens, model, prePrompt
}