From cbf8cc562a548e844b418a1520519b87ed3d1bfb Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 14 Apr 2026 12:14:35 +0000 Subject: [PATCH] feat: wire transcription for llama.cpp, add streaming support Signed-off-by: Ettore Di Giacinto --- .github/workflows/test-extra.yml | 17 ++ Makefile | 17 ++ backend/backend.proto | 11 ++ backend/cpp/llama-cpp/Makefile | 2 +- backend/cpp/llama-cpp/grpc-server.cpp | 136 ++++++++++++++++ backend/go/voxtral/govoxtral.go | 1 + backend/go/whisper/gowhisper.go | 8 + core/backend/transcript.go | 164 +++++++++++++++----- core/http/endpoints/openai/transcription.go | 131 +++++++++++++++- core/schema/transcription.go | 2 + core/services/nodes/file_staging_client.go | 15 ++ core/services/nodes/health_mock_test.go | 3 + core/services/nodes/inflight.go | 5 + core/services/nodes/inflight_test.go | 4 + pkg/grpc/backend.go | 1 + pkg/grpc/base/base.go | 4 + pkg/grpc/client.go | 44 ++++++ pkg/grpc/embed.go | 46 ++++++ pkg/grpc/interface.go | 1 + pkg/grpc/server.go | 34 +++- pkg/model/connection_evicting_client.go | 6 + tests/e2e-backends/backend_test.go | 135 +++++++++++++--- 22 files changed, 719 insertions(+), 68 deletions(-) diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index afeebea82dc2..62a1fa3262b2 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -485,6 +485,23 @@ jobs: - name: Build llama-cpp backend image and run gRPC e2e tests run: | make test-extra-backend-llama-cpp + tests-llama-cpp-grpc-transcription: + needs: detect-changes + if: needs.detect-changes.outputs.llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true' + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.4' + - name: Build llama-cpp backend image and run audio transcription gRPC e2e tests + run: | + make test-extra-backend-llama-cpp-transcription tests-ik-llama-cpp-grpc: needs: detect-changes if: needs.detect-changes.outputs.ik-llama-cpp == 'true' || needs.detect-changes.outputs.run-all == 'true' diff --git a/Makefile b/Makefile index 5ef6062978d0..14cb1b0ee5d7 100644 --- a/Makefile +++ b/Makefile @@ -493,6 +493,10 @@ test-extra-backend: protogen-go BACKEND_TEST_MODEL_URL="$${BACKEND_TEST_MODEL_URL:-$(BACKEND_TEST_MODEL_URL)}" \ BACKEND_TEST_MODEL_FILE="$$BACKEND_TEST_MODEL_FILE" \ BACKEND_TEST_MODEL_NAME="$$BACKEND_TEST_MODEL_NAME" \ + BACKEND_TEST_MMPROJ_URL="$$BACKEND_TEST_MMPROJ_URL" \ + BACKEND_TEST_MMPROJ_FILE="$$BACKEND_TEST_MMPROJ_FILE" \ + BACKEND_TEST_AUDIO_URL="$$BACKEND_TEST_AUDIO_URL" \ + BACKEND_TEST_AUDIO_FILE="$$BACKEND_TEST_AUDIO_FILE" \ BACKEND_TEST_CAPS="$$BACKEND_TEST_CAPS" \ BACKEND_TEST_PROMPT="$$BACKEND_TEST_PROMPT" \ BACKEND_TEST_OPTIONS="$$BACKEND_TEST_OPTIONS" \ @@ -507,6 +511,19 @@ test-extra-backend-llama-cpp: docker-build-llama-cpp test-extra-backend-ik-llama-cpp: docker-build-ik-llama-cpp BACKEND_IMAGE=local-ai-backend:ik-llama-cpp $(MAKE) test-extra-backend +## Audio transcription wrapper for the llama-cpp backend. +## Drives the new AudioTranscription / AudioTranscriptionStream RPCs against +## ggml-org/Qwen3-ASR-0.6B-GGUF (a small ASR model that requires its mmproj +## audio encoder companion). The audio fixture is a short public-domain +## "jfk.wav" clip ggml-org bundles with whisper.cpp's CI assets. +test-extra-backend-llama-cpp-transcription: docker-build-llama-cpp + BACKEND_IMAGE=local-ai-backend:llama-cpp \ + BACKEND_TEST_MODEL_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/Qwen3-ASR-0.6B-Q8_0.gguf \ + BACKEND_TEST_MMPROJ_URL=https://huggingface.co/ggml-org/Qwen3-ASR-0.6B-GGUF/resolve/main/mmproj-Qwen3-ASR-0.6B-Q8_0.gguf \ + BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \ + BACKEND_TEST_CAPS=health,load,transcription \ + $(MAKE) test-extra-backend + ## vllm is resolved from a HuggingFace model id (no file download) and ## exercises Predict + streaming + tool-call extraction via the hermes parser. ## Requires a host CPU with the SIMD instructions the prebuilt vllm CPU diff --git a/backend/backend.proto b/backend/backend.proto index 078b2edc120f..d10e63e8faef 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -17,6 +17,7 @@ service Backend { rpc GenerateImage(GenerateImageRequest) returns (Result) {} rpc GenerateVideo(GenerateVideoRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} + rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {} rpc TTS(TTSRequest) returns (Result) {} rpc TTSStream(TTSRequest) returns (stream Reply) {} rpc SoundGeneration(SoundGenerationRequest) returns (Result) {} @@ -322,11 +323,21 @@ message TranscriptRequest { bool translate = 5; bool diarize = 6; string prompt = 7; + float temperature = 8; + repeated string timestamp_granularities = 9; + bool stream = 10; } message TranscriptResult { repeated TranscriptSegment segments = 1; string text = 2; + string language = 3; + float duration = 4; +} + +message TranscriptStreamResponse { + string delta = 1; + TranscriptResult final_result = 2; } message TranscriptSegment { diff --git a/backend/cpp/llama-cpp/Makefile b/backend/cpp/llama-cpp/Makefile index 04c42fe80326..b33139127e54 100644 --- a/backend/cpp/llama-cpp/Makefile +++ b/backend/cpp/llama-cpp/Makefile @@ -1,5 +1,5 @@ -LLAMA_VERSION?=e97492369888f5311e4d1f3beb325a36bbed70e9 +LLAMA_VERSION?=6a6780a232b73fe44799b0c0d5f01c61612f1b79 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp CMAKE_ARGS?= diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp index 3ba9fffebeaf..fe7350528348 100644 --- a/backend/cpp/llama-cpp/grpc-server.cpp +++ b/backend/cpp/llama-cpp/grpc-server.cpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include #include @@ -76,6 +78,27 @@ static grpc::Status checkAuth(grpc::ServerContext* context) { return grpc::Status(grpc::StatusCode::UNAUTHENTICATED, "invalid token"); } +// Minimal base64 encoder. The C++ backend already pulls in base64_decode from +// llama.cpp's server-common.cpp, but no encoder is exposed — and we need one to +// hand audio bytes to the existing PredictOptions.audios path (which expects +// base64-encoded strings, just like images). +static std::string base64_encode_bytes(const unsigned char* data, size_t len) { + static const char tbl[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string out; + out.reserve(((len + 2) / 3) * 4); + for (size_t i = 0; i < len; i += 3) { + uint32_t triple = (uint32_t(data[i]) << 16); + if (i + 1 < len) triple |= (uint32_t(data[i + 1]) << 8); + if (i + 2 < len) triple |= uint32_t(data[i + 2]); + out.push_back(tbl[(triple >> 18) & 0x3F]); + out.push_back(tbl[(triple >> 12) & 0x3F]); + out.push_back(i + 1 < len ? tbl[(triple >> 6) & 0x3F] : '='); + out.push_back(i + 2 < len ? tbl[triple & 0x3F] : '='); + } + return out; +} + // END LocalAI @@ -2931,6 +2954,119 @@ class BackendServiceImpl final : public backend::Backend::Service { return grpc::Status::OK; } + + // runTranscriptionAsCompletion implements OAI /v1/audio/transcriptions on + // top of the existing chat-completion + multimodal-audio pipeline, exactly + // the way upstream llama.cpp's server does it (see + // tools/server/server-context.cpp post_transcriptions_oai → forwards into + // handle_completions_impl with a single user message attaching the audio + // file via the mtmd marker). + // + // We synthesize a backend::PredictOptions with one user message + // ("Transcribe audio to text" + optional language hint) and the audio + // bytes attached via the existing PredictOptions.audios field, then + // delegate to our own Predict() handler. This keeps every multimodal + // codepath identical to the chat path and avoids duplicating ~700 lines + // of task-construction logic. + grpc::Status runTranscriptionAsCompletion(grpc::ServerContext* context, + const backend::TranscriptRequest* request, + backend::Reply* out_reply) { + if (params_base.model.path.empty()) { + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); + } + if (request->dst().empty()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "dst (audio file path) is required"); + } + + // Read audio bytes from the path LocalAI's HTTP layer wrote. + std::ifstream f(request->dst(), std::ios::binary); + if (!f.is_open()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "failed to open audio file: " + request->dst()); + } + std::vector bytes((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + f.close(); + if (bytes.empty()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "audio file is empty: " + request->dst()); + } + + std::string b64 = base64_encode_bytes(bytes.data(), bytes.size()); + + // Build the same prompt upstream uses in convert_transcriptions_to_chatcmpl. + std::string user_prompt = "Transcribe audio to text"; + if (!request->language().empty()) { + user_prompt += " (language: " + request->language() + ")"; + } + if (!request->prompt().empty()) { + // Optional context hint from the caller. + user_prompt += "\n" + request->prompt(); + } + + backend::PredictOptions synthetic; + synthetic.set_usetokenizertemplate(true); + synthetic.set_temperature(request->temperature()); + // Generation length: leave at 0 so parse_options uses -1 (model default). + // The model's stop tokens / EOS handle termination naturally for ASR. + backend::Message* msg = synthetic.add_messages(); + msg->set_role("user"); + msg->set_content(user_prompt); + synthetic.add_audios(b64); + + return Predict(context, &synthetic, out_reply); + } + + grpc::Status AudioTranscription(ServerContext* context, + const backend::TranscriptRequest* request, + backend::TranscriptResult* response) override { + auto auth = checkAuth(context); + if (!auth.ok()) return auth; + + backend::Reply reply; + grpc::Status st = runTranscriptionAsCompletion(context, request, &reply); + if (!st.ok()) { + return st; + } + response->set_text(reply.message()); + if (!request->language().empty()) { + response->set_language(request->language()); + } + return grpc::Status::OK; + } + + grpc::Status AudioTranscriptionStream(ServerContext* context, + const backend::TranscriptRequest* request, + grpc::ServerWriter* writer) override { + auto auth = checkAuth(context); + if (!auth.ok()) return auth; + + // Buffered streaming: run the transcription as a normal chat + // completion, then emit one delta + one final event. Real + // token-by-token streaming would require refactoring PredictStream's + // 700-line writer-coupled body; the HTTP/SSE contract is identical + // either way, and clients that only consume the assembled text don't + // notice the difference. + backend::Reply reply; + grpc::Status st = runTranscriptionAsCompletion(context, request, &reply); + if (!st.ok()) { + return st; + } + + const std::string& text = reply.message(); + if (!text.empty()) { + backend::TranscriptStreamResponse delta_chunk; + delta_chunk.set_delta(text); + writer->Write(delta_chunk); + } + + backend::TranscriptStreamResponse final_chunk; + backend::TranscriptResult* final_result = final_chunk.mutable_final_result(); + final_result->set_text(text); + if (!request->language().empty()) { + final_result->set_language(request->language()); + } + writer->Write(final_chunk); + return grpc::Status::OK; + } }; diff --git a/backend/go/voxtral/govoxtral.go b/backend/go/voxtral/govoxtral.go index e5d40aa6bb76..9a296a589b94 100644 --- a/backend/go/voxtral/govoxtral.go +++ b/backend/go/voxtral/govoxtral.go @@ -56,5 +56,6 @@ func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR return pb.TranscriptResult{ Segments: segments, Text: text, + Language: opts.Language, }, nil } diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index 4a6ab616202c..bf7cb7a45be0 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -120,6 +120,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR } data := buf.AsFloat32Buffer().Data + // whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate + // for the converted file produced by AudioToWav above. + var duration float32 + if buf.Format != nil && buf.Format.SampleRate > 0 { + duration = float32(len(data)) / float32(buf.Format.SampleRate) + } segsLen := uintptr(0xdeadbeef) segsLenPtr := unsafe.Pointer(&segsLen) @@ -158,5 +164,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR return pb.TranscriptResult{ Segments: segments, Text: strings.TrimSpace(text), + Language: opts.Language, + Duration: duration, }, nil } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 6aa903e4a0de..c3bfb77b4dfa 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -10,26 +10,68 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/trace" + grpcPkg "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +// TranscriptionRequest groups the parameters accepted by ModelTranscription. +// Use this so callers don't have to pass long positional arg lists when they +// only care about a subset of fields. +type TranscriptionRequest struct { + Audio string + Language string + Translate bool + Diarize bool + Prompt string + Temperature float32 + TimestampGranularities []string +} + +func (r *TranscriptionRequest) toProto(threads uint32) *proto.TranscriptRequest { + return &proto.TranscriptRequest{ + Dst: r.Audio, + Language: r.Language, + Translate: r.Translate, + Diarize: r.Diarize, + Threads: threads, + Prompt: r.Prompt, + Temperature: r.Temperature, + TimestampGranularities: r.TimestampGranularities, + } +} + +func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (grpcPkg.Backend, error) { if modelConfig.Backend == "" { modelConfig.Backend = model.WhisperBackend } - opts := ModelOptions(modelConfig, appConfig) - transcriptionModel, err := ml.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } - if transcriptionModel == nil { return nil, fmt.Errorf("could not load transcription model") } + return transcriptionModel, nil +} + +func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { + return ModelTranscriptionWithOptions(TranscriptionRequest{ + Audio: audio, + Language: language, + Translate: translate, + Diarize: diarize, + Prompt: prompt, + }, ml, modelConfig, appConfig) +} + +func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { + transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) + if err != nil { + return nil, err + } var startTime time.Time var audioSnippet map[string]any @@ -37,25 +79,18 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() // Capture audio before the backend call — the backend may delete the file. - audioSnippet = trace.AudioSnippet(audio) + audioSnippet = trace.AudioSnippet(req.Audio) } - r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: audio, - Language: language, - Translate: translate, - Diarize: diarize, - Threads: uint32(*modelConfig.Threads), - Prompt: prompt, - }) + r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads))) if err != nil { if appConfig.EnableTracing { errData := map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, + "audio_file": req.Audio, + "language": req.Language, + "translate": req.Translate, + "diarize": req.Diarize, + "prompt": req.Prompt, } if audioSnippet != nil { maps.Copy(errData, audioSnippet) @@ -66,39 +101,22 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt Type: trace.BackendTraceTranscription, ModelName: modelConfig.Name, Backend: modelConfig.Backend, - Summary: trace.TruncateString(audio, 200), + Summary: trace.TruncateString(req.Audio, 200), Error: err.Error(), Data: errData, }) } return nil, err } - tr := &schema.TranscriptionResult{ - Text: r.Text, - } - for _, s := range r.Segments { - var tks []int - for _, t := range s.Tokens { - tks = append(tks, int(t)) - } - tr.Segments = append(tr.Segments, - schema.TranscriptionSegment{ - Text: s.Text, - Id: int(s.Id), - Start: time.Duration(s.Start), - End: time.Duration(s.End), - Tokens: tks, - Speaker: s.Speaker, - }) - } + tr := transcriptResultFromProto(r) if appConfig.EnableTracing { data := map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, + "audio_file": req.Audio, + "language": req.Language, + "translate": req.Translate, + "diarize": req.Diarize, + "prompt": req.Prompt, "result_text": tr.Text, "segments_count": len(tr.Segments), } @@ -111,10 +129,70 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt Type: trace.BackendTraceTranscription, ModelName: modelConfig.Name, Backend: modelConfig.Backend, - Summary: trace.TruncateString(audio+" -> "+tr.Text, 200), + Summary: trace.TruncateString(req.Audio+" -> "+tr.Text, 200), Data: data, }) } return tr, err } + +// TranscriptionStreamChunk is a streaming event emitted by +// ModelTranscriptionStream. Either Delta carries an incremental text fragment, +// or Final carries the completed transcription as the very last event. +type TranscriptionStreamChunk struct { + Delta string + Final *schema.TranscriptionResult +} + +// ModelTranscriptionStream runs the gRPC streaming transcription RPC and +// invokes onChunk for each event the backend produces. Backends that don't +// support real streaming should still emit one terminal event with Final set, +// which the HTTP layer turns into a single delta + done SSE pair. +func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error { + transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) + if err != nil { + return err + } + + pbReq := req.toProto(uint32(*modelConfig.Threads)) + pbReq.Stream = true + + return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) { + if chunk == nil { + return + } + out := TranscriptionStreamChunk{Delta: chunk.Delta} + if chunk.FinalResult != nil { + out.Final = transcriptResultFromProto(chunk.FinalResult) + } + onChunk(out) + }) +} + +func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionResult { + if r == nil { + return &schema.TranscriptionResult{} + } + tr := &schema.TranscriptionResult{ + Text: r.Text, + Language: r.Language, + Duration: float64(r.Duration), + } + for _, s := range r.Segments { + var tks []int + for _, t := range s.Tokens { + tks = append(tks, int(t)) + } + tr.Segments = append(tr.Segments, + schema.TranscriptionSegment{ + Text: s.Text, + Id: int(s.Id), + Start: time.Duration(s.Start), + End: time.Duration(s.End), + Tokens: tks, + Speaker: s.Speaker, + }) + } + return tr +} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 335adfb2e60d..cf18e8244175 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -1,12 +1,16 @@ package openai import ( + "encoding/json" "errors" + "fmt" "io" "net/http" "os" "path" "path/filepath" + "strconv" + "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" @@ -24,6 +28,9 @@ import ( // @accept multipart/form-data // @Param model formData string true "model" // @Param file formData file true "file" +// @Param temperature formData number false "sampling temperature" +// @Param timestamp_granularities formData []string false "timestamp granularities (word, segment)" +// @Param stream formData boolean false "stream partial results as SSE" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { @@ -42,6 +49,38 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app prompt := c.FormValue("prompt") responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format")) + // OpenAI accepts `temperature` as a string in multipart form. Tolerate + // missing/invalid values rather than failing the whole request. + var temperature float32 + if v := c.FormValue("temperature"); v != "" { + if t, err := strconv.ParseFloat(v, 32); err == nil { + temperature = float32(t) + } + } + + // timestamp_granularities[] is a multi-value form field per the OpenAI spec. + // Echo exposes all values for a key via FormParams. + var timestampGranularities []string + if form, err := c.FormParams(); err == nil { + for _, key := range []string{"timestamp_granularities[]", "timestamp_granularities"} { + if vals, ok := form[key]; ok { + for _, v := range vals { + v = strings.TrimSpace(v) + if v != "" { + timestampGranularities = append(timestampGranularities, v) + } + } + } + } + } + + stream := false + if v := c.FormValue("stream"); v != "" { + if b, err := strconv.ParseBool(v); err == nil { + stream = b + } + } + // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { @@ -73,7 +112,21 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app xlog.Debug("Audio file copied", "dst", dst) - tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig) + req := backend.TranscriptionRequest{ + Audio: dst, + Language: input.Language, + Translate: input.Translate, + Diarize: diarize, + Prompt: prompt, + Temperature: temperature, + TimestampGranularities: timestampGranularities, + } + + if stream { + return streamTranscription(c, req, ml, *config, appConfig) + } + + tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig) if err != nil { return err } @@ -93,3 +146,79 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app } } } + +// streamTranscription emits OpenAI-format SSE events for a transcription +// request: one `transcript.text.delta` per backend chunk, a final +// `transcript.text.done` with the assembled text, and `[DONE]`. Backends that +// can't truly stream still produce a single Final event, which we surface as +// one delta + done. +func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *model.ModelLoader, config config.ModelConfig, appConfig *config.ApplicationConfig) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().WriteHeader(http.StatusOK) + + writeEvent := func(payload any) error { + data, err := json.Marshal(payload) + if err != nil { + return err + } + if _, err := fmt.Fprintf(c.Response().Writer, "data: %s\n\n", data); err != nil { + return err + } + c.Response().Flush() + return nil + } + + var assembled strings.Builder + var finalResult *schema.TranscriptionResult + + err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) { + if chunk.Delta != "" { + assembled.WriteString(chunk.Delta) + _ = writeEvent(map[string]any{ + "type": "transcript.text.delta", + "delta": chunk.Delta, + }) + } + if chunk.Final != nil { + finalResult = chunk.Final + } + }) + if err != nil { + errPayload := map[string]any{ + "type": "error", + "error": map[string]any{ + "message": err.Error(), + "type": "server_error", + }, + } + _ = writeEvent(errPayload) + _, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + + // Build the final event. Prefer the backend-provided final result; if the + // backend only emitted deltas, synthesize the result from what we collected. + if finalResult == nil { + finalResult = &schema.TranscriptionResult{Text: assembled.String()} + } else if finalResult.Text == "" && assembled.Len() > 0 { + finalResult.Text = assembled.String() + } + // If the backend never produced a delta but did return a final text, emit + // it as a single delta so clients always see at least one delta event. + if assembled.Len() == 0 && finalResult.Text != "" { + _ = writeEvent(map[string]any{ + "type": "transcript.text.delta", + "delta": finalResult.Text, + }) + } + _ = writeEvent(map[string]any{ + "type": "transcript.text.done", + "text": finalResult.Text, + }) + _, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil +} diff --git a/core/schema/transcription.go b/core/schema/transcription.go index dc22abe85ef3..b0d3a8eb3abe 100644 --- a/core/schema/transcription.go +++ b/core/schema/transcription.go @@ -14,4 +14,6 @@ type TranscriptionSegment struct { type TranscriptionResult struct { Segments []TranscriptionSegment `json:"segments,omitempty"` Text string `json:"text"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` } diff --git a/core/services/nodes/file_staging_client.go b/core/services/nodes/file_staging_client.go index 661f89e1859f..96561800b8ab 100644 --- a/core/services/nodes/file_staging_client.go +++ b/core/services/nodes/file_staging_client.go @@ -294,6 +294,21 @@ func (f *FileStagingClient) AudioTranscription(ctx context.Context, in *pb.Trans return f.Backend.AudioTranscription(ctx, in, opts...) } +func (f *FileStagingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, fn func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error { + reqID := requestID() + + // Stage input audio file + if in.Dst != "" && isFilePath(in.Dst) { + backendPath, _, err := f.stageInputFile(ctx, reqID, in.Dst, "inputs") + if err != nil { + return fmt.Errorf("staging audio for transcription stream: %w", err) + } + in.Dst = backendPath + } + + return f.Backend.AudioTranscriptionStream(ctx, in, fn, opts...) +} + func (f *FileStagingClient) ExportModel(ctx context.Context, in *pb.ExportModelRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { frontendOutputPath := in.OutputPath if frontendOutputPath != "" { diff --git a/core/services/nodes/health_mock_test.go b/core/services/nodes/health_mock_test.go index 4b49d75a327a..dd0a7a89c9bf 100644 --- a/core/services/nodes/health_mock_test.go +++ b/core/services/nodes/health_mock_test.go @@ -171,6 +171,9 @@ func (c *fakeBackendClient) Detect(_ context.Context, _ *pb.DetectOptions, _ ... func (c *fakeBackendClient) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) { return nil, nil } +func (c *fakeBackendClient) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error { + return nil +} func (c *fakeBackendClient) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) { return nil, nil } diff --git a/core/services/nodes/inflight.go b/core/services/nodes/inflight.go index ad866aed2882..5e80a847e1e7 100644 --- a/core/services/nodes/inflight.go +++ b/core/services/nodes/inflight.go @@ -105,6 +105,11 @@ func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb. return c.Backend.AudioTranscription(ctx, in, opts...) } +func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error { + defer c.track(ctx)() + return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...) +} + func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) { defer c.track(ctx)() return c.Backend.Detect(ctx, in, opts...) diff --git a/core/services/nodes/inflight_test.go b/core/services/nodes/inflight_test.go index 8266b6f215b3..21a0be81d412 100644 --- a/core/services/nodes/inflight_test.go +++ b/core/services/nodes/inflight_test.go @@ -95,6 +95,10 @@ func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.Transcript return &pb.TranscriptResult{}, nil } +func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error { + return nil +} + func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) { return &pb.TokenizationResponse{}, nil } diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 8d9818186a0e..a2111f21f871 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -55,6 +55,7 @@ type Backend interface { SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) + AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index e9bea6cf96b6..17bdad6b0d99 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -61,6 +61,10 @@ func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } +func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { + return fmt.Errorf("unimplemented") +} + func (llm *Base) TTS(*pb.TTSRequest) error { return fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 881b343d88c2..a9339437f6d8 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -352,6 +352,50 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques return client.AudioTranscription(ctx, in, opts...) } +func (c *Client) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := c.dial() + if err != nil { + return err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + stream, err := client.AudioTranscriptionStream(ctx, in, opts...) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + chunk, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + f(chunk) + } + + return nil +} + func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index c14b20427b35..1aa7e1098421 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -75,6 +75,14 @@ func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.Transcript return e.s.AudioTranscription(ctx, in) } +func (e *embedBackend) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error { + bs := &embedBackendAudioTranscriptionStream{ + ctx: ctx, + fn: f, + } + return e.s.AudioTranscriptionStream(in, bs) +} + func (e *embedBackend) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) { return e.s.TokenizeString(ctx, in) } @@ -168,6 +176,44 @@ func (e *embedBackend) Free(ctx context.Context) error { return err } +var _ pb.Backend_AudioTranscriptionStreamServer = new(embedBackendAudioTranscriptionStream) + +type embedBackendAudioTranscriptionStream struct { + ctx context.Context + fn func(chunk *pb.TranscriptStreamResponse) +} + +func (e *embedBackendAudioTranscriptionStream) Send(chunk *pb.TranscriptStreamResponse) error { + e.fn(chunk) + return nil +} + +func (e *embedBackendAudioTranscriptionStream) SetHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendAudioTranscriptionStream) SendHeader(md metadata.MD) error { + return nil +} + +func (e *embedBackendAudioTranscriptionStream) SetTrailer(md metadata.MD) { +} + +func (e *embedBackendAudioTranscriptionStream) Context() context.Context { + return e.ctx +} + +func (e *embedBackendAudioTranscriptionStream) SendMsg(m any) error { + if x, ok := m.(*pb.TranscriptStreamResponse); ok { + return e.Send(x) + } + return nil +} + +func (e *embedBackendAudioTranscriptionStream) RecvMsg(m any) error { + return nil +} + var _ pb.Backend_FineTuneProgressServer = new(embedBackendFineTuneProgressStream) type embedBackendFineTuneProgressStream struct { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 8952a22da6ef..bfd449644c4d 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -18,6 +18,7 @@ type AIModel interface { GenerateVideo(*pb.GenerateVideoRequest) error Detect(*pb.DetectOptions) (pb.DetectResponse, error) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) + AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error TTS(*pb.TTSRequest) error TTSStream(*pb.TTSRequest, chan []byte) error SoundGeneration(*pb.SoundGenerationRequest) error diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 828204efb8b9..002b4922ffaa 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -168,18 +168,42 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } tresult.Segments = append(tresult.Segments, &pb.TranscriptSegment{ - Text: s.Text, - Id: int32(s.Id), - Start: int64(s.Start), - End: int64(s.End), - Tokens: tks, + Text: s.Text, + Id: int32(s.Id), + Start: int64(s.Start), + End: int64(s.End), + Tokens: tks, + Speaker: s.Speaker, }) } tresult.Text = result.Text + tresult.Language = result.Language + tresult.Duration = result.Duration return tresult, nil } +func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Backend_AudioTranscriptionStreamServer) error { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + resultChan := make(chan *pb.TranscriptStreamResponse) + + done := make(chan bool) + go func() { + for chunk := range resultChan { + stream.Send(chunk) + } + done <- true + }() + + err := s.llm.AudioTranscriptionStream(in, resultChan) + <-done + + return err +} + func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { if s.llm.Locking() { s.llm.Lock() diff --git a/pkg/model/connection_evicting_client.go b/pkg/model/connection_evicting_client.go index 12360e3c66c0..ade1e294bad6 100644 --- a/pkg/model/connection_evicting_client.go +++ b/pkg/model/connection_evicting_client.go @@ -96,6 +96,12 @@ func (c *ConnectionEvictingClient) AudioTranscription(ctx context.Context, in *p return result, err } +func (c *ConnectionEvictingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error { + err := c.Backend.AudioTranscriptionStream(ctx, in, f, opts...) + c.checkErr(err) + return err +} + func (c *ConnectionEvictingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) { result, err := c.Backend.Detect(ctx, in, opts...) c.checkErr(err) diff --git a/tests/e2e-backends/backend_test.go b/tests/e2e-backends/backend_test.go index b6f59fd28d50..1350b4ae7b9b 100644 --- a/tests/e2e-backends/backend_test.go +++ b/tests/e2e-backends/backend_test.go @@ -35,9 +35,16 @@ import ( // // Optional: // +// BACKEND_TEST_MMPROJ_URL HTTP(S) URL of an mmproj file (audio/vision encoder) +// to download alongside the main model — required for +// multimodal models like Qwen3-ASR-0.6B-GGUF. +// BACKEND_TEST_MMPROJ_FILE Path to an already-available mmproj file. +// BACKEND_TEST_AUDIO_URL HTTP(S) URL of a sample audio file used by the +// transcription specs. +// BACKEND_TEST_AUDIO_FILE Path to an already-available sample audio file. // BACKEND_TEST_CAPS Comma-separated list of capabilities to exercise. // Supported values: health, load, predict, stream, -// embeddings, tools. +// embeddings, tools, transcription. // Defaults to "health,load,predict,stream". // A backend that only does embeddings would set this to // "health,load,embeddings"; an image/TTS backend that cannot @@ -58,12 +65,13 @@ import ( // file path to LoadModel, so GGUF, ONNX, safetensors, .bin etc. all work so // long as the backend under test accepts that format. const ( - capHealth = "health" - capLoad = "load" - capPredict = "predict" - capStream = "stream" - capEmbeddings = "embeddings" - capTools = "tools" + capHealth = "health" + capLoad = "load" + capPredict = "predict" + capStream = "stream" + capEmbeddings = "embeddings" + capTools = "tools" + capTranscription = "transcription" defaultPrompt = "The capital of France is" streamPrompt = "Once upon a time" @@ -99,17 +107,19 @@ func parseCaps() map[string]bool { var _ = Describe("Backend container", Ordered, func() { var ( - caps map[string]bool - workDir string - binaryDir string - modelFile string // set when a local file is used - modelName string // set when a HuggingFace model id is used - addr string - serverCmd *exec.Cmd - conn *grpc.ClientConn - client pb.BackendClient - prompt string - options []string + caps map[string]bool + workDir string + binaryDir string + modelFile string // set when a local file is used + modelName string // set when a HuggingFace model id is used + mmprojFile string // optional multimodal projector + audioFile string // optional audio fixture for transcription specs + addr string + serverCmd *exec.Cmd + conn *grpc.ClientConn + client pb.BackendClient + prompt string + options []string ) BeforeAll(func() { @@ -155,6 +165,25 @@ var _ = Describe("Backend container", Ordered, func() { downloadFile(modelURL, modelFile) } + // Multimodal projector (mmproj): required by audio/vision-capable + // llama.cpp models like Qwen3-ASR-0.6B-GGUF. Either file or URL. + mmprojFile = os.Getenv("BACKEND_TEST_MMPROJ_FILE") + if mmprojFile == "" { + if url := os.Getenv("BACKEND_TEST_MMPROJ_URL"); url != "" { + mmprojFile = filepath.Join(workDir, "mmproj.bin") + downloadFile(url, mmprojFile) + } + } + + // Audio fixture for the transcription specs. + audioFile = os.Getenv("BACKEND_TEST_AUDIO_FILE") + if audioFile == "" { + if url := os.Getenv("BACKEND_TEST_AUDIO_URL"); url != "" { + audioFile = filepath.Join(workDir, "sample.wav") + downloadFile(url, audioFile) + } + } + // Pick a free port and launch the backend. port, err := freeport.GetFreePort() Expect(err).NotTo(HaveOccurred()) @@ -244,6 +273,7 @@ var _ = Describe("Backend container", Ordered, func() { MMap: true, NBatch: 128, Options: options, + MMProj: mmprojFile, }) Expect(err).NotTo(HaveOccurred()) Expect(res.GetSuccess()).To(BeTrue(), "LoadModel failed: %s", res.GetMessage()) @@ -385,6 +415,75 @@ var _ = Describe("Backend container", Ordered, func() { Expect(matched).To(BeTrue(), "Expected a tool call named %q in ChatDelta.tool_calls", toolName) }) + + It("transcribes audio via AudioTranscription", func() { + if !caps[capTranscription] { + Skip("transcription capability not enabled") + } + Expect(audioFile).NotTo(BeEmpty(), + "BACKEND_TEST_AUDIO_FILE or BACKEND_TEST_AUDIO_URL must be set when transcription cap is enabled") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + res, err := client.AudioTranscription(ctx, &pb.TranscriptRequest{ + Dst: audioFile, + Threads: uint32(envInt32("BACKEND_TEST_THREADS", 4)), + Temperature: 0.0, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(res.GetText())).NotTo(BeEmpty(), + "AudioTranscription returned empty text") + GinkgoWriter.Printf("AudioTranscription: text=%q language=%q duration=%v\n", + res.GetText(), res.GetLanguage(), res.GetDuration()) + }) + + It("streams audio transcription via AudioTranscriptionStream", func() { + if !caps[capTranscription] { + Skip("transcription capability not enabled") + } + Expect(audioFile).NotTo(BeEmpty(), + "BACKEND_TEST_AUDIO_FILE or BACKEND_TEST_AUDIO_URL must be set when transcription cap is enabled") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stream, err := client.AudioTranscriptionStream(ctx, &pb.TranscriptRequest{ + Dst: audioFile, + Threads: uint32(envInt32("BACKEND_TEST_THREADS", 4)), + Temperature: 0.0, + Stream: true, + }) + Expect(err).NotTo(HaveOccurred()) + + var deltas []string + var assembled strings.Builder + var finalText string + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + Expect(err).NotTo(HaveOccurred()) + if d := chunk.GetDelta(); d != "" { + deltas = append(deltas, d) + assembled.WriteString(d) + } + if final := chunk.GetFinalResult(); final != nil && final.GetText() != "" { + finalText = final.GetText() + } + } + // At least one of: a delta arrived, or the final event carried text. + Expect(deltas).NotTo(BeEmpty(), + "AudioTranscriptionStream did not emit any deltas (assembled=%q final=%q)", + assembled.String(), finalText) + + // If both arrived, the final event should match the assembled deltas. + if finalText != "" && assembled.Len() > 0 { + Expect(finalText).To(Equal(assembled.String()), + "final transcript should match concatenated deltas") + } + GinkgoWriter.Printf("AudioTranscriptionStream: deltas=%d assembled=%q final=%q\n", + len(deltas), assembled.String(), finalText) + }) }) // extractImage runs `docker create` + `docker export` to materialise the image