From 647c5f67880f0bdbe30c6535968f88d84f093c17 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:23:12 +0900 Subject: [PATCH 01/31] Add Params#new_segment_callback= method --- bindings/ruby/ext/ruby_whisper.cpp | 31 ++++++++++++++++++++++++++++++ bindings/ruby/ext/ruby_whisper.h | 1 + 2 files changed, 32 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9d9334539b8..96c435209a6 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -73,6 +73,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->new_segment_callback = Qnil; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -205,6 +206,28 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { }; rwp->params.encoder_begin_callback_user_data = &is_aborted; } + { + // This cannot be used later because it is not incremented when new_segment_callback is not given. 
+ static int n_segments = 0; + + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } + + for (int i = 0; i < n_new; i++) { + const int i_segment = n_segments + i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + n_segments += n_new; + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); @@ -365,6 +388,12 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback = value; + return value; +} void Init_whisper() { mWhisper = rb_define_module("Whisper"); @@ -412,6 +441,8 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); + + rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); } #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8c35b7cb65c..988750a8268 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ 
b/bindings/ruby/ext/ruby_whisper.h @@ -10,6 +10,7 @@ typedef struct { typedef struct { struct whisper_full_params params; bool diarize; + VALUE new_segment_callback; } ruby_whisper_params; #endif From 81e6df3bab7662da5379db51f28a989db7408c02 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:23:43 +0900 Subject: [PATCH 02/31] Add tests for Params#new_segment_callback= --- bindings/ruby/tests/test_whisper.rb | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 410b5248a89..a496b3ae58f 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,29 @@ def test_whisper } end + def test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 71b65b00ccf1816c9ea8a247fb30f71bc09707d3 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:29:50 +0900 Subject: [PATCH 03/31] Group tests for #transcribe --- bindings/ruby/tests/test_whisper.rb | 64 ++++++++++++++--------------- 1 file changed, 32 
insertions(+), 32 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index a496b3ae58f..0095ea20725 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -116,38 +116,38 @@ def test_split_on_word assert !@params.split_on_word end - def test_whisper - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - params = Whisper::Params.new - params.print_timestamps = false - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - @whisper.transcribe(jfk, params) {|text| - assert_match /ask not what your country can do for you, ask what you can do for your country/, text - } - end - - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) + sub_test_case "#transcribe" do + def setup + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @params = Whisper::Params.new + @params.print_timestamps = false + @jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + end + + def test_whisper + @whisper.transcribe(@jfk, @params) {|text| + assert_match /ask not what your country can do for you, ask what you can do for your country/, text + } + end + + def 
test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + @whisper.transcribe(@jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + @whisper.transcribe(@jfk, @params) + end end def test_build From e8243b7059b235e99c9abb1b2e9dabd3bc3ad178 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 03:09:30 +0900 Subject: [PATCH 04/31] Don't use static for thread-safety --- bindings/ruby/ext/ruby_whisper.cpp | 36 +++++++++++++----------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 96c435209a6..b16c67e00ef 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -206,28 +206,24 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { }; rwp->params.encoder_begin_callback_user_data = &is_aborted; } - { - // This cannot be used later because it is not incremented when new_segment_callback is not given. 
- static int n_segments = 0; - rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } - for (int i = 0; i < n_new; i++) { - const int i_segment = n_segments + i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } - n_segments += n_new; - }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; - } + int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; --i) { + const int i_segment = n_segments - i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); From d0d55f5f5b79d416c477a8adf44cc1677dd3604e Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: 
Fri, 18 Oct 2024 03:11:34 +0900 Subject: [PATCH 05/31] Set new_segment_callback only when necessary --- bindings/ruby/ext/ruby_whisper.cpp | 34 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index b16c67e00ef..77028939302 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -207,23 +207,25 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } + if (!NIL_P(rwp->new_segment_callback)) { + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } - int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; --i) { - const int i_segment = n_segments - i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } - }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; --i) { + const int i_segment = n_segments - i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because 
to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); From 6077d90410ed40e228c3bf1376b70fbbc19f8612 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 03:15:38 +0900 Subject: [PATCH 06/31] Remove redundant check --- bindings/ruby/ext/ruby_whisper.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 77028939302..1ee2453867d 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -210,10 +210,6 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { if (!NIL_P(rwp->new_segment_callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } - int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; --i) { const int i_segment = n_segments - i; From 8fdbb2031c8a93f0cff3897e88d565836cc88b6e Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 12:12:09 +0900 Subject: [PATCH 07/31] [skip ci] Add Ruby version README --- bindings/ruby/.gitignore | 1 - bindings/ruby/README.md | 63 +++++++++++++++++++++++++++++++++++ bindings/ruby/extsources.yaml | 2 -- 3 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 bindings/ruby/README.md diff --git a/bindings/ruby/.gitignore b/bindings/ruby/.gitignore index 6ff6e5f2119..e04a90a9c69 100644 
--- a/bindings/ruby/.gitignore +++ b/bindings/ruby/.gitignore @@ -1,4 +1,3 @@ -README.md LICENSE pkg/ lib/whisper.* diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md new file mode 100644 index 00000000000..4c6d0e86587 --- /dev/null +++ b/bindings/ruby/README.md @@ -0,0 +1,63 @@ +whispercpp +========== + +![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg) + +Ruby bindings for [whisper.cpp][], an interface of automatic speech recognition model. + +Installation +------------ + +Install the gem and add to the application's Gemfile by executing: + + $ bundle add whispercpp + +If bundler is not being used to manage dependencies, install the gem by executing: + + $ gem install whispercpp + +Usage +----- + +NOTE: This gem is still in development. API is not stable for now. + +```ruby +require "whisper" + +whisper = Whisper::Context.new("path/to/model.bin") + +params = Whisper::Params.new +params.language = "en" +params.offset = 10_000 +params.duration = 60_000 +params.max_text_tokens = 300 +params.translate = true +params.print_timestamps = false +params.new_segment_callback = ->(output, t0, t1, index) { + puts "segment #{index}: #{t0}ms -> #{t1}ms: #{output}" +} + +whisper.transcribe("path/to/audio.wav", params) do |whole_text| + puts whole_text +end + +``` + +### Preparing model ### + +Use script to download model file(s): + +```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp +sh ./models/download-ggml-model.sh base.en +``` + +There are some types of models. See [models][] page for details. + +### Preparing audio file ### + +Currently, whisper.cpp accepts only 16-bit WAV files. 
+ +[whisper.cpp]: https://github.com/ggerganov/whisper.cpp +[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models diff --git a/bindings/ruby/extsources.yaml b/bindings/ruby/extsources.yaml index 1a4b4d25bdb..94f941dff32 100644 --- a/bindings/ruby/extsources.yaml +++ b/bindings/ruby/extsources.yaml @@ -27,6 +27,4 @@ ../../examples: - ext/dr_wav.h ../..: -- README.md - LICENSE - From abbee8473cb01f557bbbf75f9ffd1f1565928b77 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 02:36:12 +0900 Subject: [PATCH 08/31] Revert "Group tests for #transcribe" This reverts commit 71b65b00ccf1816c9ea8a247fb30f71bc09707d3. --- bindings/ruby/tests/test_whisper.rb | 64 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 0095ea20725..a496b3ae58f 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -116,38 +116,38 @@ def test_split_on_word assert !@params.split_on_word end - sub_test_case "#transcribe" do - def setup - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - @params = Whisper::Params.new - @params.print_timestamps = false - @jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - end - - def test_whisper - @whisper.transcribe(@jfk, @params) {|text| - assert_match /ask not what your country can do for you, ask what you can do for your country/, text - } - end - - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - @whisper.transcribe(@jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - 
@whisper.transcribe(@jfk, @params) - end + def test_whisper + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) {|text| + assert_match /ask not what your country can do for you, ask what you can do for your country/, text + } + end + + def test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) end def test_build From 67b375a66ef84c08e444ef4a481db9a2951169f4 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 02:36:26 +0900 Subject: [PATCH 09/31] Revert "Add tests for Params#new_segment_callback=" This reverts commit 81e6df3bab7662da5379db51f28a989db7408c02. 
--- bindings/ruby/tests/test_whisper.rb | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index a496b3ae58f..410b5248a89 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,29 +127,6 @@ def test_whisper } end - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 050a116d29cd02e874c113b0da7946c45507093a Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 04:16:23 +0900 Subject: [PATCH 10/31] Add test for Context#full_n_segments --- bindings/ruby/tests/test_whisper.rb | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 410b5248a89..09da562547c 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,28 @@ def test_whisper } end + sub_test_case "After transcription" do + class << self + attr_reader :whisper + + def startup + @whisper = 
Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) + end + end + + def whisper + self.class.whisper + end + + def test_full_n_segments + assert_equal 1, whisper.full_n_segments + end + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From b152263cc391c0f41b35db495c7d6673f5a0dbdb Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 04:16:39 +0900 Subject: [PATCH 11/31] Add Context#full_n_segments --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 1ee2453867d..5e912cb9180 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -240,6 +240,12 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { return self; } +static VALUE ruby_whisper_full_n_segments(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_n_segments(rw->context)); +} + /* * params.language = "auto" | "en", etc... 
*/ @@ -398,6 +404,7 @@ void Init_whisper() { rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From 59db172ed2e4c73bc65087e0a0518d3457bee542 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:15:31 +0900 Subject: [PATCH 12/31] Add tests for lang API --- bindings/ruby/tests/test_whisper.rb | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 09da562547c..b2408ca5a32 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -149,6 +149,22 @@ def test_full_n_segments end end + def test_lang_max_id + assert_kind_of Integer, Whisper.lang_max_id + end + + def test_lang_id + assert_equal 0, Whisper.lang_id("en") + end + + def test_lang_str + assert_equal "en", Whisper.lang_str(0) + end + + def test_lang_str_full + assert_equal "english", Whisper.lang_str_full(0) + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 207a3f18111d218d0ceca4f518b9637950873873 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:15:48 +0900 Subject: [PATCH 13/31] Add lang API --- bindings/ruby/ext/ruby_whisper.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 5e912cb9180..67669e64682 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -36,6 +36,22 @@ VALUE mWhisper; VALUE cContext; VALUE cParams; +static VALUE ruby_whisper_s_lang_max_id(VALUE self) { + return INT2NUM(whisper_lang_max_id()); +} + +static VALUE ruby_whisper_s_lang_id(VALUE self, 
VALUE lang) { + return INT2NUM(whisper_lang_id(StringValueCStr(lang))); +} + +static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { + return rb_str_new2(whisper_lang_str(NUM2INT(id))); +} + +static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { + return rb_str_new2(whisper_lang_str_full(NUM2INT(id))); +} + static void ruby_whisper_free(ruby_whisper *rw) { if (rw->context) { whisper_free(rw->context); @@ -400,6 +416,11 @@ void Init_whisper() { cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); + rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); + rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); + rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); + rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); From 22035fb61eb5971004110a5aa3c58704b24717a7 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:29:09 +0900 Subject: [PATCH 14/31] Add tests for Context#full_lang_id API --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index b2408ca5a32..69324b3d160 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -147,6 +147,10 @@ def whisper def test_full_n_segments assert_equal 1, whisper.full_n_segments end + + def test_full_lang_id + assert_equal 0, whisper.full_lang_id + end end def test_lang_max_id From 8799616b3e349f6b485b6d3f37dbee3590bf027d Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:29:19 +0900 Subject: [PATCH 15/31] Add Context#full_lang_id --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 
insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 67669e64682..0ec0b646ef6 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -262,6 +262,12 @@ static VALUE ruby_whisper_full_n_segments(VALUE self) { return INT2NUM(whisper_full_n_segments(rw->context)); } +static VALUE ruby_whisper_full_lang_id(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_lang_id(rw->context)); +} + /* * params.language = "auto" | "en", etc... */ @@ -426,6 +432,7 @@ void Init_whisper() { rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); + rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From eef03e4d79e2e147a7952c621f5280f2d4093314 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:01:38 +0900 Subject: [PATCH 16/31] Add abnormal test cases for lang --- bindings/ruby/tests/test_whisper.rb | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 69324b3d160..7dfa067311c 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -159,14 +159,23 @@ def test_lang_max_id def test_lang_id assert_equal 0, Whisper.lang_id("en") + assert_raise ArgumentError do + Whisper.lang_id("non existing language") + end end def test_lang_str assert_equal "en", Whisper.lang_str(0) + assert_raise IndexError do + Whisper.lang_str(Whisper.lang_max_id + 1) + end end def test_lang_str_full assert_equal "english", Whisper.lang_str_full(0) + assert_raise IndexError do + Whisper.lang_str_full(Whisper.lang_max_id + 1) + end end def test_build From 3f2f232ec129b76fd60b75216fcda4adb345216d Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 
06:01:52 +0900 Subject: [PATCH 17/31] Raise appropriate errors from lang APIs --- bindings/ruby/ext/ruby_whisper.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 0ec0b646ef6..db91cb6dbef 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -41,15 +41,30 @@ static VALUE ruby_whisper_s_lang_max_id(VALUE self) { } static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { - return INT2NUM(whisper_lang_id(StringValueCStr(lang))); + const char * lang_str = StringValueCStr(lang); + const int id = whisper_lang_id(lang_str); + if (-1 == id) { + rb_raise(rb_eArgError, "language not found: %s", lang_str); + } + return INT2NUM(id); } static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { - return rb_str_new2(whisper_lang_str(NUM2INT(id))); + const int lang_id = NUM2INT(id); + const char * str = whisper_lang_str(lang_id); + if (nullptr == str) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str); } static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { - return rb_str_new2(whisper_lang_str_full(NUM2INT(id))); + const int lang_id = NUM2INT(id); + const char * str_full = whisper_lang_str_full(lang_id); + if (nullptr == str_full) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str_full); } static void ruby_whisper_free(ruby_whisper *rw) { From 7144916dbde66f65fd7979dd2511ba3f4d9649a5 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:25:22 +0900 Subject: [PATCH 18/31] Add tests for Context#full_get_segment_t{0,1} API --- bindings/ruby/tests/test_whisper.rb | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 7dfa067311c..af2aca9fa49 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ 
b/bindings/ruby/tests/test_whisper.rb @@ -151,6 +151,25 @@ def test_full_n_segments def test_full_lang_id assert_equal 0, whisper.full_lang_id end + + def test_full_get_segment_t0 + assert_equal 0, whisper.full_get_segment_t0(0) + assert_raise IndexError do + whisper.full_get_segment_t0(whisper.full_n_segments) + end + assert_raise IndexError do + whisper.full_get_segment_t0(-1) + end + end + + def test_full_get_segment_t1 + t1 = whisper.full_get_segment_t1(0) + assert_kind_of Integer, t1 + assert t1 > 0 + assert_raise IndexError do + whisper.full_get_segment_t1(whisper.full_n_segments) + end + end end def test_lang_max_id From 5ebd2b535b5188c1ffe5f63e52da752524c0dd33 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:25:35 +0900 Subject: [PATCH 19/31] Add Context#full_get_segment_t{0,1} --- bindings/ruby/ext/ruby_whisper.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index db91cb6dbef..dc78ff9d258 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -283,6 +283,30 @@ static VALUE ruby_whisper_full_lang_id(VALUE self) { return INT2NUM(whisper_full_lang_id(rw->context)); } +static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) { + const int c_i_segment = NUM2INT(i_segment); + if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) { + rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment); + } + return c_i_segment; +} + +static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); + return INT2NUM(t0); +} + +static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { + 
ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); + return INT2NUM(t1); +} + /* * params.language = "auto" | "en", etc... */ @@ -448,6 +472,8 @@ void Init_whisper() { rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); + rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); + rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From 3951acce9b657167113ac7c3b94d30e6ced8b4bb Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:41:28 +0900 Subject: [PATCH 20/31] Add tests for Context#full_get_segment_speaker_turn_next API --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index af2aca9fa49..207e8127a4e 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -170,6 +170,10 @@ def test_full_get_segment_t1 whisper.full_get_segment_t1(whisper.full_n_segments) end end + + def test_full_get_segment_speaker_turn_next + assert_false whisper.full_get_segment_speaker_turn_next(0) + end end def test_lang_max_id From 9e04f7a1983387d96fd452ce14f12db9c302e647 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:41:36 +0900 Subject: [PATCH 21/31] Add Context#full_get_segment_speaker_turn_next --- bindings/ruby/ext/ruby_whisper.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index dc78ff9d258..d8134ff0b2c 100644 --- 
a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -307,6 +307,14 @@ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { return INT2NUM(t1); } +static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); + return speaker_turn_next ? Qtrue : Qfalse; +} + /* * params.language = "auto" | "en", etc... */ @@ -474,6 +482,7 @@ void Init_whisper() { rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); + rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From 0672e6f740a028713b361391400188dd4061e99a Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:44:35 +0900 Subject: [PATCH 22/31] Add tests for Context#full_get_segment_text --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 207e8127a4e..48b95af94e5 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -174,6 +174,10 @@ def test_full_get_segment_t1 def test_full_get_segment_speaker_turn_next assert_false whisper.full_get_segment_speaker_turn_next(0) end + + def test_full_get_segment_text + assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0) + end end def test_lang_max_id From 5e350b139ca9528fc3318a3f3b069a3e14110cd4 Mon 
Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:44:48 +0900 Subject: [PATCH 23/31] Add Context#full_get_segment_text --- bindings/ruby/ext/ruby_whisper.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index d8134ff0b2c..56abc6022f4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -315,6 +315,14 @@ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i return speaker_turn_next ? Qtrue : Qfalse; } +static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); + return rb_str_new2(text); +} + /* * params.language = "auto" | "en", etc... */ @@ -483,6 +491,7 @@ void Init_whisper() { rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); + rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From d3a5157ce1e94f5f16fd764ff97017164a318500 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 10:55:15 +0900 Subject: [PATCH 24/31] Add tests for Params#new_segment_callback= --- bindings/ruby/tests/test_whisper.rb | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 48b95af94e5..74e8b7fc7a7 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,56 @@ def test_whisper
} end + def test_new_segment_callback + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same whisper, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 + end + } + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_closure + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + + search_word = "what" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + text = context.full_get_segment_text(i_segment) + if text.include?(search_word) + t0 = context.full_get_segment_t0(i_segment) + t1 = context.full_get_segment_t1(i_segment) + raise "search word '#{search_word}' found at between #{t0} and #{t1}" + end + end + } + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + assert_raise RuntimeError do + whisper.transcribe(jfk, @params) + end + end + sub_test_case "After transcription" do class << self attr_reader :whisper From a71d12ec657ddbb5fc6cac899d1910b0d12b3444 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 10:56:02 +0900 Subject: [PATCH 25/31] Run new segment callback --- bindings/ruby/ext/ruby_whisper.cpp | 21 ++++++++++----------- bindings/ruby/ext/ruby_whisper.h | 1 + 2 files 
changed, 11 insertions(+), 11 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 56abc6022f4..d0344d423d4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -240,18 +240,17 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { if (!NIL_P(rwp->new_segment_callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; --i) { - const int i_segment = n_segments - i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } + ruby_whisper *rw;; + VALUE context = *(VALUE *)user_data; + Data_Get_Struct(context, ruby_whisper, rw); + VALUE callback = rw->new_segment_callback; + + // Currently, doesn't support state and user_data because + // those require to resolve GC-related problems. 
+ rb_funcall(callback, rb_intern("call"), 4, context, Qnil, INT2NUM(n_new), Qnil); }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + rw->new_segment_callback = rwp->new_segment_callback; + rwp->params.new_segment_callback_user_data = &self; } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 988750a8268..1481bfa9e17 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,6 +5,7 @@ typedef struct { struct whisper_context *context; + VALUE new_segment_callback; } ruby_whisper; typedef struct { From a6028f321d31dce60e1d4675c71779204569971d Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 20:42:44 +0900 Subject: [PATCH 26/31] Split tests to multiple files --- bindings/ruby/tests/test_callback.rb | 56 ++++++++ bindings/ruby/tests/test_package.rb | 28 ++++ bindings/ruby/tests/test_params.rb | 112 ++++++++++++++++ bindings/ruby/tests/test_whisper.rb | 184 +-------------------------- 4 files changed, 198 insertions(+), 182 deletions(-) create mode 100644 bindings/ruby/tests/test_callback.rb create mode 100644 bindings/ruby/tests/test_package.rb create mode 100644 bindings/ruby/tests/test_params.rb diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb new file mode 100644 index 00000000000..644fc80d295 --- /dev/null +++ b/bindings/ruby/tests/test_callback.rb @@ -0,0 +1,56 @@ +require "test/unit" +require "whisper" + +class TestCallback < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + + def setup + @params = Whisper::Params.new + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, 
user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @whisper, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 + end + } + + @whisper.transcribe(@audio, @params) + end + + def test_new_segment_callback_closure + search_word = "what" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + text = context.full_get_segment_text(i_segment) + if text.include?(search_word) + t0 = context.full_get_segment_t0(i_segment) + t1 = context.full_get_segment_t1(i_segment) + raise "search word '#{search_word}' found at between #{t0} and #{t1}" + end + end + } + + assert_raise RuntimeError do + @whisper.transcribe(@audio, @params) + end + end +end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb new file mode 100644 index 00000000000..9d7527340f2 --- /dev/null +++ b/bindings/ruby/tests/test_package.rb @@ -0,0 +1,28 @@ +require 'test/unit' +require 'tempfile' +require 'tmpdir' +require 'shellwords' + +class TestPackage < Test::Unit::TestCase + def test_build + Tempfile.create do |file| + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert_path_exist file.to_path + end + end + + sub_test_case "Building binary on installation" do + def setup + system "rake", "build", exception: true + end + + def test_install + filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] + basename = 
"whisper.#{RbConfig::CONFIG["DLEXT"]}" + Dir.mktmpdir do |dir| + system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true + assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) + end + end + end +end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb new file mode 100644 index 00000000000..4484feeeff1 --- /dev/null +++ b/bindings/ruby/tests/test_params.rb @@ -0,0 +1,112 @@ +require 'whisper' + +class TestParams < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + end + + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + 
end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true + assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens + end + + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end +end diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 74e8b7fc7a7..5ebb8151c65 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -1,121 +1,13 @@ -TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) - require 'whisper' require 'test/unit' -require 'tempfile' -require 'tmpdir' -require 'shellwords' class TestWhisper < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + def setup @params = Whisper::Params.new end - def test_language - @params.language = "en" - assert_equal @params.language, "en" - @params.language = "auto" - assert_equal @params.language, "auto" - end - - def test_offset - @params.offset = 10_000 - assert_equal @params.offset, 10_000 - @params.offset = 0 - assert_equal @params.offset, 0 - end - - def test_duration - @params.duration = 60_000 - assert_equal 
@params.duration, 60_000 - @params.duration = 0 - assert_equal @params.duration, 0 - end - - def test_max_text_tokens - @params.max_text_tokens = 300 - assert_equal @params.max_text_tokens, 300 - @params.max_text_tokens = 0 - assert_equal @params.max_text_tokens, 0 - end - - def test_translate - @params.translate = true - assert @params.translate - @params.translate = false - assert !@params.translate - end - - def test_no_context - @params.no_context = true - assert @params.no_context - @params.no_context = false - assert !@params.no_context - end - - def test_single_segment - @params.single_segment = true - assert @params.single_segment - @params.single_segment = false - assert !@params.single_segment - end - - def test_print_special - @params.print_special = true - assert @params.print_special - @params.print_special = false - assert !@params.print_special - end - - def test_print_progress - @params.print_progress = true - assert @params.print_progress - @params.print_progress = false - assert !@params.print_progress - end - - def test_print_realtime - @params.print_realtime = true - assert @params.print_realtime - @params.print_realtime = false - assert !@params.print_realtime - end - - def test_print_timestamps - @params.print_timestamps = true - assert @params.print_timestamps - @params.print_timestamps = false - assert !@params.print_timestamps - end - - def test_suppress_blank - @params.suppress_blank = true - assert @params.suppress_blank - @params.suppress_blank = false - assert !@params.suppress_blank - end - - def test_suppress_non_speech_tokens - @params.suppress_non_speech_tokens = true - assert @params.suppress_non_speech_tokens - @params.suppress_non_speech_tokens = false - assert !@params.suppress_non_speech_tokens - end - - def test_token_timestamps - @params.token_timestamps = true - assert @params.token_timestamps - @params.token_timestamps = false - assert !@params.token_timestamps - end - - def test_split_on_word - @params.split_on_word = true 
- assert @params.split_on_word - @params.split_on_word = false - assert !@params.split_on_word - end - def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) params = Whisper::Params.new @@ -127,56 +19,6 @@ def test_whisper } end - def test_new_segment_callback - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - - @params.new_segment_callback = ->(context, state, n_new, user_data) { - assert_kind_of Integer, n_new - assert n_new > 0 - assert_same whisper, context - - n_segments = context.full_n_segments - n_new.times do |i| - i_segment = n_segments - 1 + i - start_time = context.full_get_segment_t0(i_segment) * 10 - end_time = context.full_get_segment_t1(i_segment) * 10 - text = context.full_get_segment_text(i_segment) - - assert_kind_of Integer, start_time - assert start_time >= 0 - assert_kind_of Integer, end_time - assert end_time > 0 - assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 - end - } - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_closure - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - - search_word = "what" - @params.new_segment_callback = ->(context, state, n_new, user_data) { - n_segments = context.full_n_segments - n_new.times do |i| - i_segment = n_segments - 1 + i - text = context.full_get_segment_text(i_segment) - if text.include?(search_word) - t0 = context.full_get_segment_t0(i_segment) - t1 = context.full_get_segment_t1(i_segment) - raise "search word '#{search_word}' found at between #{t0} and #{t1}" - end - end - } - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - assert_raise RuntimeError do - whisper.transcribe(jfk, @params) - end - end - sub_test_case "After transcription" do class << self attr_reader :whisper @@ -254,26 
+96,4 @@ def test_lang_str_full Whisper.lang_str_full(Whisper.lang_max_id + 1) end end - - def test_build - Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) - assert_path_exist file.to_path - end - end - - sub_test_case "Building binary on installation" do - def setup - system "rake", "build", exception: true - end - - def test_install - filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] - basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" - Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true - assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) - end - end - end end From c20afc38b784fba9b20ed8b8b4b21da2e8d264b0 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:43:58 +0900 Subject: [PATCH 27/31] Use container struct for new segment callback --- bindings/ruby/ext/ruby_whisper.cpp | 24 +++++++++++++----------- bindings/ruby/ext/ruby_whisper.h | 9 +++++++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index d0344d423d4..9ccdc2a6f7c 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -102,9 +102,14 @@ static VALUE ruby_whisper_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; + ruby_whisper_callback_user_data *new_segment_callback_user_data; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - rwp->new_segment_callback = Qnil; + new_segment_callback_user_data = ALLOC(ruby_whisper_callback_user_data); + new_segment_callback_user_data->context = nullptr; + new_segment_callback_user_data->user_data = Qnil; + new_segment_callback_user_data->callback = Qnil; + rwp->new_segment_callback_user_data = 
new_segment_callback_user_data; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -238,19 +243,16 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - if (!NIL_P(rwp->new_segment_callback)) { + if (!NIL_P(rwp->new_segment_callback_user_data->callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - ruby_whisper *rw;; - VALUE context = *(VALUE *)user_data; - Data_Get_Struct(context, ruby_whisper, rw); - VALUE callback = rw->new_segment_callback; + const ruby_whisper_callback_user_data *container = (ruby_whisper_callback_user_data *)user_data; - // Currently, doesn't support state and user_data because + // Currently, doesn't support state because // those require to resolve GC-related problems. - rb_funcall(callback, rb_intern("call"), 4, context, Qnil, INT2NUM(n_new), Qnil); + rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); }; - rw->new_segment_callback = rwp->new_segment_callback; - rwp->params.new_segment_callback_user_data = &self; + rwp->new_segment_callback_user_data->context = &self; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_user_data; } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { @@ -467,7 +469,7 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback = value; + rwp->new_segment_callback_user_data->callback = value; return value; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 1481bfa9e17..033b780e94b 100644 --- a/bindings/ruby/ext/ruby_whisper.h 
+++ b/bindings/ruby/ext/ruby_whisper.h @@ -3,15 +3,20 @@ #include "whisper.h" +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; +} ruby_whisper_callback_user_data; + typedef struct { struct whisper_context *context; - VALUE new_segment_callback; } ruby_whisper; typedef struct { struct whisper_full_params params; bool diarize; - VALUE new_segment_callback; + ruby_whisper_callback_user_data *new_segment_callback_user_data; } ruby_whisper_params; #endif From 1e72d62c2ce7c6b132fb243ee1a2c18b51043399 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:54:04 +0900 Subject: [PATCH 28/31] Add tests for Params#new_segment_callback_user_data= --- bindings/ruby/tests/test_callback.rb | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 644fc80d295..5697079fbad 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -53,4 +53,14 @@ def test_new_segment_callback_closure @whisper.transcribe(@audio, @params) end end + + def test_new_segment_callback_user_data + udata = Object.new + @params.new_segment_callback_user_data = udata + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_same udata, user_data + } + + @whisper.transcribe(@audio, @params) + end end From 0a9957d76deca8c3b93c9b9edd3545b2721856c0 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:54:36 +0900 Subject: [PATCH 29/31] Add Whisper::Params#new_segment_callback_user_data= --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9ccdc2a6f7c..168cb98320d 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -472,6 +472,12 @@ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE valu rwp->new_segment_callback_user_data->callback
= value; return value; } +static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_user_data->user_data = value; + return value; +} void Init_whisper() { mWhisper = rb_define_module("Whisper"); @@ -532,6 +538,7 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); + rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); } #ifdef __cplusplus } From 73934b580748eb3f0ffa3db4c525bdef62f46422 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 23:20:30 +0900 Subject: [PATCH 30/31] Add GC-related test for new segment callback --- bindings/ruby/tests/test_callback.rb | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 5697079fbad..80a5f4dfae6 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -63,4 +63,14 @@ def test_new_segment_callback_user_data @whisper.transcribe(@audio, @params) end + + def test_new_segment_callback_user_data_gc + @params.new_segment_callback_user_data = "My user data" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_equal "My user data", user_data + } + GC.start + + assert_same @whisper, @whisper.transcribe(@audio, @params) + end end From 1a3ff7cc6117b832dd6156097e1b89f01fff6d38 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 23:20:48 +0900 Subject: [PATCH 31/31] Protect new segment callback related structs from GC --- bindings/ruby/ext/ruby_whisper.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp 
index 168cb98320d..8dd935dda08 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -86,9 +86,12 @@ void rb_whisper_free(ruby_whisper *rw) { } void rb_whisper_params_mark(ruby_whisper_params *rwp) { + rb_gc_mark(rwp->new_segment_callback_user_data->user_data); + rb_gc_mark(rwp->new_segment_callback_user_data->callback); } void rb_whisper_params_free(ruby_whisper_params *rwp) { + // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); free(rwp); } @@ -110,6 +113,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { new_segment_callback_user_data->user_data = Qnil; new_segment_callback_user_data->callback = Qnil; rwp->new_segment_callback_user_data = new_segment_callback_user_data; + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); }