From fdd98ad3b050c64cb091d80fc1e0f540c48e648c Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Fri, 24 Apr 2026 14:41:09 +0800 Subject: [PATCH] feat: doubao qwen response native mode support --- core/relay/adaptor/ali/adaptor.go | 71 ++++++- core/relay/adaptor/ali/adaptor_test.go | 149 ++++++++++++++ core/relay/adaptor/doubao/main.go | 69 ++++++- core/relay/adaptor/doubao/main_test.go | 193 ++++++++++++++++++ core/relay/adaptor/zhipu/adaptor.go | 110 +++++++++- core/relay/adaptor/zhipu/adaptor_test.go | 110 ++++++++++ core/relay/adaptor/zhipucoding/adaptor.go | 59 +++++- .../relay/adaptor/zhipucoding/adaptor_test.go | 146 +++++++++++++ 8 files changed, 886 insertions(+), 21 deletions(-) create mode 100644 core/relay/adaptor/ali/adaptor_test.go create mode 100644 core/relay/adaptor/doubao/main_test.go create mode 100644 core/relay/adaptor/zhipu/adaptor_test.go create mode 100644 core/relay/adaptor/zhipucoding/adaptor_test.go diff --git a/core/relay/adaptor/ali/adaptor.go b/core/relay/adaptor/ali/adaptor.go index 8c4cb3bc..7f15cb8c 100644 --- a/core/relay/adaptor/ali/adaptor.go +++ b/core/relay/adaptor/ali/adaptor.go @@ -41,7 +41,12 @@ func (a *Adaptor) SupportMode(m mode.Mode) bool { m == mode.AudioTranscription || m == mode.AudioTranslation || m == mode.Anthropic || - m == mode.Gemini + m == mode.Gemini || + m == mode.Responses || + m == mode.ResponsesGet || + m == mode.ResponsesDelete || + m == mode.ResponsesCancel || + m == mode.ResponsesInputItems } func (a *Adaptor) GetRequestURL( @@ -132,6 +137,56 @@ func (a *Adaptor) GetRequestURL( Method: http.MethodPost, URL: url, }, nil + case mode.Responses: + url, err := url.JoinPath(u, "/compatible-mode/v1/responses") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: url, + }, nil + case mode.ResponsesGet: + url, err := url.JoinPath(u, "/compatible-mode/v1/responses", meta.ResponseID) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodGet, + URL: url, + }, nil + case mode.ResponsesDelete: + url, err := url.JoinPath(u, "/compatible-mode/v1/responses", meta.ResponseID) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodDelete, + URL: url, + }, nil + case mode.ResponsesCancel: + url, err := url.JoinPath(u, "/compatible-mode/v1/responses", meta.ResponseID, "cancel") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: url, + }, nil + case mode.ResponsesInputItems: + url, err := url.JoinPath(u, "/compatible-mode/v1/responses", meta.ResponseID, "input_items") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodGet, + URL: url, + }, nil default: return adaptor.RequestURL{}, fmt.Errorf("unsupported mode: %s", meta.Mode) } @@ -173,6 +228,12 @@ func (a *Adaptor) ConvertRequest( return anthropic.ConvertRequest(meta, req) case mode.Gemini: return openai.ConvertGeminiRequest(meta, req) + case mode.Responses, + mode.ResponsesGet, + mode.ResponsesDelete, + mode.ResponsesCancel, + mode.ResponsesInputItems: + return openai.ConvertRequest(meta, store, req) default: return adaptor.ConvertResult{}, fmt.Errorf("unsupported mode: %s", meta.Mode) } @@ -223,6 +284,12 @@ func (a *Adaptor) DoResponse( return openai.GeminiStreamHandler(meta, c, resp) } return openai.GeminiHandler(meta, c, resp) + case mode.Responses, + mode.ResponsesGet, + mode.ResponsesDelete, + mode.ResponsesCancel, + mode.ResponsesInputItems: + return openai.DoResponse(meta, store, c, resp) default: return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIErrorWithMessage( fmt.Sprintf("unsupported mode: %s", meta.Mode), @@ -234,7 +301,7 @@ func (a *Adaptor) DoResponse( func (a *Adaptor) Metadata() adaptor.Metadata { return adaptor.Metadata{ - Readme: "OpenAI compatibility\nNetwork search metering support\nRerank support: https://help.aliyun.com/zh/model-studio/text-rerank-api\nSTT support: https://help.aliyun.com/zh/model-studio/sambert-speech-synthesis/\nAnthropic support: /api/v2/apps/claude-code-proxy\nGemini support", + Readme: "OpenAI compatibility\nNative Responses API support\nNetwork search metering support\nRerank support: https://help.aliyun.com/zh/model-studio/text-rerank-api\nSTT support: https://help.aliyun.com/zh/model-studio/sambert-speech-synthesis/\nAnthropic support: /api/v2/apps/claude-code-proxy\nGemini support", Models: ModelList, } } diff --git a/core/relay/adaptor/ali/adaptor_test.go b/core/relay/adaptor/ali/adaptor_test.go new file mode 100644 index 00000000..f63d45e0 --- /dev/null +++ b/core/relay/adaptor/ali/adaptor_test.go @@ -0,0 +1,149 @@ +//nolint:testpackage +package ali + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" +) + +func TestAdaptorSupportModeResponses(t *testing.T) { + adaptor := &Adaptor{} + + supportedModes := []mode.Mode{ + mode.Responses, + mode.ResponsesGet, + mode.ResponsesDelete, + mode.ResponsesCancel, + mode.ResponsesInputItems, + } + for _, m := range supportedModes { + if !adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be supported", m) + } + } +} + +func TestAdaptorGetRequestURLResponses(t *testing.T) { + adaptor := &Adaptor{} + channel := &coremodel.Channel{BaseURL: "https://dashscope.aliyuncs.com"} + + tests := []struct { + name string + mode mode.Mode + responseID string + wantMethod string + wantURL string + }{ + { + name: "responses create", + mode: mode.Responses, + wantMethod: http.MethodPost, + wantURL: "https://dashscope.aliyuncs.com/compatible-mode/v1/responses", + }, + { + name: "responses get", + mode: mode.ResponsesGet, + responseID: "resp_123", + wantMethod: http.MethodGet, + wantURL: "https://dashscope.aliyuncs.com/compatible-mode/v1/responses/resp_123", + }, + { + name: "responses delete", + mode: mode.ResponsesDelete, + responseID: "resp_123", + wantMethod: http.MethodDelete, + wantURL: "https://dashscope.aliyuncs.com/compatible-mode/v1/responses/resp_123", + }, + { + name: "responses cancel", + mode: mode.ResponsesCancel, + responseID: "resp_123", + wantMethod: http.MethodPost, + wantURL: "https://dashscope.aliyuncs.com/compatible-mode/v1/responses/resp_123/cancel", + }, + { + name: "responses input items", + mode: mode.ResponsesInputItems, + responseID: "resp_123", + wantMethod: http.MethodGet, + wantURL: "https://dashscope.aliyuncs.com/compatible-mode/v1/responses/resp_123/input_items", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := meta.NewMeta( + channel, + tt.mode, + "qwen-plus", + coremodel.ModelConfig{}, + meta.WithResponseID(tt.responseID), + ) + + got, err := adaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.Method != tt.wantMethod { + t.Fatalf("expected method %s, got %s", tt.wantMethod, got.Method) + } + + if got.URL != tt.wantURL { + t.Fatalf("expected URL %s, got %s", tt.wantURL, got.URL) + } + }) + } +} + +func TestAdaptorConvertRequestResponses(t *testing.T) { + adaptor := &Adaptor{} + m := meta.NewMeta( + nil, + mode.Responses, + "qwen-plus", + coremodel.ModelConfig{}, + ) + + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/responses", + strings.NewReader(`{"model":"qwen-plus","input":"hello","stream":true}`), + ) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + result, err := adaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + body, err := io.ReadAll(result.Body) + if err != nil { + t.Fatalf("failed to read converted body: %v", err) + } + + var responseReq relaymodel.CreateResponseRequest + if err := json.Unmarshal(body, &responseReq); err != nil { + t.Fatalf("failed to unmarshal converted body: %v", err) + } + + if responseReq.Model != "qwen-plus" { + t.Fatalf("expected model qwen-plus, got %s", responseReq.Model) + } + + if !responseReq.Stream { + t.Fatal("expected stream to remain enabled") + } +} diff --git a/core/relay/adaptor/doubao/main.go b/core/relay/adaptor/doubao/main.go index c63b003f..a6722a40 100644 --- a/core/relay/adaptor/doubao/main.go +++ b/core/relay/adaptor/doubao/main.go @@ -23,7 +23,7 @@ func init() { func GetRequestURL(meta *meta.Meta) (adaptor.RequestURL, error) { u := meta.Channel.BaseURL switch meta.Mode { - case mode.ChatCompletions, mode.Anthropic: + case mode.ChatCompletions, mode.Anthropic, mode.Gemini: if strings.HasPrefix(meta.ActualModel, "bot-") { url, err := url.JoinPath(u, "/api/v3/bots/chat/completions") if err != nil { @@ -67,6 +67,56 @@ func GetRequestURL(meta *meta.Meta) (adaptor.RequestURL, error) { Method: http.MethodPost, URL: url, }, nil + case mode.Responses: + url, err := url.JoinPath(u, "/api/v3/responses") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: url, + }, nil + case mode.ResponsesGet: + url, err := url.JoinPath(u, "/api/v3/responses", meta.ResponseID) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodGet, + URL: url, + }, nil + case mode.ResponsesDelete: + url, err := url.JoinPath(u, "/api/v3/responses", meta.ResponseID) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodDelete, + URL: url, + }, nil + case mode.ResponsesCancel: + url, err := url.JoinPath(u, "/api/v3/responses", meta.ResponseID, "cancel") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: url, + }, nil + case mode.ResponsesInputItems: + url, err := url.JoinPath(u, "/api/v3/responses", meta.ResponseID, "input_items") + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodGet, + URL: url, + }, nil default: return adaptor.RequestURL{}, fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) } @@ -85,12 +135,18 @@ func (a *Adaptor) DefaultBaseURL() string { func (a *Adaptor) SupportMode(m mode.Mode) bool { return m == mode.ChatCompletions || m == mode.Anthropic || - m == mode.Embeddings + m == mode.Gemini || + m == mode.Embeddings || + m == mode.Responses || + m == mode.ResponsesGet || + m == mode.ResponsesDelete || + m == mode.ResponsesCancel || + m == mode.ResponsesInputItems } func (a *Adaptor) Metadata() adaptor.Metadata { return adaptor.Metadata{ - Readme: "Doubao / Volcano Engine endpoint\nSupports bot-style models and network search metering fields", + Readme: "Doubao / Volcano Engine endpoint\nSupports bot-style models, native Responses API, Gemini-compatible request conversion, and network search metering fields", Models: ModelList, } } @@ -116,6 +172,8 @@ func (a *Adaptor) ConvertRequest( return openai.ConvertEmbeddingsRequest(meta, req, true) case mode.ChatCompletions: return ConvertChatCompletionsRequest(meta, req) + case mode.Gemini: + return openai.ConvertGeminiRequest(meta, req) default: return openai.ConvertRequest(meta, store, req) } @@ -152,6 +210,11 @@ func (a *Adaptor) DoResponse( resp, embeddingPreHandler, ) + case mode.Gemini: + if utils.IsStreamResponse(resp) { + return openai.GeminiStreamHandler(meta, c, resp) + } + return openai.GeminiHandler(meta, c, resp) default: return openai.DoResponse(meta, store, c, resp) } diff --git a/core/relay/adaptor/doubao/main_test.go b/core/relay/adaptor/doubao/main_test.go new file mode 100644 index 00000000..3b9ebc57 --- /dev/null +++ b/core/relay/adaptor/doubao/main_test.go @@ -0,0 +1,193 @@ +//nolint:testpackage +package doubao + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" +) + +func TestAdaptorSupportMode(t *testing.T) { + adaptor := &Adaptor{} + + supportedModes := []mode.Mode{ + mode.ChatCompletions, + mode.Anthropic, + mode.Gemini, + mode.Embeddings, + mode.Responses, + mode.ResponsesGet, + mode.ResponsesDelete, + mode.ResponsesCancel, + mode.ResponsesInputItems, + } + for _, m := range supportedModes { + if !adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be supported", m) + } + } + + unsupportedModes := []mode.Mode{ + mode.Completions, + mode.ImagesGenerations, + } + for _, m := range unsupportedModes { + if adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be unsupported", m) + } + } +} + +func TestAdaptorGetRequestURL(t *testing.T) { + adaptor := &Adaptor{} + channel := &coremodel.Channel{ + BaseURL: "https://ark.cn-beijing.volces.com", + } + + tests := []struct { + name string + mode mode.Mode + model string + responseID string + wantMethod string + wantURL string + }{ + { + name: "gemini uses chat completions", + mode: mode.Gemini, + model: "doubao-seed-1-6", + wantMethod: http.MethodPost, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/chat/completions", + }, + { + name: "gemini bot uses bot chat completions", + mode: mode.Gemini, + model: "bot-123", + wantMethod: http.MethodPost, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/bots/chat/completions", + }, + { + name: "responses create", + mode: mode.Responses, + model: "doubao-seed-1-6", + wantMethod: http.MethodPost, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/responses", + }, + { + name: "responses get", + mode: mode.ResponsesGet, + model: "doubao-seed-1-6", + responseID: "resp_123", + wantMethod: http.MethodGet, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/responses/resp_123", + }, + { + name: "responses delete", + mode: mode.ResponsesDelete, + model: "doubao-seed-1-6", + responseID: "resp_123", + wantMethod: http.MethodDelete, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/responses/resp_123", + }, + { + name: "responses cancel", + mode: mode.ResponsesCancel, + model: "doubao-seed-1-6", + responseID: "resp_123", + wantMethod: http.MethodPost, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/responses/resp_123/cancel", + }, + { + name: "responses input items", + mode: mode.ResponsesInputItems, + model: "doubao-seed-1-6", + responseID: "resp_123", + wantMethod: http.MethodGet, + wantURL: "https://ark.cn-beijing.volces.com/api/v3/responses/resp_123/input_items", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := meta.NewMeta( + channel, + tt.mode, + tt.model, + coremodel.ModelConfig{}, + meta.WithResponseID(tt.responseID), + ) + + got, err := adaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.Method != tt.wantMethod { + t.Fatalf("expected method %s, got %s", tt.wantMethod, got.Method) + } + + if got.URL != tt.wantURL { + t.Fatalf("expected URL %s, got %s", tt.wantURL, got.URL) + } + }) + } +} + +func TestAdaptorConvertRequestGemini(t *testing.T) { + adaptor := &Adaptor{} + m := meta.NewMeta( + nil, + mode.Gemini, + "doubao-seed-1-6", + coremodel.ModelConfig{}, + ) + + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1beta/models/doubao-seed-1-6:streamGenerateContent", + strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`), + ) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + result, err := adaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + body, err := io.ReadAll(result.Body) + if err != nil { + t.Fatalf("failed to read converted body: %v", err) + } + + var openAIReq relaymodel.GeneralOpenAIRequest + if err := json.Unmarshal(body, &openAIReq); err != nil { + t.Fatalf("failed to unmarshal converted body: %v", err) + } + + if openAIReq.Model != "doubao-seed-1-6" { + t.Fatalf("expected model doubao-seed-1-6, got %s", openAIReq.Model) + } + + if !openAIReq.Stream { + t.Fatal("expected stream to be enabled") + } + + if len(openAIReq.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(openAIReq.Messages)) + } + + if openAIReq.Messages[0].Role != relaymodel.RoleUser { + t.Fatalf("expected user message, got %s", openAIReq.Messages[0].Role) + } +} diff --git a/core/relay/adaptor/zhipu/adaptor.go b/core/relay/adaptor/zhipu/adaptor.go index 74f19e50..e9f52ee2 100644 --- a/core/relay/adaptor/zhipu/adaptor.go +++ b/core/relay/adaptor/zhipu/adaptor.go @@ -1,7 +1,9 @@ package zhipu import ( + "fmt" "net/http" + "net/url" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/model" @@ -10,6 +12,8 @@ import ( "github.com/labring/aiproxy/core/relay/adaptor/registry" "github.com/labring/aiproxy/core/relay/meta" "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" + "github.com/labring/aiproxy/core/relay/utils" ) type Adaptor struct { @@ -26,20 +30,116 @@ func (a *Adaptor) DefaultBaseURL() string { return baseURL } +func (a *Adaptor) SupportMode(m mode.Mode) bool { + return m == mode.ChatCompletions || + m == mode.Completions || + m == mode.AudioTranscription || + m == mode.AudioSpeech || + m == mode.Embeddings || + m == mode.Rerank || + m == mode.Anthropic || + m == mode.Gemini +} + +func (a *Adaptor) GetRequestURL( + meta *meta.Meta, + _ adaptor.Store, + _ *gin.Context, +) (adaptor.RequestURL, error) { + u := meta.Channel.BaseURL + + switch meta.Mode { + case mode.ChatCompletions, mode.Anthropic, mode.Gemini: + return postURL(u, "/chat/completions") + case mode.Completions: + return postURL(u, "/completions") + case mode.AudioTranscription: + return postURL(u, "/audio/transcriptions") + case mode.AudioSpeech: + return postURL(u, "/audio/speech") + case mode.Embeddings: + return postURL(u, "/embeddings") + case mode.Rerank: + return postURL(u, "/rerank") + default: + return adaptor.RequestURL{}, fmt.Errorf("unsupported mode: %s", meta.Mode) + } +} + +func postURL(baseURL, path string) (adaptor.RequestURL, error) { + u, err := url.JoinPath(baseURL, path) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: u, + }, nil +} + +func (a *Adaptor) ConvertRequest( + meta *meta.Meta, + _ adaptor.Store, + req *http.Request, +) (adaptor.ConvertResult, error) { + switch meta.Mode { + case mode.ChatCompletions: + return openai.ConvertChatCompletionsRequest(meta, req, false) + case mode.Completions: + return openai.ConvertCompletionsRequest(meta, req) + case mode.Anthropic: + return openai.ConvertClaudeRequest(meta, req) + case mode.Gemini: + return openai.ConvertGeminiRequest(meta, req) + case mode.AudioTranscription: + return openai.ConvertSTTRequest(meta, req) + case mode.AudioSpeech: + return openai.ConvertTTSRequest(meta, req, "") + case mode.Embeddings: + return openai.ConvertEmbeddingsRequest(meta, req, false) + case mode.Rerank: + return openai.ConvertRerankRequest(meta, req) + default: + return adaptor.ConvertResult{}, fmt.Errorf("unsupported mode: %s", meta.Mode) + } +} + func (a *Adaptor) DoResponse( meta *meta.Meta, store adaptor.Store, c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, ErrorHandler(resp) + } + switch meta.Mode { + case mode.Anthropic: + if utils.IsStreamResponse(resp) { + return openai.ClaudeStreamHandler(meta, c, resp) + } + return openai.ClaudeHandler(meta, c, resp) + case mode.Gemini: + if utils.IsStreamResponse(resp) { + return openai.GeminiStreamHandler(meta, c, resp) + } + return openai.GeminiHandler(meta, c, resp) case mode.Embeddings: return EmbeddingsHandler(c, resp) - default: - if resp.StatusCode != http.StatusOK { - return adaptor.DoResponseResult{}, ErrorHandler(resp) - } + case mode.ChatCompletions, + mode.Completions, + mode.AudioTranscription, + mode.AudioSpeech, + mode.Rerank: return openai.DoResponse(meta, store, c, resp) + default: + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIErrorWithMessage( + fmt.Sprintf("unsupported mode: %s", meta.Mode), + "unsupported_mode", + http.StatusBadRequest, + ) } } @@ -49,7 +149,7 @@ func (a *Adaptor) GetBalance(_ *model.Channel) (float64, error) { func (a *Adaptor) Metadata() adaptor.Metadata { return adaptor.Metadata{ - Readme: "Zhipu AI BigModel API\nOpenAI-compatible endpoint\nSupports Gemini-compatible request conversion", + Readme: "Zhipu AI BigModel API\nOpenAI-compatible endpoint\nSupports Anthropic-compatible and Gemini-compatible request conversion", Models: ModelList, } } diff --git a/core/relay/adaptor/zhipu/adaptor_test.go b/core/relay/adaptor/zhipu/adaptor_test.go new file mode 100644 index 00000000..e13bcea5 --- /dev/null +++ b/core/relay/adaptor/zhipu/adaptor_test.go @@ -0,0 +1,110 @@ +//nolint:testpackage +package zhipu + +import ( + "net/http" + "testing" + + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" +) + +func TestAdaptorSupportMode(t *testing.T) { + adaptor := &Adaptor{} + + supportedModes := []mode.Mode{ + mode.ChatCompletions, + mode.Completions, + mode.AudioTranscription, + mode.AudioSpeech, + mode.Embeddings, + mode.Rerank, + mode.Anthropic, + mode.Gemini, + } + for _, m := range supportedModes { + if !adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be supported", m) + } + } + + unsupportedModes := []mode.Mode{ + mode.Responses, + mode.ResponsesGet, + mode.ImagesGenerations, + mode.Moderations, + mode.AudioTranslation, + } + for _, m := range unsupportedModes { + if adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be unsupported", m) + } + } +} + +func TestAdaptorGetRequestURL(t *testing.T) { + adaptor := &Adaptor{} + channel := &coremodel.Channel{ + BaseURL: "https://open.bigmodel.cn/api/paas/v4", + } + + tests := []struct { + name string + mode mode.Mode + want string + }{ + { + name: "anthropic uses chat completions", + mode: mode.Anthropic, + want: "https://open.bigmodel.cn/api/paas/v4/chat/completions", + }, + { + name: "gemini uses chat completions", + mode: mode.Gemini, + want: "https://open.bigmodel.cn/api/paas/v4/chat/completions", + }, + { + name: "completions uses completions", + mode: mode.Completions, + want: "https://open.bigmodel.cn/api/paas/v4/completions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := meta.NewMeta(channel, tt.mode, "glm-5.1", coremodel.ModelConfig{}) + + got, err := adaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.Method != http.MethodPost { + t.Fatalf("expected method %s, got %s", http.MethodPost, got.Method) + } + + if got.URL != tt.want { + t.Fatalf("expected URL %s, got %s", tt.want, got.URL) + } + + if m.Mode != tt.mode { + t.Fatalf("expected mode to remain %s, got %s", tt.mode, m.Mode) + } + }) + } +} + +func TestAdaptorGetRequestURLUnsupportedResponses(t *testing.T) { + adaptor := &Adaptor{} + m := meta.NewMeta( + &coremodel.Channel{BaseURL: "https://open.bigmodel.cn/api/paas/v4"}, + mode.Responses, + "glm-5.1", + coremodel.ModelConfig{}, + ) + + if _, err := adaptor.GetRequestURL(m, nil, nil); err == nil { + t.Fatal("expected Responses mode to be unsupported") + } +} diff --git a/core/relay/adaptor/zhipucoding/adaptor.go b/core/relay/adaptor/zhipucoding/adaptor.go index 613c1ac6..382ac01d 100644 --- a/core/relay/adaptor/zhipucoding/adaptor.go +++ b/core/relay/adaptor/zhipucoding/adaptor.go @@ -1,6 +1,7 @@ package zhipucoding import ( + "fmt" "net/http" "net/url" @@ -14,6 +15,7 @@ import ( "github.com/labring/aiproxy/core/relay/adaptor/zhipu" "github.com/labring/aiproxy/core/relay/meta" "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" "github.com/labring/aiproxy/core/relay/utils" ) @@ -36,7 +38,8 @@ func (a *Adaptor) DefaultBaseURL() string { func (a *Adaptor) SupportMode(m mode.Mode) bool { return m == mode.ChatCompletions || m == mode.Completions || - m == mode.Anthropic + m == mode.Anthropic || + m == mode.Gemini } func (a *Adaptor) GetRequestURL( @@ -44,32 +47,45 @@ func (a *Adaptor) GetRequestURL( store adaptor.Store, c *gin.Context, ) (adaptor.RequestURL, error) { - u := meta.Channel.BaseURL + originalBaseURL := meta.Channel.BaseURL switch meta.Mode { case mode.Anthropic: - url, err := url.JoinPath(u, "/api/anthropic/v1/messages") + u, err := url.JoinPath(originalBaseURL, "/api/anthropic/v1/messages") if err != nil { return adaptor.RequestURL{}, err } return adaptor.RequestURL{ Method: http.MethodPost, - URL: url, + URL: u, }, nil - default: - meta.Channel.BaseURL += "/api/coding/paas/v4" + case mode.ChatCompletions, mode.Completions, mode.Gemini: + u, err := url.JoinPath(originalBaseURL, "/api/coding/paas/v4") + if err != nil { + return adaptor.RequestURL{}, err + } + + originalMode := meta.Mode + if meta.Mode == mode.Gemini { + meta.Mode = mode.ChatCompletions + } + + meta.Channel.BaseURL = u defer func() { - meta.Channel.BaseURL = u + meta.Mode = originalMode + meta.Channel.BaseURL = originalBaseURL }() return a.Adaptor.GetRequestURL(meta, store, c) + default: + return adaptor.RequestURL{}, fmt.Errorf("unsupported mode: %s", meta.Mode) } } func (a *Adaptor) ConvertRequest( meta *meta.Meta, - store adaptor.Store, + _ adaptor.Store, req *http.Request, ) (adaptor.ConvertResult, error) { switch meta.Mode { @@ -82,8 +98,14 @@ func (a *Adaptor) ConvertRequest( return nil }) + case mode.ChatCompletions: + return openai.ConvertChatCompletionsRequest(meta, req, false) + case mode.Completions: + return openai.ConvertCompletionsRequest(meta, req) + case mode.Gemini: + return openai.ConvertGeminiRequest(meta, req) default: - return a.Adaptor.ConvertRequest(meta, store, req) + return adaptor.ConvertResult{}, fmt.Errorf("unsupported mode: %s", meta.Mode) } } @@ -93,20 +115,35 @@ func (a *Adaptor) DoResponse( c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, zhipu.ErrorHandler(resp) + } + switch meta.Mode { case mode.Anthropic: if utils.IsStreamResponse(resp) { return anthropic.StreamHandler(meta, c, resp) } return anthropic.Handler(meta, c, resp) - default: + case mode.Gemini: + if utils.IsStreamResponse(resp) { + return openai.GeminiStreamHandler(meta, c, resp) + } + return openai.GeminiHandler(meta, c, resp) + case mode.ChatCompletions, mode.Completions: return a.Adaptor.DoResponse(meta, store, c, resp) + default: + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIErrorWithMessage( + fmt.Sprintf("unsupported mode: %s", meta.Mode), + "unsupported_mode", + http.StatusBadRequest, + ) } } func (a *Adaptor) Metadata() adaptor.Metadata { return adaptor.Metadata{ - Readme: "Zhipu Coding endpoint\nChat and completions are routed to `/api/coding/paas/v4`\nAnthropic-compatible requests are routed to `/api/anthropic/v1/messages`", + Readme: "Zhipu Coding endpoint\nChat and completions are routed to `/api/coding/paas/v4`\nAnthropic-compatible requests are routed to `/api/anthropic/v1/messages`\nGemini-compatible requests are converted to chat completions", Models: zhipu.ModelList, } } diff --git a/core/relay/adaptor/zhipucoding/adaptor_test.go b/core/relay/adaptor/zhipucoding/adaptor_test.go new file mode 100644 index 00000000..e14ed9b1 --- /dev/null +++ b/core/relay/adaptor/zhipucoding/adaptor_test.go @@ -0,0 +1,146 @@ +//nolint:testpackage +package zhipucoding + +import ( + "io" + "net/http" + "strings" + "testing" + + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" +) + +func TestAdaptorSupportMode(t *testing.T) { + adaptor := &Adaptor{} + + supportedModes := []mode.Mode{ + mode.ChatCompletions, + mode.Completions, + mode.Anthropic, + mode.Gemini, + } + for _, m := range supportedModes { + if !adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be supported", m) + } + } + + unsupportedModes := []mode.Mode{ + mode.Responses, + mode.Embeddings, + mode.AudioSpeech, + mode.Rerank, + } + for _, m := range unsupportedModes { + if adaptor.SupportMode(m) { + t.Fatalf("expected mode %s to be unsupported", m) + } + } +} + +func TestAdaptorGetRequestURL(t *testing.T) { + adaptor := &Adaptor{} + channel := &coremodel.Channel{ + BaseURL: "https://open.bigmodel.cn", + } + + tests := []struct { + name string + mode mode.Mode + want string + }{ + { + name: "anthropic uses native anthropic endpoint", + mode: mode.Anthropic, + want: "https://open.bigmodel.cn/api/anthropic/v1/messages", + }, + { + name: "gemini uses coding chat completions", + mode: mode.Gemini, + want: "https://open.bigmodel.cn/api/coding/paas/v4/chat/completions", + }, + { + name: "chat uses coding chat completions", + mode: mode.ChatCompletions, + want: "https://open.bigmodel.cn/api/coding/paas/v4/chat/completions", + }, + { + name: "completions uses coding completions", + mode: mode.Completions, + want: "https://open.bigmodel.cn/api/coding/paas/v4/completions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := meta.NewMeta(channel, tt.mode, "glm-5.1", coremodel.ModelConfig{}) + + got, err := adaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.Method != http.MethodPost { + t.Fatalf("expected method %s, got %s", http.MethodPost, got.Method) + } + + if got.URL != tt.want { + t.Fatalf("expected URL %s, got %s", tt.want, got.URL) + } + + if m.Mode != tt.mode { + t.Fatalf("expected mode to remain %s, got %s", tt.mode, m.Mode) + } + + if m.Channel.BaseURL != channel.BaseURL { + t.Fatalf( + "expected base URL to remain %s, got %s", + channel.BaseURL, + m.Channel.BaseURL, + ) + } + }) + } +} + +func TestAdaptorGetRequestURLUnsupportedResponses(t *testing.T) { + adaptor := &Adaptor{} + m := meta.NewMeta( + &coremodel.Channel{BaseURL: "https://open.bigmodel.cn"}, + mode.Responses, + "glm-5.1", + coremodel.ModelConfig{}, + ) + + if _, err := adaptor.GetRequestURL(m, nil, nil); err == nil { + t.Fatal("expected Responses mode to be unsupported") + } +} + +func TestAdaptorDoResponseUsesZhipuErrorHandler(t *testing.T) { + adaptor := &Adaptor{} + m := meta.NewMeta( + &coremodel.Channel{BaseURL: "https://open.bigmodel.cn"}, + mode.ChatCompletions, + "glm-5.1", + coremodel.ModelConfig{}, + ) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser( + //nolint:lll + strings.NewReader(`{"error":{"code":"1113","message":"余额不足或无可用资源包,请充值。"}}`), + ), + } + + _, err := adaptor.DoResponse(m, nil, nil, resp) + if err == nil { + t.Fatal("expected zhipu error") + } + + if err.StatusCode() != http.StatusPaymentRequired { + t.Fatalf("expected status %d, got %d", http.StatusPaymentRequired, err.StatusCode()) + } +}