-
Notifications
You must be signed in to change notification settings - Fork 0
/
openai.go
83 lines (62 loc) · 1.47 KB
/
openai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package chatter
import (
"context"
"errors"
"io"
"github.com/sashabaranov/go-openai"
)
// OpenAI holds the configuration needed to construct an OpenAIChatter:
// the API key used to authenticate against the OpenAI API and the model
// identifier (e.g. "gpt-4") to use for chat completions.
type OpenAI struct {
APIKey string
Model string
}
// OpenAIChatter is a stateful chat client backed by the go-openai SDK.
// It accumulates the full conversation in dialog so that each query is
// sent with all prior user/assistant messages as context.
// It is not safe for concurrent use: dialog is mutated without locking.
type OpenAIChatter struct {
client *openai.Client // underlying go-openai client
model string // model name sent with every request
dialog []openai.ChatCompletionMessage // conversation history, oldest first
}
// NewOpenAIChatter builds an OpenAIChatter from the given configuration.
// It returns an error if the API key or model name is empty, since either
// omission would otherwise only surface as a confusing failure on the
// first request. (The original version could never fail, leaving its
// error return dead.)
func NewOpenAIChatter(cCfg OpenAI) (*OpenAIChatter, error) {
if cCfg.APIKey == "" {
return nil, errors.New("openai: API key must not be empty")
}
if cCfg.Model == "" {
return nil, errors.New("openai: model must not be empty")
}
client := openai.NewClient(cCfg.APIKey)
return &OpenAIChatter{client: client, model: cCfg.Model}, nil
}
// Model returns the model identifier this chatter was configured with.
func (c *OpenAIChatter) Model() string {
return c.model
}
// Close releases resources held by the chatter. The go-openai client has
// nothing to release, so this is a no-op kept to satisfy a closer-style
// interface; it always returns nil.
func (c *OpenAIChatter) Close() error {
return nil
}
// MakeSynchronousTextQuery appends prompt to the dialog as a user message,
// streams the model's reply chunk by chunk to console, and returns the
// assembled reply text. On success both the prompt and the reply are kept
// in the dialog so later queries carry full conversational context; on
// failure the prompt is rolled back so a failed request does not pollute
// the history replayed on the next query.
func (c *OpenAIChatter) MakeSynchronousTextQuery(ctx context.Context, console Console, prompt string) (string, error) {
message := openai.ChatCompletionMessage{ //nolint:exhaustruct
Role: openai.ChatMessageRoleUser,
Content: prompt,
}
c.dialog = append(c.dialog, message)

stream, errCc := c.client.CreateChatCompletionStream(
ctx,
openai.ChatCompletionRequest{ //nolint:exhaustruct
Model: c.model,
Messages: c.dialog,
},
)
if errCc != nil {
// Drop the just-added user message: the request never reached the API.
c.dialog = c.dialog[:len(c.dialog)-1]
return "", errCc
}
// The stream wraps an open HTTP response body; closing it is required
// to release the connection (the original leaked it).
defer stream.Close()

var responseText string
for {
response, errRecv := stream.Recv()
if errRecv != nil {
if !errors.Is(errRecv, io.EOF) {
// Mid-stream failure: roll back the user message as well,
// since no assistant reply will be recorded for it.
c.dialog = c.dialog[:len(c.dialog)-1]
return "", errRecv
}
break
}
// Guard against chunks with no choices (e.g. keep-alive frames);
// indexing Choices[0] unconditionally would panic on them.
if len(response.Choices) == 0 {
continue
}
responseChunk := response.Choices[0].Delta.Content
responseText += responseChunk
console.Print(responseChunk)
}
c.dialog = append(c.dialog, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: responseText,
})
console.Println()
return responseText, nil
}