diff --git a/bindings/go/.gitignore b/bindings/go/.gitignore
index 036df1d3b0d..bfdd3e3c790 100644
--- a/bindings/go/.gitignore
+++ b/bindings/go/.gitignore
@@ -1,2 +1,4 @@
 build
 models
+samples/a13.wav
+samples/benchmark_out.wav
diff --git a/bindings/go/Makefile b/bindings/go/Makefile
index e4436a6a291..fb57d0fc9f8 100644
--- a/bindings/go/Makefile
+++ b/bindings/go/Makefile
@@ -46,9 +46,19 @@ endif
 examples: $(EXAMPLES_DIR)
 
+benchmark: model-small whisper modtidy
+ifeq ($(UNAME_S),Darwin)
+	@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -bench=BenchmarkContextProcess -benchmem -run '^$$' ./pkg/whisper/...
+else
+	@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -bench=BenchmarkContextProcess -benchmem -run '^$$' ./pkg/whisper/...
+endif
+
 model-small: mkdir examples/go-model-download
 	@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
 
+model-small-tdrz: mkdir examples/go-model-download
+	@${BUILD_DIR}/go-model-download -out models ggml-small.en-tdrz.bin
+
 $(EXAMPLES_DIR): mkdir whisper modtidy
 	@echo Build example $(notdir $@)
 ifeq ($(UNAME_S),Darwin)
@@ -57,6 +67,14 @@ else
 	@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
 endif
 
+.PHONY: samples
+samples:
+	@echo "Downloading samples..."
+	@mkdir -p samples
+	@wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
+	@ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
+	@rm samples/a13.mp3
+
 mkdir:
 	@echo Mkdir ${BUILD_DIR}
 	@install -d ${BUILD_DIR}
diff --git a/bindings/go/README.md b/bindings/go/README.md
index 9d832096512..c6de0b9790a 100644
--- a/bindings/go/README.md
+++ b/bindings/go/README.md
@@ -7,8 +7,12 @@ This package provides Go bindings for whisper.cpp. They have been tested on:
  * Fedora Linux on x86_64
 
 The "low level" bindings are in the `bindings/go` directory and there is a more
-Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
-is as follows:
+Go-style package in the `bindings/go/pkg/whisper` directory.
+
+The example below uses the legacy stateless API with a single worker. For the
+recommended stateful API and concurrency-safe usage, see "New high-level API" below.
+Note: `Model.NewContext()` returns a stateless context for backward compatibility
+and is not safe for parallel `Process` calls (it may return `ErrStatelessBusy`).
 
 ```go
 import (
@@ -100,6 +104,123 @@ Getting help:
   * Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)
 
+## New high-level API (stateful and stateless contexts)
+
+The `pkg/whisper` package now exposes two context kinds:
+
+- StatefulContext: recommended for concurrency. Each context owns its own whisper_state.
+- StatelessContext: shares the model context. Simpler, but not suitable for parallel `Process` calls.
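+
+For example, to transcribe several audio clips in parallel, give each goroutine its
+own stateful context over a shared model. The snippet below is a minimal sketch (the
+`clips` slice and the `sync.WaitGroup` are placeholders); see the quick starts below
+for how the `model` is created:
+
+```go
+var wg sync.WaitGroup
+for _, clip := range clips { // clips: [][]float32, one 16 kHz mono buffer per clip
+	wg.Add(1)
+	go func(samples []float32) {
+		defer wg.Done()
+
+		// Each goroutine gets its own parameters and stateful context
+		params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
+		if err != nil {
+			return
+		}
+		ctx, err := whisper.NewStatefulContext(model, params)
+		if err != nil {
+			return
+		}
+		defer ctx.Close()
+
+		if err := ctx.Process(samples, nil, nil, nil); err != nil {
+			return
+		}
+		for {
+			seg, err := ctx.NextSegment()
+			if err != nil {
+				break
+			}
+			fmt.Println(seg)
+		}
+	}(clip)
+}
+wg.Wait()
+```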
+
+### Quick start: stateful context (recommended)
+
+```go
+package main
+
+import (
+	"fmt"
+	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+)
+
+func main() {
+	// Load model
+	model, err := whisper.NewModelContext("./models/ggml-small.en.bin")
+	if err != nil {
+		panic(err)
+	}
+	defer model.Close()
+
+	// Configure parameters (optional: provide a config func)
+	params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {
+		p.SetThreads(4)
+		p.SetLanguage("en") // or "auto"
+		p.SetTranslate(false)
+	})
+	if err != nil {
+		panic(err)
+	}
+
+	// Create stateful context (safe for running in parallel goroutines)
+	ctx, err := whisper.NewStatefulContext(model, params)
+	if err != nil {
+		panic(err)
+	}
+	defer ctx.Close()
+
+	// Mono 16 kHz PCM audio as float32 samples
+	var samples []float32
+
+	// Process. Callbacks are optional.
+	if err := ctx.Process(samples, nil, nil, nil); err != nil {
+		panic(err)
+	}
+
+	// Read segments
+	for {
+		seg, err := ctx.NextSegment()
+		if err != nil {
+			break
+		}
+		fmt.Printf("[%v -> %v] %s\n", seg.Start, seg.End, seg.Text)
+	}
+}
+```
+
+### Quick start: stateless context (single worker)
+
+```go
+// Load model as above
+model, _ := whisper.NewModelContext("./models/ggml-small.en.bin")
+defer model.Close()
+
+params, _ := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
+ctx, _ := whisper.NewStatelessContext(model, params)
+defer ctx.Close()
+
+if err := ctx.Process(samples, nil, nil, nil); err != nil { panic(err) }
+for {
+	seg, err := ctx.NextSegment()
+	if err != nil { break }
+	fmt.Println(seg.Text)
+}
+```
+
+### Deprecations and migration notes
+
+- The `Context` interface setters (SetThreads, SetLanguage, etc.) are deprecated. Configure a `Parameters` value via `NewParameters` and pass it when creating a context.
+- `Model.NewContext()` remains for backward compatibility and returns a stateless context by default. Prefer `NewStatefulContext` for concurrency.
+- Stateless contexts share the model context. A concurrency gate prevents overlapping `Process` calls and returns `ErrStatelessBusy` if another `Process` is already in flight.
+- For parallel processing, create one `StatefulContext` per goroutine.
+
+## Benchmarks
+
+Benchmarks live in `pkg/whisper` and compare CPU vs GPU, stateful vs stateless contexts, thread counts, and callback modes.
+
+### Prerequisites
+
+- Model: `models/ggml-small.en.bin` (or your choice).
+- Sample: `samples/jfk.wav`.
+- Build the C libs once (also downloads a model for examples):
+
+```bash
+cd bindings/go
+make examples
+# optionally: ./build/go-model-download -out models
+```
+
+### Run benchmarks
+
+```bash
+cd bindings/go
+make benchmark
+```
+
+### What the benchmarks measure
+
+- Variants: device (cpu/gpu) x context kind (stateless/stateful) x threads {1, 2, 4, NumCPU} x callback mode (NoCallback, WithSegmentCallback).
+- Standard Go benchmark outputs: ns/op, B/op, allocs/op. The benchmarks also call `b.SetBytes` with the input size in bytes, so throughput is reported as well.
+- Custom metric `ms_process`: wall time per `Process` iteration, reported via `b.ReportMetric`.
+- When `printTimings` is enabled, model-level timings are printed after each iteration using `model.PrintTimings()`.
+
 ## License
 
 The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go index 728c6df53d4..6ecd1a26840 100644 --- a/bindings/go/examples/go-model-download/main.go +++ b/bindings/go/examples/go-model-download/main.go @@ -18,9 +18,10 @@ import ( // CONSTANTS const ( - srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models - srcExt = ".bin" // Filename extension - bufSize = 1024 * 64 // Size of the buffer used for downloading the model + srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models + srcUrlTinydiarize = "https://huggingface.co/akashmjn/tinydiarize-whisper.cpp/resolve/main/" + srcExt = ".bin" // Filename extension + bufSize = 1024 * 64 // Size of the buffer used for downloading the model ) var ( @@ -38,6 +39,7 @@ var ( "large-v2", "large-v2-q5_0", "large-v2-q8_0", "large-v3", "large-v3-q5_0", "large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0", + "small.en-tdrz", } ) @@ -219,6 +221,12 @@ func URLForModel(model string) (string, error) { model += srcExt } + srcUrl := srcUrl + + if strings.Contains(model, "tdrz") { + srcUrl = srcUrlTinydiarize + } + // Parse the base URL url, err := url.Parse(srcUrl) if err != nil { diff --git a/bindings/go/go.mod b/bindings/go/go.mod index 7c92c7b4890..5cfd3268af1 100644 --- a/bindings/go/go.mod +++ b/bindings/go/go.mod @@ -3,13 +3,13 @@ module github.com/ggerganov/whisper.cpp/bindings/go go 1.23 require ( + github.com/go-audio/audio v1.0.0 github.com/go-audio/wav v1.1.0 github.com/stretchr/testify v1.9.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-audio/audio v1.0.0 // indirect github.com/go-audio/riff v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/bindings/go/params.go b/bindings/go/params.go index 95c5bfaf934..07801649300 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -47,6 +47,49 @@ func (p *Params) SetPrintTimestamps(v bool) { p.print_timestamps = toBool(v) } +// Enable extra debug information +func (p *Params) SetDebugMode(v bool) { + p.debug_mode = toBool(v) +} + +// Enable tinydiarize speaker turn detection +func (p *Params) SetDiarize(v bool) { + p.tdrz_enable = toBool(v) +} + +// Voice Activity Detection (VAD) +func (p *Params) SetVAD(v bool) { + p.vad = toBool(v) +} + +func (p *Params) SetVADModelPath(path string) { + p.vad_model_path = C.CString(path) +} + +func (p *Params) SetVADThreshold(t float32) { + p.vad_params.threshold = C.float(t) +} + +func (p *Params) SetVADMinSpeechMs(ms int) { + p.vad_params.min_speech_duration_ms = C.int(ms) +} + +func (p *Params) SetVADMinSilenceMs(ms int) { + p.vad_params.min_silence_duration_ms = C.int(ms) +} + +func (p *Params) SetVADMaxSpeechSec(s float32) { + p.vad_params.max_speech_duration_s = C.float(s) +} + +func (p *Params) SetVADSpeechPadMs(ms int) { + p.vad_params.speech_pad_ms = C.int(ms) +} + +func (p *Params) SetVADSamplesOverlap(sec float32) { + p.vad_params.samples_overlap = C.float(sec) +} + // Set language id func (p *Params) SetLanguage(lang int) error { if lang == -1 { diff --git a/bindings/go/pkg/whisper/concurrency_gate.go b/bindings/go/pkg/whisper/concurrency_gate.go new file mode 100644 index 00000000000..03469a29ce7 --- /dev/null +++ b/bindings/go/pkg/whisper/concurrency_gate.go @@ -0,0 +1,58 @@ +package whisper + +import ( + "sync" + "sync/atomic" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +// 
Gate provides a simple acquire/release contract per key. +// The default implementation is a single-entry lock per key (limit=1). +type Gate interface { + // Acquire returns true if the key was acquired; false if already held + Acquire(key any) bool + // Release releases the key if currently held + Release(key any) +} + +// singleFlightGate is a minimal lock with limit=1 per key +type singleFlightGate struct { + m sync.Map // key -> *int32 (0 available, 1 held) +} + +func (g *singleFlightGate) Acquire(key any) bool { + ptr, _ := g.m.LoadOrStore(key, new(int32)) + busy := ptr.(*int32) + return atomic.CompareAndSwapInt32(busy, 0, 1) +} + +func (g *singleFlightGate) Release(key any) { + if v, ok := g.m.Load(key); ok { + atomic.StoreInt32(v.(*int32), 0) + } +} + +var defaultGate Gate = &singleFlightGate{} + +// SetGate allows applications to override the default gate (e.g., for custom policies) +// Passing nil resets to the default singleFlightGate. +func SetGate(g Gate) { + if g == nil { + defaultGate = &singleFlightGate{} + return + } + defaultGate = g +} + +func gate() Gate { return defaultGate } + +// modelKey derives a stable key per underlying model context for guarding stateless ops +func modelKey(model *ModelContext) *whisper.Context { + if model == nil || model.ctxAccessor() == nil { + return nil + } + ctx, _ := model.ctxAccessor().context() + return ctx +} diff --git a/bindings/go/pkg/whisper/consts.go b/bindings/go/pkg/whisper/consts.go index 5c22dc13a31..90ca20664c2 100644 --- a/bindings/go/pkg/whisper/consts.go +++ b/bindings/go/pkg/whisper/consts.go @@ -11,11 +11,21 @@ import ( // ERRORS var ( - ErrUnableToLoadModel = errors.New("unable to load model") - ErrInternalAppError = errors.New("internal application error") + ErrUnableToLoadModel = errors.New("unable to load model") + + // Deprecated: Use ErrModelClosed instead for checking the model is closed error + ErrInternalAppError = errors.New("internal application error") + ErrProcessingFailed = errors.New("processing failed") ErrUnsupportedLanguage = errors.New("unsupported language") ErrModelNotMultilingual = errors.New("model is not multilingual") + ErrModelClosed = errors.Join(errors.New("model has been closed"), ErrInternalAppError) + ErrStatelessBusy = errors.New("stateless context is busy; concurrent processing not supported") + + // Private errors + errParametersRequired = errors.New("parameters are required") + errModelRequired = errors.New("model is required") + errUnableToCreateState = errors.New("unable to create state") ) /////////////////////////////////////////////////////////////////////////////// @@ -26,3 +36,10 @@ const SampleRate = whisper.SampleRate // SampleBits is the number of bytes per sample. 
const SampleBits = whisper.SampleBits + +type SamplingStrategy uint32 + +const ( + SAMPLING_GREEDY SamplingStrategy = SamplingStrategy(whisper.SAMPLING_GREEDY) + SAMPLING_BEAM_SEARCH SamplingStrategy = SamplingStrategy(whisper.SAMPLING_BEAM_SEARCH) +) diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go deleted file mode 100644 index cb3d9eb8c1c..00000000000 --- a/bindings/go/pkg/whisper/context.go +++ /dev/null @@ -1,349 +0,0 @@ -package whisper - -import ( - "fmt" - "io" - "runtime" - "strings" - "time" - - // Bindings - whisper "github.com/ggerganov/whisper.cpp/bindings/go" -) - -/////////////////////////////////////////////////////////////////////////////// -// TYPES - -type context struct { - n int - model *model - params whisper.Params -} - -// Make sure context adheres to the interface -var _ Context = (*context)(nil) - -/////////////////////////////////////////////////////////////////////////////// -// LIFECYCLE - -func newContext(model *model, params whisper.Params) (Context, error) { - context := new(context) - context.model = model - context.params = params - - // Return success - return context, nil -} - -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -// Set the language to use for speech recognition. -func (context *context) SetLanguage(lang string) error { - if context.model.ctx == nil { - return ErrInternalAppError - } - if !context.model.IsMultilingual() { - return ErrModelNotMultilingual - } - - if lang == "auto" { - context.params.SetLanguage(-1) - } else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { - return ErrUnsupportedLanguage - } else if err := context.params.SetLanguage(id); err != nil { - return err - } - // Return success - return nil -} - -func (context *context) IsMultilingual() bool { - return context.model.IsMultilingual() -} - -// Get language -func (context *context) Language() string { - id := context.params.Language() - if id == -1 { - return "auto" - } - return whisper.Whisper_lang_str(context.params.Language()) -} - -func (context *context) DetectedLanguage() string { - return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id()) -} - -// Set translate flag -func (context *context) SetTranslate(v bool) { - context.params.SetTranslate(v) -} - -func (context *context) SetSplitOnWord(v bool) { - context.params.SetSplitOnWord(v) -} - -// Set number of threads to use -func (context *context) SetThreads(v uint) { - context.params.SetThreads(int(v)) -} - -// Set time offset -func (context *context) SetOffset(v time.Duration) { - context.params.SetOffset(int(v.Milliseconds())) -} - -// Set duration of audio to process -func (context *context) SetDuration(v time.Duration) { - context.params.SetDuration(int(v.Milliseconds())) -} - -// Set timestamp token probability threshold (~0.01) -func (context *context) SetTokenThreshold(t float32) { - context.params.SetTokenThreshold(t) -} - -// Set timestamp token sum probability threshold (~0.01) -func (context *context) SetTokenSumThreshold(t float32) { - context.params.SetTokenSumThreshold(t) -} - -// Set max segment length in characters -func (context *context) SetMaxSegmentLength(n uint) { - context.params.SetMaxSegmentLength(int(n)) -} - -// Set token timestamps flag -func (context *context) SetTokenTimestamps(b bool) { - context.params.SetTokenTimestamps(b) -} - -// Set max tokens per segment (0 = no limit) -func (context *context) SetMaxTokensPerSegment(n uint) { - 
context.params.SetMaxTokensPerSegment(int(n)) -} - -// Set audio encoder context -func (context *context) SetAudioCtx(n uint) { - context.params.SetAudioCtx(int(n)) -} - -// Set maximum number of text context tokens to store -func (context *context) SetMaxContext(n int) { - context.params.SetMaxContext(n) -} - -// Set Beam Size -func (context *context) SetBeamSize(n int) { - context.params.SetBeamSize(n) -} - -// Set Entropy threshold -func (context *context) SetEntropyThold(t float32) { - context.params.SetEntropyThold(t) -} - -// Set Temperature -func (context *context) SetTemperature(t float32) { - context.params.SetTemperature(t) -} - -// Set the fallback temperature incrementation -// Pass -1.0 to disable this feature -func (context *context) SetTemperatureFallback(t float32) { - context.params.SetTemperatureFallback(t) -} - -// Set initial prompt -func (context *context) SetInitialPrompt(prompt string) { - context.params.SetInitialPrompt(prompt) -} - -// ResetTimings resets the mode timings. Should be called before processing -func (context *context) ResetTimings() { - context.model.ctx.Whisper_reset_timings() -} - -// PrintTimings prints the model timings to stdout. -func (context *context) PrintTimings() { - context.model.ctx.Whisper_print_timings() -} - -// SystemInfo returns the system information -func (context *context) SystemInfo() string { - return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n", - context.params.Threads(), - runtime.NumCPU(), - whisper.Whisper_print_system_info(), - ) -} - -// Use mel data at offset_ms to try and auto-detect the spoken language -// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. -// Returns the probabilities of all languages. -func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) { - langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads) - if err != nil { - return nil, err - } - return langProbs, nil -} - -// Process new sample data and return any errors -func (context *context) Process( - data []float32, - callEncoderBegin EncoderBeginCallback, - callNewSegment SegmentCallback, - callProgress ProgressCallback, -) error { - if context.model.ctx == nil { - return ErrInternalAppError - } - // If the callback is defined then we force on single_segment mode - if callNewSegment != nil { - context.params.SetSingleSegment(true) - } - - // We don't do parallel processing at the moment - processors := 0 - if processors > 1 { - if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin, - func(new int) { - if callNewSegment != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() - s0 := num_segments - new - for i := s0; i < num_segments; i++ { - callNewSegment(toSegment(context.model.ctx, i)) - } - } - }); err != nil { - return err - } - } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin, - func(new int) { - if callNewSegment != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() - s0 := num_segments - new - for i := s0; i < num_segments; i++ { - callNewSegment(toSegment(context.model.ctx, i)) - } - } - }, func(progress int) { - if callProgress != nil { - callProgress(progress) - } - }); err != nil { - return err - } - - // Return success - return nil -} - -// Return the next segment of tokens -func (context *context) NextSegment() (Segment, error) { - if context.model.ctx == nil { - return Segment{}, ErrInternalAppError - } - if context.n >= 
context.model.ctx.Whisper_full_n_segments() { - return Segment{}, io.EOF - } - - // Populate result - result := toSegment(context.model.ctx, context.n) - - // Increment the cursor - context.n++ - - // Return success - return result, nil -} - -// Test for text tokens -func (context *context) IsText(t Token) bool { - switch { - case context.IsBEG(t): - return false - case context.IsSOT(t): - return false - case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot(): - return false - case context.IsPREV(t): - return false - case context.IsSOLM(t): - return false - case context.IsNOT(t): - return false - default: - return true - } -} - -// Test for "begin" token -func (context *context) IsBEG(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg() -} - -// Test for "start of transcription" token -func (context *context) IsSOT(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot() -} - -// Test for "end of transcription" token -func (context *context) IsEOT(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot() -} - -// Test for "start of prev" token -func (context *context) IsPREV(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev() -} - -// Test for "start of lm" token -func (context *context) IsSOLM(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm() -} - -// Test for "No timestamps" token -func (context *context) IsNOT(t Token) bool { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not() -} - -// Test for token associated with a specific language -func (context *context) IsLANG(t Token, lang string) bool { - if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 { - return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id) - } else { - return false - } -} - -/////////////////////////////////////////////////////////////////////////////// -// PRIVATE METHODS - -func toSegment(ctx *whisper.Context, n int) Segment { - return Segment{ - Num: n, - Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), - Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, - End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, - Tokens: toTokens(ctx, n), - } -} - -func toTokens(ctx *whisper.Context, n int) []Token { - result := make([]Token, ctx.Whisper_full_n_tokens(n)) - for i := 0; i < len(result); i++ { - data := ctx.Whisper_full_get_token_data(n, i) - - result[i] = Token{ - Id: int(ctx.Whisper_full_get_token_id(n, i)), - Text: ctx.Whisper_full_get_token_text(n, i), - P: ctx.Whisper_full_get_token_p(n, i), - Start: time.Duration(data.T0()) * time.Millisecond * 10, - End: time.Duration(data.T1()) * time.Millisecond * 10, - } - } - return result -} diff --git a/bindings/go/pkg/whisper/context_benchmark_test.go b/bindings/go/pkg/whisper/context_benchmark_test.go new file mode 100644 index 00000000000..8cc6e5d30a6 --- /dev/null +++ b/bindings/go/pkg/whisper/context_benchmark_test.go @@ -0,0 +1,285 @@ +package whisper_test + +import ( + "fmt" + "io" + "math" + "os" + "runtime" + "testing" + "time" + + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-audio/audio" + wav "github.com/go-audio/wav" +) + +func processAndExtractSegmentsSequentially(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) { + if err := ctx.Process(samples, nil, nil, nil); err != nil { + return nil, err + } + + var segments 
[]whisper.Segment + for { + seg, err := ctx.NextSegment() + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + segments = append(segments, seg) + } + + return segments, nil +} + +func processAndExtractSegmentsWithCallback(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) { + segments := make([]whisper.Segment, 0) + + if err := ctx.Process(samples, nil, func(seg whisper.Segment) { + segments = append(segments, seg) + }, nil); err != nil { + return nil, err + } + + return segments, nil +} + +// benchProcessVariants runs the common benchmark matrix across context kinds, +// thread sets, and callback modes, for given samples. If singleIteration is true +// it runs only one iteration regardless of b.N. If printTimings is true, +// model timings and custom ms_process metric are reported for NoCallback runs. +func benchProcessVariants( + b *testing.B, + samples []float32, + singleIteration bool, + printTimings bool, + useGPU bool, +) { + threadSets := []uint{1, 2, 4, uint(runtime.NumCPU())} + + device := "cpu" + if useGPU { + device = "gpu" + } + + // Initialize model per device mode + mp := whisper.NewModelContextParams() + mp.SetUseGPU(useGPU) + model, err := whisper.NewModelContextWithParams(ModelPath, mp) + if err != nil { + b.Fatalf("load model (%s): %v", device, err) + } + defer func() { _ = model.Close() }() + + // Context kinds: stateless and stateful + ctxKinds := []struct { + name string + new func() (whisper.Context, error) + }{ + { + name: "stateless", + new: func() (whisper.Context, error) { + params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {}) + if err != nil { + return nil, err + } + return whisper.NewStatelessContext(model, params) + }, + }, + { + name: "stateful", + new: func() (whisper.Context, error) { + params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil) + if err != nil { + return nil, err + } + return whisper.NewStatefulContext(model, params) + }, + }, + } + + for _, kind := range ctxKinds { + b.Run(device+"/"+kind.name, func(b *testing.B) { + for _, threads := range threadSets { + b.Run(fmt.Sprintf("threads=%d/NoCallback", threads), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(samples) * 4)) + ctx, err := kind.new() + if err != nil { + b.Fatalf("new %s context: %v", kind.name, err) + } + defer func() { _ = ctx.Close() }() + ctx.SetThreads(threads) + + iters := b.N + if singleIteration { + iters = 1 + } + + b.ResetTimer() + for i := 0; i < iters; i++ { + model.ResetTimings() + start := time.Now() + + segments, err := processAndExtractSegmentsSequentially(ctx, samples) + if err != nil { + b.Fatalf("process and extract segments sequentially: %v", err) + } + + b.Logf("segments: %+v", segments) + + elapsed := time.Since(start) + + if printTimings { + model.PrintTimings() + } + + b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process") + } + }) + + b.Run(fmt.Sprintf("threads=%d/WithSegmentCallback", threads), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(samples) * 4)) + ctx, err := kind.new() + if err != nil { + b.Fatalf("new %s context: %v", kind.name, err) + } + defer func() { _ = ctx.Close() }() + ctx.SetThreads(threads) + + iters := b.N + if singleIteration { + iters = 1 + } + + b.ResetTimer() + for i := 0; i < iters; i++ { + start := time.Now() + model.ResetTimings() + + // Passing a segment callback forces single-segment mode and exercises token extraction + segments, err := processAndExtractSegmentsWithCallback(ctx, 
samples) + if err != nil { + b.Fatalf("process with callback: %v", err) + } + + b.Logf("segments: %+v", segments) + + elapsed := time.Since(start) + if printTimings { + model.PrintTimings() + } + + b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process") + } + }) + } + }) + } +} + +// BenchmarkContextProcess runs the high-level Context.Process across +// different thread counts, with and without segment callbacks. +func BenchmarkContextProcessCPU(b *testing.B) { + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + b.Skipf("model not found: %s", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + b.Skipf("sample not found: %s", SamplePath) + } + + // Load audio once (reuse helper) + data := helperLoadSample(b, SamplePath) + + benchProcessVariants(b, data, false, true, false) +} + +// BenchmarkContextProcessBig runs one single iteration over a big input +// (the short sample concatenated 10x) to simulate long audio processing. +// This is complementary to BenchmarkContextProcess which runs many iterations +// over the short sample. +func BenchmarkContextProcessBigCPU(b *testing.B) { + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + b.Skipf("model not found: %s", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + b.Skipf("sample not found: %s", SamplePath) + } + + // Load audio once (reuse helper with meta) + data, sampleRate, numChans := helperLoadSampleWithMeta(b, SamplePath) + + // Build big dataset: input concatenated 10x + bigData := make([]float32, len(data)*10) + for i := 0; i < 10; i++ { + copy(bigData[i*len(data):(i+1)*len(data)], data) + } + + // Write the big dataset to a wav file for inspection + outPath := "../../samples/benchmark_out.wav" + fout, err := os.Create(outPath) + if err != nil { + b.Fatalf("create output wav: %v", err) + } + enc := wav.NewEncoder(fout, sampleRate, 16, numChans, 1) + intBuf := &audio.IntBuffer{ + Format: &audio.Format{NumChannels: numChans, SampleRate: sampleRate}, + SourceBitDepth: 16, + Data: make([]int, len(bigData)), + } + for i, s := range bigData { + v := int(math.Round(float64(s) * 32767.0)) + if v > 32767 { + v = 32767 + } else if v < -32768 { + v = -32768 + } + intBuf.Data[i] = v + } + if err := enc.Write(intBuf); err != nil { + _ = fout.Close() + b.Fatalf("encode wav: %v", err) + } + if err := enc.Close(); err != nil { + _ = fout.Close() + b.Fatalf("close encoder: %v", err) + } + _ = fout.Close() + + benchProcessVariants(b, bigData, true, true, false) +} + +// GPU variants reuse model-level GPU enablement via model params +func BenchmarkContextProcessGPU(b *testing.B) { + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + b.Skipf("model not found: %s", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + b.Skipf("sample not found: %s", SamplePath) + } + + data := helperLoadSample(b, SamplePath) + + benchProcessVariants(b, data, false, true, true) +} + +func BenchmarkContextProcessBigGPU(b *testing.B) { + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + b.Skipf("model not found: %s", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + b.Skipf("sample not found: %s", SamplePath) + } + + data, _, _ := helperLoadSampleWithMeta(b, SamplePath) + + bigData := make([]float32, len(data)*10) + for i := 0; i < 10; i++ { + copy(bigData[i*len(data):(i+1)*len(data)], data) + } + + benchProcessVariants(b, bigData, true, true, true) +} diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index 
e98a4c2b80b..f18238ba54b 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -1,124 +1,583 @@ package whisper_test import ( + "io" "os" "testing" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-audio/wav" assert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSetLanguage(t *testing.T) { assert := assert.New(t) - model, err := whisper.New(ModelPath) - assert.NoError(err) - assert.NotNil(model) - defer model.Close() + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } - context, err := model.NewContext() - assert.NoError(err) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() - // This returns an error since - // the model 'models/ggml-small.en.bin' - // that is loaded is not multilingual - err = context.SetLanguage("en") - assert.Error(err) + // This returns an error since the small.en model is not multilingual + err := ctx.SetLanguage("en") + assert.Error(err) + }) + } } func TestContextModelIsMultilingual(t *testing.T) { assert := assert.New(t) - model, err := whisper.New(ModelPath) - assert.NoError(err) - assert.NotNil(model) - defer model.Close() + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } - context, err := model.NewContext() - assert.NoError(err) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + assert.False(ctx.IsMultilingual()) + }) + } +} - isMultilingual := context.IsMultilingual() +func TestLanguage(t *testing.T) { + assert := assert.New(t) - // This returns false since - // the model 'models/ggml-small.en.bin' - // that is loaded is not multilingual - assert.False(isMultilingual) + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + expectedLanguage := "en" + actualLanguage := ctx.Language() + assert.Equal(expectedLanguage, actualLanguage) + }) + } } -func TestLanguage(t *testing.T) { +// Generic behavior: Language() and DetectedLanguage() match for both context types +func TestContext_Generic_LanguageAndDetectedLanguage(t *testing.T) { assert := assert.New(t) - model, err := whisper.New(ModelPath) - assert.NoError(err) - assert.NotNil(model) - defer model.Close() + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } - context, err := model.NewContext() - assert.NoError(err) + data := helperLoadSample(t, SamplePath) + + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() - // This always returns en since - // the model 
'models/ggml-small.en.bin' - // that is loaded is not multilingual - expectedLanguage := "en" - actualLanguage := context.Language() - assert.Equal(expectedLanguage, actualLanguage) + langBefore := ctx.Language() + assert.NoError(ctx.Process(data, nil, nil, nil)) + detected := ctx.DetectedLanguage() + assert.Equal(langBefore, detected) + }) + } } func TestProcess(t *testing.T) { assert := assert.New(t) - fh, err := os.Open(SamplePath) - assert.NoError(err) - defer fh.Close() + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } - // Decode the WAV file - load the full buffer - dec := wav.NewDecoder(fh) - buf, err := dec.FullPCMBuffer() - assert.NoError(err) - assert.Equal(uint16(1), dec.NumChans) + data := helperLoadSample(t, SamplePath) - data := buf.AsFloat32Buffer().Data + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } - model, err := whisper.New(ModelPath) - assert.NoError(err) - assert.NotNil(model) - defer model.Close() + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + err := ctx.Process(data, nil, nil, nil) + assert.NoError(err) + }) + } +} - context, err := model.NewContext() - assert.NoError(err) +func TestDetectedLanguage(t *testing.T) { + assert := assert.New(t) - err = context.Process(data, nil, nil, nil) - assert.NoError(err) + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + err := ctx.Process(data, nil, nil, nil) + assert.NoError(err) + expectedLanguage := "en" + actualLanguage := ctx.DetectedLanguage() + assert.Equal(expectedLanguage, actualLanguage) + }) + } } -func TestDetectedLanguage(t *testing.T) { +// TestContext_ConcurrentProcessing tests that multiple contexts can process concurrently +// without interfering with each other (validates the whisper_state isolation fix) +func TestContext_ConcurrentProcessing(t *testing.T) { assert := assert.New(t) - fh, err := os.Open(SamplePath) - assert.NoError(err) - defer fh.Close() + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } - // Decode the WAV file - load the full buffer - dec := wav.NewDecoder(fh) - buf, err := dec.FullPCMBuffer() - assert.NoError(err) - assert.Equal(uint16(1), dec.NumChans) + data := helperLoadSample(t, SamplePath) + + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, 
cleanup := tc.new(t) + defer cleanup() + + err := ctx.Process(data, nil, nil, nil) + assert.NoError(err) + + seg, err := ctx.NextSegment() + assert.NoError(err) + assert.NotEmpty(seg.Text) + }) + } +} + +// TestContext_Close tests that Context.Close() properly frees resources +// and allows context to be used even after it has been closed +func TestContext_Close(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + + // Close the context + err := ctx.Close() + require.NoError(t, err) + + // Try to use closed context - should return errors + err = ctx.Process([]float32{0.1, 0.2, 0.3}, nil, nil, nil) + require.ErrorIs(t, err, whisper.ErrModelClosed) + // TODO: remove this logic after deprecating the ErrInternalAppError + require.ErrorIs(t, err, whisper.ErrInternalAppError) + + lang := ctx.DetectedLanguage() + require.Empty(t, lang) + + _, err = ctx.NextSegment() + assert.ErrorIs(err, whisper.ErrModelClosed) + // TODO: remove this logic after deprecating the ErrInternalAppError + assert.ErrorIs(err, whisper.ErrInternalAppError) + + // Multiple closes should be safe + err = ctx.Close() + require.NoError(t, err) + }) + } +} + +func Test_Close_Context_of_Closed_Model(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + t.Run("stateless", func(t *testing.T) { + model, err := whisper.NewModelContext(ModelPath) + assert.NoError(err) + defer func() { _ = model.Close() }() + params := helperNewParams(t, model, nil) + ctx, err := whisper.NewStatelessContext(model, params) + assert.NoError(err) + require.NoError(t, model.Close()) + require.NoError(t, ctx.Close()) + }) + + t.Run("stateful", func(t *testing.T) { + model, err := whisper.NewModelContext(ModelPath) + assert.NoError(err) + defer func() { _ = model.Close() }() + params := helperNewParams(t, model, nil) + ctx, err := whisper.NewStatefulContext(model, params) + assert.NoError(err) + require.NoError(t, model.Close()) + require.NoError(t, ctx.Close()) + }) +} + +func TestContext_VAD_And_Diarization_Params_DoNotPanic(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } - data := buf.AsFloat32Buffer().Data + data := helperLoadSample(t, SamplePath) - model, err := whisper.New(ModelPath) + model, err := whisper.NewModelContext(ModelPath) assert.NoError(err) - assert.NotNil(model) - defer model.Close() + defer func() { _ = model.Close() }() - context, err := model.NewContext() + params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil) assert.NoError(err) + assert.NotNil(params) - err = context.Process(data, nil, nil, nil) + ctx, err := whisper.NewStatefulContext(model, params) assert.NoError(err) + defer func() { _ = ctx.Close() }() + + p := ctx.Params() + p.SetDiarize(true) + p.SetVAD(true) + p.SetVADThreshold(0.5) + p.SetVADMinSpeechMs(200) + p.SetVADMinSilenceMs(100) + 
p.SetVADMaxSpeechSec(10) + p.SetVADSpeechPadMs(30) + p.SetVADSamplesOverlap(0.02) + + err = ctx.Process(data, nil, nil, nil) + assert.NoError(err) +} + +func TestDiarization_TwoSpeakers_Boundaries(t *testing.T) { + data := helperLoadSample(t, MultiSpeakerSamplePath) + + model, err := whisper.NewModelContext(ModelTinydiarizePath) + require.NoError(t, err) + defer func() { _ = model.Close() }() + + params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) { + p.SetDiarize(true) + p.SetVAD(false) + p.SetSplitOnWord(true) + p.SetMaxSegmentLength(1) + p.SetMaxTokensPerSegment(64) + p.SetTokenTimestamps(true) + }) + require.NoError(t, err) + + // diarize ON with beam search and tighter segmentation + ctxOn, err := whisper.NewStatefulContext(model, params) + require.NoError(t, err) + defer func() { _ = ctxOn.Close() }() + + require.NoError(t, ctxOn.Process(data, nil, nil, nil)) + var turnsOn int + for { + seg, err := ctxOn.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + if seg.SpeakerTurnNext { + turnsOn++ + } + } + require.Greater(t, turnsOn, 0, "expected speaker turn boundaries with diarization enabled") + + // diarize OFF baseline with same segmentation and beam + ctxOff, err := whisper.NewStatefulContext(model, params) + require.NoError(t, err) + defer func() { _ = ctxOff.Close() }() + + require.NoError(t, ctxOff.Process(data, nil, nil, nil)) + var turnsOff int + for { + seg, err := ctxOff.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + if seg.SpeakerTurnNext { + turnsOff++ + } + } + + require.GreaterOrEqual(t, turnsOn, turnsOff, "diarization should not reduce turn boundaries") +} + +func TestContext_SpeakerTurnNext_Field_Present(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + + cases := []struct { + name string + new func(t *testing.T) (whisper.Context, func()) + }{ + {name: "stateless", new: helperNewStatelessContext}, + {name: "stateful", new: helperNewStatefulContext}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx, cleanup := tc.new(t) + defer cleanup() + + err := ctx.Process(data, nil, nil, nil) + assert.NoError(err) + + seg, err := ctx.NextSegment() + assert.NoError(err) + t.Logf("SpeakerTurnNext: %v", seg.SpeakerTurnNext) + _ = seg.SpeakerTurnNext + }) + } +} + +// Ensure Process produces at least one segment for both stateless and stateful contexts +func TestContext_Process_ProducesSegments_BothKinds(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + + // Stateless + stateless, cleanupS := helperNewStatelessContext(t) + defer cleanupS() + require.NoError(t, stateless.Process(data, nil, nil, nil)) + var statelessCount int + for { + _, err := stateless.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + statelessCount++ + } + assert.Greater(statelessCount, 0, "stateless should produce at least one segment") + + // Stateful + stateful, cleanupSt := helperNewStatefulContext(t) + defer cleanupSt() + 
require.NoError(t, stateful.Process(data, nil, nil, nil)) + var statefulCount int + for { + _, err := stateful.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + statefulCount++ + } + assert.Greater(statefulCount, 0, "stateful should produce at least one segment") +} + +// With temperature=0 (greedy), stateless and stateful should produce identical segments +func TestContext_Process_SameResults_TemperatureZero(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + + // Use a single model to avoid environment differences + model, err := whisper.NewModelContext(ModelPath) + require.NoError(t, err) + defer func() { _ = model.Close() }() + + // Independent params with temperature=0 for determinism + p := helperNewParams(t, model, func(p *whisper.Parameters) { + p.SetTemperature(0) + p.SetThreads(1) + }) + + stateless, err := whisper.NewStatelessContext(model, p) + require.NoError(t, err) + defer func() { _ = stateless.Close() }() + + stateful, err := whisper.NewStatefulContext(model, p) + require.NoError(t, err) + defer func() { _ = stateful.Close() }() + + require.NoError(t, stateless.Process(data, nil, nil, nil)) + require.NoError(t, stateful.Process(data, nil, nil, nil)) + + // Collect segment texts + var segsStateless, segsStateful []string + for { + seg, err := stateless.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + segsStateless = append(segsStateless, seg.Text) + } + for { + seg, err := stateful.NextSegment() + if err == io.EOF { + break + } + require.NoError(t, err) + segsStateful = append(segsStateful, seg.Text) + } + + // Both should have at least one segment and be identical + require.Greater(t, len(segsStateless), 0) + require.Greater(t, len(segsStateful), 0) + assert.Equal(len(segsStateful), len(segsStateless)) + for i := range segsStateless { + assert.Equal(segsStateless[i], segsStateful[i], "segment %d text differs", i) + } +} + +// Model.GetTimings: stateless processing updates model timings (non-zero), +// stateful processing does not (zero timings) +func TestModel_GetTimings_Stateless_NonZero_Stateful_Zero(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + + model, err := whisper.NewModelContext(ModelPath) + require.NoError(t, err) + defer func() { _ = model.Close() }() + + // Stateless should produce non-zero timings + t.Run("stateless", func(t *testing.T) { + model.ResetTimings() + params := helperNewParams(t, model, nil) + ctx, err := whisper.NewStatelessContext(model, params) + require.NoError(t, err) + defer func() { _ = ctx.Close() }() + + require.NoError(t, ctx.Process(data, nil, nil, nil)) + + timings, ok := model.GetTimings() + require.True(t, ok, "expected timings to be available after stateless processing") + nonZero := timings.SampleMS > 0 || timings.EncodeMS > 0 || timings.DecodeMS > 0 || timings.BatchdMS > 0 || timings.PromptMS > 0 + assert.True(nonZero, "expected at least one non-zero timing after stateless processing: %#v", timings) + }) + + // Stateful should keep model-level 
timings at zero
+	t.Run("stateful", func(t *testing.T) {
+		model.ResetTimings()
+		params := helperNewParams(t, model, nil)
+		ctx, err := whisper.NewStatefulContext(model, params)
+		require.NoError(t, err)
+		defer func() { _ = ctx.Close() }()
+
+		require.NoError(t, ctx.Process(data, nil, nil, nil))
-	expectedLanguage := "en"
-	actualLanguage := context.DetectedLanguage()
-	assert.Equal(expectedLanguage, actualLanguage)
+		timings, ok := model.GetTimings()
+		// Expect timings present but all zero; if not present at all, treat as zero-equivalent
+		if ok {
+			assert.Equal(float32(0), timings.SampleMS)
+			assert.Equal(float32(0), timings.EncodeMS)
+			assert.Equal(float32(0), timings.DecodeMS)
+			assert.Equal(float32(0), timings.BatchdMS)
+			assert.Equal(float32(0), timings.PromptMS)
+		} else {
+			t.Log("timings not available for stateful processing; treating as zero")
+		}
+	})
 }
diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go
index e3122c44b76..eabdb2db097 100644
--- a/bindings/go/pkg/whisper/interface.go
+++ b/bindings/go/pkg/whisper/interface.go
@@ -1,6 +1,7 @@
 package whisper
 
 import (
+	"fmt"
 	"io"
 	"time"
 )
@@ -20,15 +21,21 @@ type ProgressCallback func(int)
 // continue processing. It is called during the Process function
 type EncoderBeginCallback func() bool
 
+type ParamsConfigure func(*Parameters)
+
 // Model is the interface to a whisper model. Create a new model with the
 // function whisper.New(string)
+// Deprecated: Use the ModelContext struct (created via NewModelContext) instead of relying on this interface
 type Model interface {
 	io.Closer
 
 	// Return a new speech-to-text context.
+	// It may return an error if the model is not loaded or has been closed
+	// Deprecated: Use NewStatefulContext or NewStatelessContext instead of relying on this interface
 	NewContext() (Context, error)
 
 	// Return true if the model is multilingual.
+	// It returns false if the model is not loaded or has been closed
 	IsMultilingual() bool
 
 	// Return all languages supported.
@@ -36,29 +43,73 @@
 }
 
 // Context is the speech recognition context.
+// Deprecated: Use the StatefulContext or StatelessContext structs instead of relying on this interface
 type Context interface {
-	SetLanguage(string) error     // Set the language to use for speech recognition, use "auto" for auto detect language.
-	SetTranslate(bool)            // Set translate flag
-	IsMultilingual() bool         // Return true if the model is multilingual.
- Language() string // Get language - DetectedLanguage() string // Get detected language - - SetOffset(time.Duration) // Set offset - SetDuration(time.Duration) // Set duration - SetThreads(uint) // Set number of threads to use - SetSplitOnWord(bool) // Set split on word flag - SetTokenThreshold(float32) // Set timestamp token probability threshold - SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold - SetMaxSegmentLength(uint) // Set max segment length in characters - SetTokenTimestamps(bool) // Set token timestamps flag - SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) - SetAudioCtx(uint) // Set audio encoder context - SetMaxContext(n int) // Set maximum number of text context tokens to store - SetBeamSize(n int) // Set Beam Size - SetEntropyThold(t float32) // Set Entropy threshold - SetInitialPrompt(prompt string) // Set initial prompt - SetTemperature(t float32) // Set temperature - SetTemperatureFallback(t float32) // Set temperature incrementation + io.Closer + + // Deprecated: Use Params().SetLanguage() instead + SetLanguage(string) error + + // Deprecated: Use Params().SetTranslate() instead + SetTranslate(bool) + + // Deprecated: Use Params().SetSplitOnWord() instead + SetSplitOnWord(bool) + + // Deprecated: Use Params().SetThreads() instead + SetThreads(uint) + + // Deprecated: Use Params().SetOffset() instead + SetOffset(time.Duration) + + // Deprecated: Use Params().SetDuration() instead + SetDuration(time.Duration) + + // Deprecated: Use Params().SetTokenThreshold() instead + SetTokenThreshold(float32) + + // Deprecated: Use Params().SetTokenSumThreshold() instead + SetTokenSumThreshold(float32) + // Deprecated: Use Params().SetMaxSegmentLength() instead + + SetMaxSegmentLength(uint) + + // Deprecated: Use Params().SetTokenTimestamps() instead + SetTokenTimestamps(bool) + + // Deprecated: Use Params().SetMaxTokensPerSegment() instead + SetMaxTokensPerSegment(uint) + + // Deprecated: Use Params().SetAudioCtx() instead + SetAudioCtx(uint) + + // Deprecated: Use Params().SetMaxContext() instead + SetMaxContext(int) + + // Deprecated: Use Params().SetBeamSize() instead + SetBeamSize(int) + + // Deprecated: Use Params().SetEntropyThold() instead + SetEntropyThold(float32) + + // Deprecated: Use Params().SetTemperature() instead + SetTemperature(float32) + + // Deprecated: Use Params().SetTemperatureFallback() instead + SetTemperatureFallback(float32) + + // Deprecated: Use Params().SetInitialPrompt() instead + SetInitialPrompt(string) + + // Get language of the context parameters + // Deprecated: Use Params().Language() instead + Language() string + + // Deprecated: Use Model().IsMultilingual() instead + IsMultilingual() bool + + // Get detected language + DetectedLanguage() string // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the @@ -69,19 +120,39 @@ type Context interface { // is reached, when io.EOF is returned. 
NextSegment() (Segment, error) - IsBEG(Token) bool // Test for "begin" token - IsSOT(Token) bool // Test for "start of transcription" token - IsEOT(Token) bool // Test for "end of transcription" token - IsPREV(Token) bool // Test for "start of prev" token - IsSOLM(Token) bool // Test for "start of lm" token - IsNOT(Token) bool // Test for "No timestamps" token - IsLANG(Token, string) bool // Test for token associated with a specific language - IsText(Token) bool // Test for text token + // Deprecated: Use Model().TokenIdentifier().IsBEG() instead + IsBEG(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsSOT() instead + IsSOT(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsEOT() instead + IsEOT(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsPREV() instead + IsPREV(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsSOLM() instead + IsSOLM(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsNOT() instead + IsNOT(Token) bool + + // Deprecated: Use Model().TokenIdentifier().IsLANG() instead + IsLANG(Token, string) bool - // Timings + // Deprecated: Use Model().TokenIdentifier().IsText() instead + IsText(Token) bool + + // Deprecated: Use Model().PrintTimings() instead + // these are model-level performance metrics PrintTimings() + + // Deprecated: Use Model().ResetTimings() instead + // these are model-level performance metrics ResetTimings() + // SystemInfo returns the system information SystemInfo() string } @@ -98,12 +169,29 @@ type Segment struct { // The tokens of the segment. Tokens []Token + + // True if the next segment is predicted as a speaker turn (tinydiarize) + // It works only with the diarization supporting models (like small.en-tdrz.bin) with the diarization enabled + // using Parameters.SetDiarize(true) + SpeakerTurnNext bool +} + +func (s Segment) String() string { + // foramt: [00:01:39.000 --> 00:01:50.000] And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + return fmt.Sprintf("[%s --> %s] %s", s.Start.Truncate(time.Millisecond), s.End.Truncate(time.Millisecond), s.Text) } // Token is a text or special token type Token struct { - Id int - Text string - P float32 + // ID of the token + Id int + + // Text of the token + Text string + + // Probability of the token + P float32 + + // Timestamp of the token Start, End time.Duration } diff --git a/bindings/go/pkg/whisper/log.go b/bindings/go/pkg/whisper/log.go new file mode 100644 index 00000000000..66eb0d5c78c --- /dev/null +++ b/bindings/go/pkg/whisper/log.go @@ -0,0 +1,9 @@ +package whisper + +import low "github.com/ggerganov/whisper.cpp/bindings/go" + +// DisableLogs disables all C-side logging from whisper.cpp and ggml. +// Call once early in your program before creating models/contexts. 
+func DisableLogs() { + low.DisableLogs() +} diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a150223c7..c22490fb84c 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -3,99 +3,177 @@ package whisper import ( "fmt" "os" - "runtime" // Bindings - whisper "github.com/ggerganov/whisper.cpp/bindings/go" + low "github.com/ggerganov/whisper.cpp/bindings/go" ) -/////////////////////////////////////////////////////////////////////////////// -// TYPES - -type model struct { - path string - ctx *whisper.Context +type ModelContext struct { + path string + ca *ctxAccessor + tokId *tokenIdentifier } // Make sure model adheres to the interface -var _ Model = (*model)(nil) - -/////////////////////////////////////////////////////////////////////////////// -// LIFECYCLE +var _ Model = (*ModelContext)(nil) + +// Timings is a compact, high-level timing snapshot in milliseconds +type Timings struct { + SampleMS float32 + EncodeMS float32 + DecodeMS float32 + BatchdMS float32 + PromptMS float32 +} +// Deprecated: Use NewModelContext instead func New(path string) (Model, error) { - model := new(model) + return NewModelContext(path) +} + +// NewModelContext creates a new model context + +func NewModelContext( + path string, +) (*ModelContext, error) { + return NewModelContextWithParams( + path, + NewModelContextParams(), + ) +} + +// NewModelContextWithParams creates a new model context with custom initialization params +func NewModelContextWithParams( + path string, + params ModelContextParams, +) (*ModelContext, error) { + model := new(ModelContext) if _, err := os.Stat(path); err != nil { return nil, err - } else if ctx := whisper.Whisper_init(path); ctx == nil { + } + + ctx := low.Whisper_init_with_params(path, params.toLow()) + if ctx == nil { return nil, ErrUnableToLoadModel - } else { - model.ctx = ctx - model.path = path } - // Return success + model.ca = newCtxAccessor(ctx) + model.tokId = newTokenIdentifier(model.ca) + model.path = path + return model, nil } -func (model *model) Close() error { - if model.ctx != nil { - model.ctx.Whisper_free() - } - - // Release resources - model.ctx = nil - - // Return success - return nil +func (model *ModelContext) Close() error { + return model.ca.close() } -/////////////////////////////////////////////////////////////////////////////// -// STRINGIFY +func (model *ModelContext) ctxAccessor() *ctxAccessor { + return model.ca +} -func (model *model) String() string { +func (model *ModelContext) String() string { str := "" } -/////////////////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - // Return true if model is multilingual (language and translation options are supported) -func (model *model) IsMultilingual() bool { - return model.ctx.Whisper_is_multilingual() != 0 +func (model *ModelContext) IsMultilingual() bool { + ctx, err := model.ca.context() + if err != nil { + return false + } + + return ctx.Whisper_is_multilingual() != 0 } // Return all recognized languages. 
Initially it is set to auto-detect -func (model *model) Languages() []string { - result := make([]string, 0, whisper.Whisper_lang_max_id()) - for i := 0; i < whisper.Whisper_lang_max_id(); i++ { - str := whisper.Whisper_lang_str(i) - if model.ctx.Whisper_lang_id(str) >= 0 { +func (model *ModelContext) Languages() []string { + ctx, err := model.ca.context() + if err != nil { + return nil + } + + result := make([]string, 0, low.Whisper_lang_max_id()) + for i := 0; i < low.Whisper_lang_max_id(); i++ { + str := low.Whisper_lang_str(i) + if ctx.Whisper_lang_id(str) >= 0 { result = append(result, str) } } + return result } -func (model *model) NewContext() (Context, error) { - if model.ctx == nil { - return nil, ErrInternalAppError +// NewContext creates a new speech-to-text context. +// Each context is backed by an isolated whisper_state for safe concurrent processing. +func (model *ModelContext) NewContext() (Context, error) { + // Create new context with default params + params, err := NewParameters(model, SAMPLING_GREEDY, nil) + if err != nil { + return nil, err + } + + // Return new context (stateless for backward compatibility with timings) + return NewStatelessContext( + model, + params, + ) +} + +// PrintTimings prints the model performance timings to stdout. +func (model *ModelContext) PrintTimings() { + ctx, err := model.ca.context() + if err != nil { + return + } + + ctx.Whisper_print_timings() +} + +// ResetTimings resets the model performance timing counters. +func (model *ModelContext) ResetTimings() { + ctx, err := model.ca.context() + if err != nil { + return } - // Create new context - params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) - params.SetTranslate(false) - params.SetPrintSpecial(false) - params.SetPrintProgress(false) - params.SetPrintRealtime(false) - params.SetPrintTimestamps(false) - params.SetThreads(runtime.NumCPU()) - params.SetNoContext(true) - - // Return new context - return newContext(model, params) + ctx.Whisper_reset_timings() +} + +// GetTimings returns a compact snapshot of model-level processing timings. +// +// Behavior notes: +// - Stateless contexts (created via ModelContext.NewContext or NewStatelessContext) +// update model-level timings during Process. After a stateless Process call, +// the returned timings are expected to be non-zero (ok == true). +// - Stateful contexts (created via NewStatefulContext) use a per-state backend +// and do not affect model-level timings. After a stateful Process call, +// the returned timings are expected to be zero values (fields equal 0) or +// the call may return ok == false depending on the underlying implementation. +// +// Use ResetTimings before measurement to clear previous values. 
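Editor's note: as the comment above spells out, only a stateless Process updates the model-level timing counters, so the natural measurement loop is reset, process, read. A minimal sketch, assuming the small English model path used elsewhere in the README; DisableLogs is optional:

```go
package main

import (
	"fmt"

	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)

func main() {
	whisper.DisableLogs() // optional: silence C-side logging before loading the model

	model, err := whisper.NewModelContext("./models/ggml-small.en.bin") // assumed model path
	if err != nil {
		panic(err)
	}
	defer func() { _ = model.Close() }()

	params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
	if err != nil {
		panic(err)
	}

	// Stateless context: shares the model context, so Process updates model-level timings
	ctx, err := whisper.NewStatelessContext(model, params)
	if err != nil {
		panic(err)
	}
	defer func() { _ = ctx.Close() }()

	var samples []float32 // 16 kHz mono PCM, decoded elsewhere

	model.ResetTimings() // clear counters before measuring
	if err := ctx.Process(samples, nil, nil, nil); err != nil {
		panic(err)
	}

	if t, ok := model.GetTimings(); ok {
		fmt.Printf("sample=%.1fms encode=%.1fms decode=%.1fms\n", t.SampleMS, t.EncodeMS, t.DecodeMS)
	}
}
```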
+func (model *ModelContext) GetTimings() (Timings, bool) { + ctx, err := model.ca.context() + if err != nil { + return Timings{}, false + } + if t, ok := ctx.Whisper_get_timings_go(); ok { + return Timings{ + SampleMS: t.SampleMS, + EncodeMS: t.EncodeMS, + DecodeMS: t.DecodeMS, + BatchdMS: t.BatchdMS, + PromptMS: t.PromptMS, + }, true + } + return Timings{}, false +} + +func (model *ModelContext) tokenIdentifier() *tokenIdentifier { + return model.tokId } diff --git a/bindings/go/pkg/whisper/model_context_params.go b/bindings/go/pkg/whisper/model_context_params.go new file mode 100644 index 00000000000..62733b9bbd1 --- /dev/null +++ b/bindings/go/pkg/whisper/model_context_params.go @@ -0,0 +1,27 @@ +package whisper + +import ( + low "github.com/ggerganov/whisper.cpp/bindings/go" +) + +type ModelContextParams struct { + p low.ContextParams +} + +func NewModelContextParams() ModelContextParams { + return ModelContextParams{ + p: low.Whisper_context_default_params(), + } +} + +func (p *ModelContextParams) SetUseGPU(v bool) { + p.p.SetUseGPU(v) +} + +func (p *ModelContextParams) SetGPUDevice(n int) { + p.p.SetGPUDevice(n) +} + +func (p *ModelContextParams) toLow() low.ContextParams { + return p.p +} diff --git a/bindings/go/pkg/whisper/model_test.go b/bindings/go/pkg/whisper/model_test.go index 8797f0d0fd0..bd152088060 100644 --- a/bindings/go/pkg/whisper/model_test.go +++ b/bindings/go/pkg/whisper/model_test.go @@ -13,7 +13,7 @@ func TestNew(t *testing.T) { model, err := whisper.New(ModelPath) assert.NoError(err) assert.NotNil(model) - defer model.Close() + defer func() { _ = model.Close() }() }) @@ -42,20 +42,34 @@ func TestNewContext(t *testing.T) { model, err := whisper.New(ModelPath) assert.NoError(err) assert.NotNil(model) - defer model.Close() + defer func() { _ = model.Close() }() context, err := model.NewContext() assert.NoError(err) assert.NotNil(context) } +func TestNewContext_ClosedModel(t *testing.T) { + assert := assert.New(t) + + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + assert.NoError(model.Close()) + + context, err := model.NewContext() + assert.ErrorIs(err, whisper.ErrInternalAppError) + assert.ErrorIs(err, whisper.ErrModelClosed) + assert.Nil(context) +} + func TestIsMultilingual(t *testing.T) { assert := assert.New(t) model, err := whisper.New(ModelPath) assert.NoError(err) assert.NotNil(model) - defer model.Close() + defer func() { _ = model.Close() }() isMultilingual := model.IsMultilingual() @@ -71,7 +85,7 @@ func TestLanguages(t *testing.T) { model, err := whisper.New(ModelPath) assert.NoError(err) assert.NotNil(model) - defer model.Close() + defer func() { _ = model.Close() }() expectedLanguages := []string{ "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", diff --git a/bindings/go/pkg/whisper/params_wrap.go b/bindings/go/pkg/whisper/params_wrap.go new file mode 100644 index 00000000000..44b74562b67 --- /dev/null +++ b/bindings/go/pkg/whisper/params_wrap.go @@ -0,0 +1,120 @@ +package whisper + +import ( + "runtime" + "time" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +// Parameters is a high-level wrapper that implements the Parameters interface +// and delegates to the underlying low-level whisper.Params. 
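Editor's note: ModelContextParams wraps whisper_context_params, so GPU selection happens once at model-load time rather than per Process call. A minimal sketch, assuming the same model path; whether a GPU is actually used still depends on how the C library was built:

```go
// Sketch: force CPU-only inference (or pick a CUDA device) at model-load time.
mp := whisper.NewModelContextParams()
mp.SetUseGPU(false) // set true and choose a device index on CUDA builds
mp.SetGPUDevice(0)

model, err := whisper.NewModelContextWithParams("./models/ggml-small.en.bin", mp) // assumed path
if err != nil {
	panic(err)
}
defer func() { _ = model.Close() }()
```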
+type Parameters struct { + p *whisper.Params +} + +func defaultParamsConfigure(params *Parameters) { + params.SetTranslate(false) + params.SetPrintSpecial(false) + params.SetPrintProgress(false) + params.SetPrintRealtime(false) + params.SetPrintTimestamps(false) + // Default behavior backward compatibility + params.SetThreads(uint(runtime.NumCPU())) + params.SetNoContext(true) +} + +func NewParameters( + model *ModelContext, + sampling SamplingStrategy, + configure ParamsConfigure, +) (*Parameters, error) { + ctx, err := model.ca.context() + if err != nil { + return nil, ErrModelClosed + } + + p := ctx.Whisper_full_default_params(whisper.SamplingStrategy(sampling)) + safeParams := &Parameters{ + p: &p, + } + + defaultParamsConfigure(safeParams) + + if configure != nil { + configure(safeParams) + } + + return safeParams, nil +} + +func (w *Parameters) SetTranslate(v bool) { w.p.SetTranslate(v) } +func (w *Parameters) SetSplitOnWord(v bool) { w.p.SetSplitOnWord(v) } +func (w *Parameters) SetThreads(v uint) { w.p.SetThreads(int(v)) } +func (w *Parameters) SetOffset(d time.Duration) { w.p.SetOffset(int(d.Milliseconds())) } +func (w *Parameters) SetDuration(d time.Duration) { w.p.SetDuration(int(d.Milliseconds())) } +func (w *Parameters) SetTokenThreshold(t float32) { w.p.SetTokenThreshold(t) } +func (w *Parameters) SetTokenSumThreshold(t float32) { w.p.SetTokenSumThreshold(t) } +func (w *Parameters) SetMaxSegmentLength(n uint) { w.p.SetMaxSegmentLength(int(n)) } +func (w *Parameters) SetTokenTimestamps(b bool) { w.p.SetTokenTimestamps(b) } +func (w *Parameters) SetMaxTokensPerSegment(n uint) { w.p.SetMaxTokensPerSegment(int(n)) } +func (w *Parameters) SetAudioCtx(n uint) { w.p.SetAudioCtx(int(n)) } +func (w *Parameters) SetMaxContext(n int) { w.p.SetMaxContext(n) } +func (w *Parameters) SetBeamSize(n int) { w.p.SetBeamSize(n) } +func (w *Parameters) SetEntropyThold(t float32) { w.p.SetEntropyThold(t) } +func (w *Parameters) SetInitialPrompt(prompt string) { w.p.SetInitialPrompt(prompt) } +func (w *Parameters) SetTemperature(t float32) { w.p.SetTemperature(t) } +func (w *Parameters) SetTemperatureFallback(t float32) { w.p.SetTemperatureFallback(t) } +func (w *Parameters) SetNoContext(v bool) { w.p.SetNoContext(v) } +func (w *Parameters) SetPrintSpecial(v bool) { w.p.SetPrintSpecial(v) } +func (w *Parameters) SetPrintProgress(v bool) { w.p.SetPrintProgress(v) } +func (w *Parameters) SetPrintRealtime(v bool) { w.p.SetPrintRealtime(v) } +func (w *Parameters) SetPrintTimestamps(v bool) { w.p.SetPrintTimestamps(v) } +func (w *Parameters) SetDebugMode(v bool) { w.p.SetDebugMode(v) } + +// Diarization (tinydiarize) +func (w *Parameters) SetDiarize(v bool) { w.p.SetDiarize(v) } + +// Voice Activity Detection (VAD) +func (w *Parameters) SetVAD(v bool) { w.p.SetVAD(v) } +func (w *Parameters) SetVADModelPath(p string) { w.p.SetVADModelPath(p) } +func (w *Parameters) SetVADThreshold(t float32) { w.p.SetVADThreshold(t) } +func (w *Parameters) SetVADMinSpeechMs(ms int) { w.p.SetVADMinSpeechMs(ms) } +func (w *Parameters) SetVADMinSilenceMs(ms int) { w.p.SetVADMinSilenceMs(ms) } +func (w *Parameters) SetVADMaxSpeechSec(s float32) { w.p.SetVADMaxSpeechSec(s) } +func (w *Parameters) SetVADSpeechPadMs(ms int) { w.p.SetVADSpeechPadMs(ms) } +func (w *Parameters) SetVADSamplesOverlap(sec float32) { w.p.SetVADSamplesOverlap(sec) } + +func (w *Parameters) SetLanguage(lang string) error { + if lang == "auto" { + return w.p.SetLanguage(-1) + } + id := whisper.Whisper_lang_id_str(lang) + if id < 0 { + return 
ErrUnsupportedLanguage + } + return w.p.SetLanguage(id) +} + +func (w *Parameters) SetSingleSegment(v bool) { + w.p.SetSingleSegment(v) +} + +// Getter methods for Parameters interface +func (w *Parameters) Language() string { + id := w.p.Language() + if id == -1 { + return "auto" + } + + return whisper.Whisper_lang_str(id) +} + +func (w *Parameters) Threads() int { + return w.p.Threads() +} + +func (w *Parameters) unsafeParams() (*whisper.Params, error) { + return w.p, nil +} diff --git a/bindings/go/pkg/whisper/stateful_context.go b/bindings/go/pkg/whisper/stateful_context.go new file mode 100644 index 00000000000..08e04094c09 --- /dev/null +++ b/bindings/go/pkg/whisper/stateful_context.go @@ -0,0 +1,397 @@ +package whisper + +import ( + "fmt" + "io" + "runtime" + "strings" + "time" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +type StatefulContext struct { + n int + model *ModelContext + st *whisperState + params *Parameters +} + +// NewStatefulContext creates a new stateful context +func NewStatefulContext(model *ModelContext, params *Parameters) (*StatefulContext, error) { + if model == nil { + return nil, errModelRequired + } + + if params == nil { + return nil, errParametersRequired + } + + c := new(StatefulContext) + c.model = model + c.params = params + + // allocate isolated state per context + ctx, err := model.ctxAccessor().context() + if err != nil { + return nil, err + } + + st := ctx.Whisper_init_state() + if st == nil { + return nil, errUnableToCreateState + } + + c.st = newWhisperState(st) + + // Return success + return c, nil +} + +// DetectedLanguage returns the detected language for the current context data +func (context *StatefulContext) DetectedLanguage() string { + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return "" + } + + st, err := context.st.unsafeState() + if err != nil { + return "" + } + + return whisper.Whisper_lang_str( + ctx.Whisper_full_lang_id_from_state( + st, + ), + ) +} + +// Close frees the whisper state and marks the context as closed. +func (context *StatefulContext) Close() error { + return context.st.close() +} + +// Params returns a high-level parameters wrapper +func (context *StatefulContext) Params() *Parameters { + return context.params +} + +// ResetTimings resets the model performance timing counters. +// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics. +func (context *StatefulContext) ResetTimings() { + context.model.ResetTimings() +} + +// PrintTimings prints the model performance timings to stdout. +// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics. +func (context *StatefulContext) PrintTimings() { + context.model.PrintTimings() +} + +// SystemInfo returns the system information +func (context *StatefulContext) SystemInfo() string { + return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n", + context.params.Threads(), + runtime.NumCPU(), + whisper.Whisper_print_system_info(), + ) +} + +// Use mel data at offset_ms to try and auto-detect the spoken language +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// Returns the probabilities of all languages for this context's state. 
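Editor's note: NewParameters applies the backward-compatible defaults first and only then runs the optional configure callback, so a config function needs to set just what differs. A hedged sketch enabling language auto-detection and VAD; the VAD model path and threshold values are assumptions, and `model` is a loaded *whisper.ModelContext as in the README quick start:

```go
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {
	p.SetThreads(4)
	p.SetTranslate(false)
	if err := p.SetLanguage("auto"); err != nil { // "auto" maps to language id -1
		panic(err)
	}

	// Optional voice activity detection; path and values below are assumptions
	p.SetVAD(true)
	p.SetVADModelPath("./models/ggml-silero-vad.bin")
	p.SetVADThreshold(0.5)
	p.SetVADMinSilenceMs(100)
})
if err != nil {
	panic(err)
}
```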
+func (context *StatefulContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) { + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return nil, err + } + + st, err := context.st.unsafeState() + if err != nil { + return nil, err + } + + langProbs, err := ctx.Whisper_lang_auto_detect_with_state(st, offset_ms, n_threads) + if err != nil { + return nil, err + } + + return langProbs, nil +} + +// Process new sample data and return any errors +func (context *StatefulContext) Process( + data []float32, + callEncoderBegin EncoderBeginCallback, + callNewSegment SegmentCallback, + callProgress ProgressCallback, +) error { + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return err + } + + // If the callback is defined then we force on single_segment mode + if callNewSegment != nil { + context.params.SetSingleSegment(true) + } + + lowLevelParams, err := context.params.unsafeParams() + if err != nil { + return err + } + + st, err := context.st.unsafeState() + if err != nil { + return err + } + + if err := ctx.Whisper_full_with_state(st, *lowLevelParams, data, callEncoderBegin, + func(new int) { + if callNewSegment != nil { + num_segments := ctx.Whisper_full_n_segments_from_state(st) + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + callNewSegment(toSegmentFromState(ctx, st, i)) + } + } + }, func(progress int) { + if callProgress != nil { + callProgress(progress) + } + }); err != nil { + return err + } + + // Return success + return nil +} + +// NextSegment returns the next segment from the context buffer +func (context *StatefulContext) NextSegment() (Segment, error) { + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return Segment{}, err + } + + st, err := context.st.unsafeState() + if err != nil { + return Segment{}, err + } + + if context.n >= ctx.Whisper_full_n_segments_from_state(st) { + return Segment{}, io.EOF + } + + result := toSegmentFromState(ctx, st, context.n) + context.n++ + + return result, nil +} + +func (context *StatefulContext) IsMultilingual() bool { + return context.model.IsMultilingual() +} + +// Token helpers +// Deprecated: Use Model.IsText() instead - token checking is model-specific. +func (context *StatefulContext) IsText(t Token) bool { + result, _ := context.model.tokenIdentifier().IsText(t) + return result +} + +// Deprecated: Use Model.IsBEG() instead - token checking is model-specific. +func (context *StatefulContext) IsBEG(t Token) bool { + result, _ := context.model.tokenIdentifier().IsBEG(t) + return result +} + +// Deprecated: Use Model.IsSOT() instead - token checking is model-specific. +func (context *StatefulContext) IsSOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsSOT(t) + return result +} + +// Deprecated: Use Model.IsEOT() instead - token checking is model-specific. +func (context *StatefulContext) IsEOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsEOT(t) + return result +} + +// Deprecated: Use Model.IsPREV() instead - token checking is model-specific. +func (context *StatefulContext) IsPREV(t Token) bool { + result, _ := context.model.tokenIdentifier().IsPREV(t) + return result +} + +// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific. +func (context *StatefulContext) IsSOLM(t Token) bool { + result, _ := context.model.tokenIdentifier().IsSOLM(t) + return result +} + +// Deprecated: Use Model.IsNOT() instead - token checking is model-specific. 
+func (context *StatefulContext) IsNOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsNOT(t) + return result +} + +func (context *StatefulContext) SetLanguage(lang string) error { + if context.model.ctxAccessor().isClosed() { + // TODO: remove this logic after deprecating the ErrInternalAppError + return ErrModelClosed + } + + if !context.model.IsMultilingual() { + return ErrModelNotMultilingual + } + + return context.params.SetLanguage(lang) +} + +// Deprecated: Use Model.IsLANG() instead - token checking is model-specific. +func (context *StatefulContext) IsLANG(t Token, lang string) bool { + result, _ := context.model.tokenIdentifier().IsLANG(t, lang) + return result +} + +// State-backed helper functions +func toSegmentFromState(ctx *whisper.Context, st *whisper.State, n int) Segment { + return Segment{ + Num: n, + Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text_from_state(st, n)), + Start: time.Duration(ctx.Whisper_full_get_segment_t0_from_state(st, n)) * time.Millisecond * 10, + End: time.Duration(ctx.Whisper_full_get_segment_t1_from_state(st, n)) * time.Millisecond * 10, + Tokens: toTokensFromState(ctx, st, n), + SpeakerTurnNext: ctx.Whisper_full_get_segment_speaker_turn_next_from_state(st, n), + } +} + +func toTokensFromState(ctx *whisper.Context, st *whisper.State, n int) []Token { + result := make([]Token, ctx.Whisper_full_n_tokens_from_state(st, n)) + + for i := 0; i < len(result); i++ { + data := ctx.Whisper_full_get_token_data_from_state(st, n, i) + result[i] = Token{ + Id: int(ctx.Whisper_full_get_token_id_from_state(st, n, i)), + Text: ctx.Whisper_full_get_token_text_from_state(st, n, i), + P: ctx.Whisper_full_get_token_p_from_state(st, n, i), + Start: time.Duration(data.T0()) * time.Millisecond * 10, + End: time.Duration(data.T1()) * time.Millisecond * 10, + } + } + + return result +} + +// Deprecated: Use Params().Language() instead +func (context *StatefulContext) Language() string { + return context.params.Language() +} + +// Deprecated: Use Params().SetAudioCtx() instead +func (context *StatefulContext) SetAudioCtx(n uint) { + context.params.SetAudioCtx(n) +} + +// SetBeamSize implements Context. +// Deprecated: Use Params().SetBeamSize() instead +func (context *StatefulContext) SetBeamSize(v int) { + context.params.SetBeamSize(v) +} + +// SetDuration implements Context. +// Deprecated: Use Params().SetDuration() instead +func (context *StatefulContext) SetDuration(v time.Duration) { + context.params.SetDuration(v) +} + +// SetEntropyThold implements Context. +// Deprecated: Use Params().SetEntropyThold() instead +func (context *StatefulContext) SetEntropyThold(v float32) { + context.params.SetEntropyThold(v) +} + +// SetInitialPrompt implements Context. +// Deprecated: Use Params().SetInitialPrompt() instead +func (context *StatefulContext) SetInitialPrompt(v string) { + context.params.SetInitialPrompt(v) +} + +// SetMaxContext implements Context. +// Deprecated: Use Params().SetMaxContext() instead +func (context *StatefulContext) SetMaxContext(v int) { + context.params.SetMaxContext(v) +} + +// SetMaxSegmentLength implements Context. +// Deprecated: Use Params().SetMaxSegmentLength() instead +func (context *StatefulContext) SetMaxSegmentLength(v uint) { + context.params.SetMaxSegmentLength(v) +} + +// SetMaxTokensPerSegment implements Context. 
+// Deprecated: Use Params().SetMaxTokensPerSegment() instead +func (context *StatefulContext) SetMaxTokensPerSegment(v uint) { + context.params.SetMaxTokensPerSegment(v) +} + +// SetOffset implements Context. +// Deprecated: Use Params().SetOffset() instead +func (context *StatefulContext) SetOffset(v time.Duration) { + context.params.SetOffset(v) +} + +// SetSplitOnWord implements Context. +// Deprecated: Use Params().SetSplitOnWord() instead +func (context *StatefulContext) SetSplitOnWord(v bool) { + context.params.SetSplitOnWord(v) +} + +// SetTemperature implements Context. +// Deprecated: Use Params().SetTemperature() instead +func (context *StatefulContext) SetTemperature(v float32) { + context.params.SetTemperature(v) +} + +// SetTemperatureFallback implements Context. +// Deprecated: Use Params().SetTemperatureFallback() instead +func (context *StatefulContext) SetTemperatureFallback(v float32) { + context.params.SetTemperatureFallback(v) +} + +// SetThreads implements Context. +// Deprecated: Use Params().SetThreads() instead +func (context *StatefulContext) SetThreads(v uint) { + context.params.SetThreads(v) +} + +// SetTokenSumThreshold implements Context. +// Deprecated: Use Params().SetTokenSumThreshold() instead +func (context *StatefulContext) SetTokenSumThreshold(v float32) { + context.params.SetTokenSumThreshold(v) +} + +// SetTokenThreshold implements Context. +// Deprecated: Use Params().SetTokenThreshold() instead +func (context *StatefulContext) SetTokenThreshold(v float32) { + context.params.SetTokenThreshold(v) +} + +// SetTokenTimestamps implements Context. +// Deprecated: Use Params().SetTokenTimestamps() instead +func (context *StatefulContext) SetTokenTimestamps(v bool) { + context.params.SetTokenTimestamps(v) +} + +// SetTranslate implements Context. +// Deprecated: Use Params().SetTranslate() instead +func (context *StatefulContext) SetTranslate(v bool) { + context.params.SetTranslate(v) +} + +// Make stateful context compatible with the old deprecated interface for +// the simple migration into multi-threaded processing. 
+var _ Context = (*StatefulContext)(nil) diff --git a/bindings/go/pkg/whisper/stateful_context_test.go b/bindings/go/pkg/whisper/stateful_context_test.go new file mode 100644 index 00000000000..0062aed10fa --- /dev/null +++ b/bindings/go/pkg/whisper/stateful_context_test.go @@ -0,0 +1,81 @@ +package whisper_test + +import ( + "os" + "sync" + "testing" + + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + assert "github.com/stretchr/testify/assert" +) + +// Stateful-specific: parallel processing supported +func TestContext_Parallel_DifferentInputs_Stateful(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + if _, err := os.Stat(SamplePath); os.IsNotExist(err) { + t.Skip("Skipping test, sample not found:", SamplePath) + } + + data := helperLoadSample(t, SamplePath) + assert.Greater(len(data), 10) + + // Create half-sample (second half) + half := make([]float32, len(data)/2) + copy(half, data[len(data)/2:]) + + model, err := whisper.NewModelContext(ModelPath) + assert.NoError(err) + defer func() { _ = model.Close() }() + + params1 := helperNewParams(t, model, nil) + params2 := helperNewParams(t, model, nil) + + ctx1, err := whisper.NewStatefulContext(model, params1) + assert.NoError(err) + defer func() { _ = ctx1.Close() }() + ctx2, err := whisper.NewStatefulContext(model, params2) + assert.NoError(err) + defer func() { _ = ctx2.Close() }() + + var wg sync.WaitGroup + var first1, first2 string + var e1, e2 error + wg.Add(2) + + go func() { + defer wg.Done() + e1 = ctx1.Process(data, nil, nil, nil) + if e1 == nil { + seg, err := ctx1.NextSegment() + if err == nil { + first1 = seg.Text + } else { + e1 = err + } + } + }() + + go func() { + defer wg.Done() + e2 = ctx2.Process(half, nil, nil, nil) + if e2 == nil { + seg, err := ctx2.NextSegment() + if err == nil { + first2 = seg.Text + } else { + e2 = err + } + } + }() + + wg.Wait() + assert.NoError(e1) + assert.NoError(e2) + assert.NotEmpty(first1) + assert.NotEmpty(first2) + assert.NotEqual(first1, first2, "first segments should differ for different inputs") +} diff --git a/bindings/go/pkg/whisper/stateless_context.go b/bindings/go/pkg/whisper/stateless_context.go new file mode 100644 index 00000000000..7dbe8be29f7 --- /dev/null +++ b/bindings/go/pkg/whisper/stateless_context.go @@ -0,0 +1,377 @@ +package whisper + +import ( + "fmt" + "io" + "runtime" + "strings" + "time" + + // Bindings + whisper "github.com/ggerganov/whisper.cpp/bindings/go" +) + +type StatelessContext struct { + n int + model *ModelContext + params *Parameters + closed bool +} + +// NewStatelessContext creates a new stateless context backed by the model's context +func NewStatelessContext(model *ModelContext, params *Parameters) (*StatelessContext, error) { + if model == nil { + return nil, errModelRequired + } + + if params == nil { + return nil, errParametersRequired + } + + // Ensure model context is available + if _, err := model.ctxAccessor().context(); err != nil { + return nil, err + } + + c := new(StatelessContext) + c.model = model + c.params = params + + return c, nil +} + +// DetectedLanguage returns the detected language for the current context data +func (context *StatelessContext) DetectedLanguage() string { + if context.closed { + return "" + } + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return "" + } + return whisper.Whisper_lang_str(ctx.Whisper_full_lang_id()) +} + +// Close marks the context as closed. 
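Editor's note: the test above is the intended concurrency pattern: one StatefulContext per goroutine, all sharing a single loaded model. A condensed sketch of that fan-out (the pre-split chunks are assumptions):

```go
var wg sync.WaitGroup
chunks := [][]float32{chunkA, chunkB} // assumed pre-split 16 kHz mono PCM

for i, chunk := range chunks {
	wg.Add(1)
	go func(i int, chunk []float32) {
		defer wg.Done()

		p, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
		if err != nil {
			log.Printf("worker %d: %v", i, err)
			return
		}
		ctx, err := whisper.NewStatefulContext(model, p) // isolated whisper_state per worker
		if err != nil {
			log.Printf("worker %d: %v", i, err)
			return
		}
		defer func() { _ = ctx.Close() }()

		if err := ctx.Process(chunk, nil, nil, nil); err != nil {
			log.Printf("worker %d: %v", i, err)
			return
		}
		for {
			seg, err := ctx.NextSegment()
			if err != nil {
				break // io.EOF ends iteration
			}
			log.Printf("worker %d: %s", i, seg.Text)
		}
	}(i, chunk)
}
wg.Wait()
```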
+func (context *StatelessContext) Close() error { + context.closed = true + return nil +} + +// Params returns a high-level parameters wrapper +func (context *StatelessContext) Params() *Parameters { + return context.params +} + +// ResetTimings resets the model performance timing counters. +// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics. +func (context *StatelessContext) ResetTimings() { + context.model.ResetTimings() +} + +// PrintTimings prints the model performance timings to stdout. +// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics. +func (context *StatelessContext) PrintTimings() { + context.model.PrintTimings() +} + +// SystemInfo returns the system information +func (context *StatelessContext) SystemInfo() string { + return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n", + context.params.Threads(), + runtime.NumCPU(), + whisper.Whisper_print_system_info(), + ) +} + +// Use mel data at offset_ms to try and auto-detect the spoken language +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// Returns the probabilities of all languages for this context. +func (context *StatelessContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) { + if context.closed { + return nil, ErrModelClosed + } + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return nil, err + } + langProbs, err := ctx.Whisper_lang_auto_detect(offset_ms, n_threads) + if err != nil { + return nil, err + } + return langProbs, nil +} + +// Process new sample data and return any errors +func (context *StatelessContext) Process( + data []float32, + callEncoderBegin EncoderBeginCallback, + callNewSegment SegmentCallback, + callProgress ProgressCallback, +) error { + if context.closed { + return ErrModelClosed + } + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return err + } + // Concurrency guard: prevent concurrent stateless processing on shared model ctx + k := modelKey(context.model) + if !gate().Acquire(k) { + return ErrStatelessBusy + } + defer gate().Release(k) + + // If the callback is defined then we force on single_segment mode + if callNewSegment != nil { + context.params.SetSingleSegment(true) + } + + lowLevelParams, err := context.params.unsafeParams() + if err != nil { + return err + } + + if err := ctx.Whisper_full(*lowLevelParams, data, callEncoderBegin, + func(new int) { + if callNewSegment != nil { + num_segments := ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + callNewSegment(toSegmentFromContext(ctx, i)) + } + } + }, func(progress int) { + if callProgress != nil { + callProgress(progress) + } + }); err != nil { + return err + } + + // Return success + return nil +} + +// NextSegment returns the next segment from the context buffer +func (context *StatelessContext) NextSegment() (Segment, error) { + if context.closed { + return Segment{}, ErrModelClosed + } + ctx, err := context.model.ctxAccessor().context() + if err != nil { + return Segment{}, err + } + + if context.n >= ctx.Whisper_full_n_segments() { + return Segment{}, io.EOF + } + + result := toSegmentFromContext(ctx, context.n) + context.n++ + + return result, nil +} + +func (context *StatelessContext) IsMultilingual() bool { + return context.model.IsMultilingual() +} + +// Token helpers +// Deprecated: Use Model.IsText() instead - token checking is model-specific. 
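Editor's note: the gate above is keyed by the model, so two stateless contexts sharing one ModelContext cannot run Process concurrently; the overlapping call fails fast with ErrStatelessBusy instead of corrupting shared state. A sketch of handling that by serializing callers (the mutex lives in the calling code, not in the package):

```go
var mu sync.Mutex // serializes Process calls on stateless contexts sharing one model

func processSerialized(ctx *whisper.StatelessContext, samples []float32) error {
	mu.Lock()
	defer mu.Unlock()

	err := ctx.Process(samples, nil, nil, nil)
	if errors.Is(err, whisper.ErrStatelessBusy) {
		// Another caller bypassed the mutex; back off and retry, or switch the
		// hot path to StatefulContext for genuine parallelism.
		return err
	}
	return err
}
```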
+func (context *StatelessContext) IsText(t Token) bool { + result, _ := context.model.tokenIdentifier().IsText(t) + return result +} + +// Deprecated: Use Model.IsBEG() instead - token checking is model-specific. +func (context *StatelessContext) IsBEG(t Token) bool { + result, _ := context.model.tokenIdentifier().IsBEG(t) + return result +} + +// Deprecated: Use Model.IsSOT() instead - token checking is model-specific. +func (context *StatelessContext) IsSOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsSOT(t) + return result +} + +// Deprecated: Use Model.IsEOT() instead - token checking is model-specific. +func (context *StatelessContext) IsEOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsEOT(t) + return result +} + +// Deprecated: Use Model.IsPREV() instead - token checking is model-specific. +func (context *StatelessContext) IsPREV(t Token) bool { + result, _ := context.model.tokenIdentifier().IsPREV(t) + return result +} + +// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific. +func (context *StatelessContext) IsSOLM(t Token) bool { + result, _ := context.model.tokenIdentifier().IsSOLM(t) + return result +} + +// Deprecated: Use Model.IsNOT() instead - token checking is model-specific. +func (context *StatelessContext) IsNOT(t Token) bool { + result, _ := context.model.tokenIdentifier().IsNOT(t) + return result +} + +func (context *StatelessContext) SetLanguage(lang string) error { + if context.closed || context.model.ctxAccessor().isClosed() { + return ErrModelClosed + } + + if !context.model.IsMultilingual() { + return ErrModelNotMultilingual + } + + return context.params.SetLanguage(lang) +} + +// Deprecated: Use Model.IsLANG() instead - token checking is model-specific. +func (context *StatelessContext) IsLANG(t Token, lang string) bool { + result, _ := context.model.tokenIdentifier().IsLANG(t, lang) + return result +} + +// Context-backed helper functions +func toSegmentFromContext(ctx *whisper.Context, n int) Segment { + return Segment{ + Num: n, + Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), + Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, + End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, + Tokens: toTokensFromContext(ctx, n), + SpeakerTurnNext: false, // speaker turn available only with state-backed accessors + } +} + +func toTokensFromContext(ctx *whisper.Context, n int) []Token { + result := make([]Token, ctx.Whisper_full_n_tokens(n)) + + for i := 0; i < len(result); i++ { + data := ctx.Whisper_full_get_token_data(n, i) + result[i] = Token{ + Id: int(ctx.Whisper_full_get_token_id(n, i)), + Text: ctx.Whisper_full_get_token_text(n, i), + P: ctx.Whisper_full_get_token_p(n, i), + Start: time.Duration(data.T0()) * time.Millisecond * 10, + End: time.Duration(data.T1()) * time.Millisecond * 10, + } + } + + return result +} + +// Deprecated: Use Params().Language() instead +func (context *StatelessContext) Language() string { + return context.params.Language() +} + +// Deprecated: Use Params().SetAudioCtx() instead +func (context *StatelessContext) SetAudioCtx(n uint) { + context.params.SetAudioCtx(n) +} + +// SetBeamSize implements Context. +// Deprecated: Use Params().SetBeamSize() instead +func (context *StatelessContext) SetBeamSize(v int) { + context.params.SetBeamSize(v) +} + +// SetDuration implements Context. 
+// Deprecated: Use Params().SetDuration() instead +func (context *StatelessContext) SetDuration(v time.Duration) { + context.params.SetDuration(v) +} + +// SetEntropyThold implements Context. +// Deprecated: Use Params().SetEntropyThold() instead +func (context *StatelessContext) SetEntropyThold(v float32) { + context.params.SetEntropyThold(v) +} + +// SetInitialPrompt implements Context. +// Deprecated: Use Params().SetInitialPrompt() instead +func (context *StatelessContext) SetInitialPrompt(v string) { + context.params.SetInitialPrompt(v) +} + +// SetMaxContext implements Context. +// Deprecated: Use Params().SetMaxContext() instead +func (context *StatelessContext) SetMaxContext(v int) { + context.params.SetMaxContext(v) +} + +// SetMaxSegmentLength implements Context. +// Deprecated: Use Params().SetMaxSegmentLength() instead +func (context *StatelessContext) SetMaxSegmentLength(v uint) { + context.params.SetMaxSegmentLength(v) +} + +// SetMaxTokensPerSegment implements Context. +// Deprecated: Use Params().SetMaxTokensPerSegment() instead +func (context *StatelessContext) SetMaxTokensPerSegment(v uint) { + context.params.SetMaxTokensPerSegment(v) +} + +// SetOffset implements Context. +// Deprecated: Use Params().SetOffset() instead +func (context *StatelessContext) SetOffset(v time.Duration) { + context.params.SetOffset(v) +} + +// SetSplitOnWord implements Context. +// Deprecated: Use Params().SetSplitOnWord() instead +func (context *StatelessContext) SetSplitOnWord(v bool) { + context.params.SetSplitOnWord(v) +} + +// SetTemperature implements Context. +// Deprecated: Use Params().SetTemperature() instead +func (context *StatelessContext) SetTemperature(v float32) { + context.params.SetTemperature(v) +} + +// SetTemperatureFallback implements Context. +// Deprecated: Use Params().SetTemperatureFallback() instead +func (context *StatelessContext) SetTemperatureFallback(v float32) { + context.params.SetTemperatureFallback(v) +} + +// SetThreads implements Context. +// Deprecated: Use Params().SetThreads() instead +func (context *StatelessContext) SetThreads(v uint) { + context.params.SetThreads(v) +} + +// SetTokenSumThreshold implements Context. +// Deprecated: Use Params().SetTokenSumThreshold() instead +func (context *StatelessContext) SetTokenSumThreshold(v float32) { + context.params.SetTokenSumThreshold(v) +} + +// SetTokenThreshold implements Context. +// Deprecated: Use Params().SetTokenThreshold() instead +func (context *StatelessContext) SetTokenThreshold(v float32) { + context.params.SetTokenThreshold(v) +} + +// SetTokenTimestamps implements Context. +// Deprecated: Use Params().SetTokenTimestamps() instead +func (context *StatelessContext) SetTokenTimestamps(v bool) { + context.params.SetTokenTimestamps(v) +} + +// SetTranslate implements Context. 
+// Deprecated: Use Params().SetTranslate() instead +func (context *StatelessContext) SetTranslate(v bool) { + context.params.SetTranslate(v) +} + +var _ Context = (*StatelessContext)(nil) diff --git a/bindings/go/pkg/whisper/stateless_context_test.go b/bindings/go/pkg/whisper/stateless_context_test.go new file mode 100644 index 00000000000..0eb867d1914 --- /dev/null +++ b/bindings/go/pkg/whisper/stateless_context_test.go @@ -0,0 +1,52 @@ +package whisper_test + +import ( + "sync" + "testing" + + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + assert "github.com/stretchr/testify/assert" +) + +// Ensure stateless contexts cannot process in parallel without isolation +func TestStatelessContext_NotParallelSafe(t *testing.T) { + data := helperLoadSample(t, SamplePath) + + model, closeModel := helperNewModelContext(t) + defer closeModel() + + params := helperNewParams(t, model, nil) + + // Create two stateless contexts sharing the same underlying model context + ctx1, err := whisper.NewStatelessContext(model, params) + assert.NoError(t, err) + defer func() { _ = ctx1.Close() }() + + ctx2, err := whisper.NewStatelessContext(model, params) + assert.NoError(t, err) + defer func() { _ = ctx2.Close() }() + + // Run both in parallel - expect a panic or error from underlying whisper_full + // We capture panics to assert the behavior. + var wg sync.WaitGroup + wg.Add(2) + + var err1, err2 error + + go func() { + defer wg.Done() + err1 = ctx1.Process(data, nil, nil, nil) + }() + + go func() { + defer wg.Done() + err2 = ctx2.Process(data, nil, nil, nil) + }() + + wg.Wait() + + // At least one should return ErrStatelessBusy + if err1 != whisper.ErrStatelessBusy && err2 != whisper.ErrStatelessBusy { + t.Fatalf("expected ErrStatelessBusy when processing in parallel with StatelessContext, got err1=%v err2=%v", err1, err2) + } +} diff --git a/bindings/go/pkg/whisper/test_helpers_test.go b/bindings/go/pkg/whisper/test_helpers_test.go new file mode 100644 index 00000000000..15fedc9613a --- /dev/null +++ b/bindings/go/pkg/whisper/test_helpers_test.go @@ -0,0 +1,129 @@ +package whisper_test + +import ( + "os" + "testing" + + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + wav "github.com/go-audio/wav" +) + +func helperLoadSample(tb testing.TB, path string) []float32 { + tb.Helper() + fh, err := os.Open(path) + if err != nil { + tb.Fatalf("open sample: %v", err) + } + defer func() { _ = fh.Close() }() + + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + if err != nil { + tb.Fatalf("decode wav: %v", err) + } + if dec.NumChans != 1 { + tb.Fatalf("expected mono wav, got channels=%d", dec.NumChans) + } + return buf.AsFloat32Buffer().Data +} + +// helperLoadSampleWithMeta loads wav and returns samples with sample rate and channels +func helperLoadSampleWithMeta(tb testing.TB, path string) ([]float32, int, int) { + tb.Helper() + fh, err := os.Open(path) + if err != nil { + tb.Fatalf("open sample: %v", err) + } + defer func() { _ = fh.Close() }() + + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + if err != nil { + tb.Fatalf("decode wav: %v", err) + } + if dec.NumChans != 1 { + tb.Fatalf("expected mono wav, got channels=%d", dec.NumChans) + } + return buf.AsFloat32Buffer().Data, int(dec.SampleRate), int(dec.NumChans) +} + +func helperNewModel(t *testing.T) (whisper.Model, func()) { + t.Helper() + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + model, err := whisper.New(ModelPath) + if 
err != nil { + t.Fatalf("load model: %v", err) + } + return model, func() { _ = model.Close() } +} + +func helperNewModelContext(t *testing.T) (*whisper.ModelContext, func()) { + t.Helper() + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + model, err := whisper.NewModelContext(ModelPath) + if err != nil { + t.Fatalf("load model ctx: %v", err) + } + return model, func() { _ = model.Close() } +} + +func helperNewParams(t *testing.T, model *whisper.ModelContext, configure whisper.ParamsConfigure) *whisper.Parameters { + t.Helper() + params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, configure) + if err != nil { + t.Fatalf("new params: %v", err) + } + return params +} + +func helperProcessOnce(t *testing.T, ctx whisper.Context, data []float32) { + t.Helper() + if err := ctx.Process(data, nil, nil, nil); err != nil { + t.Fatalf("process: %v", err) + } +} + +func helperFirstSegmentText(t *testing.T, ctx whisper.Context) string { + t.Helper() + seg, err := ctx.NextSegment() + if err != nil { + t.Fatalf("next segment: %v", err) + } + return seg.Text +} + +// helperNewStatelessContext creates a fresh stateless context and returns a cleanup func +func helperNewStatelessContext(t *testing.T) (whisper.Context, func()) { + t.Helper() + model, closeModel := helperNewModelContext(t) + params := helperNewParams(t, model, nil) + ctx, err := whisper.NewStatelessContext(model, params) + if err != nil { + t.Fatalf("new stateless context: %v", err) + } + cleanup := func() { + _ = ctx.Close() + closeModel() + } + return ctx, cleanup +} + +// helperNewStatefulContext creates a fresh stateful context and returns a cleanup func +func helperNewStatefulContext(t *testing.T) (whisper.Context, func()) { + t.Helper() + model, closeModel := helperNewModelContext(t) + params := helperNewParams(t, model, nil) + ctx, err := whisper.NewStatefulContext(model, params) + if err != nil { + t.Fatalf("new stateful context: %v", err) + } + cleanup := func() { + _ = ctx.Close() + closeModel() + } + return ctx, cleanup +} diff --git a/bindings/go/pkg/whisper/token_identifier.go b/bindings/go/pkg/whisper/token_identifier.go new file mode 100644 index 00000000000..1f8ab712b90 --- /dev/null +++ b/bindings/go/pkg/whisper/token_identifier.go @@ -0,0 +1,115 @@ +package whisper + +import whisper "github.com/ggerganov/whisper.cpp/bindings/go" + +type tokenIdentifier struct { + ctx *ctxAccessor +} + +func newTokenIdentifier(whisperContext *ctxAccessor) *tokenIdentifier { + return &tokenIdentifier{ + ctx: whisperContext, + } +} + +// Token type checking methods (model-specific vocabulary) +func (ti *tokenIdentifier) IsBEG(t Token) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_beg(), nil +} + +func (ti *tokenIdentifier) IsEOT(t Token) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_eot(), nil +} + +func (ti *tokenIdentifier) IsSOT(t Token) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_sot(), nil +} + +func (ti *tokenIdentifier) IsPREV(t Token) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_prev(), nil +} + +func (ti *tokenIdentifier) IsSOLM(t Token) (bool, error) { + ctx, err := 
ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_solm(), nil +} + +func (ti *tokenIdentifier) IsNOT(t Token) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + return whisper.Token(t.Id) == ctx.Whisper_token_not(), nil +} + +func (ti *tokenIdentifier) IsLANG(t Token, lang string) (bool, error) { + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + if id := ctx.Whisper_lang_id(lang); id >= 0 { + return whisper.Token(t.Id) == ctx.Whisper_token_lang(id), nil + } + + return false, nil +} + +func (ti *tokenIdentifier) IsText(t Token) (bool, error) { + // Check if it's any of the special tokens + if isBeg, _ := ti.IsBEG(t); isBeg { + return false, nil + } + + if isSot, _ := ti.IsSOT(t); isSot { + return false, nil + } + + ctx, err := ti.ctx.context() + if err != nil { + return false, err + } + + if whisper.Token(t.Id) >= ctx.Whisper_token_eot() { + return false, nil + } + + if isPrev, _ := ti.IsPREV(t); isPrev { + return false, nil + } + + if isSolm, _ := ti.IsSOLM(t); isSolm { + return false, nil + } + + if isNot, _ := ti.IsNOT(t); isNot { + return false, nil + } + + return true, nil +} diff --git a/bindings/go/pkg/whisper/util_test.go b/bindings/go/pkg/whisper/util_test.go index 8ea2d5b4781..a2fadca5885 100644 --- a/bindings/go/pkg/whisper/util_test.go +++ b/bindings/go/pkg/whisper/util_test.go @@ -1,6 +1,18 @@ package whisper_test +import ( + "os" + "testing" +) + const ( - ModelPath = "../../models/ggml-small.en.bin" - SamplePath = "../../samples/jfk.wav" + ModelPath = "../../models/ggml-small.en.bin" + ModelTinydiarizePath = "../../models/ggml-small.en-tdrz.bin" + SamplePath = "../../samples/jfk.wav" + MultiSpeakerSamplePath = "../../samples/a13.wav" ) + +func TestMain(m *testing.M) { + // whisper.DisableLogs() + os.Exit(m.Run()) +} diff --git a/bindings/go/pkg/whisper/whisper_ctx.go b/bindings/go/pkg/whisper/whisper_ctx.go new file mode 100644 index 00000000000..ab935c76a87 --- /dev/null +++ b/bindings/go/pkg/whisper/whisper_ctx.go @@ -0,0 +1,36 @@ +package whisper + +import whisper "github.com/ggerganov/whisper.cpp/bindings/go" + +type ctxAccessor struct { + ctx *whisper.Context +} + +func newCtxAccessor(ctx *whisper.Context) *ctxAccessor { + return &ctxAccessor{ + ctx: ctx, + } +} + +func (ctx *ctxAccessor) close() error { + if ctx.ctx == nil { + return nil + } + + ctx.ctx.Whisper_free() + ctx.ctx = nil + + return nil +} + +func (ctx *ctxAccessor) isClosed() bool { + return ctx.ctx == nil +} + +func (ctx *ctxAccessor) context() (*whisper.Context, error) { + if ctx.isClosed() { + return nil, ErrModelClosed + } + + return ctx.ctx, nil +} diff --git a/bindings/go/pkg/whisper/whisper_ctx_test.go b/bindings/go/pkg/whisper/whisper_ctx_test.go new file mode 100644 index 00000000000..d308f11da7f --- /dev/null +++ b/bindings/go/pkg/whisper/whisper_ctx_test.go @@ -0,0 +1,85 @@ +package whisper + +import ( + "os" + "testing" + + w "github.com/ggerganov/whisper.cpp/bindings/go" + assert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testModelPathCtx = "../../models/ggml-small.en.bin" + +func TestWhisperCtx_NilWrapper(t *testing.T) { + wctx := newCtxAccessor(nil) + + assert.True(t, wctx.isClosed()) + + raw, err := wctx.context() + assert.Nil(t, raw) + require.ErrorIs(t, err, ErrModelClosed) + + require.NoError(t, wctx.close()) + // idempotent + require.NoError(t, wctx.close()) +} + +func TestWhisperCtx_Lifecycle(t *testing.T) { + if 
_, err := os.Stat(testModelPathCtx); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", testModelPathCtx) + } + + raw := w.Whisper_init(testModelPathCtx) + require.NotNil(t, raw) + + wctx := newCtxAccessor(raw) + assert.False(t, wctx.isClosed()) + + got, err := wctx.context() + require.NoError(t, err) + require.NotNil(t, got) + + // close frees underlying ctx and marks closed + require.NoError(t, wctx.close()) + assert.True(t, wctx.isClosed()) + + got, err = wctx.context() + assert.Nil(t, got) + require.ErrorIs(t, err, ErrModelClosed) + + // idempotent + require.NoError(t, wctx.close()) + // no further free; raw already freed by wctx.Close() +} + +func TestWhisperCtx_FromModelLifecycle(t *testing.T) { + if _, err := os.Stat(testModelPathCtx); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", testModelPathCtx) + } + + modelNew, err := New(testModelPathCtx) + require.NoError(t, err) + require.NotNil(t, modelNew) + + model := modelNew.(*ModelContext) + + wc := model.ctxAccessor() + require.NotNil(t, wc) + + // Should be usable before model.Close + raw, err := wc.context() + require.NoError(t, err) + require.NotNil(t, raw) + + // Close model should close underlying context + require.NoError(t, model.Close()) + + assert.True(t, wc.isClosed()) + raw, err = wc.context() + assert.Nil(t, raw) + require.ErrorIs(t, err, ErrModelClosed) + + // Idempotent close on wrapper + require.NoError(t, wc.close()) +} diff --git a/bindings/go/pkg/whisper/whisper_state.go b/bindings/go/pkg/whisper/whisper_state.go new file mode 100644 index 00000000000..cee48948731 --- /dev/null +++ b/bindings/go/pkg/whisper/whisper_state.go @@ -0,0 +1,32 @@ +package whisper + +import whisper "github.com/ggerganov/whisper.cpp/bindings/go" + +type whisperState struct { + state *whisper.State +} + +func newWhisperState(state *whisper.State) *whisperState { + return &whisperState{ + state: state, + } +} + +func (s *whisperState) close() error { + if s.state == nil { + return nil + } + + s.state.Whisper_free_state() + s.state = nil + + return nil +} + +func (s *whisperState) unsafeState() (*whisper.State, error) { + if s.state == nil { + return nil, ErrModelClosed + } + + return s.state, nil +} diff --git a/bindings/go/pkg/whisper/whisper_state_test.go b/bindings/go/pkg/whisper/whisper_state_test.go new file mode 100644 index 00000000000..2c4c6dd305e --- /dev/null +++ b/bindings/go/pkg/whisper/whisper_state_test.go @@ -0,0 +1,53 @@ +package whisper + +import ( + "os" + "testing" + + w "github.com/ggerganov/whisper.cpp/bindings/go" + assert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testModelPathState = "../../models/ggml-small.en.bin" + +func TestWhisperState_NilWrapper(t *testing.T) { + ws := newWhisperState(nil) + + state, err := ws.unsafeState() + assert.Nil(t, state) + require.ErrorIs(t, err, ErrModelClosed) + + require.NoError(t, ws.close()) + // idempotent + require.NoError(t, ws.close()) +} + +func TestWhisperState_Lifecycle(t *testing.T) { + if _, err := os.Stat(testModelPathState); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", testModelPathState) + } + + ctx := w.Whisper_init(testModelPathState) + require.NotNil(t, ctx) + defer ctx.Whisper_free() + + state := ctx.Whisper_init_state() + require.NotNil(t, state) + + ws := newWhisperState(state) + + got, err := ws.unsafeState() + require.NoError(t, err) + require.NotNil(t, got) + + // close frees underlying state and marks closed + require.NoError(t, ws.close()) + + got, err = 
ws.unsafeState() + assert.Nil(t, got) + require.ErrorIs(t, err, ErrModelClosed) + + // idempotent + require.NoError(t, ws.close()) +} diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 3ef73414d90..023a33d26db 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -2,6 +2,7 @@ package whisper import ( "errors" + "sync" "unsafe" ) @@ -14,6 +15,7 @@ import ( #cgo darwin LDFLAGS: -lggml-metal -lggml-blas #cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics #include +#include #include extern void callNewSegment(void* user_data, int new); @@ -59,6 +61,22 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_ params.progress_callback_user_data = (void*)(ctx); return params; } + +// Disable all C-side logging (whisper.cpp and ggml) +static void go_cb_log_disable(enum ggml_log_level level, const char * text, void * user_data) { + (void) level; (void) text; (void) user_data; +} + +static void whisper_log_disable_all(void) { + ggml_log_set(go_cb_log_disable, NULL); + whisper_log_set(go_cb_log_disable, NULL); +} + +// Enable default logging (stdout) for whisper.cpp and ggml +static void whisper_log_enable_default(void) { + ggml_log_set(NULL, NULL); + whisper_log_set(NULL, NULL); +} */ import "C" @@ -67,10 +85,13 @@ import "C" type ( Context C.struct_whisper_context + State C.struct_whisper_state Token C.whisper_token TokenData C.struct_whisper_token_data SamplingStrategy C.enum_whisper_sampling_strategy Params C.struct_whisper_full_params + Timings C.struct_whisper_timings + ContextParams C.struct_whisper_context_params ) /////////////////////////////////////////////////////////////////////////////// @@ -96,6 +117,12 @@ var ( ErrInvalidLanguage = errors.New("invalid language") ) +// DisableLogs disables all logging coming from the C libraries (whisper.cpp and ggml). +// Call once early in program startup if you want to silence device/backend prints. +func DisableLogs() { + C.whisper_log_disable_all() +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS @@ -111,11 +138,54 @@ func Whisper_init(path string) *Context { } } +// Whisper_context_default_params returns default model context params +func Whisper_context_default_params() ContextParams { + return ContextParams(C.whisper_context_default_params()) +} + +// SetUseGPU enables or disables GPU acceleration on the model context (if available) +func (p *ContextParams) SetUseGPU(v bool) { + if v { + p.use_gpu = C.bool(true) + } else { + p.use_gpu = C.bool(false) + } +} + +// SetGPUDevice selects the GPU device index for the model context (CUDA) +func (p *ContextParams) SetGPUDevice(n int) { + p.gpu_device = C.int(n) +} + +// Whisper_init_with_params allocates and initializes a model using custom context params +func Whisper_init_with_params(path string, params ContextParams) *Context { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil { + return (*Context)(ctx) + } else { + return nil + } +} + // Frees all memory allocated by the model. func (ctx *Context) Whisper_free() { C.whisper_free((*C.struct_whisper_context)(ctx)) } +// Allocates a new state associated with the context. Returns nil on failure. 
+func (ctx *Context) Whisper_init_state() *State { + if s := C.whisper_init_state((*C.struct_whisper_context)(ctx)); s != nil { + return (*State)(s) + } + return nil +} + +// Frees all memory allocated by the state. +func (s *State) Whisper_free_state() { + C.whisper_free_state((*C.struct_whisper_state)(s)) +} + // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the provided whisper context. func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { @@ -126,6 +196,15 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error { } } +// Convert RAW PCM audio to log mel spectrogram into the provided state. +func (ctx *Context) Whisper_pcm_to_mel_with_state(state *State, data []float32, threads int) error { + if C.whisper_pcm_to_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // This can be used to set a custom log mel spectrogram inside the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 @@ -137,6 +216,15 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error { } } +// Set a custom log mel spectrogram into the provided state. +func (ctx *Context) Whisper_set_mel_with_state(state *State, data []float32, n_mel int) error { + if C.whisper_set_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. @@ -148,6 +236,15 @@ func (ctx *Context) Whisper_encode(offset, threads int) error { } } +// Run the Whisper encoder using the provided state. +func (ctx *Context) Whisper_encode_with_state(state *State, offset, threads int) error { + if C.whisper_encode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. // tokens + n_tokens is the provided context for the decoder. @@ -160,6 +257,15 @@ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error { } } +// Run the Whisper decoder using the provided state. +func (ctx *Context) Whisper_decode_with_state(state *State, tokens []Token, past, threads int) error { + if C.whisper_decode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens. 
// Returns the number of tokens on success func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) { @@ -181,6 +287,10 @@ func (ctx *Context) Whisper_lang_id(lang string) int { return int(C.whisper_lang_id(C.CString(lang))) } +func Whisper_lang_id_str(lang string) int { + return int(C.whisper_lang_id(C.CString(lang))) +} + // Largest language id (i.e. number of available languages - 1) func Whisper_lang_max_id() int { return int(C.whisper_lang_max_id()) @@ -205,6 +315,16 @@ func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float3 } } +// Use mel data at offset_ms to auto-detect language using the provided state. +func (ctx *Context) Whisper_lang_auto_detect_with_state(state *State, offset_ms, n_threads int) ([]float32, error) { + probs := make([]float32, Whisper_lang_max_id()+1) + if n := int(C.whisper_lang_auto_detect_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 { + return nil, ErrAutoDetectFailed + } else { + return probs, nil + } +} + func (ctx *Context) Whisper_n_len() int { return int(C.whisper_n_len((*C.struct_whisper_context)(ctx))) } @@ -290,6 +410,32 @@ func (ctx *Context) Whisper_reset_timings() { C.whisper_reset_timings((*C.struct_whisper_context)(ctx)) } +// TimingsGo is a Go-friendly copy of whisper_timings +type TimingsGo struct { + SampleMS float32 + EncodeMS float32 + DecodeMS float32 + BatchdMS float32 + PromptMS float32 +} + +// Whisper_get_timings_go retrieves timing counters and converts them to TimingsGo +func (ctx *Context) Whisper_get_timings_go() (TimingsGo, bool) { + t := C.whisper_get_timings((*C.struct_whisper_context)(ctx)) + if t == nil { + return TimingsGo{}, false + } + // The C struct is 5 consecutive floats; reinterpret and copy + arr := (*[5]C.float)(unsafe.Pointer(t)) + return TimingsGo{ + SampleMS: float32(arr[0]), + EncodeMS: float32(arr[1]), + DecodeMS: float32(arr[2]), + BatchdMS: float32(arr[3]), + PromptMS: float32(arr[4]), + }, true +} + // Print system information func Whisper_print_system_info() string { return C.GoString(C.whisper_print_system_info()) @@ -323,6 +469,28 @@ func (ctx *Context) Whisper_full( } } +// Run the entire model using the provided state: PCM -> mel -> encoder -> decoder -> text +func (ctx *Context) Whisper_full_with_state( + state *State, + params Params, + samples []float32, + encoderBeginCallback func() bool, + newSegmentCallback func(int), + progressCallback func(int), +) error { + registerEncoderBeginCallback(ctx, encoderBeginCallback) + registerNewSegmentCallback(ctx, newSegmentCallback) + registerProgressCallback(ctx, progressCallback) + defer registerEncoderBeginCallback(ctx, nil) + defer registerNewSegmentCallback(ctx, nil) + defer registerProgressCallback(ctx, nil) + if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 { + return nil + } else { + return ErrConversionFailed + } +} + // Split the input audio in chunks and process each chunk separately using whisper_full() // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. 
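Editor's note: for callers of the low-level package, the *_with_state entry points mirror the high-level StatefulContext: allocate one whisper_state per worker, run Whisper_full_with_state, then read results with the *_from_state getters shown further down. A hedged sketch; the model path and the sample source are assumptions:

```go
package main

import (
	"fmt"

	low "github.com/ggerganov/whisper.cpp/bindings/go"
)

func main() {
	cp := low.Whisper_context_default_params()
	cp.SetUseGPU(true) // takes effect only on builds with GPU support

	ctx := low.Whisper_init_with_params("./models/ggml-small.en.bin", cp) // assumed path
	if ctx == nil {
		panic("unable to load model")
	}
	defer ctx.Whisper_free()

	st := ctx.Whisper_init_state() // isolated state; one per worker
	if st == nil {
		panic("unable to allocate state")
	}
	defer st.Whisper_free_state()

	params := ctx.Whisper_full_default_params(low.SAMPLING_GREEDY)
	params.SetThreads(4)

	var samples []float32 // 16 kHz mono PCM, decoded elsewhere (must be non-empty)

	if err := ctx.Whisper_full_with_state(st, params, samples, nil, nil, nil); err != nil {
		panic(err)
	}

	for i := 0; i < ctx.Whisper_full_n_segments_from_state(st); i++ {
		t0 := ctx.Whisper_full_get_segment_t0_from_state(st, i) // timestamps in 10 ms units
		t1 := ctx.Whisper_full_get_segment_t1_from_state(st, i)
		fmt.Printf("[%d -> %d] %s\n", t0, t1, ctx.Whisper_full_get_segment_text_from_state(st, i))
	}
}
```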
@@ -357,102 +525,157 @@ func (ctx *Context) Whisper_full_n_segments() int {
 	return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
 }
 
+func (ctx *Context) Whisper_full_n_segments_from_state(state *State) int {
+	return int(C.whisper_full_n_segments_from_state((*C.struct_whisper_state)(state)))
+}
+
 // Get the start and end time of the specified segment.
 func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
 	return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
 }
 
+func (ctx *Context) Whisper_full_get_segment_t0_from_state(state *State, segment int) int64 {
+	return int64(C.whisper_full_get_segment_t0_from_state((*C.struct_whisper_state)(state), C.int(segment)))
+}
+
 // Get the start and end time of the specified segment.
 func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
 	return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
 }
 
+func (ctx *Context) Whisper_full_get_segment_t1_from_state(state *State, segment int) int64 {
+	return int64(C.whisper_full_get_segment_t1_from_state((*C.struct_whisper_state)(state), C.int(segment)))
+}
+
 // Get the text of the specified segment.
 func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
 	return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
 }
 
+func (ctx *Context) Whisper_full_get_segment_text_from_state(state *State, segment int) string {
+	return C.GoString(C.whisper_full_get_segment_text_from_state((*C.struct_whisper_state)(state), C.int(segment)))
+}
+
 // Get number of tokens in the specified segment.
 func (ctx *Context) Whisper_full_n_tokens(segment int) int {
 	return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
 }
 
+func (ctx *Context) Whisper_full_n_tokens_from_state(state *State, segment int) int {
+	return int(C.whisper_full_n_tokens_from_state((*C.struct_whisper_state)(state), C.int(segment)))
+}
+
 // Get the token text of the specified token index in the specified segment.
 func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
 	return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
 }
 
+func (ctx *Context) Whisper_full_get_token_text_from_state(state *State, segment int, token int) string {
+	return C.GoString(C.whisper_full_get_token_text_from_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
+}
+
 // Get the token of the specified token index in the specified segment.
 func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
 	return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
 }
 
+func (ctx *Context) Whisper_full_get_token_id_from_state(state *State, segment int, token int) Token {
+	return Token(C.whisper_full_get_token_id_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
+}
+
 // Get token data for the specified token in the specified segment.
 // This contains probabilities, timestamps, etc.
 func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
 	return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
 }
 
+func (ctx *Context) Whisper_full_get_token_data_from_state(state *State, segment int, token int) TokenData {
+	return TokenData(C.whisper_full_get_token_data_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
+}
+
 // Get the probability of the specified token in the specified segment.
 func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
 	return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
 }
 
+func (ctx *Context) Whisper_full_get_token_p_from_state(state *State, segment int, token int) float32 {
+	return float32(C.whisper_full_get_token_p_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
+}
+
+func (ctx *Context) Whisper_full_lang_id_from_state(state *State) int {
+	return int(C.whisper_full_lang_id_from_state((*C.struct_whisper_state)(state)))
+}
+
+func (ctx *Context) Whisper_n_len_from_state(state *State) int {
+	return int(C.whisper_n_len_from_state((*C.struct_whisper_state)(state)))
+}
+
+func (ctx *Context) Whisper_get_logits_from_state(state *State) []float32 {
+	return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_logits_from_state((*C.struct_whisper_state)(state))))[:ctx.Whisper_n_vocab()]
+}
+
+// Get whether the next segment is predicted as a speaker turn (tinydiarize)
+func (ctx *Context) Whisper_full_get_segment_speaker_turn_next_from_state(state *State, segment int) bool {
+	return bool(C.whisper_full_get_segment_speaker_turn_next_from_state((*C.struct_whisper_state)(state), C.int(segment)))
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // CALLBACKS
 
 var (
-	cbNewSegment   = make(map[unsafe.Pointer]func(int))
-	cbProgress     = make(map[unsafe.Pointer]func(int))
-	cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
+	cbNewSegment   sync.Map // map[unsafe.Pointer]func(int)
+	cbProgress     sync.Map // map[unsafe.Pointer]func(int)
+	cbEncoderBegin sync.Map // map[unsafe.Pointer]func() bool
 )
 
 func registerNewSegmentCallback(ctx *Context, fn func(int)) {
+	k := unsafe.Pointer(ctx)
 	if fn == nil {
-		delete(cbNewSegment, unsafe.Pointer(ctx))
+		cbNewSegment.Delete(k)
 	} else {
-		cbNewSegment[unsafe.Pointer(ctx)] = fn
+		cbNewSegment.Store(k, fn)
 	}
 }
 
 func registerProgressCallback(ctx *Context, fn func(int)) {
+	k := unsafe.Pointer(ctx)
 	if fn == nil {
-		delete(cbProgress, unsafe.Pointer(ctx))
+		cbProgress.Delete(k)
 	} else {
-		cbProgress[unsafe.Pointer(ctx)] = fn
+		cbProgress.Store(k, fn)
 	}
 }
 
 func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
+	k := unsafe.Pointer(ctx)
 	if fn == nil {
-		delete(cbEncoderBegin, unsafe.Pointer(ctx))
+		cbEncoderBegin.Delete(k)
 	} else {
-		cbEncoderBegin[unsafe.Pointer(ctx)] = fn
+		cbEncoderBegin.Store(k, fn)
 	}
 }
 
 //export callNewSegment
 func callNewSegment(user_data unsafe.Pointer, new C.int) {
-	if fn, ok := cbNewSegment[user_data]; ok {
-		fn(int(new))
+	if v, ok := cbNewSegment.Load(user_data); ok {
+		v.(func(int))(int(new))
 	}
 }
 
 //export callProgress
 func callProgress(user_data unsafe.Pointer, progress C.int) {
-	if fn, ok := cbProgress[user_data]; ok {
-		fn(int(progress))
+	if v, ok := cbProgress.Load(user_data); ok {
+		v.(func(int))(int(progress))
 	}
 }
 
 //export callEncoderBegin
 func callEncoderBegin(user_data unsafe.Pointer) C.bool {
-	if fn, ok := cbEncoderBegin[user_data]; ok {
-		if fn() {
+	if v, ok := cbEncoderBegin.Load(user_data); ok {
+		if v.(func() bool)() {
 			return C.bool(true)
-		} else {
-			return C.bool(false)
 		}
+		return C.bool(false)
 	}
 	return true
 }
diff --git a/bindings/go/whisper_test.go b/bindings/go/whisper_test.go
index 40648ffa8d4..330981fb453 100644
--- a/bindings/go/whisper_test.go
+++ b/bindings/go/whisper_test.go
@@ -1,8 +1,10 @@
 package whisper_test
 
 import (
+	"errors"
 	"os"
 	"runtime"
+	"sync"
 	"testing"
 	"time"
 
@@ -17,6 +19,11 @@ const (
 	SamplePath = "samples/jfk.wav"
 )
 
+func TestMain(m *testing.M) {
+	whisper.DisableLogs()
+	os.Exit(m.Run())
+}
+
 func Test_Whisper_000(t *testing.T) {
 	assert := assert.New(t)
 	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
@@ -39,7 +46,7 @@ func Test_Whisper_001(t *testing.T) {
 	// Open samples
 	fh, err := os.Open(SamplePath)
 	assert.NoError(err)
-	defer fh.Close()
+	defer func() { _ = fh.Close() }()
 
 	// Read samples
 	d := wav.NewDecoder(fh)
@@ -89,7 +96,7 @@ func Test_Whisper_003(t *testing.T) {
 	// Open samples
 	fh, err := os.Open(SamplePath)
 	assert.NoError(err)
-	defer fh.Close()
+	defer func() { _ = fh.Close() }()
 
 	// Read samples
 	d := wav.NewDecoder(fh)
@@ -111,3 +118,157 @@ func Test_Whisper_003(t *testing.T) {
 		t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
 	}
 }
+
+func Test_Whisper_State_Init_Free(t *testing.T) {
+	assert := assert.New(t)
+	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+		t.Skip("Skipping test, model not found:", ModelPath)
+	}
+
+	ctx := whisper.Whisper_init(ModelPath)
+	assert.NotNil(ctx)
+	defer ctx.Whisper_free()
+
+	state := ctx.Whisper_init_state()
+	assert.NotNil(state)
+	state.Whisper_free_state()
+}
+
+func Test_Whisper_Full_With_State(t *testing.T) {
+	assert := assert.New(t)
+	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+		t.Skip("Skipping test, model not found:", ModelPath)
+	}
+	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+		t.Skip("Skipping test, sample not found:", SamplePath)
+	}
+
+	// Open samples
+	fh, err := os.Open(SamplePath)
+	assert.NoError(err)
+	defer func() { _ = fh.Close() }()
+
+	// Read samples
+	d := wav.NewDecoder(fh)
+	buf, err := d.FullPCMBuffer()
+	assert.NoError(err)
+	data := buf.AsFloat32Buffer().Data
+
+	ctx := whisper.Whisper_init(ModelPath)
+	assert.NotNil(ctx)
+	defer ctx.Whisper_free()
+
+	state := ctx.Whisper_init_state()
+	assert.NotNil(state)
+	defer state.Whisper_free_state()
+
+	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
+	// Run using state
+	err = ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
+	assert.NoError(err)
+
+	// Validate results are stored in state
+	nSegments := ctx.Whisper_full_n_segments_from_state(state)
+	assert.GreaterOrEqual(nSegments, 1)
+	text := ctx.Whisper_full_get_segment_text_from_state(state, 0)
+	assert.NotEmpty(text)
+}
+
+func Test_Whisper_Lang_Auto_Detect_With_State(t *testing.T) {
+	assert := assert.New(t)
+	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+		t.Skip("Skipping test, model not found:", ModelPath)
+	}
+	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+		t.Skip("Skipping test, sample not found:", SamplePath)
+	}
+
+	// Open samples
+	fh, err := os.Open(SamplePath)
+	assert.NoError(err)
+	defer func() { _ = fh.Close() }()
+
+	// Read samples
+	d := wav.NewDecoder(fh)
+	buf, err := d.FullPCMBuffer()
+	assert.NoError(err)
+	data := buf.AsFloat32Buffer().Data
+
+	ctx := whisper.Whisper_init(ModelPath)
+	assert.NotNil(ctx)
+	defer ctx.Whisper_free()
+
+	state := ctx.Whisper_init_state()
+	assert.NotNil(state)
+	defer state.Whisper_free_state()
+
+	threads := runtime.NumCPU()
+	// Prepare mel into state then detect
+	assert.NoError(ctx.Whisper_pcm_to_mel_with_state(state, data, threads))
+	probs, err := ctx.Whisper_lang_auto_detect_with_state(state, 0, threads)
+	assert.NoError(err)
+	assert.Equal(whisper.Whisper_lang_max_id()+1, len(probs))
+}
+
+func Test_Whisper_Concurrent_With_State(t *testing.T) {
+	assert := assert.New(t)
+	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
+		t.Skip("Skipping test, model not found:", ModelPath)
+	}
+	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
+		t.Skip("Skipping test, sample not found:", SamplePath)
+	}
+
+	// Load audio once
+	fh, err := os.Open(SamplePath)
+	assert.NoError(err)
+	defer func() { _ = fh.Close() }()
+	dec := wav.NewDecoder(fh)
+	buf, err := dec.FullPCMBuffer()
+	assert.NoError(err)
+	data := buf.AsFloat32Buffer().Data
+
+	ctx := whisper.Whisper_init(ModelPath)
+	assert.NotNil(ctx)
+	defer ctx.Whisper_free()
+
+	// Each goroutine has its own state
+	state1 := ctx.Whisper_init_state()
+	state2 := ctx.Whisper_init_state()
+	assert.NotNil(state1)
+	assert.NotNil(state2)
+	defer state1.Whisper_free_state()
+	defer state2.Whisper_free_state()
+
+	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
+
+	var wg sync.WaitGroup
+	var mu sync.Mutex // guard calls into shared ctx, per upstream note not thread-safe for same context
+	errs := make(chan error, 2)
+
+	worker := func(state *whisper.State) {
+		defer wg.Done()
+		mu.Lock()
+		err := ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
+		if err == nil {
+			n := ctx.Whisper_full_n_segments_from_state(state)
+			if n <= 0 {
+				err = errors.New("no segments")
+			} else {
+				_ = ctx.Whisper_full_get_segment_text_from_state(state, 0)
+			}
+		}
+		mu.Unlock()
+		errs <- err
+	}
+
+	wg.Add(2)
+	go worker(state1)
+	go worker(state2)
+	wg.Wait()
+	close(errs)
+
+	for e := range errs {
+		assert.NoError(e)
+	}
+}
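Beyond the tests above, the following is a minimal sketch (not part of the diff) of a full with-state run that streams segments from the new-segment callback registered by `Whisper_full_with_state` and then reads the Go-side timings helper. The model path and placeholder samples are illustrative; real code would decode a WAV file the same way the tests do.

```go
package main

import (
	"fmt"

	whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)

func main() {
	ctx := whisper.Whisper_init("models/ggml-small.en.bin")
	if ctx == nil {
		panic("failed to load model")
	}
	defer ctx.Whisper_free()

	state := ctx.Whisper_init_state()
	if state == nil {
		panic("failed to create state")
	}
	defer state.Whisper_free_state()

	// Stand-in for real 16 kHz mono PCM
	samples := make([]float32, 16000)

	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)

	// The callback receives the number of newly decoded segments; results for
	// this run live in the state, so read them with the *_from_state accessors.
	onNewSegment := func(nNew int) {
		total := ctx.Whisper_full_n_segments_from_state(state)
		for i := total - nNew; i < total; i++ {
			t0 := ctx.Whisper_full_get_segment_t0_from_state(state, i)
			t1 := ctx.Whisper_full_get_segment_t1_from_state(state, i)
			fmt.Printf("[%d -> %d] %s\n", t0, t1, ctx.Whisper_full_get_segment_text_from_state(state, i))
		}
	}

	if err := ctx.Whisper_full_with_state(state, params, samples, nil, onNewSegment, nil); err != nil {
		panic(err)
	}

	// Aggregated model timings, converted to the Go-friendly struct
	if t, ok := ctx.Whisper_get_timings_go(); ok {
		fmt.Printf("encode %.1f ms, decode %.1f ms\n", t.EncodeMS, t.DecodeMS)
	}
}
```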