Skip to content

Commit

Permalink
feat: 添加usage属性以适配下游依赖该属性的请求
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Apr 15, 2024
1 parent c1aab4e commit 8241fb3
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 88 deletions.
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ GOOS=windows GOARCH=amd64 ${cmd} ${args} -ldflags '-w -s' -o ${outdir}/win-serve
GOOS=linux GOARCH=amd64 ${cmd} ${args} -ldflags '-w -s' -o ${outdir}/linux-server server.go
GOARM=7 GOOS=linux GOARCH=arm64 ${cmd} ${args} -ldflags '-w -s' -o ${outdir}/linux-server-arm64 server.go

# cp .env.example ${outdir}/.env.example
# cp .env.example ${outdir}/.env.example
cp config.yaml $outdir/config.yaml
15 changes: 12 additions & 3 deletions internal/common/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@ import (
"github.com/sirupsen/logrus"
)

// 计算prompt的token长度
func CalcTokens(prompt string) int {
// 计算content的token长度
func CalcTokens(content string) int {
resolver, err := encoder.NewEncoder()
if err != nil {
logrus.Error(err)
return 0
}
result, err := resolver.Encode(prompt)
result, err := resolver.Encode(content)
if err != nil {
logrus.Error(err)
return 0
}
return len(result)
}

func CalcUsageTokens(content string, previousTokens int) map[string]int {
tokens := CalcTokens(content)
return map[string]int{
"completion_tokens": tokens,
"prompt_tokens": previousTokens,
"total_tokens": previousTokens + tokens,
}
}
23 changes: 14 additions & 9 deletions internal/middle/bing/assembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ func Complete(ctx *gin.Context, req gpt.ChatCompletionRequest, matchers []common
}
}

pMessages, prompt, err := buildConversation(pad, messages)
pMessages, prompt, tokens, err := buildConversation(pad, messages)
if err != nil {
middle.ResponseWithE(ctx, -1, err)
return
}

ctx.Set("tokens", tokens)
// 清理多余的标签
matchers = appendMatchers(matchers)
chat := edge.New(options.
Expand Down Expand Up @@ -168,6 +169,7 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
pos = 0
content = ""
created = time.Now().Unix()
tokens = ctx.GetInt("tokens")
)

logrus.Info("waitResponse ...")
Expand All @@ -192,20 +194,19 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
raw = common.ExecMatchers(matchers, raw)

if sse {
middle.ResponseWithSSE(ctx, MODEL, raw, created)
} else {
content += raw
middle.ResponseWithSSE(ctx, MODEL, raw, nil, created)
}
content += raw
}

if !sse {
middle.ResponseWith(ctx, MODEL, content)
} else {
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", created)
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", common.CalcUsageTokens(content, tokens), created)
}
}

