Skip to content

Commit

Permalink
feat: 工具调用添加任务拆解功能
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed May 20, 2024
1 parent 9ab34a5 commit 410abff
Show file tree
Hide file tree
Showing 11 changed files with 296 additions and 67 deletions.
3 changes: 3 additions & 0 deletions flags.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ flag: tool
attribute:
id: (string) 指定tool_function里的name值,默认-1
tasks: (bool) 是否任务拆解,默认 false
使用示例
<tool id="xxx" />
<tool id="xxx" tasks />
```
78 changes: 78 additions & 0 deletions internal/agent/com.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,83 @@
package agent

const ToolTasks = `{{- range $index, $value := .pMessages}}
{{- if eq $value.role "tool" }}
<|tool|>
TOOL_RESPONSE:
name: "{{ $value.name }}"
description: "{{ ToolDesc $value.name }}"
output: {{ $value.content }}
<|end|>
{{- else }}
<|{{$value.role}}|>
{{$value.content}}
<|end|>
{{end -}}
{{end}}
你是一个智能机器人,你拥有专注于拆解多个任务的能力。有时候,你可以依赖工具的运行结果,来更准确的回答用户。
请你根据用户请求,拆解出3个以内的子任务。在完成拆解过程中,USER代表用户的输入,TOOL_RESPONSE代表工具运行结果。ASSISTANT 代表你的输出。
你的每次输出都必须以0,1开头,代表是否需要拆解任务:
0: 无拆解任务。
1: [task1, task2, task3]。
例如:
USER: 你好呀 <|end|>
ANSWER: 0: 无拆解任务 <|end|>
USER: 今天杭州的天气如何 <|end|>
ANSWER: 1: [{"toolId": "testToolId", "task": "今天杭州的天气"}] <|end|>
TOOL_RESPONSE: """
晴天......
"""
USER: 今天杭州的天气适合去哪里玩? <|end|>
ANSWER: 1: [{"toolId": "testToolId", "task": "今天杭州的天气"}, {"toolId": "testToolId2", "task": "杭州合适去哪里游玩"}] <|end|>
TOOL_RESPONSE: """
晴天. 西湖、灵隐寺、千岛湖……
"""
ANSWER: 0: 无拆解任务 <|end|>
USER: 获取深圳天气并发送给QQ群组中 <|end|>
ANSWER: 1: [{"toolId": "testToolId", "task": "深圳的天气"}, {"toolId": "testToolId2", "task": "发送QQ群组信息"}] <|end|>
现在,我们开始吧!下面是你本次可以使用的工具:
"""
[
{{- range $index, $value := .tools}}
{{- if eq $value.type "function" }}
{
"toolId": "{{$value.function.id}}",
"description": "{{$value.function.description}}",
"parameters": {
"type": "object",
"properties": {
{{- range $key, $v := $value.function.parameters.properties}}
"{{$key}}": {
"type": "{{$v.type}}",
"description": "{{$v.description}}"
}
{{- end }}
}
},
"required": [{{Join $value.function.parameters.required ", " }}]
},
{{- end -}}
{{- end}}
]
"""
阅读上下文,不要重复选中相同的工具。
下面是正式的对话内容:
USER: {{.content}}
ANSWER: `

const ToolCall = `{{- range $index, $value := .pMessages}}
{{- if eq $value.role "tool" }}
<|tool|>
Expand Down
10 changes: 9 additions & 1 deletion internal/common/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,15 @@ func StringCombiner[T any](messages []T, iter func(message T) string) string {
}

