-
Notifications
You must be signed in to change notification settings - Fork 42
/
chatgpt.go
86 lines (73 loc) · 2.15 KB
/
chatgpt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package openai
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
)
// CallChatGPT sends inputMessages to the chat completion endpoint and returns a
// channel that receives the reply.
//
// The request runs in a background goroutine. When stream is true and the
// server answers 200 OK, the body is read as a server-sent event stream and
// every non-empty content delta is sent on the channel as it arrives.
// Otherwise the body is decoded as a single ChatResponse and its message
// content (if any) is sent once.
//
// The returned error is always nil. Failures are instead sent on the channel
// as human-readable strings, in the same stream as regular content. The
// channel is closed once the response has been fully processed. Callers must
// drain it: beyond its small buffer, every send blocks until it is received.
func CallChatGPT(cfg Config, inputMessages []ChatMessage, stream bool) (<-chan string, error) {
	// Small buffer so the producer can run slightly ahead of the consumer.
	messageUpdates := make(chan string, 2)

	go func() {
		defer close(messageUpdates)

		jsonData, err := json.Marshal(ChatRequest{
			Model:       cfg.Model,
			Temperature: cfg.Temperature,
			Seed:        cfg.Seed,
			MaxTokens:   cfg.MaxTokens,
			Stream:      stream,
			Messages:    inputMessages,
		})
		if err != nil {
			// json.Marshal rejects e.g. NaN/Inf floats; without this check an
			// empty body would be sent and the API error would be confusing.
			log.Warnf("openai error encoding request: %s", err)
			messageUpdates <- fmt.Sprintf("Error encoding request: %s", err)
			return
		}

		resp, err := doRequest(cfg, apiCompletionURL, jsonData)
		if err != nil {
			messageUpdates <- err.Error()
			return
		}
		defer resp.Body.Close()

		// Non-streamed replies, and any non-200 reply (even to a streamed
		// request), arrive as a single ChatResponse JSON document.
		if !stream || resp.StatusCode != http.StatusOK {
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				log.Warnf("Openai Error %d: %s", resp.StatusCode, err)
				messageUpdates <- fmt.Sprintf("Error %d: %s", resp.StatusCode, err)
				return
			}
			var chatResponse ChatResponse
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				log.Warnf("Openai Error %d: %s", resp.StatusCode, err)
				messageUpdates <- fmt.Sprintf("Error %d: %s", resp.StatusCode, err)
				return
			}
			if err := chatResponse.GetError(); err != nil {
				// string(body): a raw []byte would be logged as a list of numbers.
				log.Warn("Openai Error: ", err, chatResponse, string(body))
				messageUpdates <- err.Error()
				return
			}
			if resp.StatusCode != http.StatusOK {
				// The body decoded but carried no recognisable error. Still tell
				// the caller the request failed instead of closing silently.
				log.Warnf("Openai Error %d: %s", resp.StatusCode, body)
				messageUpdates <- fmt.Sprintf("Error %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode))
				return
			}
			if message := chatResponse.GetMessage().Content; message != "" {
				messageUpdates <- message
			}
			return
		}

		// Streamed reply: each "data:" line holds a JSON ChatResponse carrying
		// one content delta. The stream ends with "data: [DONE]".
		const maxEventLineSize = 1 << 20
		scanner := bufio.NewScanner(resp.Body)
		// Allow lines longer than bufio's 64 KiB default, so one large event
		// does not abort the stream with bufio.ErrTooLong.
		scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxEventLineSize)
		for scanner.Scan() {
			line := scanner.Text()
			// An SSE data field is "data:" at the start of a line, optionally
			// followed by one space. Blank separators, ":" comments and other
			// fields carry no content.
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			deltaJSON := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
			if deltaJSON == "[DONE]" {
				// end of event stream
				return
			}
			var delta ChatResponse
			if err := json.Unmarshal([]byte(deltaJSON), &delta); err != nil {
				log.Warnf("openai error in json: %s (json: %s)", err, deltaJSON)
				continue
			}
			if deltaContent := delta.GetDelta().Content; deltaContent != "" {
				messageUpdates <- deltaContent
			}
		}
		// Scan stops on EOF or on an error. Report errors so a broken
		// connection is not mistaken for a normal end of the reply.
		if err := scanner.Err(); err != nil {
			log.Warnf("openai error reading stream: %s", err)
			messageUpdates <- fmt.Sprintf("Error reading stream: %s", err)
		}
	}()

	return messageUpdates, nil
}