From 4ef2071f1ed46b79f871b3906a3560d710fa4b2c Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Fri, 22 May 2026 01:18:08 +0800 Subject: [PATCH 1/2] feat: siliconflow support audio image video --- core/relay/adaptor/openai/gemini.go | 116 +- core/relay/adaptor/siliconflow/adaptor.go | 120 +- .../relay/adaptor/siliconflow/adaptor_test.go | 1047 +++++++++++++++++ core/relay/adaptor/siliconflow/async_usage.go | 150 +++ core/relay/adaptor/siliconflow/chat.go | 226 ++++ core/relay/adaptor/siliconflow/chat_test.go | 171 +++ core/relay/adaptor/siliconflow/image.go | 124 +- core/relay/adaptor/siliconflow/video.go | 814 +++++++++++++ 8 files changed, 2744 insertions(+), 24 deletions(-) create mode 100644 core/relay/adaptor/siliconflow/adaptor_test.go create mode 100644 core/relay/adaptor/siliconflow/async_usage.go create mode 100644 core/relay/adaptor/siliconflow/chat.go create mode 100644 core/relay/adaptor/siliconflow/chat_test.go create mode 100644 core/relay/adaptor/siliconflow/video.go diff --git a/core/relay/adaptor/openai/gemini.go b/core/relay/adaptor/openai/gemini.go index 392cf38b..70dd8f38 100644 --- a/core/relay/adaptor/openai/gemini.go +++ b/core/relay/adaptor/openai/gemini.go @@ -3,7 +3,10 @@ package openai import ( "bytes" "fmt" + "mime" "net/http" + "net/url" + "path/filepath" "sort" "strconv" "strings" @@ -685,19 +688,16 @@ func convertGeminiContentToOpenAI( hasContent = true case part.InlineData != nil: - // Handle image - imageURL := part.InlineData.Data - if !strings.HasPrefix(imageURL, "http") && !strings.HasPrefix(imageURL, "data:") { - // Base64 data - imageURL = "data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data - } - - currentContentParts = append(currentContentParts, relaymodel.MessageContent{ - Type: relaymodel.ContentTypeImageURL, - ImageURL: &relaymodel.ImageURL{ - URL: imageURL, - }, - }) + currentContentParts = append( + currentContentParts, + convertGeminiInlineDataToOpenAIContent(part.InlineData), + ) + hasContent = true + 
case part.FileData != nil: + currentContentParts = append( + currentContentParts, + convertGeminiFileDataToOpenAIContent(part.FileData), + ) hasContent = true } } @@ -720,6 +720,96 @@ func convertGeminiContentToOpenAI( return messages } +func convertGeminiInlineDataToOpenAIContent( + inlineData *relaymodel.GeminiInlineData, +) relaymodel.MessageContent { + dataURL := inlineData.Data + if !strings.HasPrefix(dataURL, "http") && !strings.HasPrefix(dataURL, "data:") { + dataURL = "data:" + inlineData.MimeType + ";base64," + inlineData.Data + } + + switch { + case strings.HasPrefix(inlineData.MimeType, "audio/"): + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeInputAudio, + InputAudio: &relaymodel.InputAudio{ + URL: dataURL, + }, + } + case strings.HasPrefix(inlineData.MimeType, "video/"): + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeVideoURL, + VideoURL: &relaymodel.VideoURL{ + URL: dataURL, + }, + } + default: + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeImageURL, + ImageURL: &relaymodel.ImageURL{ + URL: dataURL, + }, + } + } +} + +func convertGeminiFileDataToOpenAIContent( + fileData *relaymodel.GeminiFileData, +) relaymodel.MessageContent { + mimeType := fileData.MimeType + if mimeType == "" { + mimeType = inferGeminiFileDataMimeType(fileData.FileURI) + } + + switch { + case strings.HasPrefix(mimeType, "audio/"): + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeInputAudio, + InputAudio: &relaymodel.InputAudio{ + URL: fileData.FileURI, + }, + } + case strings.HasPrefix(mimeType, "video/"): + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeVideoURL, + VideoURL: &relaymodel.VideoURL{ + URL: fileData.FileURI, + }, + } + default: + return relaymodel.MessageContent{ + Type: relaymodel.ContentTypeImageURL, + ImageURL: &relaymodel.ImageURL{ + URL: fileData.FileURI, + }, + } + } +} + +func inferGeminiFileDataMimeType(fileURI string) string { + if after, ok := 
strings.CutPrefix(fileURI, "data:"); ok { + mediaType, _, hasData := strings.Cut(after, ",") // header ends at the first "," + if beforeParams, _, ok := strings.Cut(mediaType, ";"); ok { + return beforeParams + } + + if hasData { + return mediaType + } + } + + path := fileURI + if parsed, err := url.Parse(fileURI); err == nil && parsed.Path != "" { + path = parsed.Path + } + + if ext := filepath.Ext(path); ext != "" { + return mime.TypeByExtension(strings.ToLower(ext)) + } + + return "" +} + // ConvertGeminiToResponsesRequest converts a Gemini request to Responses API format func ConvertGeminiToResponsesRequest( meta *meta.Meta, diff --git a/core/relay/adaptor/siliconflow/adaptor.go b/core/relay/adaptor/siliconflow/adaptor.go index aac86fc3..48fd672e 100644 --- a/core/relay/adaptor/siliconflow/adaptor.go +++ b/core/relay/adaptor/siliconflow/adaptor.go @@ -1,7 +1,9 @@ package siliconflow import ( + "fmt" "net/http" + "net/url" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/model" @@ -10,6 +12,7 @@ import ( "github.com/labring/aiproxy/core/relay/adaptor/registry" "github.com/labring/aiproxy/core/relay/meta" "github.com/labring/aiproxy/core/relay/mode" + "github.com/labring/aiproxy/core/relay/utils" ) var _ adaptor.Adaptor = (*Adaptor)(nil) @@ -30,12 +33,74 @@ func (a *Adaptor) DefaultBaseURL() string { func (a *Adaptor) Metadata() adaptor.Metadata { return adaptor.Metadata{ - Readme: "SiliconFlow API\nOpenAI-compatible chat, embeddings, audio, and rerank endpoints\nSupports Gemini-compatible request conversion", + Readme: "SiliconFlow API\nOpenAI-compatible chat, embeddings, image, audio, rerank, and video endpoints\nChat supports audio/video understanding request conversion", Models: ModelList, } } -// +func (a *Adaptor) SupportMode(meta *meta.Meta) bool { + m := adaptor.ModeFromMeta(meta) + + return m == mode.ChatCompletions || + m == mode.Completions || + m == mode.Embeddings || + m == mode.ImagesGenerations || + m == mode.AudioSpeech || + m == 
mode.AudioTranscription || + m == mode.Rerank || + m == mode.VideoGenerationsJobs || + m == mode.VideoGenerationsGetJobs || + m == mode.VideoGenerationsContent || + m == mode.Videos || + m == mode.VideosGet || + m == mode.VideosContent || + m == mode.Anthropic || + m == mode.Gemini +} + +func (a *Adaptor) GetRequestURL( + meta *meta.Meta, + _ adaptor.Store, + _ *gin.Context, +) (adaptor.RequestURL, error) { + u := meta.Channel.BaseURL + + var path string + + switch meta.Mode { + case mode.ChatCompletions, mode.Anthropic, mode.Gemini: + path = "/chat/completions" + case mode.Completions: + path = "/completions" + case mode.Embeddings: + path = "/embeddings" + case mode.ImagesGenerations: + path = "/images/generations" + case mode.AudioSpeech: + path = "/audio/speech" + case mode.AudioTranscription: + path = "/audio/transcriptions" + case mode.Rerank: + path = "/rerank" + case mode.VideoGenerationsJobs, mode.Videos: + path = "/video/submit" + case mode.VideoGenerationsGetJobs, mode.VideoGenerationsContent, + mode.VideosGet, mode.VideosContent: + path = "/video/status" + default: + return adaptor.RequestURL{}, fmt.Errorf("unsupported mode: %s", meta.Mode) + } + + requestURL, err := url.JoinPath(u, path) + if err != nil { + return adaptor.RequestURL{}, err + } + + return adaptor.RequestURL{ + Method: http.MethodPost, + URL: requestURL, + }, nil +} func (a *Adaptor) ConvertRequest( meta *meta.Meta, @@ -43,12 +108,37 @@ func (a *Adaptor) ConvertRequest( req *http.Request, ) (adaptor.ConvertResult, error) { switch meta.Mode { + case mode.ChatCompletions: + return openai.ConvertChatCompletionsRequest( + meta, + req, + false, + patchChatMultimodalContent, + ) case mode.Embeddings: if isVLEmbeddingModel(meta) { return openai.ConvertEmbeddingsRequest(meta, req, false, patchVLEmbeddingsInput) } return a.Adaptor.ConvertRequest(meta, store, req) + case mode.ImagesGenerations: + return ConvertImageRequest(meta, req) + case mode.VideoGenerationsJobs: + return 
ConvertVideoRequest(meta, req) + case mode.Videos: + return ConvertVideoRequest(meta, req) + case mode.VideoGenerationsGetJobs: + return ConvertVideoStatusRequest(meta, req) + case mode.VideoGenerationsContent: + return ConvertVideoContentStatusRequest(meta, req) + case mode.VideosGet: + return ConvertVideosStatusRequest(meta, req) + case mode.VideosContent: + return ConvertVideosStatusRequest(meta, req) + case mode.Anthropic: + return openai.ConvertClaudeRequest(meta, req, patchSiliconFlowMultimodalContent) + case mode.Gemini: + return openai.ConvertGeminiRequest(meta, req, patchSiliconFlowMultimodalContent) default: return a.Adaptor.ConvertRequest(meta, store, req) } @@ -61,6 +151,24 @@ func (a *Adaptor) DoResponse( resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { switch meta.Mode { + case mode.ChatCompletions: + if utils.IsStreamResponse(resp) { + return openai.StreamHandler(meta, c, resp, nil) + } + + return openai.Handler(meta, c, resp, nil) + case mode.Anthropic: + if utils.IsStreamResponse(resp) { + return openai.ClaudeStreamHandler(meta, c, resp) + } + + return openai.ClaudeHandler(meta, c, resp) + case mode.Gemini: + if utils.IsStreamResponse(resp) { + return openai.GeminiStreamHandler(meta, c, resp) + } + + return openai.GeminiHandler(meta, c, resp) case mode.AudioSpeech: if resp.StatusCode != http.StatusOK { return adaptor.DoResponseResult{}, ErrorHandler(resp) @@ -84,6 +192,14 @@ func (a *Adaptor) DoResponse( } return a.Adaptor.DoResponse(meta, store, c, resp) + case mode.ImagesGenerations: + return ImageHandler(meta, c, resp) + case mode.VideoGenerationsJobs, mode.Videos: + return VideoSubmitHandler(meta, store, c, resp) + case mode.VideoGenerationsGetJobs, mode.VideosGet: + return VideoStatusHandler(meta, store, c, resp) + case mode.VideoGenerationsContent, mode.VideosContent: + return VideoContentHandler(meta, c, resp) default: if !adaptor.IsSuccessfulResponseStatus(meta.Mode, resp.StatusCode) { return adaptor.DoResponseResult{}, 
ErrorHandler(resp) diff --git a/core/relay/adaptor/siliconflow/adaptor_test.go b/core/relay/adaptor/siliconflow/adaptor_test.go new file mode 100644 index 00000000..3ce7fc69 --- /dev/null +++ b/core/relay/adaptor/siliconflow/adaptor_test.go @@ -0,0 +1,1047 @@ +//nolint:testpackage +package siliconflow + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" +) + +type siliconflowTestStore struct { + saved []adaptor.StoreCache +} + +func (s *siliconflowTestStore) GetStore(string, int, string) (adaptor.StoreCache, error) { + return adaptor.StoreCache{}, nil +} + +func (s *siliconflowTestStore) SaveStore(cache adaptor.StoreCache) error { + s.saved = append(s.saved, cache) + return nil +} + +func (s *siliconflowTestStore) SaveStoreWithOption( + cache adaptor.StoreCache, + _ adaptor.SaveStoreOption, +) error { + s.saved = append(s.saved, cache) + return nil +} + +func (s *siliconflowTestStore) SaveIfNotExistStore(cache adaptor.StoreCache) error { + s.saved = append(s.saved, cache) + return nil +} + +func TestAdaptorSupportModeDoesNotSupportResponses(t *testing.T) { + sfAdaptor := &Adaptor{} + + supportedModes := []mode.Mode{ + mode.ChatCompletions, + mode.Completions, + mode.Embeddings, + mode.ImagesGenerations, + mode.AudioSpeech, + mode.AudioTranscription, + mode.Rerank, + mode.VideoGenerationsJobs, + mode.VideoGenerationsGetJobs, + mode.VideoGenerationsContent, + mode.Videos, + mode.VideosGet, + mode.VideosContent, + mode.Anthropic, + mode.Gemini, + } + for _, m := range supportedModes { + if !sfAdaptor.SupportMode(&meta.Meta{Mode: m}) { + t.Fatalf("expected mode %s to be supported", m) + } + } + + unsupportedModes := []mode.Mode{ + 
mode.Responses, + mode.ResponsesGet, + mode.ResponsesDelete, + mode.ResponsesCancel, + mode.ResponsesInputItems, + mode.ImagesEdits, + mode.VideosDelete, + mode.VideosRemix, + } + for _, m := range unsupportedModes { + if sfAdaptor.SupportMode(&meta.Meta{Mode: m}) { + t.Fatalf("expected mode %s to be unsupported", m) + } + } +} + +func TestAdaptorGetRequestURLUsesSiliconFlowEndpoints(t *testing.T) { + sfAdaptor := &Adaptor{} + + tests := []struct { + name string + mode mode.Mode + wantMethod string + wantURL string + }{ + { + name: "chat", + mode: mode.ChatCompletions, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/chat/completions", + }, + { + name: "image generation", + mode: mode.ImagesGenerations, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/images/generations", + }, + { + name: "video submit", + mode: mode.VideoGenerationsJobs, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/submit", + }, + { + name: "video status", + mode: mode.VideoGenerationsGetJobs, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/status", + }, + { + name: "video content", + mode: mode.VideoGenerationsContent, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/status", + }, + { + name: "videos create", + mode: mode.Videos, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/submit", + }, + { + name: "videos get", + mode: mode.VideosGet, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/status", + }, + { + name: "videos content", + mode: mode.VideosContent, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/video/status", + }, + { + name: "anthropic", + mode: mode.Anthropic, + wantMethod: http.MethodPost, + wantURL: "https://api.siliconflow.cn/v1/chat/completions", + }, + { + name: "gemini", + mode: mode.Gemini, + wantMethod: http.MethodPost, + wantURL: 
"https://api.siliconflow.cn/v1/chat/completions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := meta.NewMeta( + &coremodel.Channel{BaseURL: "https://api.siliconflow.cn/v1"}, + tt.mode, + "test-model", + coremodel.ModelConfig{}, + ) + + got, err := sfAdaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.Method != tt.wantMethod || got.URL != tt.wantURL { + t.Fatalf("unexpected request URL: %#v", got) + } + }) + } +} + +func TestAdaptorGetRequestURLRejectsResponsesOnlyRouting(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta( + &coremodel.Channel{BaseURL: "https://api.siliconflow.cn/v1"}, + mode.ChatCompletions, + "gpt-5-codex", + coremodel.ModelConfig{}, + ) + + got, err := sfAdaptor.GetRequestURL(m, nil, nil) + if err != nil { + t.Fatalf("GetRequestURL returned error: %v", err) + } + + if got.URL != "https://api.siliconflow.cn/v1/chat/completions" { + t.Fatalf("expected chat endpoint, got %q", got.URL) + } + + m.Mode = mode.Responses + if _, err := sfAdaptor.GetRequestURL(m, nil, nil); err == nil { + t.Fatal("expected responses mode to be rejected") + } +} + +func TestAdaptorDoResponseChatDoesNotUseResponsesOnlyConversion(t *testing.T) { + gin.SetMode(gin.TestMode) + + sfAdaptor := &Adaptor{} + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := meta.NewMeta(nil, mode.ChatCompletions, "gpt-5-codex", coremodel.ModelConfig{}) + m.RequestUsage = coremodel.Usage{InputTokens: 3} + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{ + "id":"chatcmpl-123", + "object":"chat.completion", + "choices":[ + { + "index":0, + "message":{"role":"assistant","content":"Done"}, + "finish_reason":"stop" + } + ], + 
"usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4} + }`))), + } + + result, adaptorErr := sfAdaptor.DoResponse(m, nil, ctx, resp) + if adaptorErr != nil { + t.Fatalf("DoResponse returned error: %v", adaptorErr) + } + + if result.UpstreamID != "chatcmpl-123" { + t.Fatalf("expected chat upstream id, got %q", result.UpstreamID) + } + + var body map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if body["object"] != "chat.completion" { + t.Fatalf("expected chat completion response, got %#v", body["object"]) + } +} + +func TestConvertAnthropicRequestUsesChatCompletionsShape(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.Anthropic, "claude-alias", coremodel.ModelConfig{}) + m.ActualModel = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + req := newJSONRequest(t, "/v1/messages", `{ + "model":"claude-alias", + "messages":[ + { + "role":"user", + "content":[ + {"type":"text","text":"Describe this image."}, + { + "type":"image", + "source":{ + "type":"url", + "url":"https://example.com/image.png" + } + } + ] + } + ], + "max_tokens":128 + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "model", "Qwen/Qwen3-Omni-30B-A3B-Instruct") + assertMapNumber(t, got, "max_tokens", 128) + + messages, ok := got["messages"].([]any) + if !ok || len(messages) != 1 { + t.Fatalf("expected one message, got %#v", got["messages"]) + } + + message, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("expected message object, got %#v", messages[0]) + } + + content, ok := message["content"].([]any) + if !ok || len(content) != 2 { + t.Fatalf("expected two content parts, got %#v", message["content"]) + } + + assertSiliconFlowContentType(t, content[0], "text") + assertSiliconFlowContentType(t, content[1], "image_url") +} + 
+func TestConvertGeminiRequestUsesChatCompletionsShape(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.Gemini, "gemini-alias", coremodel.ModelConfig{}) + m.ActualModel = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + req := newJSONRequest(t, "/v1beta/models/gemini-alias:generateContent", `{ + "contents":[ + { + "role":"user", + "parts":[ + {"text":"Describe this media."}, + {"inlineData":{"mimeType":"audio/wav","data":"QUJD"}} + ] + } + ], + "generationConfig":{"maxOutputTokens":128} + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "model", "Qwen/Qwen3-Omni-30B-A3B-Instruct") + + messages, ok := got["messages"].([]any) + if !ok || len(messages) != 1 { + t.Fatalf("expected one message, got %#v", got["messages"]) + } + + message, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("expected message object, got %#v", messages[0]) + } + + content, ok := message["content"].([]any) + if !ok || len(content) != 2 { + t.Fatalf("expected two content parts, got %#v", message["content"]) + } + + assertSiliconFlowContentType(t, content[0], "text") + assertSiliconFlowContentType(t, content[1], "audio_url") +} + +func TestConvertGeminiRequestInfersFileDataMediaTypeFromURI(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.Gemini, "gemini-alias", coremodel.ModelConfig{}) + m.ActualModel = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + req := newJSONRequest(t, "/v1beta/models/gemini-alias:generateContent", `{ + "contents":[ + { + "role":"user", + "parts":[ + {"fileData":{"fileUri":"https://example.com/audio.mp3"}}, + {"fileData":{"fileUri":"https://example.com/video.mp4?token=abc"}} + ] + } + ] + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + + messages, ok := 
got["messages"].([]any) + if !ok || len(messages) != 1 { + t.Fatalf("expected one message, got %#v", got["messages"]) + } + + message, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("expected message object, got %#v", messages[0]) + } + + content, ok := message["content"].([]any) + if !ok || len(content) != 2 { + t.Fatalf("expected two content parts, got %#v", message["content"]) + } + + assertSiliconFlowContentType(t, content[0], "audio_url") + assertSiliconFlowContentType(t, content[1], "video_url") +} + +func TestConvertImageRequestMapsOpenAIFields(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.ImagesGenerations, "alias-image", coremodel.ModelConfig{}) + m.ActualModel = "stabilityai/stable-diffusion-3-5-large" + + req := newJSONRequest(t, "/v1/images/generations", `{ + "model":"alias-image", + "prompt":"A city at sunset", + "negative_prompt":"low quality", + "size":"1024x1024", + "n":2, + "steps":24, + "scale":7, + "response_format":"url" + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "model", "stabilityai/stable-diffusion-3-5-large") + assertMapValue(t, got, "prompt", "A city at sunset") + assertMapValue(t, got, "negative_prompt", "low quality") + assertMapValue(t, got, "image_size", "1024x1024") + assertMapNumber(t, got, "batch_size", 2) + assertMapNumber(t, got, "num_inference_steps", 24) + assertMapNumber(t, got, "guidance_scale", 7) + + if _, ok := got["size"]; ok { + t.Fatal("expected size to be removed") + } + + if _, ok := got["n"]; ok { + t.Fatal("expected n to be removed") + } +} + +func TestImageHandlerMapsSiliconFlowResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + 
"/v1/images/generations", + nil, + ) + + m := meta.NewMeta( + nil, + mode.ImagesGenerations, + "stabilityai/stable-diffusion-3-5-large", + coremodel.ModelConfig{}, + meta.WithRequestUsage(coremodel.Usage{ + InputTokens: 12, + OutputTokens: 2, + }), + ) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{ + "images":[ + {"url":"https://example.com/one.png"}, + {"url":"https://example.com/two.png"} + ], + "seed":123 + }`))), + } + + result, adaptorErr := ImageHandler(m, ctx, resp) + if adaptorErr != nil { + t.Fatalf("ImageHandler returned error: %v", adaptorErr) + } + + var imageResponse relaymodel.ImageResponse + if err := json.Unmarshal(recorder.Body.Bytes(), &imageResponse); err != nil { + t.Fatalf("failed to unmarshal image response: %v", err) + } + + if len(imageResponse.Data) != 2 { + t.Fatalf("expected 2 images, got %d", len(imageResponse.Data)) + } + + if imageResponse.Data[0].URL != "https://example.com/one.png" || + imageResponse.Data[1].URL != "https://example.com/two.png" { + t.Fatalf("unexpected image response data: %#v", imageResponse.Data) + } + + if int64(result.Usage.InputTokens) != 12 || + int64(result.Usage.OutputTokens) != 2 || + int64(result.Usage.TotalTokens) != 14 { + t.Fatalf("unexpected usage: %#v", result.Usage) + } +} + +func TestConvertVideoRequestMapsJSONFields(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.VideoGenerationsJobs, "alias-video", coremodel.ModelConfig{}) + m.ActualModel = "Wan-AI/Wan2.2-T2V-A14B" + + req := newJSONRequest(t, "/v1/video/generations/jobs", `{ + "model":"alias-video", + "prompt":"A calm ocean", + "size":"1280x720", + "negative_prompt":"rain", + "seed":123 + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "model", "Wan-AI/Wan2.2-T2V-A14B") + assertMapValue(t, 
got, "prompt", "A calm ocean") + assertMapValue(t, got, "image_size", "1280x720") + assertMapValue(t, got, "negative_prompt", "rain") + assertMapNumber(t, got, "seed", 123) +} + +func TestConvertVideoRequestMapsImageReference(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta(nil, mode.VideoGenerationsJobs, "alias-video", coremodel.ModelConfig{}) + m.ActualModel = "Wan-AI/Wan2.2-I2V-A14B" + + req := newJSONRequest(t, "/v1/video/generations/jobs", `{ + "model":"alias-video", + "prompt":"Animate this scene", + "size":"960x960", + "input_reference":"https://example.com/input.png" + }`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "image", "https://example.com/input.png") +} + +func TestConvertVideoStatusRequestUsesJobID(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta( + nil, + mode.VideoGenerationsGetJobs, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithJobID("request-123"), + ) + + req := newJSONRequest(t, "/v1/video/generations/jobs/request-123", `{}`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "requestId", "request-123") +} + +func TestConvertVideoContentStatusRequestUsesGenerationID(t *testing.T) { + sfAdaptor := &Adaptor{} + m := meta.NewMeta( + nil, + mode.VideoGenerationsContent, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithGenerationID("request-123"), + ) + + req := newJSONRequest(t, "/v1/video/generations/request-123/content/video", `{}`) + + result, err := sfAdaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readJSONBody(t, result.Body) + assertMapValue(t, got, "requestId", "request-123") +} + +func 
TestVideoSubmitHandlerMapsRequestIDToJob(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/video/generations/jobs", + nil, + ) + + m := meta.NewMeta( + nil, + mode.VideoGenerationsJobs, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + ) + m.Channel.ID = 9 + m.Group = coremodel.GroupCache{ID: "group-1"} + m.Token = coremodel.TokenCache{ID: 7} + m.Set(metaVideoRequest, videoSubmitRequest{ + Prompt: "A calm ocean", + ImageSize: "1280x720", + }) + + store := &siliconflowTestStore{} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{"requestId":"request-123"}`))), + } + + result, adaptorErr := VideoSubmitHandler(m, store, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoSubmitHandler returned error: %v", adaptorErr) + } + + if result.UpstreamID != "request-123" || !result.AsyncUsage { + t.Fatalf("unexpected result: %#v", result) + } + + if m.RequestUsage.OutputTokens != 1 || m.RequestUsage.TotalTokens != 1 { + t.Fatalf("expected submitted video usage to count videos, got %#v", m.RequestUsage) + } + + if len(store.saved) != 1 || store.saved[0].ID != coremodel.VideoJobStoreID("request-123") { + t.Fatalf("unexpected saved stores: %#v", store.saved) + } + + var job relaymodel.VideoGenerationJob + if err := json.Unmarshal(recorder.Body.Bytes(), &job); err != nil { + t.Fatalf("failed to unmarshal job: %v", err) + } + + if job.ID != "request-123" || + job.Status != relaymodel.VideoGenerationJobStatusQueued || + job.Prompt != "A calm ocean" || + job.Width != 1280 || + job.Height != 720 { + t.Fatalf("unexpected job: %#v", job) + } +} + +func TestVideoStatusHandlerMapsSucceededResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + 
ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/video/generations/jobs/request-123", + nil, + ) + + m := meta.NewMeta( + nil, + mode.VideoGenerationsGetJobs, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithJobID("request-123"), + ) + m.Channel.ID = 9 + m.Group = coremodel.GroupCache{ID: "group-1"} + m.Token = coremodel.TokenCache{ID: 7} + m.Set(metaVideoRequest, videoSubmitRequest{ + Prompt: "A calm ocean", + ImageSize: "1280x720", + }) + + store := &siliconflowTestStore{} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{ + "status":"Succeed", + "results":{"videos":[{"url":"https://example.com/video.mp4"}]} + }`))), + } + + _, adaptorErr := VideoStatusHandler(m, store, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoStatusHandler returned error: %v", adaptorErr) + } + + var job relaymodel.VideoGenerationJob + if err := json.Unmarshal(recorder.Body.Bytes(), &job); err != nil { + t.Fatalf("failed to unmarshal job: %v", err) + } + + if job.ID != "request-123" || + job.Status != relaymodel.VideoGenerationJobStatusSucceeded || + len(job.Generations) != 1 { + t.Fatalf("unexpected job: %#v", job) + } + + if len(store.saved) != 1 || + store.saved[0].ID != coremodel.VideoGenerationStoreID(job.Generations[0].ID) { + t.Fatalf("unexpected saved stores: %#v", store.saved) + } + + if job.Generations[0].ID != "request-123" { + t.Fatalf("expected generation id to use upstream request id, got %q", job.Generations[0].ID) + } +} + +func TestVideosHandlerMapsSubmitResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/videos", + nil, + ) + + m := meta.NewMeta(nil, mode.Videos, "Wan-AI/Wan2.2-T2V-A14B", coremodel.ModelConfig{}) + m.Channel.ID = 9 + m.Group = 
coremodel.GroupCache{ID: "group-1"} + m.Token = coremodel.TokenCache{ID: 7} + m.Set(metaVideoRequest, videoSubmitRequest{ + Prompt: "A calm ocean", + ImageSize: "1280x720", + }) + + store := &siliconflowTestStore{} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{"requestId":"request-123"}`))), + } + + result, adaptorErr := VideoSubmitHandler(m, store, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoSubmitHandler returned error: %v", adaptorErr) + } + + if result.UpstreamID != "request-123" || !result.AsyncUsage { + t.Fatalf("unexpected result: %#v", result) + } + + if len(store.saved) != 1 || + store.saved[0].ID != coremodel.VideoGenerationStoreID("request-123") { + t.Fatalf("unexpected saved stores: %#v", store.saved) + } + + var video relaymodel.Video + if err := json.Unmarshal(recorder.Body.Bytes(), &video); err != nil { + t.Fatalf("failed to unmarshal video: %v", err) + } + + if video.ID != "request-123" || + video.Object != relaymodel.VideoObject || + video.Status != relaymodel.VideoStatusQueued || + video.Prompt != "A calm ocean" || + video.Size != "1280x720" { + t.Fatalf("unexpected video: %#v", video) + } +} + +func TestVideosGetHandlerMapsStatusResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/videos/request-123", + nil, + ) + + m := meta.NewMeta( + nil, + mode.VideosGet, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithVideoID("request-123"), + ) + m.Channel.ID = 9 + m.Group = coremodel.GroupCache{ID: "group-1"} + m.Token = coremodel.TokenCache{ID: 7} + m.Set(metaVideoRequest, videoSubmitRequest{ + Prompt: "A calm ocean", + ImageSize: "1280x720", + }) + + store := &siliconflowTestStore{} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + 
Body: io.NopCloser(bytes.NewReader([]byte(`{ + "status":"Succeed", + "results":{"videos":[{"url":"https://example.com/video.mp4"}]} + }`))), + } + + result, adaptorErr := VideoStatusHandler(m, store, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoStatusHandler returned error: %v", adaptorErr) + } + + if result.UpstreamID != "request-123" { + t.Fatalf("expected upstream id, got %q", result.UpstreamID) + } + + var video relaymodel.Video + if err := json.Unmarshal(recorder.Body.Bytes(), &video); err != nil { + t.Fatalf("failed to unmarshal video: %v", err) + } + + if video.ID != "request-123" || + video.Object != relaymodel.VideoObject || + video.Status != relaymodel.VideoStatusCompleted || + video.Progress != 100 { + t.Fatalf("unexpected video: %#v", video) + } +} + +func TestVideoContentHandlerDownloadsGeneratedVideo(t *testing.T) { + gin.SetMode(gin.TestMode) + + videoServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + _, _ = w.Write([]byte("video-data")) + }), + ) + defer videoServer.Close() + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodGet, + "/v1/video/generations/request-123/content/video", + nil, + ) + + m := meta.NewMeta( + &coremodel.Channel{}, + mode.VideoGenerationsContent, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithGenerationID("request-123"), + ) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{ + "status":"Succeed", + "results":{"videos":[{"url":"` + videoServer.URL + `"}]} + }`))), + } + + result, adaptorErr := VideoContentHandler(m, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoContentHandler returned error: %v", adaptorErr) + } + + if result.UpstreamID != "request-123" { + t.Fatalf("expected upstream id, got %q", 
result.UpstreamID) + } + + if recorder.Body.String() != "video-data" { + t.Fatalf("unexpected video body: %q", recorder.Body.String()) + } + + if recorder.Header().Get("Content-Type") != "video/mp4" { + t.Fatalf("unexpected content type: %q", recorder.Header().Get("Content-Type")) + } +} + +func TestVideosContentHandlerDownloadsGeneratedVideo(t *testing.T) { + gin.SetMode(gin.TestMode) + + videoServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + _, _ = w.Write([]byte("video-data")) + }), + ) + defer videoServer.Close() + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequestWithContext( + context.Background(), + http.MethodGet, + "/v1/videos/request-123/content", + nil, + ) + + m := meta.NewMeta( + &coremodel.Channel{}, + mode.VideosContent, + "Wan-AI/Wan2.2-T2V-A14B", + coremodel.ModelConfig{}, + meta.WithVideoID("request-123"), + ) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{ + "status":"Succeed", + "results":{"videos":[{"url":"` + videoServer.URL + `"}]} + }`))), + } + + result, adaptorErr := VideoContentHandler(m, ctx, resp) + if adaptorErr != nil { + t.Fatalf("VideoContentHandler returned error: %v", adaptorErr) + } + + if result.UpstreamID != "request-123" { + t.Fatalf("expected upstream id, got %q", result.UpstreamID) + } + + if recorder.Body.String() != "video-data" { + t.Fatalf("unexpected video body: %q", recorder.Body.String()) + } +} + +func TestFetchAsyncUsageUsesReturnedVideoCount(t *testing.T) { + statusServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "status":"Succeed", + "results":{ + "videos":[ + {"url":"https://example.com/one.mp4"}, + {"url":"https://example.com/two.mp4"} + ] + } + 
}`)) + }), + ) + defer statusServer.Close() + + sfAdaptor := &Adaptor{} + + usage, _, done, err := sfAdaptor.FetchAsyncUsage( + context.Background(), + adaptor.AsyncUsageRequest{ + Channel: &coremodel.Channel{ + BaseURL: statusServer.URL, + Key: "test-key", + }, + Info: &coremodel.AsyncUsageInfo{ + Mode: int(mode.Videos), + UpstreamID: "request-123", + Usage: coremodel.Usage{ + OutputTokens: 8, + TotalTokens: 8, + }, + }, + }, + ) + if err != nil { + t.Fatalf("FetchAsyncUsage returned error: %v", err) + } + + if !done { + t.Fatalf("expected async usage to be done") + } + + if usage.OutputTokens != 2 || usage.TotalTokens != 2 { + t.Fatalf("expected usage to count returned videos, got %#v", usage) + } +} + +func newJSONRequest(t *testing.T, path, body string) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + path, + bytes.NewReader([]byte(body)), + ) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + + return req +} + +func readJSONBody(t *testing.T, body io.Reader) map[string]any { + t.Helper() + + data, err := io.ReadAll(body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("failed to unmarshal body %s: %v", string(data), err) + } + + return got +} + +func assertMapValue(t *testing.T, got map[string]any, key, want string) { + t.Helper() + + if got[key] != want { + t.Fatalf("expected %s=%q, got %#v", key, want, got[key]) + } +} + +func assertMapNumber(t *testing.T, got map[string]any, key string, want int) { + t.Helper() + + value, ok := got[key].(float64) + if !ok || int(value) != want { + t.Fatalf("expected %s=%d, got %#v", key, want, got[key]) + } +} + +func assertSiliconFlowContentType(t *testing.T, got any, wantType string) { + t.Helper() + + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected 
content object, got %T", got) + } + + if gotMap["type"] != wantType { + t.Fatalf("expected content type %q, got %#v", wantType, gotMap["type"]) + } +} diff --git a/core/relay/adaptor/siliconflow/async_usage.go b/core/relay/adaptor/siliconflow/async_usage.go new file mode 100644 index 00000000..22c4786a --- /dev/null +++ b/core/relay/adaptor/siliconflow/async_usage.go @@ -0,0 +1,150 @@ +package siliconflow + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/bytedance/sonic" + "github.com/labring/aiproxy/core/common" + "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor" + "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" + relayutils "github.com/labring/aiproxy/core/relay/utils" +) + +var _ adaptor.AsyncUsageFetcher = (*Adaptor)(nil) + +func (a *Adaptor) FetchAsyncUsage( + ctx context.Context, + request adaptor.AsyncUsageRequest, +) (model.Usage, model.UsageContext, bool, error) { + info := request.Info + if info == nil { + return model.Usage{}, model.UsageContext{}, false, errors.New("async usage info is nil") + } + + switch mode.Mode(info.Mode) { + case mode.VideoGenerationsJobs, mode.Videos: + default: + return model.Usage{}, model.UsageContext{}, false, fmt.Errorf( + "unsupported async usage mode: %d", + info.Mode, + ) + } + + response, err := a.fetchVideoStatus(ctx, request.Channel, info) + if err != nil { + return model.Usage{}, model.UsageContext{}, false, err + } + + switch siliconFlowVideoStatusToOpenAI(response.Status) { + case relaymodel.VideoStatusCompleted: + outputTokens := siliconFlowVideoOutputTokens(response) + + return model.Usage{ + OutputTokens: model.ZeroNullInt64(outputTokens), + TotalTokens: model.ZeroNullInt64(outputTokens), + }, model.UsageContext{}, true, nil + case relaymodel.VideoStatusQueued, relaymodel.VideoStatusInProgress: + return model.Usage{}, model.UsageContext{}, false, nil + default: + return 
model.Usage{}, model.UsageContext{}, true, fmt.Errorf( + "siliconflow video task ended with status %q: %s", + response.Status, + response.Reason, + ) + } +} + +func (a *Adaptor) fetchVideoStatus( + ctx context.Context, + channel *model.Channel, + info *model.AsyncUsageInfo, +) (*videoStatusResponse, error) { + if info.UpstreamID == "" { + return nil, errors.New("upstream id is empty") + } + + baseURL := a.DefaultBaseURL() + if info.BaseURL != "" { + baseURL = info.BaseURL + } else if channel != nil && channel.BaseURL != "" { + baseURL = channel.BaseURL + } + + requestURL, err := url.JoinPath(baseURL, "/video/status") + if err != nil { + return nil, fmt.Errorf("build video status url: %w", err) + } + + body, err := sonic.Marshal(videoStatusRequest{RequestID: info.UpstreamID}) + if err != nil { + return nil, fmt.Errorf("marshal video status request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + if channel != nil { + req.Header.Set("Authorization", "Bearer "+channel.Key) + } + + var ( + proxyURL string + skipTLSVerify bool + ) + if channel != nil { + proxyURL = channel.ProxyURL + skipTLSVerify = channel.SkipTLSVerify + } + + client, err := relayutils.LoadHTTPClientWithTLSConfigE(0, proxyURL, skipTLSVerify) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var response videoStatusResponse + if err := common.UnmarshalResponse(resp, &response); err != nil { + return nil, fmt.Errorf("decode video status response: %w", err) + } + + return &response, nil +} + +func siliconFlowVideoOutputTokens(response *videoStatusResponse) int64 { + if response == nil { + return 1 + } + + var count int64 + for _, video 
:= range response.Results.Videos { + if video.URL != "" { + count++ + } + } + + if count > 0 { + return count + } + + return 1 +} diff --git a/core/relay/adaptor/siliconflow/chat.go b/core/relay/adaptor/siliconflow/chat.go new file mode 100644 index 00000000..aa490c32 --- /dev/null +++ b/core/relay/adaptor/siliconflow/chat.go @@ -0,0 +1,226 @@ +package siliconflow + +import ( + "strings" + + "github.com/bytedance/sonic/ast" + relaymodel "github.com/labring/aiproxy/core/relay/model" +) + +func patchChatMultimodalContent(node *ast.Node) error { + messagesNode := node.Get("messages") + if !messagesNode.Exists() || messagesNode.TypeSafe() != ast.V_ARRAY { + return nil + } + + var patchErr error + + err := messagesNode.ForEach(func(_ ast.Sequence, messageNode *ast.Node) bool { + contentNode := messageNode.Get("content") + if !contentNode.Exists() || contentNode.TypeSafe() != ast.V_ARRAY { + return true + } + + patchErr = patchChatContentItems(contentNode) + + return patchErr == nil + }) + if err != nil { + return err + } + + return patchErr +} + +func patchSiliconFlowMultimodalContent(openAIReq *relaymodel.GeneralOpenAIRequest) error { + for i := range openAIReq.Messages { + patchSiliconFlowMessageContent(&openAIReq.Messages[i]) + } + + return nil +} + +func patchSiliconFlowMessageContent(message *relaymodel.Message) { + contentParts := message.ParseContent() + if len(contentParts) == 0 { + return + } + + patchedParts := make([]map[string]any, 0, len(contentParts)) + for _, part := range contentParts { + switch part.Type { + case relaymodel.ContentTypeText: + patchedParts = append(patchedParts, map[string]any{ + "type": "text", + "text": part.Text, + }) + case relaymodel.ContentTypeImageURL: + if part.ImageURL == nil { + continue + } + + patchedParts = append(patchedParts, map[string]any{ + "type": "image_url", + "image_url": part.ImageURL, + }) + case relaymodel.ContentTypeInputAudio: + if part.InputAudio == nil { + continue + } + + patchedParts = append(patchedParts, 
map[string]any{ + "type": "audio_url", + "audio_url": map[string]string{ + "url": openAIInputAudioDataURL(part.InputAudio), + }, + }) + case relaymodel.ContentTypeVideoURL: + if part.VideoURL == nil { + continue + } + + patchedParts = append(patchedParts, map[string]any{ + "type": "video_url", + "video_url": part.VideoURL, + }) + } + } + + if len(patchedParts) == 1 && patchedParts[0]["type"] == relaymodel.ContentTypeText { + message.Content = patchedParts[0]["text"] + return + } + + message.Content = patchedParts +} + +func patchChatContentItems(contentNode *ast.Node) error { + return contentNode.ForEach(func(_ ast.Sequence, item *ast.Node) bool { + if item == nil || item.TypeSafe() != ast.V_OBJECT { + return true + } + + contentType, ok, err := chatContentType(item) + if err != nil { + return false + } + + if !ok || contentType != "input_audio" { + return true + } + + audioURL, ok, err := openAIInputAudioURL(item) + if err != nil { + return false + } + + if ok { + *item = newSiliconFlowAudioURLContent(audioURL) + } + + return true + }) +} + +func chatContentType(item *ast.Node) (string, bool, error) { + typeNode := item.Get("type") + if !typeNode.Exists() { + return "", false, nil + } + + contentType, err := typeNode.String() + if err != nil { + return "", false, err + } + + return contentType, true, nil +} + +func openAIInputAudioURL(item *ast.Node) (string, bool, error) { + audioNode := item.Get("input_audio") + if !audioNode.Exists() || audioNode.TypeSafe() != ast.V_OBJECT { + return "", false, nil + } + + urlNode := audioNode.Get("url") + if urlNode.Exists() && urlNode.TypeSafe() == ast.V_STRING { + audioURL, err := urlNode.String() + if err != nil { + return "", false, err + } + + if audioURL != "" { + return audioURL, true, nil + } + } + + dataNode := audioNode.Get("data") + if !dataNode.Exists() || dataNode.TypeSafe() != ast.V_STRING { + return "", false, nil + } + + data, err := dataNode.String() + if err != nil { + return "", false, err + } + + if data == 
"" { + return "", false, nil + } + + if strings.HasPrefix(data, "data:audio/") { + return data, true, nil + } + + format := "wav" + + formatNode := audioNode.Get("format") + if formatNode.Exists() && formatNode.TypeSafe() == ast.V_STRING { + format, err = formatNode.String() + if err != nil { + return "", false, err + } + } + + format = strings.TrimPrefix(strings.TrimSpace(strings.ToLower(format)), ".") + if format == "" { + format = "wav" + } + + return "data:audio/" + format + ";base64," + data, true, nil +} + +func openAIInputAudioDataURL(inputAudio *relaymodel.InputAudio) string { + if inputAudio == nil { + return "" + } + + if inputAudio.URL != "" { + return inputAudio.URL + } + + data := strings.TrimSpace(inputAudio.Data) + if data == "" { + return "" + } + + if strings.HasPrefix(data, "data:audio/") { + return data + } + + format := strings.TrimPrefix(strings.TrimSpace(strings.ToLower(inputAudio.Format)), ".") + if format == "" { + format = "wav" + } + + return "data:audio/" + format + ";base64," + data +} + +func newSiliconFlowAudioURLContent(audioURL string) ast.Node { + return ast.NewObject([]ast.Pair{ + ast.NewPair("type", ast.NewString("audio_url")), + ast.NewPair("audio_url", ast.NewObject([]ast.Pair{ + ast.NewPair("url", ast.NewString(audioURL)), + })), + }) +} diff --git a/core/relay/adaptor/siliconflow/chat_test.go b/core/relay/adaptor/siliconflow/chat_test.go new file mode 100644 index 00000000..b698578d --- /dev/null +++ b/core/relay/adaptor/siliconflow/chat_test.go @@ -0,0 +1,171 @@ +package siliconflow_test + +import ( + "bytes" + "context" + "net/http" + "testing" + + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor/siliconflow" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" +) + +func TestConvertRequestChatPatchesInputAudioToAudioURL(t *testing.T) { + adaptor := &siliconflow.Adaptor{} + m := meta.NewMeta( + nil, + mode.ChatCompletions, + 
"Qwen/Qwen3-Omni-30B-A3B-Instruct", + coremodel.ModelConfig{}, + ) + + req := newChatRequest(t, []byte(`{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [ + { + "role": "user", + "content": [ + {"type":"text","text":"Transcribe this audio."}, + {"type":"input_audio","input_audio":{"data":"QUJD","format":"wav"}}, + {"type":"input_audio","input_audio":{"url":"https://example.com/audio.mp3"}}, + { + "type":"video_url", + "video_url":{ + "url":"https://example.com/video.mp4", + "detail":"low", + "max_frames":8, + "fps":1 + } + } + ] + } + ], + "stream": true + }`)) + + result, err := adaptor.ConvertRequest(m, nil, req) + if err != nil { + t.Fatalf("ConvertRequest returned error: %v", err) + } + + got := readConvertResultBody(t, result.Body) + if got["model"] != "Qwen/Qwen3-Omni-30B-A3B-Instruct" { + t.Fatalf("expected actual model, got %#v", got["model"]) + } + + messages, ok := got["messages"].([]any) + if !ok || len(messages) != 1 { + t.Fatalf("expected one message, got %#v", got["messages"]) + } + + message, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("expected message object, got %#v", messages[0]) + } + + content, ok := message["content"].([]any) + if !ok || len(content) != 4 { + t.Fatalf("expected four content items, got %#v", message["content"]) + } + + assertSiliconFlowTextContent(t, content[0], "Transcribe this audio.") + assertSiliconFlowAudioURL(t, content[1], "data:audio/wav;base64,QUJD") + assertSiliconFlowAudioURL(t, content[2], "https://example.com/audio.mp3") + assertSiliconFlowVideoURL(t, content[3]) + + streamOptions, ok := got["stream_options"].(map[string]any) + if !ok || streamOptions["include_usage"] != true { + t.Fatalf("expected include_usage stream_options, got %#v", got["stream_options"]) + } +} + +func newChatRequest(t *testing.T, body []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/v1/chat/completions", + bytes.NewReader(body), + ) + if 
err != nil { + t.Fatalf("failed to create request: %v", err) + } + + return req +} + +func assertSiliconFlowAudioURL(t *testing.T, got any, wantURL string) { + t.Helper() + + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected audio content object, got %T", got) + } + + if gotMap["type"] != "audio_url" { + t.Fatalf("expected type=audio_url, got %#v", gotMap["type"]) + } + + audioURL, ok := gotMap["audio_url"].(map[string]any) + if !ok { + t.Fatalf("expected audio_url object, got %#v", gotMap["audio_url"]) + } + + if audioURL["url"] != wantURL { + t.Fatalf("expected audio url %q, got %#v", wantURL, audioURL["url"]) + } +} + +func assertSiliconFlowTextContent(t *testing.T, got any, wantText string) { + t.Helper() + + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected text content object, got %T", got) + } + + if gotMap["type"] != "text" { + t.Fatalf("expected type=text, got %#v", gotMap["type"]) + } + + if gotMap["text"] != wantText { + t.Fatalf("expected text %q, got %#v", wantText, gotMap["text"]) + } +} + +func assertSiliconFlowVideoURL(t *testing.T, got any) { + t.Helper() + + gotMap, ok := got.(map[string]any) + if !ok { + t.Fatalf("expected video content object, got %T", got) + } + + if gotMap["type"] != "video_url" { + t.Fatalf("expected type=video_url, got %#v", gotMap["type"]) + } + + videoURL, ok := gotMap["video_url"].(map[string]any) + if !ok { + t.Fatalf("expected video_url object, got %#v", gotMap["video_url"]) + } + + if videoURL["url"] != "https://example.com/video.mp4" { + t.Fatalf("expected video url, got %#v", videoURL["url"]) + } + + if videoURL["detail"] != "low" { + t.Fatalf("expected video detail, got %#v", videoURL["detail"]) + } + + if videoURL["max_frames"] != float64(8) { + t.Fatalf("expected max_frames=8, got %#v", videoURL["max_frames"]) + } + + if videoURL["fps"] != float64(1) { + t.Fatalf("expected fps=1, got %#v", videoURL["fps"]) + } +} diff --git a/core/relay/adaptor/siliconflow/image.go 
b/core/relay/adaptor/siliconflow/image.go index 4d2dff09..367227e5 100644 --- a/core/relay/adaptor/siliconflow/image.go +++ b/core/relay/adaptor/siliconflow/image.go @@ -2,13 +2,19 @@ package siliconflow import ( "bytes" - "io" "net/http" + "strconv" + "time" "github.com/bytedance/sonic" + "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" + "github.com/labring/aiproxy/core/common/image" + "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor" "github.com/labring/aiproxy/core/relay/adaptor/openai" "github.com/labring/aiproxy/core/relay/meta" + relaymodel "github.com/labring/aiproxy/core/relay/model" ) type ImageRequest struct { @@ -23,19 +29,31 @@ type ImageRequest struct { PromptEnhancement bool `json:"prompt_enhancement"` } -func ConvertImageRequest(meta *meta.Meta, request *http.Request) (http.Header, io.Reader, error) { +type imageResponse struct { + Images []imageResponseImage `json:"images"` + Timings map[string]any `json:"timings,omitempty"` + Seed int64 `json:"seed,omitempty"` +} + +type imageResponseImage struct { + URL string `json:"url"` +} + +func ConvertImageRequest(meta *meta.Meta, request *http.Request) (adaptor.ConvertResult, error) { var reqMap map[string]any err := common.UnmarshalRequestReusable(request, &reqMap) if err != nil { - return nil, nil, err + return adaptor.ConvertResult{}, err } meta.Set(openai.MetaResponseFormat, reqMap["response_format"]) reqMap["model"] = meta.ActualModel - reqMap["batch_size"] = reqMap["n"] - delete(reqMap, "n") + if _, ok := reqMap["n"]; ok { + reqMap["batch_size"] = reqMap["n"] + delete(reqMap, "n") + } if _, ok := reqMap["steps"]; ok { reqMap["num_inference_steps"] = reqMap["steps"] @@ -47,13 +65,101 @@ func ConvertImageRequest(meta *meta.Meta, request *http.Request) (http.Header, i delete(reqMap, "scale") } - reqMap["image_size"] = reqMap["size"] - delete(reqMap, "size") + if _, ok := reqMap["size"]; ok { + reqMap["image_size"] = reqMap["size"] + 
delete(reqMap, "size") + } data, err := sonic.Marshal(&reqMap) if err != nil { - return nil, nil, err + return adaptor.ConvertResult{}, err + } + + return adaptor.ConvertResult{ + Header: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {strconv.Itoa(len(data))}, + }, + Body: bytes.NewReader(data), + }, nil +} + +func ImageHandler( + meta *meta.Meta, + c *gin.Context, + resp *http.Response, +) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, ErrorHandler(resp) + } + + defer resp.Body.Close() + + log := common.GetLogger(c) + + var sfResponse imageResponse + if err := common.UnmarshalResponse(resp, &sfResponse); err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIError( + err, + "unmarshal_response_body_failed", + http.StatusInternalServerError, + ) + } + + openaiResponse := relaymodel.ImageResponse{ + Created: time.Now().Unix(), + Data: make([]*relaymodel.ImageData, 0, len(sfResponse.Images)), + } + + for _, img := range sfResponse.Images { + openaiResponse.Data = append(openaiResponse.Data, &relaymodel.ImageData{ + URL: img.URL, + }) + } + + var err error + + if meta.GetString(openai.MetaResponseFormat) == "b64_json" { + for i := range openaiResponse.Data { + data := openaiResponse.Data[i] + if data.B64Json != "" || data.URL == "" { + continue + } + + _, data.B64Json, err = image.GetImageFromURL(c.Request.Context(), data.URL) + if err != nil { + log.Warnf( + "convert siliconflow image url to b64_json failed, keep original url: %v", + err, + ) + + continue + } + } + } + + data, err := sonic.Marshal(openaiResponse) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIError( + err, + "marshal_response_body_failed", + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + + _, err = c.Writer.Write(data) + if err 
!= nil { + log.Warnf("write response body failed: %v", err) + } + + usage := model.Usage{ + InputTokens: meta.RequestUsage.InputTokens, + OutputTokens: meta.RequestUsage.OutputTokens, + TotalTokens: meta.RequestUsage.InputTokens + meta.RequestUsage.OutputTokens, } - return http.Header{}, bytes.NewReader(data), nil + return adaptor.DoResponseResult{Usage: usage}, nil } diff --git a/core/relay/adaptor/siliconflow/video.go b/core/relay/adaptor/siliconflow/video.go new file mode 100644 index 00000000..fcdaeb84 --- /dev/null +++ b/core/relay/adaptor/siliconflow/video.go @@ -0,0 +1,814 @@ +package siliconflow + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/gin-gonic/gin" + "github.com/labring/aiproxy/core/common" + "github.com/labring/aiproxy/core/common/image" + "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" + relaymodel "github.com/labring/aiproxy/core/relay/model" + relayutils "github.com/labring/aiproxy/core/relay/utils" +) + +const ( + metaVideoRequest = "siliconflow_video_request" + siliconFlowVideoTTL = 24 * time.Hour +) + +type videoSubmitRequest struct { + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + ImageSize string `json:"image_size,omitempty"` + Image string `json:"image,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Seed any `json:"seed,omitempty"` +} + +type videoSubmitResponse struct { + RequestID string `json:"requestId"` +} + +type videoStatusRequest struct { + RequestID string `json:"requestId"` +} + +type videoStatusResponse struct { + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + Results videoStatusResults `json:"results,omitempty"` +} + +type 
videoStatusResults struct { + Videos []videoStatusVideo `json:"videos,omitempty"` + Timings map[string]any `json:"timings,omitempty"` + Seed int64 `json:"seed,omitempty"` +} + +type videoStatusVideo struct { + URL string `json:"url"` +} + +func ConvertVideoRequest(meta *meta.Meta, req *http.Request) (adaptor.ConvertResult, error) { + if meta.Mode == mode.VideosRemix { + return adaptor.ConvertResult{}, errors.New("siliconflow does not support videos remix") + } + + var request videoSubmitRequest + + if strings.HasPrefix(req.Header.Get("Content-Type"), "multipart/form-data") { + parsed, err := multipartVideoSubmitRequest(req) + if err != nil { + return adaptor.ConvertResult{}, err + } + + request = parsed + } else { + var reqMap map[string]any + if err := common.UnmarshalRequestReusable(req, &reqMap); err != nil { + return adaptor.ConvertResult{}, err + } + + request = jsonVideoSubmitRequest(reqMap) + } + + request.Model = meta.ActualModel + meta.Set(metaVideoRequest, request) + + data, err := sonic.Marshal(request) + if err != nil { + return adaptor.ConvertResult{}, err + } + + return adaptor.ConvertResult{ + Header: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {strconv.Itoa(len(data))}, + }, + Body: bytes.NewReader(data), + }, nil +} + +func ConvertVideoStatusRequest(meta *meta.Meta, _ *http.Request) (adaptor.ConvertResult, error) { + data, err := sonic.Marshal(videoStatusRequest{RequestID: meta.JobID}) + if err != nil { + return adaptor.ConvertResult{}, err + } + + return adaptor.ConvertResult{ + Header: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {strconv.Itoa(len(data))}, + }, + Body: bytes.NewReader(data), + }, nil +} + +func ConvertVideoContentStatusRequest( + meta *meta.Meta, + _ *http.Request, +) (adaptor.ConvertResult, error) { + data, err := sonic.Marshal(videoStatusRequest{RequestID: meta.GenerationID}) + if err != nil { + return adaptor.ConvertResult{}, err + } + + return adaptor.ConvertResult{ + 
Header: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {strconv.Itoa(len(data))}, + }, + Body: bytes.NewReader(data), + }, nil +} + +func ConvertVideosStatusRequest(meta *meta.Meta, _ *http.Request) (adaptor.ConvertResult, error) { + data, err := sonic.Marshal(videoStatusRequest{RequestID: meta.VideoID}) + if err != nil { + return adaptor.ConvertResult{}, err + } + + return adaptor.ConvertResult{ + Header: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {strconv.Itoa(len(data))}, + }, + Body: bytes.NewReader(data), + }, nil +} + +func jsonVideoSubmitRequest(reqMap map[string]any) videoSubmitRequest { + request := videoSubmitRequest{ + Prompt: stringFromMap(reqMap, "prompt"), + ImageSize: videoImageSize(reqMap), + Image: videoImage(reqMap), + NegativePrompt: stringFromMap(reqMap, "negative_prompt"), + Seed: reqMap["seed"], + } + + return request +} + +func multipartVideoSubmitRequest(req *http.Request) (videoSubmitRequest, error) { + if err := common.ParseMultipartFormWithLimit(req); err != nil { + return videoSubmitRequest{}, fmt.Errorf("parse multipart form: %w", err) + } + + request := videoSubmitRequest{ + Prompt: req.PostFormValue("prompt"), + ImageSize: strings.TrimSpace(req.PostFormValue("size")), + NegativePrompt: req.PostFormValue("negative_prompt"), + } + + if request.ImageSize == "" { + width := req.PostFormValue("width") + + height := req.PostFormValue("height") + if width != "" && height != "" { + request.ImageSize = width + "x" + height + } + } + + if seed := strings.TrimSpace(req.PostFormValue("seed")); seed != "" { + request.Seed = seed + } + + imageValue := req.PostFormValue("input_reference") + if imageValue == "" { + imageValue = req.PostFormValue("image") + } + + if imageValue != "" { + request.Image = imageValue + return request, nil + } + + imageData, err := multipartVideoImageDataURL(req.MultipartForm.File) + if err != nil { + return videoSubmitRequest{}, err + } + + request.Image = imageData 
+ + return request, nil +} + +func videoImageSize(reqMap map[string]any) string { + if size := stringFromMap(reqMap, "size"); size != "" { + return size + } + + width, widthOK := intFromAny(reqMap["width"]) + + height, heightOK := intFromAny(reqMap["height"]) + if widthOK && heightOK && width > 0 && height > 0 { + return fmt.Sprintf("%dx%d", width, height) + } + + if imageSize := stringFromMap(reqMap, "image_size"); imageSize != "" { + return imageSize + } + + return "" +} + +func videoImage(reqMap map[string]any) string { + if inputReference := stringFromMap(reqMap, "input_reference"); inputReference != "" { + return inputReference + } + + if image := stringFromMap(reqMap, "image"); image != "" { + return image + } + + return "" +} + +func stringFromMap(reqMap map[string]any, key string) string { + value, ok := reqMap[key] + if !ok { + return "" + } + + str, ok := value.(string) + if !ok { + return "" + } + + return strings.TrimSpace(str) +} + +func intFromAny(value any) (int, bool) { + switch v := value.(type) { + case int: + return v, true + case int64: + return int(v), true + case float64: + return int(v), true + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(v)) + if err != nil { + return 0, false + } + + return parsed, true + default: + return 0, false + } +} + +func multipartVideoImageDataURL(files map[string][]*multipart.FileHeader) (string, error) { + fileHeaders := make( + []*multipart.FileHeader, + 0, + len(files["input_reference"])+len(files["image"]), + ) + fileHeaders = append(fileHeaders, files["input_reference"]...) + fileHeaders = append(fileHeaders, files["image"]...) 
+ + if len(fileHeaders) == 0 { + return "", nil + } + + if len(fileHeaders) > 1 { + return "", errors.New("video image supports at most 1 file") + } + + return multipartImageDataURL(fileHeaders[0]) +} + +func multipartImageDataURL(fileHeader *multipart.FileHeader) (string, error) { + file, err := fileHeader.Open() + if err != nil { + return "", err + } + defer file.Close() + + data, err := io.ReadAll(common.LimitReader(file, image.MaxImageSize+1)) + if err != nil { + return "", err + } + + if len(data) > image.MaxImageSize { + return "", fmt.Errorf("image too large: max: %d", image.MaxImageSize) + } + + contentType := fileHeader.Header.Get("Content-Type") + if contentType == "" { + contentType = http.DetectContentType(data) + } + + if !image.IsImageURL(contentType) { + if ext := strings.ToLower(filepath.Ext(fileHeader.Filename)); ext != "" { + if detected := mime.TypeByExtension(ext); detected != "" { + contentType = detected + } + } + } + + if !image.IsImageURL(contentType) { + return "", errors.New("image file is not an image") + } + + contentType = image.TrimImageContentType(contentType) + + return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(data), nil +} + +func VideoSubmitHandler( + meta *meta.Meta, + store adaptor.Store, + c *gin.Context, + resp *http.Response, +) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, ErrorHandler(resp) + } + + defer resp.Body.Close() + + var response videoSubmitResponse + if err := common.UnmarshalResponse(resp, &response); err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + if response.RequestID == "" { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoErrorWithMessage( + "missing requestId in siliconflow video submit response", + http.StatusInternalServerError, + ) + } + + meta.RequestUsage = siliconFlowVideoSubmitUsage() + + if 
meta.Mode == mode.Videos { + video := buildVideo(meta, response.RequestID, relaymodel.VideoStatusQueued, nil) + if err := saveVideoGenerationStore( + meta, + store, + response.RequestID, + time.Now().Add(siliconFlowVideoTTL), + ); err != nil { + common.GetLogger(c).Errorf("save siliconflow video store failed: %v", err) + } + + data, err := sonic.Marshal(video) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + _, _ = c.Writer.Write(data) + + return adaptor.DoResponseResult{ + UpstreamID: response.RequestID, + AsyncUsage: true, + }, nil + } + + if err := saveVideoJobStore( + meta, + store, + response.RequestID, + time.Now().Add(siliconFlowVideoTTL), + ); err != nil { + common.GetLogger(c).Errorf("save siliconflow video job store failed: %v", err) + } + + job := buildVideoJob(meta, response.RequestID, relaymodel.VideoGenerationJobStatusQueued, nil) + + data, err := sonic.Marshal(job) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + _, _ = c.Writer.Write(data) + + return adaptor.DoResponseResult{ + UpstreamID: response.RequestID, + AsyncUsage: true, + }, nil +} + +func siliconFlowVideoSubmitUsage() model.Usage { + return model.Usage{ + OutputTokens: model.ZeroNullInt64(1), + TotalTokens: model.ZeroNullInt64(1), + } +} + +func VideoStatusHandler( + meta *meta.Meta, + store adaptor.Store, + c *gin.Context, + resp *http.Response, +) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, ErrorHandler(resp) + } + + defer resp.Body.Close() + + var response videoStatusResponse 
+ if err := common.UnmarshalResponse(resp, &response); err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + if meta.Mode == mode.VideosGet { + video := buildVideo( + meta, + meta.VideoID, + siliconFlowVideoStatusToOpenAI(response.Status), + &response, + ) + if video.Status == relaymodel.VideoStatusCompleted { + if err := saveVideoGenerationStore( + meta, + store, + video.ID, + time.Now().Add(siliconFlowVideoTTL), + ); err != nil { + common.GetLogger(c).Errorf("save siliconflow video store failed: %v", err) + } + } + + data, err := sonic.Marshal(video) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + _, _ = c.Writer.Write(data) + + return adaptor.DoResponseResult{UpstreamID: video.ID}, nil + } + + job := buildVideoJob(meta, meta.JobID, siliconFlowVideoStatus(response.Status), &response) + + if response.Status == "Succeed" { + expiresAt := time.Now().Add(siliconFlowVideoTTL) + for _, generation := range job.Generations { + if err := saveVideoGenerationStore(meta, store, generation.ID, expiresAt); err != nil { + common.GetLogger(c). 
+ Errorf("save siliconflow video generation store failed: %v", err) + } + } + } + + data, err := sonic.Marshal(job) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + _, _ = c.Writer.Write(data) + + return adaptor.DoResponseResult{}, nil +} + +func VideoContentHandler( + meta *meta.Meta, + c *gin.Context, + resp *http.Response, +) (adaptor.DoResponseResult, adaptor.Error) { + if resp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, ErrorHandler(resp) + } + + defer resp.Body.Close() + + var response videoStatusResponse + if err := common.UnmarshalResponse(resp, &response); err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + + videoURL := firstSiliconFlowVideoURL(response.Results.Videos) + if videoURL == "" { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoErrorWithMessage( + "video url is empty", + http.StatusInternalServerError, + ) + } + + videoResp, err := fetchSiliconFlowVideoContent(c.Request.Context(), meta, videoURL) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + defer videoResp.Body.Close() + + if videoResp.StatusCode != http.StatusOK { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoErrorWithMessage( + fmt.Sprintf("unexpected video status code: %d", videoResp.StatusCode), + http.StatusInternalServerError, + ) + } + + c.Writer.Header(). 
+ Set("Content-Type", firstNonEmptyString(videoResp.Header.Get("Content-Type"), "video/mp4")) + c.Writer.Header().Set("Content-Length", videoResp.Header.Get("Content-Length")) + _, _ = io.Copy(c.Writer, videoResp.Body) + + return adaptor.DoResponseResult{UpstreamID: siliconFlowContentUpstreamID(meta)}, nil +} + +func fetchSiliconFlowVideoContent( + ctx context.Context, + meta *meta.Meta, + videoURL string, +) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, videoURL, nil) + if err != nil { + return nil, err + } + + var ( + proxyURL string + skipTLSVerify bool + ) + if meta != nil { + proxyURL = meta.Channel.ProxyURL + skipTLSVerify = meta.Channel.SkipTLSVerify + } + + client, err := relayutils.LoadHTTPClientWithTLSConfigE(0, proxyURL, skipTLSVerify) + if err != nil { + return nil, err + } + + return client.Do(req) +} + +func firstSiliconFlowVideoURL(videos []videoStatusVideo) string { + for _, video := range videos { + if video.URL != "" { + return video.URL + } + } + + return "" +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + + return "" +} + +func siliconFlowContentUpstreamID(meta *meta.Meta) string { + if meta == nil { + return "" + } + + if meta.Mode == mode.VideosContent { + return meta.VideoID + } + + return meta.GenerationID +} + +func buildVideo( + meta *meta.Meta, + id string, + status relaymodel.VideoStatus, + response *videoStatusResponse, +) relaymodel.Video { + now := time.Now().Unix() + + var request videoSubmitRequest + if value, ok := meta.Get(metaVideoRequest); ok { + request, _ = value.(videoSubmitRequest) + } + + video := relaymodel.Video{ + ID: id, + Object: relaymodel.VideoObject, + CreatedAt: now, + Status: status, + Model: meta.OriginModel, + Prompt: request.Prompt, + Size: request.ImageSize, + } + + switch video.Status { + case relaymodel.VideoStatusCompleted: + video.Progress = 100 + case 
relaymodel.VideoStatusInProgress: + video.Progress = 50 + case relaymodel.VideoStatusQueued: + video.Progress = 0 + } + + if response != nil && response.Status == "Failed" { + reason := response.Reason + if reason == "" { + reason = "failed" + } + + video.Error = map[string]any{"message": reason} + } + + return video +} + +func buildVideoJob( + meta *meta.Meta, + id string, + status relaymodel.VideoGenerationJobStatus, + response *videoStatusResponse, +) relaymodel.VideoGenerationJob { + now := time.Now().Unix() + expiresAt := now + int64((24 * time.Hour).Seconds()) + + var request videoSubmitRequest + if value, ok := meta.Get(metaVideoRequest); ok { + request, _ = value.(videoSubmitRequest) + } + + job := relaymodel.VideoGenerationJob{ + Object: relaymodel.VideoGenerationJobObject, + ID: id, + Status: status, + CreatedAt: now, + ExpiresAt: &expiresAt, + Generations: []relaymodel.VideoGenerations{}, + Prompt: request.Prompt, + Model: meta.OriginModel, + NVariants: 1, + } + + if request.ImageSize != "" { + job.Width, job.Height = parseSize(request.ImageSize) + } + + if response == nil { + return job + } + + if response.Status == "Succeed" || response.Status == "Failed" { + job.FinishedAt = &now + } + + if response.Status == "Failed" { + reason := response.Reason + if reason == "" { + reason = "failed" + } + + job.FinishReason = &reason + } + + for _, video := range response.Results.Videos { + if video.URL == "" { + continue + } + + job.Generations = append(job.Generations, relaymodel.VideoGenerations{ + Object: relaymodel.VideoGenerationObject, + ID: id, + JobID: id, + CreatedAt: now, + Width: job.Width, + Height: job.Height, + Prompt: job.Prompt, + }) + + break + } + + return job +} + +func siliconFlowVideoStatus(status string) relaymodel.VideoGenerationJobStatus { + switch status { + case "Succeed": + return relaymodel.VideoGenerationJobStatusSucceeded + case "InProgress": + return relaymodel.VideoGenerationJobStatusRunning + case "Failed": + return 
relaymodel.VideoGenerationJobStatus("failed") + default: + return relaymodel.VideoGenerationJobStatusQueued + } +} + +func siliconFlowVideoStatusToOpenAI(status string) relaymodel.VideoStatus { + switch status { + case "Succeed": + return relaymodel.VideoStatusCompleted + case "InProgress": + return relaymodel.VideoStatusInProgress + case "Failed": + return relaymodel.VideoStatusFailed + default: + return relaymodel.VideoStatusQueued + } +} + +func parseSize(size string) (int, int) { + width, height, ok := strings.Cut(size, "x") + if !ok { + return 0, 0 + } + + parsedWidth, err := strconv.Atoi(strings.TrimSpace(width)) + if err != nil { + return 0, 0 + } + + parsedHeight, err := strconv.Atoi(strings.TrimSpace(height)) + if err != nil { + return 0, 0 + } + + return parsedWidth, parsedHeight +} + +func saveVideoJobStore( + meta *meta.Meta, + store adaptor.Store, + jobID string, + expiresAt time.Time, +) error { + if store == nil { + return nil + } + + return store.SaveStore(adaptor.StoreCache{ + ID: model.VideoJobStoreID(jobID), + GroupID: meta.Group.ID, + TokenID: meta.Token.ID, + ChannelID: meta.Channel.ID, + Model: meta.OriginModel, + ExpiresAt: expiresAt, + }) +} + +func saveVideoGenerationStore( + meta *meta.Meta, + store adaptor.Store, + generationID string, + expiresAt time.Time, +) error { + if store == nil || generationID == "" { + return nil + } + + return store.SaveStore(adaptor.StoreCache{ + ID: model.VideoGenerationStoreID(generationID), + GroupID: meta.Group.ID, + TokenID: meta.Token.ID, + ChannelID: meta.Channel.ID, + Model: meta.OriginModel, + ExpiresAt: expiresAt, + }) +} From e86a89c1bec96ee8a4745a31f85b58e6d3117d8c Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Fri, 22 May 2026 01:25:41 +0800 Subject: [PATCH 2/2] feat: siliconflow support audio image video --- core/relay/adaptor/siliconflow/image.go | 55 +++++++++++++++++-------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/core/relay/adaptor/siliconflow/image.go 
b/core/relay/adaptor/siliconflow/image.go index 367227e5..dee54047 100644 --- a/core/relay/adaptor/siliconflow/image.go +++ b/core/relay/adaptor/siliconflow/image.go @@ -2,11 +2,13 @@ package siliconflow import ( "bytes" + "errors" "net/http" "strconv" "time" "github.com/bytedance/sonic" + "github.com/bytedance/sonic/ast" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" "github.com/labring/aiproxy/core/common/image" @@ -40,37 +42,39 @@ type imageResponseImage struct { } func ConvertImageRequest(meta *meta.Meta, request *http.Request) (adaptor.ConvertResult, error) { - var reqMap map[string]any - - err := common.UnmarshalRequestReusable(request, &reqMap) + node, err := common.UnmarshalRequest2NodeReusable(request) if err != nil { return adaptor.ConvertResult{}, err } - meta.Set(openai.MetaResponseFormat, reqMap["response_format"]) + responseFormat, err := node.Get("response_format").String() + if err != nil && !errors.Is(err, ast.ErrNotExist) { + return adaptor.ConvertResult{}, err + } - reqMap["model"] = meta.ActualModel - if _, ok := reqMap["n"]; ok { - reqMap["batch_size"] = reqMap["n"] - delete(reqMap, "n") + meta.Set(openai.MetaResponseFormat, responseFormat) + + if _, err := node.Set("model", ast.NewString(meta.ActualModel)); err != nil { + return adaptor.ConvertResult{}, err } - if _, ok := reqMap["steps"]; ok { - reqMap["num_inference_steps"] = reqMap["steps"] - delete(reqMap, "steps") + if err := renameImageRequestField(&node, "n", "batch_size"); err != nil { + return adaptor.ConvertResult{}, err + } + + if err := renameImageRequestField(&node, "steps", "num_inference_steps"); err != nil { + return adaptor.ConvertResult{}, err } - if _, ok := reqMap["scale"]; ok { - reqMap["guidance_scale"] = reqMap["scale"] - delete(reqMap, "scale") + if err := renameImageRequestField(&node, "scale", "guidance_scale"); err != nil { + return adaptor.ConvertResult{}, err } - if _, ok := reqMap["size"]; ok { - reqMap["image_size"] = reqMap["size"] - 
delete(reqMap, "size") + if err := renameImageRequestField(&node, "size", "image_size"); err != nil { + return adaptor.ConvertResult{}, err } - data, err := sonic.Marshal(&reqMap) + data, err := node.MarshalJSON() if err != nil { return adaptor.ConvertResult{}, err } @@ -84,6 +88,21 @@ func ConvertImageRequest(meta *meta.Meta, request *http.Request) (adaptor.Conver }, nil } +func renameImageRequestField(node *ast.Node, oldKey, newKey string) error { + value := node.Get(oldKey) + if !value.Exists() { + return nil + } + + if _, err := node.Set(newKey, *value); err != nil { + return err + } + + _, err := node.Unset(oldKey) + + return err +} + func ImageHandler( meta *meta.Meta, c *gin.Context,