Skip to content

Commit

Permalink
feat: 1.修改google模型名称,添加flash模型;2.添加tool增强标签,用于工具选择默认
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed May 19, 2024
1 parent 6c3c228 commit a37e593
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 66 deletions.
10 changes: 10 additions & 0 deletions flags.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,14 @@ flag: histories
"model": "coze",
"stream": false
}
```

#### tools 工具 开启 默认选中模式,作用是让工具选择在不匹配时默认选择一个,仅支持无参工具
```text
flag: tool
attribute:
id: (string) 指定tool_function里的name值,默认-1
<tool id="xxx">
```
45 changes: 28 additions & 17 deletions internal/agent/com.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,44 @@ output: {{$value.content}}
{{end -}}
{{end}}
<Instruction>
你是一个智能机器人,除了可以回答用户问题外,你还掌握工具的使用能力。有时候,你可以依赖工具的运行结果,来更准确的回答用户。
你是一个智能机器人,你专注于选择工具的给用户使用的能力。有时候,你可以依赖工具的运行结果,来更准确的回答用户。
工具使用了 JSON Schema 的格式声明,其中 toolId 是工具的 description 是工具的描述,parameters 是工具的参数,包括参数的类型和描述,required 是必填参数的列表。
请你根据工具描述,决定回答问题或是使用工具。在完成任务过程中,USER代表用户的输入,TOOL_RESPONSE代表工具运行结果。ASSISTANT 代表你的输出。
{{- if eq .toolDef "-1" }}
你的每次输出都必须以0,1开头,代表是否需要调用工具:
0: 不使用工具,直接回答内容
0: 不使用工具。
1: 使用工具,返回工具调用的参数。
{{- else }}
你的本次输必须以1开头,代表是否需要调用工具:
0: 不使用工具。
1: 使用工具,返回工具调用的参数。
{{- end }}
例如:
USER: 你好呀
ANSWER: 0: 你好,有什么可以帮助你的么?
USER: 今天杭州的天气如何
ANSWER: 1: {"toolId":"testToolId",arguments:{"city": "杭州"}}
USER: 你好呀 <|end|>
{{- if eq .toolDef "-1" }}
ANSWER: 0: <|end|>
{{- else }}
ANSWER: 1: {"toolId":"{{.toolDef}}",arguments:{}} <|end|>
{{- end }}
USER: 今天杭州的天气如何 <|end|>
ANSWER: 1: {"toolId":"testToolId",arguments:{"city": "杭州"}} <|end|>
TOOL_RESPONSE: """
晴天......
"""
ANSWER: 0: 今天杭州是晴天。
USER: 今天杭州的天气适合去哪里玩?
ANSWER: 1: {"toolId":"testToolId2",arguments:{"query": "杭州 天气 去哪里玩"}}
USER: 今天杭州的天气适合去哪里玩? <|end|>
ANSWER: 1: {"toolId":"testToolId2",arguments:{"query": "杭州 天气 去哪里玩"}} <|end|>
TOOL_RESPONSE: """
晴天. 西湖、灵隐寺、千岛湖……
"""
ANSWER: 0: 今天杭州是晴天,适合去西湖、灵隐寺、千岛湖等地玩。
</Instruction>
{{- if eq .toolDef "-1" }}
ANSWER: 0: <|end|>
{{- else }}
ANSWER: 1: {"toolId":"{{.toolDef}}",arguments:{}} <|end|>
{{- end }}
现在,我们开始吧!下面是你本次可以使用的工具:
Expand All @@ -58,13 +69,13 @@ ANSWER: 0: 今天杭州是晴天,适合去西湖、灵隐寺、千岛湖等地
"type": "{{$v.type}}",
"description": "{{$v.description}}"
}
{{end -}}
{{- end }}
}
},
"required": {{$value.function.parameters.required}}
"required": [{{join $value.function.parameters.required ", " }}]
},
{{end -}}
{{end -}}
{{- end -}}
{{- end}}
]
"""
Expand Down
15 changes: 15 additions & 0 deletions internal/common/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
"pad", // bing中使用的标记:填充引导对话,尝试避免道歉
"notebook", // notebook模式
"histories",
"tool",
})
)

Expand Down Expand Up @@ -671,6 +672,20 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
}
continue
}

if node.t == XML_TYPE_X && node.tag == "tool" {
id := "-1"
if e, ok := node.attr["id"]; ok {
if o, k := e.(string); k {
id = o
}
}
clean(content[node.index:node.end])
if id != "-1" {
ctx.Set("tool", id)
}
continue
}
}
}

Expand Down
12 changes: 8 additions & 4 deletions internal/middle/gemini/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
)

const MODEL = "gemini"
const GOOGLE_BASE = "https://generativelanguage.googleapis.com/%s?alt=sse&key=%s"
const login = "http://127.0.0.1:8081/v1/login"

var (
Expand Down Expand Up @@ -55,7 +54,7 @@ type API struct {

func (API) Match(_ *gin.Context, model string) bool {
switch model {
case "gemini-1.0", "gemini-1.5":
case "gemini-1.0-pro-latest", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest":
return true
default:
return false
Expand All @@ -65,12 +64,17 @@ func (API) Match(_ *gin.Context, model string) bool {
func (API) Models() []middle.Model {
return []middle.Model{
{
Id: "gemini-1.0",
Id: "gemini-1.0-pro-latest",
Object: "model",
Created: 1686935002,
By: "gemini-adapter",
}, {
Id: "gemini-1.5",
Id: "gemini-1.5-pro-latest",
Object: "model",
Created: 1686935002,
By: "gemini-adapter",
}, {
Id: "gemini-1.5-flash-latest",
Object: "model",
Created: 1686935002,
By: "gemini-adapter",
Expand Down
42 changes: 19 additions & 23 deletions internal/middle/gemini/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strings"
)

const GOOGLE_BASE_FORMAT = "https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s"

type funcDecl struct {
Name string `json:"name"`
Description string `json:"description"`
Expand All @@ -25,35 +27,29 @@ type funcDecl struct {
}

// 构建请求,返回响应
func build(ctx context.Context, proxies, token string, messages []map[string]interface{}, req pkg.ChatCompletion) (*http.Response, error) {
var (
burl = fmt.Sprintf(GOOGLE_BASE, "v1beta/models/gemini-1.0-pro-latest:streamGenerateContent", token)
)

if req.Model == "gemini-1.5" {
burl = fmt.Sprintf(GOOGLE_BASE, "v1beta/models/gemini-1.5-pro-latest:streamGenerateContent", token)
}
func build(ctx context.Context, proxies, token string, messages []map[string]interface{}, completion pkg.ChatCompletion) (*http.Response, error) {
gURL := fmt.Sprintf(GOOGLE_BASE_FORMAT, completion.Model, token)

if req.Temperature < 0.1 {
req.Temperature = 1
if completion.Temperature < 0.1 {
completion.Temperature = 1
}

if req.MaxTokens == 0 {
req.MaxTokens = 2048
if completion.MaxTokens == 0 {
completion.MaxTokens = 2048
}

if req.TopK == 0 {
req.TopK = 100
if completion.TopK == 0 {
completion.TopK = 100
}

if req.TopP == 0 {
req.TopP = 0.95
if completion.TopP == 0 {
completion.TopP = 0.95
}

// 参数基本与openai对齐
_funcDecls := make([]funcDecl, 0)
if toolsL := len(req.Tools); toolsL > 0 {
for _, v := range req.Tools {
if toolsL := len(completion.Tools); toolsL > 0 {
for _, v := range completion.Tools {
kv := v.GetKeyv("function").GetKeyv("parameters")
required, ok := kv.Get("required")
if !ok {
Expand Down Expand Up @@ -93,10 +89,10 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
payload := map[string]any{
"contents": messages, // [ { role: user, parts: [ { text: 'xxx' } ] } ]
"generationConfig": map[string]any{
"topK": req.TopK,
"topP": req.TopP,
"temperature": req.Temperature, // 0.8
"maxOutputTokens": req.MaxTokens,
"topK": completion.TopK,
"topP": completion.TopP,
"temperature": completion.Temperature, // 0.8
"maxOutputTokens": completion.MaxTokens,
"stopSequences": []string{},
},
// 安全级别
Expand Down Expand Up @@ -137,7 +133,7 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
res, err := emit.ClientBuilder().
Proxies(proxies).
Context(ctx).
POST(burl).
POST(gURL).
JHeader().
Bytes(marshal).
Do()
Expand Down
1 change: 1 addition & 0 deletions internal/middle/lmsys/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ func fetchCookies(ctx context.Context, proxies string) (cookies string) {
DoS(http.StatusOK)
if err != nil {
var e emit.Error
logrus.Errorf("retry[%d]: %v", index, err)
if errors.As(err, &e) && e.Code == 429 {
return
}
Expand Down
67 changes: 45 additions & 22 deletions internal/middle/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middle

import (
"encoding/json"
"fmt"
"github.com/bincooo/chatgpt-adapter/v2/internal/agent"
"github.com/bincooo/chatgpt-adapter/v2/internal/common"
"github.com/bincooo/chatgpt-adapter/v2/internal/vars"
Expand All @@ -12,31 +13,56 @@ import (
"time"
)

func buildTemplate(tools []pkg.Keyv[interface{}], messages []pkg.Keyv[interface{}], template string, max int) (message string, err error) {
pMessages := messages
func buildTemplate(ctx *gin.Context, completion pkg.ChatCompletion, template string, max int) (message string, err error) {
toolDef := ctx.GetString("tool")
if toolDef == "" {
toolDef = "-1"
}

pMessages := completion.Messages
content := "continue"
if messageL := len(messages); messageL > 0 && messages[messageL-1]["role"] == "user" {
content = messages[messageL-1].GetString("content")
if messageL := len(pMessages); messageL > 0 && pMessages[messageL-1]["role"] == "user" {
content = pMessages[messageL-1].GetString("content")
if max == 0 {
pMessages = make([]pkg.Keyv[interface{}], 0)
} else if max > 0 && messageL > max {
pMessages = messages[messageL-max : messageL-1]
pMessages = pMessages[messageL-max : messageL-1]
} else {
pMessages = messages[:messageL-1]
pMessages = pMessages[:messageL-1]
}
}

for _, t := range tools {
if !t.GetKeyv("function").Has("id") {
t.GetKeyv("function").Set("id", common.RandStr(5))
for _, t := range completion.Tools {
id := common.RandStr(5)
fn := t.GetKeyv("function")
if !fn.Has("id") {
t.GetKeyv("function").Set("id", id)
} else {
id = fn.GetString("id")
}

if toolDef != "-1" && fn.Has("name") {
if toolDef == fn.GetString("name") {
toolDef = id
}
}
}

parser := templateBuilder().
Vars("tools", tools).
Vars("toolDef", toolDef).
Vars("tools", completion.Tools).
Vars("pMessages", pMessages).
Vars("content", content).
Do()
Func("join", func(slice []interface{}, sep string) string {
if len(slice) == 0 {
return ""
}
var result []string
for _, v := range slice {
result = append(result, fmt.Sprintf("\"%v\"", v))
}
return strings.Join(result, sep)
}).Do()
return parser(template)
}

Expand All @@ -45,11 +71,8 @@ func buildTemplate(tools []pkg.Keyv[interface{}], messages []pkg.Keyv[interface{
// return:
// bool > 是否执行了工具
// error > 执行异常
func CompleteToolCalls(ctx *gin.Context, req pkg.ChatCompletion, callback func(message string) (string, error)) (bool, error) {
message, err := buildTemplate(
req.Tools,
req.Messages,
agent.ToolCall, 5)
func CompleteToolCalls(ctx *gin.Context, completion pkg.ChatCompletion, callback func(message string) (string, error)) (bool, error) {
message, err := buildTemplate(ctx, completion, agent.ToolCall, 5)
if err != nil {
return false, err
}
Expand All @@ -64,10 +87,10 @@ func CompleteToolCalls(ctx *gin.Context, req pkg.ChatCompletion, callback func(m
ctx.Set(vars.GinCompletionUsage, common.CalcUsageTokens(content, previousTokens))

// 解析参数
return parseToToolCall(ctx, content, req), nil
return parseToToolCall(ctx, content, completion), nil
}

func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) bool {
func parseToToolCall(ctx *gin.Context, content string, completion pkg.ChatCompletion) bool {
j := ""
created := time.Now().Unix()
slice := strings.Split(content, "TOOL_RESPONSE")
Expand All @@ -87,7 +110,7 @@ func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) b
}

name := ""
for _, t := range req.Tools {
for _, t := range completion.Tools {
if strings.Contains(j, t.GetKeyv("function").GetString("id")) {
name = t.GetKeyv("function").GetString("name")
break
Expand All @@ -113,11 +136,11 @@ func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) b
}
bytes, _ := json.Marshal(obj)

if req.Stream {
SSEToolCallResponse(ctx, req.Model, name, string(bytes), created)
if completion.Stream {
SSEToolCallResponse(ctx, completion.Model, name, string(bytes), created)
return true
} else {
ToolCallResponse(ctx, req.Model, name, string(bytes))
ToolCallResponse(ctx, completion.Model, name, string(bytes))
return true
}
}
Expand Down

0 comments on commit a37e593

Please sign in to comment.