func buildConversation(pad bool, messages []map[string]string) (pMessages []edge.ChatMessage, prompt string, err error) {
func buildConversation(pad bool, messages []map[string]string) (pMessages []edge.ChatMessage, prompt string, tokens int, err error) {
pos := len(messages) - 1
if pos < 0 {
return
Expand Down Expand Up @@ -260,6 +261,7 @@ func buildConversation(pad bool, messages []map[string]string) (pMessages []edge
for {
if pos >= messageL {
if len(buffer) > 0 {
tokens += common.CalcTokens(strings.Join(buffer, ""))
pMessagesVar = append(pMessagesVar, blockProcessing(strings.Title(role), buffer))
}
break
Expand All @@ -269,7 +271,7 @@ func buildConversation(pad bool, messages []map[string]string) (pMessages []edge
curr := condition(message["role"])
content := message["content"]
if curr == "" {
return nil, "", errors.New(
return nil, "", -1, errors.New(
fmt.Sprintf("'%s' is not one of ['system', 'assistant', 'user', 'function'] - 'messages.%d.role'",
message["role"], pos))
}
Expand All @@ -286,6 +288,8 @@ func buildConversation(pad bool, messages []map[string]string) (pMessages []edge
buffer = append(buffer, content)
continue
}

tokens += common.CalcTokens(strings.Join(buffer, ""))
pMessagesVar = append(pMessagesVar, blockProcessing(strings.Title(role), buffer))
buffer = append(make([]string, 0), content)
role = curr
Expand All @@ -300,7 +304,7 @@ func buildConversation(pad bool, messages []map[string]string) (pMessages []edge
dict["messages"] = pMessagesVar
indent, e := json.MarshalIndent(dict, "", " ")
if e != nil {
return nil, "", e
return nil, "", -1, e
}

if pad { // 填充引导对话,尝试避免道歉
Expand Down Expand Up @@ -335,5 +339,6 @@ func buildConversation(pad bool, messages []map[string]string) (pMessages []edge
})
}

return pMessages, prompt, nil
tokens += common.CalcTokens(prompt)
return pMessages, prompt, tokens, nil
}
2 changes: 1 addition & 1 deletion internal/middle/bing/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func parseToToolCall(ctx *gin.Context, cookie, proxies string, fun *gpt.Function
} else {
// 没有解析出 JSON
if sse {
middle.ResponseWithSSE(ctx, MODEL, content, created)
middle.ResponseWithSSE(ctx, MODEL, content, nil, created)
return false, nil
} else {
middle.ResponseWith(ctx, MODEL, content)
Expand Down
18 changes: 10 additions & 8 deletions internal/middle/claude/assembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ func Complete(ctx *gin.Context, req gpt.ChatCompletionRequest, matchers []common
}
}

attr, err := buildConversation(messages)
attr, tokens, err := buildConversation(messages)
if err != nil {
middle.ResponseWithE(ctx, -1, err)
return
}

ctx.Set("tokens", tokens)
chat, err := claude2.New(options)
if err != nil {
middle.ResponseWithE(ctx, -1, err)
Expand Down Expand Up @@ -97,6 +98,7 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
var (
content = ""
created = time.Now().Unix()
tokens = ctx.GetInt("tokens")
)
logrus.Infof("waitResponse ...")

Expand All @@ -114,20 +116,19 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
fmt.Printf("----- raw -----\n %s\n", message.Text)
raw := common.ExecMatchers(matchers, message.Text)
if sse {
middle.ResponseWithSSE(ctx, MODEL, raw, created)
} else {
content += raw
middle.ResponseWithSSE(ctx, MODEL, raw, nil, created)
}
content += raw
}

if !sse {
middle.ResponseWith(ctx, MODEL, content)
} else {
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", created)
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", common.CalcUsageTokens(content, tokens), created)
}
}

func buildConversation(messages []map[string]string) (attrs []types.Attachment, err error) {
func buildConversation(messages []map[string]string) (attrs []types.Attachment, tokens int, err error) {
pos := len(messages) - 1
if pos < 0 {
return
Expand Down Expand Up @@ -185,7 +186,7 @@ func buildConversation(messages []map[string]string) (attrs []types.Attachment,
curr := condition(message["role"])
content := message["content"]
if curr == "" {
return nil, errors.New(
return nil, -1, errors.New(
fmt.Sprintf("'%s' is not one of ['system', 'assistant', 'user', 'function'] - 'messages.%d.role'",
message["role"], pos))
}
Expand Down Expand Up @@ -216,6 +217,7 @@ func buildConversation(messages []map[string]string) (attrs []types.Attachment,
pMessages = fmt.Sprintf("%s\n--------\n\n%s", s, pMessages)
}

tokens = common.CalcTokens(pMessages)
attrs = append(attrs, types.Attachment{
Content: pMessages,
FileName: "paste.txt",
Expand All @@ -224,7 +226,7 @@ func buildConversation(messages []map[string]string) (attrs []types.Attachment,
})
}

return attrs, nil
return attrs, tokens, nil
}

func padtxt(length int) string {
Expand Down
2 changes: 1 addition & 1 deletion internal/middle/claude/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func parseToToolCall(ctx *gin.Context, cookie, proxies, model string, fun *gpt.F
} else {
// 没有解析出 JSON
if sse {
middle.ResponseWithSSE(ctx, MODEL, content, created)
middle.ResponseWithSSE(ctx, MODEL, content, nil, created)
return false, nil
} else {
middle.ResponseWith(ctx, MODEL, content)
Expand Down
18 changes: 11 additions & 7 deletions internal/middle/cohere/assembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ func Complete(ctx *gin.Context, req gpt.ChatCompletionRequest, matchers []common
"system:",
})
} else {
p, s, m, err := buildChatConversation(messages)
p, s, m, tokens, err := buildChatConversation(messages)
if err != nil {
middle.ResponseWithE(ctx, -1, err)
return
}

ctx.Set("tokens", tokens)

system = s
message = m
pMessages = p
Expand Down Expand Up @@ -117,6 +119,7 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
logrus.Infof("waitResponse ...")
prefix := ""
cmd := ctx.GetInt("cmd")
tokens := ctx.GetInt("tokens")

for {
raw, ok := <-chatResponse
Expand Down Expand Up @@ -155,16 +158,15 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
}
raw = common.ExecMatchers(matchers, raw)
if sse {
middle.ResponseWithSSE(ctx, MODEL, raw, created)
} else {
content += raw
middle.ResponseWithSSE(ctx, MODEL, raw, nil, created)
}
content += raw
}

if !sse {
middle.ResponseWith(ctx, MODEL, content)
} else {
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", created)
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", common.CalcUsageTokens(content, tokens), created)
}
}

Expand Down Expand Up @@ -242,7 +244,7 @@ func buildConversation(messages []map[string]string) (content string, err error)
return cohere.MergeMessages(pMessages), nil
}

func buildChatConversation(messages []map[string]string) (pMessages []cohere.Message, system, content string, err error) {
func buildChatConversation(messages []map[string]string) (pMessages []cohere.Message, system, content string, tokens int, err error) {
pos := len(messages) - 1
if pos < 0 {
return
Expand Down Expand Up @@ -313,7 +315,7 @@ func buildChatConversation(messages []map[string]string) (pMessages []cohere.Mes
curr := condition(message["role"])
tMessage := message["content"]
if curr == "" {
return nil, "", "", errors.New(
return nil, "", "", -1, errors.New(
fmt.Sprintf("'%s' is not one of ['system', 'assistant', 'user', 'function'] - 'messages.%d.role'",
message["role"], pos))
}
Expand Down Expand Up @@ -352,6 +354,7 @@ func buildChatConversation(messages []map[string]string) (pMessages []cohere.Mes
for {
if pos >= messageL {
join := strings.Join(buffer, "\n\n")
tokens += common.CalcTokens(join)
pMessages = append(pMessages, cohere.Message{
Role: role,
Message: join,
Expand All @@ -373,6 +376,7 @@ func buildChatConversation(messages []map[string]string) (pMessages []cohere.Mes
continue
}

tokens += common.CalcTokens(strings.Join(buffer, ""))
pMessages = append(pMessages, cohere.Message{
Role: role,
Message: strings.Join(buffer, "\n\n"),
Expand Down
2 changes: 1 addition & 1 deletion internal/middle/cohere/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func parseToToolCall(ctx *gin.Context, cookie string, req gpt.ChatCompletionRequ
} else {
// 没有解析出 JSON
if req.Stream {
middle.ResponseWithSSE(ctx, MODEL, content, created)
middle.ResponseWithSSE(ctx, MODEL, content, nil, created)
return false, nil
} else {
middle.ResponseWith(ctx, MODEL, content)
Expand Down
20 changes: 12 additions & 8 deletions internal/middle/coze/assembler.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ func Complete(ctx *gin.Context, req gpt.ChatCompletionRequest, matchers []common
}
}

pMessages, err := buildConversation(messages)
pMessages, tokens, err := buildConversation(messages)
if err != nil {
middle.ResponseWithE(ctx, -1, err)
return
}

ctx.Set("tokens", tokens)
options := newOptions(proxies, pMessages)
co, msToken := extCookie(cookie)
chat := coze.New(co, msToken, options)
Expand Down Expand Up @@ -174,6 +175,7 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
logrus.Infof("waitResponse ...")
prefix := ""
cmd := ctx.GetInt("cmd")
tokens := ctx.GetInt("tokens")

for {
raw, ok := <-chatResponse
Expand Down Expand Up @@ -212,20 +214,19 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
}
raw = common.ExecMatchers(matchers, raw)
if sse {
middle.ResponseWithSSE(ctx, MODEL, raw, created)
} else {
content += raw
middle.ResponseWithSSE(ctx, MODEL, raw, nil, created)
}
content += raw
}

if !sse {
middle.ResponseWith(ctx, MODEL, content)
} else {
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", created)
middle.ResponseWithSSE(ctx, MODEL, "[DONE]", common.CalcUsageTokens(content, tokens), created)
}
}

func buildConversation(messages []map[string]string) (pMessages []coze.Message, err error) {
func buildConversation(messages []map[string]string) (pMessages []coze.Message, tokens int, err error) {
var prompt string
pos := len(messages) - 1
if pos < 0 {
Expand Down Expand Up @@ -273,6 +274,7 @@ func buildConversation(messages []map[string]string) (pMessages []coze.Message,
for {
if pos >= messageL {
if len(buffer) > 0 {
tokens += common.CalcTokens(strings.Join(buffer, ""))
pMessages = append(pMessages, coze.Message{
Role: role,
Content: strings.Join(buffer, "\n\n"),
Expand All @@ -285,7 +287,7 @@ func buildConversation(messages []map[string]string) (pMessages []coze.Message,
curr := condition(message["role"])
content := message["content"]
if curr == "" {
return nil, errors.New(
return nil, -1, errors.New(
fmt.Sprintf("'%s' is not one of ['system', 'assistant', 'user', 'function'] - 'messages.%d.role'",
message["role"], pos))
}
Expand All @@ -303,6 +305,8 @@ func buildConversation(messages []map[string]string) (pMessages []coze.Message,
buffer = append(buffer, content)
continue
}

tokens += common.CalcTokens(strings.Join(buffer, ""))
pMessages = append(pMessages, coze.Message{
Role: role,
Content: strings.Join(buffer, "\n\n"),
Expand All @@ -311,5 +315,5 @@ func buildConversation(messages []map[string]string) (pMessages []coze.Message,
role = curr
}

return pMessages, nil
return pMessages, tokens, nil
}
2 changes: 1 addition & 1 deletion internal/middle/coze/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func parseToToolCall(ctx *gin.Context, cookie, proxies string, fun *gpt.Function
} else {
// 没有解析出 JSON
if sse {
middle.ResponseWithSSE(ctx, MODEL, content, created)
middle.ResponseWithSSE(ctx, MODEL, content, nil, created)
return false, nil
} else {
middle.ResponseWith(ctx, MODEL, content)
Expand Down
Loading

0 comments on commit 8241fb3

Please sign in to comment.