Skip to content

Commit

Permalink
feat: 添加角色序列映射标记 char_sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Jun 24, 2024
1 parent 0d6a8a9 commit 58d85a8
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 46 deletions.
4 changes: 0 additions & 4 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ google:
# threshold: BLOCK_NONE

you:
enabled: true
helper: 8082
cookies:
- 'xxx1'
- 'xxx2'
serverless:
enabled: false
disabled-gpu: true
proxies: http://127.0.0.1:7890
headless: new

# coze 默认配置;;;内置配置经常变动,难以维护改为配置化。新增webSdk模式,但还未得知速率是否有限制???
Expand All @@ -59,13 +57,11 @@ serverless:

interpreter:
baseUrl: http://127.0.0.1:8000
useProxy: false
echoCode: false
# reverseUrl: ws://127.0.0.1:8000/ws TODO -

custom-llm:
baseUrl: http://127.0.0.1:8080
useProxy: false

# toolCall 默认配置化; 在 flags 关闭时也可用
toolCall:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/bincooo/coze-api v1.0.2-0.20240620163352-ad627f0dffd2
github.com/bincooo/edge-api v1.0.4-0.20240620163255-3f676529cccd
github.com/bincooo/emit.io v0.0.0-20240622171207-c5018480b050
github.com/bincooo/you.com v0.0.0-20240624043209-de5389f9bd83
github.com/bincooo/you.com v0.0.0-20240624182857-8b5f69e16877
github.com/dlclark/regexp2 v1.7.0
github.com/eko/gocache/lib/v4 v4.1.6
github.com/eko/gocache/store/go_cache/v4 v4.2.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ github.com/bincooo/emit.io v0.0.0-20240622171207-c5018480b050 h1:vVstCrqaYprk3l8
github.com/bincooo/emit.io v0.0.0-20240622171207-c5018480b050/go.mod h1:Tag5SFgOt/R1ZN9YZ/DDZK/elTWHf1mNnRv4oh4TMOE=
github.com/bincooo/requests v0.0.0-20230720064210-7eae5d6c9d1e h1:38ztKJW0K6qQGitqAlcJs8O2nal60qYoS9oNZnpkZWE=
github.com/bincooo/requests v0.0.0-20230720064210-7eae5d6c9d1e/go.mod h1:0WuzYU+4cQL/hVbjoncY5TACMTbD9I+pLCdnPjfItp0=
github.com/bincooo/you.com v0.0.0-20240624043209-de5389f9bd83 h1:K4zmRo2+90Otn8uf7CgGdsIRuDYGOMw4pQDycU7HhM0=
github.com/bincooo/you.com v0.0.0-20240624043209-de5389f9bd83/go.mod h1:1QlkSxFhT1VGFFbWMCM31wnF2Jzuagk6Or15B/sjxXw=
github.com/bincooo/you.com v0.0.0-20240624182857-8b5f69e16877 h1:mBjRBcdbGvM/msVZ2Ly4V0YPit5upeEFyKc3Y+qKrEk=
github.com/bincooo/you.com v0.0.0-20240624182857-8b5f69e16877/go.mod h1:1QlkSxFhT1VGFFbWMCM31wnF2Jzuagk6Or15B/sjxXw=
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
Expand Down
17 changes: 17 additions & 0 deletions internal/common/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/bincooo/chatgpt-adapter/internal/vars"
"github.com/bincooo/chatgpt-adapter/pkg"
"github.com/gin-gonic/gin"
"strings"
)
Expand Down Expand Up @@ -41,6 +42,22 @@ func NewMatchers() []Matcher {
func NewCancelMather(ctx *gin.Context) (chan error, Matcher) {
count := 0
cancel := make(chan error, 1)

newBlocks := make([]string, 0)
newBlocks = append(newBlocks, blocks...)

keyv, ok := GetGinValue[pkg.Keyv[string]](ctx, vars.GinCharSequences)
if ok {
user := keyv.GetString("user")
assistant := keyv.GetString("assistant")
if user != "" {
newBlocks = append(newBlocks, user)
}
if assistant != "" {
newBlocks = append(newBlocks, assistant)
}
}

return cancel, &SymbolMatcher{
Find: "<|",
H: func(index int, content string) (state int, result string) {
Expand Down
25 changes: 25 additions & 0 deletions internal/common/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
"pad", // bing中使用的标记:填充引导对话,尝试避免道歉
"notebook", // notebook模式
"histories",
"char_sequences", // 角色序列映射
"tool",
})
)
Expand Down Expand Up @@ -678,6 +679,30 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
continue
}