func NeedToToolCall(ctx *gin.Context) bool {
tool := ctx.GetString("tool")
var tool = "-1"
if t, ok := ctx.Get("tool"); ok {
keyv := t.(pkg.Keyv[interface{}])
tool = keyv.GetString("id")
if tool == "-1" && keyv.Is("tasks", true) {
tool = "tasks"
}
}

completion := GetGinCompletion(ctx)
messageL := len(completion.Messages)
if messageL == 0 {
Expand Down
13 changes: 10 additions & 3 deletions internal/common/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,10 +680,17 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
id = o
}
}
clean(content[node.index:node.end])
if id != "-1" {
ctx.Set("tool", id)
tasks := false
if e, ok := node.attr["tasks"]; ok {
if o, k := e.(bool); k {
tasks = o
}
}
clean(content[node.index:node.end])
ctx.Set("tool", pkg.Keyv[interface{}]{
"id": id,
"tasks": tasks,
})
continue
}
}
Expand Down
6 changes: 3 additions & 3 deletions internal/middle/bing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (API) Completion(ctx *gin.Context) {

// 清理多余的标签
var cancel chan error
cancel, matchers = addMatchers(matchers)
cancel, matchers = joinMatchers(ctx, matchers)
ctx.Set("tokens", tokens)
chatResponse, err := chat.Reply(ctx.Request.Context(), currMessage, pMessages)
if err != nil {
Expand All @@ -92,7 +92,7 @@ func (API) Completion(ctx *gin.Context) {
waitResponse(ctx, matchers, cancel, chatResponse, completion.Stream)
}

func addMatchers(matchers []pkg.Matcher) (chan error, []pkg.Matcher) {
func joinMatchers(ctx *gin.Context, matchers []pkg.Matcher) (chan error, []pkg.Matcher) {
// 清理 [1]、[2] 标签
// 清理 [^1^]、[^2^] 标签
// 清理 [^1^ 标签
Expand Down Expand Up @@ -151,7 +151,7 @@ func addMatchers(matchers []pkg.Matcher) (chan error, []pkg.Matcher) {
})

// 自定义标记块中断
cancel, matcher := pkg.NewCancelMather()
cancel, matcher := pkg.NewCancelMather(ctx)
matchers = append(matchers, matcher)
return cancel, matchers
}
2 changes: 1 addition & 1 deletion internal/middle/coze/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (API) Completion(ctx *gin.Context) {
}

// 自定义标记块中断
cancel, matcher := pkg.NewCancelMather()
cancel, matcher := pkg.NewCancelMather(ctx)
matchers = append(matchers, matcher)

waitResponse(ctx, matchers, cancel, chatResponse, completion.Stream)
Expand Down
6 changes: 3 additions & 3 deletions internal/middle/lmsys/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,13 @@ label:
return
}

cancel, matchers := joinMatchers(matchers)
cancel, matchers := joinMatchers(ctx, matchers)
waitResponse(ctx, matchers, ch, cancel, completion.Stream)
}

func joinMatchers(matchers []pkg.Matcher) (chan error, []pkg.Matcher) {
func joinMatchers(ctx *gin.Context, matchers []pkg.Matcher) (chan error, []pkg.Matcher) {
// 自定义标记块中断
cancel, matcher := pkg.NewCancelMather()
cancel, matcher := pkg.NewCancelMather(ctx)
matchers = append(matchers, matcher)

// 违反内容中断并返回错误1
Expand Down
19 changes: 12 additions & 7 deletions internal/middle/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/bincooo/chatgpt-adapter/v2/internal/common"
"github.com/bincooo/chatgpt-adapter/v2/internal/vars"
"github.com/bincooo/chatgpt-adapter/v2/pkg"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -156,11 +157,11 @@ func SSEResponse(ctx *gin.Context, model, content string, created int64) {
response.Choices[0].FinishReason = &finishReason
}

event(ctx.Writer, response)
event(ctx, response)

if done {
time.Sleep(100 * time.Millisecond)
event(ctx.Writer, "[DONE]")
event(ctx, "[DONE]")
}
}

Expand Down Expand Up @@ -228,21 +229,21 @@ func SSEToolCallResponse(ctx *gin.Context, model, name, args string, created int
ToolCalls: []pkg.Keyv[interface{}]{toolCall},
}

event(ctx.Writer, response)
event(ctx, response)

delete(toolCall, "id")
delete(toolCall, "type")
toolCall["function"] = map[string]string{"arguments": args}
response.Choices[0].Delta.ToolCalls[0] = toolCall
response.Choices[0].Delta.Role = ""
event(ctx.Writer, response)
event(ctx, response)

response.Choices[0].FinishReason = &toolCalls
response.Choices[0].Delta = nil
response.Usage = usage
event(ctx.Writer, response)
event(ctx, response)

event(ctx.Writer, "[DONE]")
event(ctx, "[DONE]")
}

func NotSSEHeader(ctx *gin.Context) bool {
Expand All @@ -265,13 +266,15 @@ func setSSEHeader(ctx *gin.Context) {
}
}

func event(w gin.ResponseWriter, data interface{}) {
func event(ctx *gin.Context, data interface{}) {
w := ctx.Writer
str, ok := data.(string)
if ok {
layout := "data: %s\n\n"
_, err := fmt.Fprintf(w, layout, str)
if err != nil {
logrus.Error(err)
ctx.Set(vars.GinClose, true)
return
}

Expand All @@ -282,12 +285,14 @@ func event(w gin.ResponseWriter, data interface{}) {
marshal, err := json.Marshal(data)
if err != nil {
logrus.Error(err)
ctx.Set(vars.GinClose, true)
return
}

_, err = fmt.Fprintf(w, "data: %s\n\n", marshal)
if err != nil {
logrus.Error(err)
ctx.Set(vars.GinClose, true)
return
}
w.Flush()
Expand Down
Loading

0 comments on commit 410abff

Please sign in to comment.