// 角色序列映射
if node.t == XML_TYPE_X && node.tag == "char_sequences" {
var (
user = ""
assistant = ""
)
if e, ok := node.attr["user"]; ok {
if o, k := e.(string); k {
user = o
}
}
if e, ok := node.attr["assistant"]; ok {
if o, k := e.(string); k {
assistant = o
}
}
ctx.Set(vars.GinCharSequences, pkg.Keyv[string]{
"user": user,
"assistant": assistant,
})
clean(content[node.index:node.end])
continue
}

if node.t == XML_TYPE_X && node.tag == "tool" {
id := "-1"
if e, ok := node.attr["id"]; ok {
Expand Down
40 changes: 36 additions & 4 deletions internal/plugin/llm/bing/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,37 @@ func mergeMessages(ctx *gin.Context, pad bool, max int, completion pkg.ChatCompl
}
}

var (
user = ""
assistant = ""
)

{
keyv, ok := common.GetGinValue[pkg.Keyv[string]](ctx, vars.GinCharSequences)
if ok {
user = keyv.GetString("user")
assistant = keyv.GetString("assistant")
}

if user == "" {
user = "<|user|>"
}
if assistant == "" {
assistant = "<|assistant|>"
}
}

tor := func(r string) string {
switch r {
case "user":
return user
case "assistant":
return assistant
default:
return "<|" + r + "|>"
}
}

// 合并历史对话
iterator := func(opts struct {
Previous string
Expand All @@ -139,7 +170,8 @@ func mergeMessages(ctx *gin.Context, pad bool, max int, completion pkg.ChatCompl
if e != nil {
return nil, logger.WarpError(e)
}
opts.Buffer.WriteString(fmt.Sprintf("<|%s|>\n%s\n<|end|>", role, content))

opts.Buffer.WriteString(fmt.Sprintf("%s\n%s\n<|end|>", tor(role), content))
if condition(role) != condition(opts.Next) {
result = append(result, edge.BuildUserMessage(opts.Buffer.String()))
opts.Buffer.Reset()
Expand All @@ -154,7 +186,7 @@ func mergeMessages(ctx *gin.Context, pad bool, max int, completion pkg.ChatCompl
return
}

opts.Buffer.WriteString(fmt.Sprintf("<|%s|>\n%s\n<|end|>", role, opts.Message["content"]))
opts.Buffer.WriteString(fmt.Sprintf("%s\n%s\n<|end|>", tor(role), opts.Message["content"]))
return
}

Expand All @@ -165,7 +197,7 @@ func mergeMessages(ctx *gin.Context, pad bool, max int, completion pkg.ChatCompl
opts.Buffer.Reset()
}

opts.Buffer.WriteString(fmt.Sprintf("<|%s|>\n%s\n<|end|>", role, opts.Message["content"]))
opts.Buffer.WriteString(fmt.Sprintf("%s\n%s\n<|end|>", tor(role), opts.Message["content"]))
result = append(result, edge.BuildSwitchMessage(condition(role), opts.Buffer.String()))
return
}
Expand All @@ -187,7 +219,7 @@ func mergeMessages(ctx *gin.Context, pad bool, max int, completion pkg.ChatCompl
if message["author"] == "user" {
newMessages = append(newMessages[:pos], newMessages[pos+1:]...)
text = strings.TrimSpace(message["text"].(string))
text = strings.TrimLeft(text, "<|user|>")
text = strings.TrimLeft(text, tor("user"))
text = strings.TrimRight(text, "<|end|>")
break
}
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/cohere/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (API) Completion(ctx *gin.Context) {
// TODO - 官方Go库出了,后续修改
if notebook {
//toolObject = coh.ToolObject{}
message = mergeMessages(completion.Messages)
message = mergeMessages(ctx, completion.Messages)
ctx.Set(ginTokens, common.CalcTokens(message))
chat = coh.New(cookie, completion.Temperature, completion.Model, false)
chat.TopK(completion.TopK)
Expand Down
37 changes: 34 additions & 3 deletions internal/plugin/llm/cohere/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func waitResponse(ctx *gin.Context, matchers []common.Matcher, chatResponse chan
return
}

func mergeMessages(messages []pkg.Keyv[interface{}]) (content string) {
func mergeMessages(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (content string) {
condition := func(expr string) string {
switch expr {
case "system", "user", "assistant", "function", "tool":
Expand All @@ -133,6 +133,37 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (content string) {
}
}

var (
user = ""
assistant = ""
)

{
keyv, ok := common.GetGinValue[pkg.Keyv[string]](ctx, vars.GinCharSequences)
if ok {
user = keyv.GetString("user")
assistant = keyv.GetString("assistant")
}

if user == "" {
user = "<|user|>"
}
if assistant == "" {
assistant = "<|assistant|>"
}
}

tor := func(r string) string {
switch r {
case "user":
return user
case "assistant":
return assistant
default:
return "<|" + r + "|>"
}
}

iterator := func(opts struct {
Previous string
Next string
Expand All @@ -155,7 +186,7 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (content string) {
opts.Buffer.WriteString(fmt.Sprintf(opts.Message["content"]))
messages = []map[string]string{
{
"role": condition(role),
"role": tor(condition(role)),
"content": opts.Buffer.String(),
},
}
Expand All @@ -166,7 +197,7 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (content string) {
// 尾部添加一个assistant空消息
if newMessages[len(newMessages)-1]["role"] != "assistant" {
newMessages = append(newMessages, map[string]string{
"role": "assistant",
"role": tor("assistant"),
"content": "",
})
}
Expand Down
47 changes: 41 additions & 6 deletions internal/plugin/llm/coze/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/bincooo/chatgpt-adapter/internal/common"
"github.com/bincooo/chatgpt-adapter/internal/gin.handler/response"
"github.com/bincooo/chatgpt-adapter/internal/plugin"
"github.com/bincooo/chatgpt-adapter/internal/vars"
"github.com/bincooo/chatgpt-adapter/logger"
"github.com/bincooo/chatgpt-adapter/pkg"
"github.com/bincooo/coze-api"
Expand Down Expand Up @@ -123,8 +124,37 @@ func (API) Completion(ctx *gin.Context) {
proxies = ctx.GetString("proxies")
completion = common.GetGinCompletion(ctx)
matchers = common.GetGinMatchers(ctx)

user = ""
assistant = ""
)

{
keyv, ok := common.GetGinValue[pkg.Keyv[string]](ctx, vars.GinCharSequences)
if ok {
user = keyv.GetString("user")
assistant = keyv.GetString("assistant")
}

if user == "" {
user = "user"
}
if assistant == "" {
assistant = "assistant"
}
}

tor := func(r string) string {
switch r {
case "user":
return user
case "assistant":
return assistant
default:
return r
}
}

if plugin.NeedToToolCall(ctx) {
if completeToolCalls(ctx, cookie, proxies, completion) {
return
Expand All @@ -151,7 +181,7 @@ func (API) Completion(ctx *gin.Context) {

var lock *common.ExpireLock
if mode == 'o' {
l, e := draftBot(ctx, pMessages, chat, completion)
l, e := draftBot(ctx, pMessages[0], chat, completion)
if e != nil {
response.Error(ctx, e.Code, e.Err)
return
Expand All @@ -164,7 +194,13 @@ func (API) Completion(ctx *gin.Context) {
query = pMessages[len(pMessages)-1].Content
chat.WebSdk(chat.TransferMessages(pMessages[:len(pMessages)-1]))
} else {
query = coze.MergeMessages(pMessages)
var newP []coze.Message
for _, message := range pMessages {
message.Role = tor(message.Role)
newP = append(newP, message)
}
query = coze.MergeMessages(newP)
query = query[:len(query)-13] + "<|" + tor("assistant") + "|>"
}

chatResponse, err := chat.Reply(common.GetGinContext(ctx), coze.Text, query)
Expand Down Expand Up @@ -193,11 +229,10 @@ func (API) Completion(ctx *gin.Context) {
}

// return true 终止
func draftBot(ctx *gin.Context, pMessages []coze.Message, chat coze.Chat, completion pkg.ChatCompletion) (eLock *common.ExpireLock, emitErr *emit.Error) {
func draftBot(ctx *gin.Context, systemMessage coze.Message, chat coze.Chat, completion pkg.ChatCompletion) (eLock *common.ExpireLock, emitErr *emit.Error) {
var system string
message := pMessages[0]
if message.Role == "system" {
system = message.Content
if systemMessage.Role == "system" {
system = systemMessage.Content
}

var value map[string]interface{}
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/coze/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func completeToolCalls(ctx *gin.Context, cookie, proxies string, completion pkg.
chat.Session(plugin.HTTPClient)
var lock *common.ExpireLock
if mode == 'o' {
l, e := draftBot(ctx, pMessages, chat, completion)
l, e := draftBot(ctx, pMessages[0], chat, completion)
if e != nil {
return "", logger.WarpError(e.Err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/lmsys/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (API) Completion(ctx *gin.Context) {
}
}

newMessages := mergeMessages(completion.Messages)
newMessages := mergeMessages(ctx, completion.Messages)
ctx.Set(ginTokens, common.CalcTokens(newMessages))
retry := 3
label:
Expand Down
Loading

0 comments on commit 58d85a8

Please sign in to comment.