diff --git a/CMakeLists.txt b/CMakeLists.txt
index c253488..b117275 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,12 +86,14 @@ target_sources(
   PRIVATE src/plugin-main.c
           src/transcription-filter.cpp
           src/transcription-filter.c
+          src/transcription-utils.cpp
           src/model-utils/model-downloader.cpp
           src/model-utils/model-downloader-ui.cpp
           src/model-utils/model-infos.cpp
           src/whisper-utils/whisper-processing.cpp
           src/whisper-utils/whisper-utils.cpp
           src/whisper-utils/silero-vad-onnx.cpp
+          src/whisper-utils/token-buffer-thread.cpp
           src/translation/translation.cpp
           src/utils.cpp)
diff --git a/cmake/BuildWhispercpp.cmake b/cmake/BuildWhispercpp.cmake
index d9b8d96..a77efaa 100644
--- a/cmake/BuildWhispercpp.cmake
+++ b/cmake/BuildWhispercpp.cmake
@@ -107,12 +107,12 @@ elseif(WIN32)
   install(FILES ${WHISPER_DLLS} DESTINATION "obs-plugins/64bit")
 else()
-  set(Whispercpp_Build_GIT_TAG "f22d27a385d34b1e544031efe8aa2e3d73922791")
+  set(Whispercpp_Build_GIT_TAG "7395c70a748753e3800b63e3422a2b558a097c80")
   set(WHISPER_EXTRA_CXX_FLAGS "-fPIC")
   set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_BLAS=OFF -DWHISPER_CUBLAS=OFF -DWHISPER_OPENBLAS=OFF
                                     -DWHISPER_NO_AVX=ON -DWHISPER_NO_AVX2=ON)
 
-  # On Linux and MacOS build a static Whisper library
+  # On Linux build a static Whisper library
   ExternalProject_Add(
     Whispercpp_Build
     DOWNLOAD_EXTRACT_TIMESTAMP true
@@ -133,7 +133,7 @@ else()
 
   ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR)
 
-  # on Linux and MacOS add the static Whisper library to the link line
+  # add the static Whisper library to the link line
   add_library(Whispercpp::Whisper STATIC IMPORTED)
   set_target_properties(
     Whispercpp::Whisper
diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini
index 6f7c1e9..089dbf0 100644
--- a/data/locale/en-US.ini
+++ b/data/locale/en-US.ini
@@ -51,3 +51,7 @@ translate_add_context="Translate with context"
 whisper_translate="Translate to English (Whisper)"
 buffer_size_msec="Buffer size (ms)"
 overlap_size_msec="Overlap size (ms)"
+suppress_sentences="Suppress sentences (each line)"
+translate_output="Translation output"
+dtw_token_timestamps="DTW token timestamps"
+buffered_output="Buffered output (Experimental)"
diff --git a/src/captions-thread.h b/src/captions-thread.h
deleted file mode 100644
index 1cdb079..0000000
--- a/src/captions-thread.h
+++ /dev/null
@@ -1,118 +0,0 @@
-#ifndef CAPTIONS_THREAD_H
-#define CAPTIONS_THREAD_H
-
-#include <chrono>
-#include <condition_variable>
-#include <deque>
-#include <functional>
-#include <mutex>
-#include <string>
-#include <thread>
-#include <vector>
-
-#include <obs.h>
-
-#include "plugin-support.h"
-
-class CaptionMonitor {
-public:
-	// default constructor
-	CaptionMonitor() = default;
-
-	~CaptionMonitor()
-	{
-		{
-			std::lock_guard<std::mutex> lock(queueMutex);
-			stop = true;
-		}
-		condVar.notify_all();
-		workerThread.join();
-	}
-
-	void initialize(std::function<void(const std::string &)> callback_, size_t maxSize_,
-			std::chrono::seconds maxTime_)
-	{
-		this->callback = callback_;
-		this->maxSize = maxSize_;
-		this->maxTime = maxTime_;
-		this->initialized = true;
-		this->workerThread = std::thread(&CaptionMonitor::monitor, this);
-	}
-
-	void addWords(const std::vector<std::string> &words)
-	{
-		{
-			std::lock_guard<std::mutex> lock(queueMutex);
-			for (const auto &word : words) {
-				wordQueue.push_back(word);
-			}
-			this->newDataAvailable = true;
-		}
-		condVar.notify_all();
-	}
-
-private:
-	void monitor()
-	{
-		obs_log(LOG_INFO, "CaptionMonitor::monitor");
-		auto startTime = std::chrono::steady_clock::now();
-		while (true) {
-			std::unique_lock<std::mutex> lock(this->queueMutex);
-			// wait for new data or stop signal
-			this->condVar.wait(lock,
-					   [this] { return this->newDataAvailable || this->stop; });
-
-			if (this->stop) {
-				break;
-			}
-
-			if (this->wordQueue.empty()) {
-				continue;
-			}
-
-			// emit up to maxSize words from the wordQueue
-			std::vector<std::string> emitted;
-			while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
-				emitted.push_back(this->wordQueue.front());
-				this->wordQueue.pop_front();
-			}
-			// emit the caption, joining the words with a space
-			std::string output;
-			for (const auto &word : emitted) {
-				output += word + " ";
-			}
-			this->callback(output);
-			// push back the words that were emitted, in reverse order
-			for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
-				this->wordQueue.push_front(*it);
-			}
-
-			if (this->wordQueue.size() >= this->maxSize ||
-			    std::chrono::steady_clock::now() - startTime >= this->maxTime) {
-				// flush the queue if it's full or we've reached the max time
-				size_t words_to_flush =
-					std::min(this->wordQueue.size(), this->maxSize);
-				for (size_t i = 0; i < words_to_flush; ++i) {
-					wordQueue.pop_front();
-				}
-				startTime = std::chrono::steady_clock::now();
-			}
-
-			newDataAvailable = false;
-		}
-		obs_log(LOG_INFO, "CaptionMonitor::monitor: done");
-	}
-
-	std::deque<std::string> wordQueue;
-	std::thread workerThread;
-	std::mutex queueMutex;
-	std::condition_variable condVar;
-	std::function<void(const std::string &)> callback;
-	size_t maxSize;
-	std::chrono::seconds maxTime;
-	bool stop;
-	bool initialized = false;
-	bool newDataAvailable = false;
-};
-
-#endif // CAPTIONS_THREAD_H
diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h
index 374ec2a..7817a5d 100644
--- a/src/transcription-filter-data.h
+++ b/src/transcription-filter-data.h
@@ -17,25 +17,13 @@
 #include "translation/translation.h"
 #include "whisper-utils/silero-vad-onnx.h"
-#include "captions-thread.h"
+#include "whisper-utils/whisper-processing.h"
+#include "whisper-utils/token-buffer-thread.h"
 
 #define MAX_PREPROC_CHANNELS 10
 
 #define MT_ obs_module_text
 
-enum DetectionResult {
-	DETECTION_RESULT_UNKNOWN = 0,
-	DETECTION_RESULT_SILENCE = 1,
-	DETECTION_RESULT_SPEECH = 2,
-};
-
-struct DetectionResultWithText {
-	DetectionResult result;
-	std::string text;
-	uint64_t start_timestamp_ms;
-	uint64_t end_timestamp_ms;
-};
-
 struct transcription_filter_data {
 	obs_source_t *context; // obs filter source (this filter)
 	size_t channels;       // number of channels
@@ -64,7 +52,7 @@ struct transcription_filter_data {
 	struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
 
 	/* Resampler */
-	audio_resampler_t *resampler;
+	audio_resampler_t *resampler_to_whisper;
 
 	/* whisper */
 	std::string whisper_model_path;
@@ -90,15 +78,16 @@ struct transcription_filter_data {
 	bool translate = false;
 	std::string source_lang;
 	std::string target_lang;
+	std::string translation_output;
 	bool buffered_output = false;
+	bool enable_token_ts_dtw = false;
+	std::string suppress_sentences;
 
 	// Last transcription result
 	std::string last_text;
 
 	// Text source to output the subtitles
-	obs_weak_source_t *text_source;
-	char *text_source_name;
-	std::mutex *text_source_mutex;
+	std::string text_source_name;
 	// Callback to set the text in the output text source (subtitles)
 	std::function<void(const DetectionResultWithText &)> setTextCallback;
 	// Output file path to write the subtitles
@@ -115,7 +104,7 @@ struct transcription_filter_data {
 	// translation context
 	struct translation_context translation_ctx;
 
-	CaptionMonitor captions_monitor;
+	TokenBufferThread captions_monitor;
 
 	// ctor
 	transcription_filter_data()
@@ -125,11 +114,9 @@ struct transcription_filter_data {
 			copy_buffers[i] = nullptr;
 		}
 		context = nullptr;
-		resampler = nullptr;
+		resampler_to_whisper = nullptr;
 		whisper_model_path = "";
 		whisper_context = nullptr;
-		text_source = nullptr;
-		text_source_mutex = nullptr;
 		whisper_buf_mutex = nullptr;
 		whisper_ctx_mutex = nullptr;
 		wshiper_thread_cv = nullptr;
diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp
index c12b2a7..8056a06 100644
--- a/src/transcription-filter.cpp
+++ b/src/transcription-filter.cpp
@@ -4,6 +4,7 @@
 #include "plugin-support.h"
 #include "transcription-filter.h"
 #include "transcription-filter-data.h"
+#include "transcription-utils.h"
 #include "model-utils/model-downloader.h"
 #include "whisper-utils/whisper-processing.h"
 #include "whisper-utils/whisper-language.h"
@@ -132,18 +133,8 @@ void transcription_filter_destroy(void *data)
 	obs_log(gf->log_level, "filter destroy");
 	shutdown_whisper_thread(gf);
 
-	if (gf->text_source_name) {
-		bfree(gf->text_source_name);
-		gf->text_source_name = nullptr;
-	}
-
-	if (gf->text_source) {
-		obs_weak_source_release(gf->text_source);
-		gf->text_source = nullptr;
-	}
-
-	if (gf->resampler) {
-		audio_resampler_destroy(gf->resampler);
+	if (gf->resampler_to_whisper) {
+		audio_resampler_destroy(gf->resampler_to_whisper);
 	}
 
 	{
@@ -159,87 +150,14 @@ void transcription_filter_destroy(void *data)
 	delete gf->whisper_buf_mutex;
 	delete gf->whisper_ctx_mutex;
 	delete gf->wshiper_thread_cv;
-	delete gf->text_source_mutex;
 
 	delete gf;
 }
 
-void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
-{
-	if (!gf->text_source_name) {
-		obs_log(gf->log_level, "text_source_name is null");
-		return;
-	}
-
-	std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
-
-	// acquire a weak ref to the new text source
-	obs_source_t *source = obs_get_source_by_name(gf->text_source_name);
-	if (source) {
-		gf->text_source = obs_source_get_weak_source(source);
-		obs_source_release(source);
-		if (!gf->text_source) {
-			obs_log(LOG_ERROR, "failed to get weak source for text source %s",
-				gf->text_source_name);
-		}
-	} else {
-		obs_log(LOG_ERROR, "text source '%s' not found", gf->text_source_name);
-	}
-}
-
-#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
-#define is_trail_byte(c) (((c)&0xc0) == 0x80)
-
-inline int lead_byte_length(const uint8_t c)
+void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy,
+			    struct transcription_filter_data *gf)
 {
-	if ((c & 0xe0) == 0xc0) {
-		return 2;
-	} else if ((c & 0xf0) == 0xe0) {
-		return 3;
-	} else if ((c & 0xf8) == 0xf0) {
-		return 4;
-	} else {
-		return 1;
-	}
-}
-
-inline bool is_valid_lead_byte(const uint8_t *c)
-{
-	const int length = lead_byte_length(c[0]);
-	if (length == 1) {
-		return true;
-	}
-	if (length == 2 && is_trail_byte(c[1])) {
-		return true;
-	}
-	if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
-		return true;
-	}
-	if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
-		return true;
-	}
-	return false;
-}
-
-void send_caption_to_source(const std::string &str_copy, struct transcription_filter_data *gf)
-{
-	if (!gf->text_source_mutex) {
-		obs_log(LOG_ERROR, "text_source_mutex is null");
-		return;
-	}
-
-	if (!gf->text_source) {
-		// attempt to acquire a weak ref to the text source if it's yet available
-		acquire_weak_text_source_ref(gf);
-	}
-
-	std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
-
-	if (!gf->text_source) {
-		obs_log(gf->log_level, "text_source is null");
-		return;
-	}
-
-	auto target = obs_weak_source_get_source(gf->text_source);
+	auto target = obs_get_source_by_name(target_source_name.c_str());
 	if (!target) {
 		obs_log(gf->log_level, "text_source target is null");
 		return;
@@ -267,52 +185,9 @@ void set_text_callback(struct transcription_filter_data *gf,
 	}
 	gf->last_sub_render_time = now;
 
-#ifdef _WIN32
-	// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
-	// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
-	std::stringstream ss;
-	uint8_t *c_str = (uint8_t *)result.text.c_str();
-	for (size_t i = 0; i < result.text.size(); ++i) {
-		if (is_lead_byte(c_str[i])) {
-			// this is a unicode leading byte
-			// if the next char is 0xff - it's a bug char, replace it with 0x9f
-			if (c_str[i + 1] == 0xff) {
-				c_str[i + 1] = 0x9f;
-			}
-			if (!is_valid_lead_byte(c_str + i)) {
-				// This is a bug lead byte, because it's length 3 and the i+2 byte is also
-				// a lead byte
-				c_str[i] = c_str[i] - 0x20;
-			}
-		} else {
-			if (c_str[i] >= 0xf8) {
-				// this may be a malformed lead byte.
-				// lets see if it becomes a valid lead byte if we "fix" it
-				uint8_t buf_[4];
-				buf_[0] = c_str[i] - 0x20;
-				buf_[1] = c_str[i + 1];
-				buf_[2] = c_str[i + 2];
-				buf_[3] = c_str[i + 3];
-				if (is_valid_lead_byte(buf_)) {
-					// this is a malformed lead byte, fix it
-					c_str[i] = c_str[i] - 0x20;
-				}
-			}
-		}
-	}
-
-	std::string str_copy = (char *)c_str;
-#else
-	std::string str_copy = result.text;
-#endif
-
-	// remove trailing spaces, newlines, tabs or punctuation
-	str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
-				    [](unsigned char ch) {
-					    return !std::isspace(ch) || !std::ispunct(ch);
-				    })
-			       .base(),
-		       str_copy.end());
+	// recondition the text
+	std::string str_copy = fix_utf8(result.text);
+	str_copy = remove_leading_trailing_nonalpha(str_copy);
 
 	if (gf->translate) {
 		obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(),
@@ -324,7 +199,13 @@ void set_text_callback(struct transcription_filter_data *gf,
 			obs_log(LOG_INFO, "Translation: '%s' -> '%s'", str_copy.c_str(),
 				translated_text.c_str());
 		}
-		str_copy = translated_text;
+		if (gf->translation_output == "none") {
+			// overwrite the original text with the translated text
+			str_copy = translated_text;
+		} else {
+			// send the translation to the selected source
+			send_caption_to_source(gf->translation_output, translated_text, gf);
+		}
 	} else {
 		obs_log(gf->log_level, "Failed to translate text");
 	}
@@ -333,7 +214,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 	gf->last_text = str_copy;
 
 	if (gf->buffered_output) {
-		gf->captions_monitor.addWords(split_words(str_copy));
+		gf->captions_monitor.addWords(result.tokens);
 	}
 
 	if (gf->caption_to_stream) {
@@ -344,7 +225,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 		}
 	}
 
-	if (gf->output_file_path != "" && !gf->text_source_name) {
+	if (gf->output_file_path != "" && gf->text_source_name.empty()) {
 		// Check if we should save the sentence
 		if (gf->save_only_while_recording && !obs_frontend_recording_active()) {
 			// We are not recording, do not save the sentence to file
@@ -396,7 +277,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 	} else {
 		if (!gf->buffered_output) {
 			// Send the caption to the text source
-			send_caption_to_source(str_copy, gf);
+			send_caption_to_source(gf->text_source_name, str_copy, gf);
 		}
 	}
 };
@@ -427,12 +308,21 @@ void transcription_filter_update(void *data, obs_data_t *s)
 	gf->process_while_muted = obs_data_get_bool(s, "process_while_muted");
 	gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration");
 	gf->last_sub_render_time = 0;
-	gf->buffered_output = obs_data_get_bool(s, "buffered_output");
+	bool new_buffered_output = obs_data_get_bool(s, "buffered_output");
+	if (new_buffered_output != gf->buffered_output) {
+		gf->buffered_output = new_buffered_output;
+		gf->overlap_ms = gf->buffered_output ? MAX_OVERLAP_SIZE_MSEC : DEFAULT_OVERLAP_SIZE_MSEC;
+		gf->overlap_frames =
+			(size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
+	}
 
 	bool new_translate = obs_data_get_bool(s, "translate");
 	gf->source_lang = obs_data_get_string(s, "translate_source_language");
 	gf->target_lang = obs_data_get_string(s, "translate_target_language");
 	gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
+	gf->translation_output = obs_data_get_string(s, "translate_output");
+	gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
 
 	if (new_translate != gf->translate) {
 		if (new_translate) {
@@ -451,19 +341,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
 	    strcmp(new_text_source_name, "(null)") == 0 ||
 	    strcmp(new_text_source_name, "text_file") == 0 || strlen(new_text_source_name) == 0) {
 		// new selected text source is not valid, release the old one
-		if (gf->text_source) {
-			if (!gf->text_source_mutex) {
-				obs_log(LOG_ERROR, "text_source_mutex is null");
-				return;
-			}
-			std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
-			old_weak_text_source = gf->text_source;
-			gf->text_source = nullptr;
-		}
-		if (gf->text_source_name) {
-			bfree(gf->text_source_name);
-			gf->text_source_name = nullptr;
-		}
+		gf->text_source_name.clear();
 		gf->output_file_path = "";
 		if (strcmp(new_text_source_name, "text_file") == 0) {
 			// set the output file path
@@ -475,24 +353,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
 		}
 	} else {
 		// new selected text source is valid, check if it's different from the old one
-		if (gf->text_source_name == nullptr ||
-		    strcmp(new_text_source_name, gf->text_source_name) != 0) {
+		if (gf->text_source_name != new_text_source_name) {
 			// new text source is different from the old one, release the old one
-			if (gf->text_source) {
-				if (!gf->text_source_mutex) {
-					obs_log(LOG_ERROR, "text_source_mutex is null");
-					return;
-				}
-				std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
-				old_weak_text_source = gf->text_source;
-				gf->text_source = nullptr;
-			}
-			if (gf->text_source_name) {
-				// free the old text source name
-				bfree(gf->text_source_name);
-				gf->text_source_name = nullptr;
-			}
-			gf->text_source_name = bstrdup(new_text_source_name);
+			gf->text_source_name = new_text_source_name;
 		}
 	}
 
@@ -507,7 +370,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
 	}
 
 	obs_log(gf->log_level, "update whisper model");
-	update_whsiper_model_path(gf, s);
+	update_whisper_model(gf, s);
 
 	obs_log(gf->log_level, "update whisper params");
 	std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
@@ -597,15 +460,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 	dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
 	dst.speakers = convert_speaker_layout((uint8_t)1);
 
-	gf->resampler = audio_resampler_create(&dst, &src);
+	gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
 
 	obs_log(gf->log_level, "setup mutexes and condition variables");
 	gf->whisper_buf_mutex = new std::mutex();
 	gf->whisper_ctx_mutex = new std::mutex();
 	gf->wshiper_thread_cv = new std::condition_variable();
-	gf->text_source_mutex = new std::mutex();
 
 	obs_log(gf->log_level, "clear text source data");
-	gf->text_source = nullptr;
 	const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
 	if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
 	    strcmp(subtitle_sources, "(null)") == 0 || strlen(subtitle_sources) == 0) {
@@ -619,8 +480,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 			// create a new OBS text source called "LocalVocal Subtitles"
 			obs_source_t *scene_as_source = obs_frontend_get_current_scene();
 			obs_scene_t *scene = obs_scene_from_source(scene_as_source);
+#ifdef _WIN32
+			source = obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles",
+						   nullptr, nullptr);
+#else
 			source = obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles",
 						   nullptr, nullptr);
+#endif
 			if (source) {
 				// add source to the current scene
 				obs_scene_add(scene, source);
@@ -660,11 +526,11 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 			}
 			obs_source_release(scene_as_source);
 		}
-		gf->text_source_name = bstrdup("LocalVocal Subtitles");
+		gf->text_source_name = "LocalVocal Subtitles";
 		obs_data_set_string(settings, "subtitle_sources", "LocalVocal Subtitles");
 	} else {
 		// set the text source name
-		gf->text_source_name = bstrdup(subtitle_sources);
+		gf->text_source_name = subtitle_sources;
 	}
 
 	obs_log(gf->log_level, "clear paths and whisper context");
 	gf->whisper_model_file_currently_loaded = "";
@@ -673,13 +539,14 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 	gf->whisper_context = nullptr;
 
 	gf->captions_monitor.initialize(
+		gf,
 		[gf](const std::string &text) {
 			obs_log(LOG_INFO, "Captions: %s", text.c_str());
 			if (gf->buffered_output) {
-				send_caption_to_source(text, gf);
+				send_caption_to_source(gf->text_source_name, text, gf);
 			}
 		},
-		20, std::chrono::seconds(10));
+		30, std::chrono::seconds(10));
 
 	obs_log(gf->log_level, "run update");
 	// get the settings updated on the filter data struct
@@ -791,6 +658,7 @@ void transcription_filter_defaults(obs_data_t *s)
 	obs_data_set_default_string(s, "translate_target_language", "__es__");
 	obs_data_set_default_string(s, "translate_source_language", "__en__");
 	obs_data_set_default_bool(s, "translate_add_context", true);
+	obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
 
 	// Whisper parameters
 	obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
@@ -805,6 +673,7 @@ void transcription_filter_defaults(obs_data_t *s)
 	obs_data_set_default_bool(s, "print_realtime", false);
 	obs_data_set_default_bool(s, "print_timestamps", false);
 	obs_data_set_default_bool(s, "token_timestamps", false);
+	obs_data_set_default_bool(s, "dtw_token_timestamps", false);
 	obs_data_set_default_double(s, "thold_pt", 0.01);
 	obs_data_set_default_double(s, "thold_ptsum", 0.01);
 	obs_data_set_default_int(s, "max_len", 0);
@@ -919,6 +788,15 @@ obs_properties_t *transcription_filter_properties(void *data)
 		obs_property_list_add_string(prop_src, language.second.c_str(),
 					     language.first.c_str());
 	}
+	// add option for routing the translation to an output source
+	obs_property_t *prop_output = obs_properties_add_list(
+		translation_group, "translate_output", MT_("translate_output"),
+		OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
+	obs_property_list_add_string(prop_output, "Write to captions output", "none");
+	// TODO add file output option
+	// obs_property_list_add_string(...
+	obs_enum_sources(add_sources_to_list, prop_output);
+
 	// add callback to enable/disable translation group
 	obs_property_set_modified_callback(translation_group_prop, [](obs_properties_t *props,
@@ -928,7 +806,7 @@ obs_properties_t *transcription_filter_properties(void *data)
 		// Show/Hide the translation group
 		const bool translate_enabled = obs_data_get_bool(settings, "translate");
 		for (const auto &prop : {"translate_target_language", "translate_source_language",
-					 "translate_add_context"}) {
+					 "translate_add_context", "translate_output"}) {
 			obs_property_set_visible(obs_properties_get(props, prop),
 						 translate_enabled);
 		}
@@ -946,21 +824,38 @@ obs_properties_t *transcription_filter_properties(void *data)
 		for (const std::string &prop_name :
 		     {"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
 		      "overlap_size_msec", "step_by_step_processing", "min_sub_duration",
-		      "process_while_muted", "buffered_output", "vad_enabled", "log_level"}) {
+		      "process_while_muted", "buffered_output", "vad_enabled", "log_level",
+		      "suppress_sentences"}) {
 			obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
 						 show_hide);
 		}
 		return true;
 	});
 
-	obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
+	obs_property_t *buffered_output_prop =
+		obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
+	// add on-change handler for buffered_output
+	obs_property_set_modified_callback(buffered_output_prop, [](obs_properties_t *props,
+								     obs_property_t *property,
+								     obs_data_t *settings) {
+		UNUSED_PARAMETER(property);
+		UNUSED_PARAMETER(props);
+		// if buffered output is enabled set the overlap to max else set it to default
+		obs_data_set_int(settings, "overlap_size_msec",
+				 obs_data_get_bool(settings, "buffered_output")
+					 ? MAX_OVERLAP_SIZE_MSEC
+					 : DEFAULT_OVERLAP_SIZE_MSEC);
+		return true;
+	});
+
 	obs_properties_add_bool(ppts, "log_words", MT_("log_words"));
 	obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream"));
 	obs_properties_add_int_slider(ppts, "buffer_size_msec", MT_("buffer_size_msec"), 1000,
 				      DEFAULT_BUFFER_SIZE_MSEC, 250);
-	obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"), 50, 300,
-				      50);
+	obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"),
+				      MIN_OVERLAP_SIZE_MSEC, MAX_OVERLAP_SIZE_MSEC,
+				      (MAX_OVERLAP_SIZE_MSEC - MIN_OVERLAP_SIZE_MSEC) / 5);
 
 	obs_property_t *step_by_step_processing = obs_properties_add_bool(
 		ppts, "step_by_step_processing", MT_("step_by_step_processing"));
@@ -985,10 +880,14 @@ obs_properties_t *transcription_filter_properties(void *data)
 	obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
 						       OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
-	obs_property_list_add_int(list, "DEBUG", LOG_DEBUG);
+	obs_property_list_add_int(list, "DEBUG (Won't show)", LOG_DEBUG);
 	obs_property_list_add_int(list, "INFO", LOG_INFO);
 	obs_property_list_add_int(list, "WARNING", LOG_WARNING);
 
+	// add a text input for sentences to suppress
+	obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
+				OBS_TEXT_MULTILINE);
+
 	obs_properties_t *whisper_params_group = obs_properties_create();
 	obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
 				 OBS_GROUP_NORMAL, whisper_params_group);
@@ -1043,6 +942,9 @@ obs_properties_t *transcription_filter_properties(void *data)
 	obs_properties_add_bool(whisper_params_group, "print_timestamps", MT_("print_timestamps"));
 	// bool token_timestamps; // enable token-level timestamps
 	obs_properties_add_bool(whisper_params_group, "token_timestamps", MT_("token_timestamps"));
+	// enable DTW timestamps
+	obs_properties_add_bool(whisper_params_group, "dtw_token_timestamps",
+				MT_("dtw_token_timestamps"));
 	// float thold_pt; // timestamp token probability threshold (~0.01)
 	obs_properties_add_float_slider(whisper_params_group, "thold_pt", MT_("thold_pt"), 0.0f,
 					1.0f, 0.05f);
diff --git a/src/transcription-filter.h b/src/transcription-filter.h
index 6784540..b089510 100644
--- a/src/transcription-filter.h
+++ b/src/transcription-filter.h
@@ -19,6 +19,8 @@ const char *const PLUGIN_INFO_TEMPLATE =
 	"OCC AI ❤️ "
 	"Support & Follow";
 
+const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp
new file mode 100644
index 0000000..c7f9d40
--- /dev/null
+++ b/src/transcription-utils.cpp
@@ -0,0 +1,104 @@
+#include "transcription-utils.h"
+
+#include <sstream>
+#include <algorithm>
+
+#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
+#define is_trail_byte(c) (((c)&0xc0) == 0x80)
+
+inline int lead_byte_length(const uint8_t c)
+{
+	if ((c & 0xe0) == 0xc0) {
+		return 2;
+	} else if ((c & 0xf0) == 0xe0) {
+		return 3;
+	} else if ((c & 0xf8) == 0xf0) {
+		return 4;
+	} else {
+		return 1;
+	}
+}
+
+inline bool is_valid_lead_byte(const uint8_t *c)
+{
+	const int length = lead_byte_length(c[0]);
+	if (length == 1) {
+		return true;
+	}
+	if (length == 2 && is_trail_byte(c[1])) {
+		return true;
+	}
+	if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
+		return true;
+	}
+	if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
+		return true;
+	}
+	return false;
+}
+
+std::string fix_utf8(const std::string &str)
+{
+#ifdef _WIN32
+	// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
+	// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
+	std::stringstream ss;
+	uint8_t *c_str = (uint8_t *)str.c_str();
+	for (size_t i = 0; i < str.size(); ++i) {
+		if (is_lead_byte(c_str[i])) {
+			// this is a unicode leading byte
+			// if the next char is 0xff - it's a bug char, replace it with 0x9f
+			if (c_str[i + 1] == 0xff) {
+				c_str[i + 1] = 0x9f;
+			}
+			if (!is_valid_lead_byte(c_str + i)) {
+				// This is a bug lead byte, because it's length 3 and the i+2 byte is also
+				// a lead byte
+				c_str[i] = c_str[i] - 0x20;
+			}
+		} else {
+			if (c_str[i] >= 0xf8) {
+				// this may be a malformed lead byte.
+				// lets see if it becomes a valid lead byte if we "fix" it
+				uint8_t buf_[4];
+				buf_[0] = c_str[i] - 0x20;
+				buf_[1] = c_str[i + 1];
+				buf_[2] = c_str[i + 2];
+				buf_[3] = c_str[i + 3];
+				if (is_valid_lead_byte(buf_)) {
+					// this is a malformed lead byte, fix it
+					c_str[i] = c_str[i] - 0x20;
+				}
+			}
+		}
+	}
+
+	return std::string((char *)c_str);
+#else
+	return str;
+#endif
+}
+
+/*
+* Remove leading and trailing non-alphabetic characters from a string.
+* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation.
+* @param str: the string to remove leading and trailing non-alphabetic characters from.
+* @return: the string with leading and trailing non-alphabetic characters removed.
+*/
+std::string remove_leading_trailing_nonalpha(const std::string &str)
+{
+	std::string str_copy = str;
+	// remove trailing spaces, newlines, tabs or punctuation
+	str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
+				    [](unsigned char ch) {
+					    return !std::isspace(ch) && !std::ispunct(ch);
+				    })
+			       .base(),
+		       str_copy.end());
+	// remove leading spaces, newlines, tabs or punctuation
+	str_copy.erase(str_copy.begin(),
+		       std::find_if(str_copy.begin(), str_copy.end(), [](unsigned char ch) {
+			       return !std::isspace(ch) && !std::ispunct(ch);
+		       }));
+	return str_copy;
+}
diff --git a/src/transcription-utils.h b/src/transcription-utils.h
new file mode 100644
index 0000000..5e2e500
--- /dev/null
+++ b/src/transcription-utils.h
@@ -0,0 +1,9 @@
+#ifndef TRANSCRIPTION_UTILS_H
+#define TRANSCRIPTION_UTILS_H
+
+#include <string>
+
+std::string fix_utf8(const std::string &str);
+std::string remove_leading_trailing_nonalpha(const std::string &str);
+
+#endif // TRANSCRIPTION_UTILS_H
diff --git a/src/whisper-utils/silero-vad-onnx.cpp b/src/whisper-utils/silero-vad-onnx.cpp
index 437e6e7..2951245 100644
--- a/src/whisper-utils/silero-vad-onnx.cpp
+++ b/src/whisper-utils/silero-vad-onnx.cpp
@@ -10,7 +10,10 @@
 #include <cstdio>
 #include <cstring>
 
-//#define __DEBUG_SPEECH_PROB___
+#include <obs-module.h>
+#include "plugin-support.h"
+
+// #define __DEBUG_SPEECH_PROB___
 
 timestamp_t::timestamp_t(int start_, int end_) : start(start_), end(end_){};
 
@@ -144,8 +147,8 @@ void VadIterator::predict(const std::vector<float> &data)
 			float speech = current_sample - window_size_samples;
 			// minus window_size_samples to get precise start time point.
-			printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
-			       current_sample - window_size_samples);
+			obs_log(LOG_INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
+				speech_prob, current_sample - window_size_samples);
 #endif //__DEBUG_SPEECH_PROB___
 			if (temp_end != 0) {
 				temp_end = 0;
@@ -194,16 +197,18 @@ void VadIterator::predict(const std::vector<float> &data)
 				float speech = current_sample - window_size_samples;
 				// minus window_size_samples to get precise start time point.
-				printf("{ speeking: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
-				       speech_prob, current_sample - window_size_samples);
+				obs_log(LOG_INFO, "{ speaking: %.3f s (%.3f) %08d}",
+					1.0 * speech / sample_rate, speech_prob,
+					current_sample - window_size_samples);
 #endif //__DEBUG_SPEECH_PROB___
 			} else {
 #ifdef __DEBUG_SPEECH_PROB___
 				float speech = current_sample - window_size_samples;
 				// minus window_size_samples to get precise start time point.
-				printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
-				       speech_prob, current_sample - window_size_samples);
+				obs_log(LOG_INFO, "{ silence: %.3f s (%.3f) %08d}",
+					1.0 * speech / sample_rate, speech_prob,
+					current_sample - window_size_samples);
 #endif //__DEBUG_SPEECH_PROB___
 			}
 			return;
@@ -215,8 +220,8 @@ void VadIterator::predict(const std::vector<float> &data)
 			float speech = current_sample - window_size_samples - speech_pad_samples;
 			// minus window_size_samples to get precise start time point.
-			printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob,
-			       current_sample - window_size_samples);
+			obs_log(LOG_INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate,
+				speech_prob, current_sample - window_size_samples);
 #endif //__DEBUG_SPEECH_PROB___
 			if (triggered == true) {
 				if (temp_end == 0) {
@@ -285,7 +290,7 @@ void VadIterator::collect_chunks(const std::vector<float> &input_wav,
 	output_wav.clear();
 	for (size_t i = 0; i < speeches.size(); i++) {
 #ifdef __DEBUG_SPEECH_PROB___
-		std::cout << speeches[i].c_str() << std::endl;
+		obs_log(LOG_INFO, "%s", speeches[i].string().c_str());
 #endif //#ifdef __DEBUG_SPEECH_PROB___
 		std::vector<float> slice(&input_wav[speeches[i].start],
 					 &input_wav[speeches[i].end]);
diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp
new file mode 100644
index 0000000..13d2ffc
--- /dev/null
+++ b/src/whisper-utils/token-buffer-thread.cpp
@@ -0,0 +1,131 @@
+#include "token-buffer-thread.h"
+#include "./whisper-utils.h"
+
+TokenBufferThread::~TokenBufferThread()
+{
+	{
+		std::lock_guard<std::mutex> lock(queueMutex);
+		stop = true;
+	}
+	condVar.notify_all();
+	workerThread.join();
+}
+
+void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
+				   std::function<void(const std::string &)> callback_,
+				   size_t maxSize_, std::chrono::seconds maxTime_)
+{
+	this->gf = gf_;
+	this->callback = callback_;
+	this->maxSize = maxSize_;
+	this->maxTime = maxTime_;
+	this->initialized = true;
+	this->workerThread = std::thread(&TokenBufferThread::monitor, this);
+}
+
+void TokenBufferThread::log_token_vector(const std::vector<whisper_token_data> &tokens)
+{
+	std::string output;
+	for (const auto &token : tokens) {
+		const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
+		output += token_str;
+	}
+	obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str());
+}
+
+void TokenBufferThread::addWords(const std::vector<whisper_token_data> &words)
+{
+	obs_log(LOG_INFO, "TokenBufferThread::addWords");
+	{
+		std::lock_guard<std::mutex> lock(queueMutex);
+
+		// convert current wordQueue to vector
+		std::vector<whisper_token_data> currentWords(wordQueue.begin(), wordQueue.end());
+
+		log_token_vector(currentWords);
+		log_token_vector(words);
+
+		// run reconstructSentence
+		std::vector<whisper_token_data> reconstructed =
+			reconstructSentence(currentWords, words);
+
+		log_token_vector(reconstructed);
+
+		// clear the wordQueue
+		wordQueue.clear();
+
+		// add the reconstructed sentence to the wordQueue
+		for (const auto &word : reconstructed) {
+			wordQueue.push_back(word);
+		}
+
+		newDataAvailable = true;
+	}
+	condVar.notify_all();
+}
+
+void TokenBufferThread::monitor()
+{
+	obs_log(LOG_INFO, "TokenBufferThread::monitor");
+	auto startTime = std::chrono::steady_clock::now();
+	while (this->initialized && !this->stop) {
+		std::unique_lock<std::mutex> lock(this->queueMutex);
+		// wait for new data or stop signal
+		this->condVar.wait(lock, [this] { return this->newDataAvailable || this->stop; });
+
+		if (this->stop) {
+			break;
+		}
+
+		if (this->wordQueue.empty()) {
+			continue;
+		}
+
+		if (this->gf->whisper_context == nullptr) {
+			continue;
+		}
+
+		// emit up to maxSize words from the wordQueue
+		std::vector<whisper_token_data> emitted;
+		while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
+			emitted.push_back(this->wordQueue.front());
+			this->wordQueue.pop_front();
+		}
+		obs_log(LOG_INFO, "TokenBufferThread::monitor: emitting %d words", emitted.size());
+		log_token_vector(emitted);
+		// emit the caption from the tokens
+		std::string output;
+		for (const auto &token : emitted) {
+			const char *token_str =
+				whisper_token_to_str(this->gf->whisper_context, token.id);
+			output += token_str;
+		}
+		this->callback(output);
+		// push back the words that were emitted, in reverse order
+		for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
+			this->wordQueue.push_front(*it);
+		}
+
+		// check if we need to flush the queue
+		auto elapsedTime = std::chrono::duration_cast<std::chrono::seconds>(
+			std::chrono::steady_clock::now() - startTime);
+		if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
+			// flush the queue if it's full or we've reached the max time
+			size_t words_to_flush = std::min(this->wordQueue.size(), this->maxSize);
+			// make sure we leave at least 3 words in the queue (avoid underflow)
+			size_t words_remaining = this->wordQueue.size() - words_to_flush;
+			if (words_remaining < 3) {
+				words_to_flush -= std::min(words_to_flush, 3 - words_remaining);
+			}
+			obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words",
+				words_to_flush);
+			for (size_t i = 0; i < words_to_flush; ++i) {
+				wordQueue.pop_front();
+			}
+			startTime = std::chrono::steady_clock::now();
+		}
+
+		newDataAvailable = false;
+	}
+	obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
+}
diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h
new file mode 100644
index 0000000..1a56b70
--- /dev/null
+++ b/src/whisper-utils/token-buffer-thread.h
@@ -0,0 +1,49 @@
+#ifndef TOKEN_BUFFER_THREAD_H
+#define TOKEN_BUFFER_THREAD_H
+
+#include <chrono>
+#include <condition_variable>
+#include <deque>
+#include <functional>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <obs.h>
+
+#include <whisper.h>
+
+#include "plugin-support.h"
+
+struct transcription_filter_data;
+
+class TokenBufferThread {
+public:
+	// default constructor
+	TokenBufferThread() = default;
+
+	~TokenBufferThread();
+	void initialize(struct transcription_filter_data *gf,
+			std::function<void(const std::string &)> callback_, size_t maxSize_,
+			std::chrono::seconds maxTime_);
+
+	void addWords(const std::vector<whisper_token_data> &words);
+
+private:
+	void monitor();
+	void log_token_vector(const std::vector<whisper_token_data> &tokens);
+	struct transcription_filter_data *gf;
+	std::deque<whisper_token_data> wordQueue;
+	std::thread workerThread;
+	std::mutex queueMutex;
+	std::condition_variable condVar;
+	std::function<void(const std::string &)> callback;
+	size_t maxSize;
+	std::chrono::seconds maxTime;
+	bool stop = false;
+	bool initialized = false;
+	bool newDataAvailable = false;
+};
+
+#endif
diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index 807cb54..9970619 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -5,6 +5,7 @@
 #include "plugin-support.h"
 #include "transcription-filter-data.h"
 #include "whisper-processing.h"
+#include "whisper-utils.h"
 
 #include <whisper.h>
 #include <obs-module.h>
@@ -109,7 +110,8 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v
 	return true;
 }
 
-struct whisper_context *init_whisper_context(const std::string &model_path_in)
+struct whisper_context *init_whisper_context(const std::string &model_path_in,
+					     struct transcription_filter_data *gf)
 {
 	std::string model_path = model_path_in;
 
@@ -131,14 +133,15 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
 	whisper_log_set(
 		[](enum ggml_log_level level, const char *text, void *user_data) {
 			UNUSED_PARAMETER(level);
-			UNUSED_PARAMETER(user_data);
+			struct transcription_filter_data *ctx =
+				static_cast<struct transcription_filter_data *>(user_data);
 			// remove trailing newline
 			char *text_copy = bstrdup(text);
 			text_copy[strcspn(text_copy, "\n")] = 0;
-			obs_log(LOG_INFO, "Whisper: %s", text_copy);
+			obs_log(ctx->log_level, "Whisper: %s", text_copy);
 			bfree(text_copy);
 		},
-		nullptr);
+		gf);
 
 	struct whisper_context_params cparams = whisper_context_default_params();
 #ifdef LOCALVOCAL_WITH_CUDA
@@ -152,6 +155,16 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
 	obs_log(LOG_INFO, "Using CPU for inference");
 #endif
 
+	cparams.dtw_token_timestamps = gf->enable_token_ts_dtw;
+	if (gf->enable_token_ts_dtw) {
+		obs_log(LOG_INFO, "DTW token timestamps enabled");
+		cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
+		// cparams.dtw_n_top = 4;
+	} else {
+		obs_log(LOG_INFO, "DTW token timestamps disabled");
+		cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
+	}
+
 	struct whisper_context *ctx = nullptr;
 	try {
 #ifdef _WIN32
@@ -196,16 +209,19 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in)
 }
 
 struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
-						     const float *pcm32f_data, size_t pcm32f_size)
+						     const float *pcm32f_data, size_t pcm32f_size,
+						     bool zero_start)
 {
+	UNUSED_PARAMETER(zero_start);
+
 	if (gf == nullptr) {
 		obs_log(LOG_ERROR, "run_whisper_inference: gf is null");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}
 
 	if (pcm32f_data == nullptr || pcm32f_size == 0) {
 		obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}
 
 	obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
@@ -215,7 +231,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 	std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
 	if (gf->whisper_context == nullptr) {
 		obs_log(LOG_WARNING, "whisper context is null");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}
 
 	// Get the duration in ms since the beginning of the stream (gf->start_timestamp_ms)
@@ -234,47 +250,92 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 		obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
 		whisper_free(gf->whisper_context);
 		gf->whisper_context = nullptr;
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}
 
 	if (whisper_full_result != 0) {
 		obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	} else {
 		// duration in ms
 		const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
 
 		const int n_segment = 0;
-		const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
+		// const char *text = whisper_full_get_segment_text(gf->whisper_context, n_segment);
 		const int64_t t0 = offset_ms;
 		const int64_t t1 = offset_ms + duration_ms;
 
 		float sentence_p = 0.0f;
 		const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment);
+		std::string text = "";
+		std::string tokenIds = "";
+		std::vector<whisper_token_data> tokens;
+		bool end = false;
 		for (int j = 0; j < n_tokens; ++j) {
 			sentence_p += whisper_full_get_token_p(gf->whisper_context, n_segment, j);
+			// get token
+			whisper_token_data token =
+				whisper_full_get_token_data(gf->whisper_context, n_segment, j);
+			const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
+			bool keep = !end;
+			// if the token starts with '[' and ends with ']', don't keep it
+			if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') {
+				keep = false;
+			}
+			if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) {
+				keep = false;
+			}
+			// if the second to last token is .id == 13 ('.'), don't keep it
+			if (j == n_tokens - 2 && token.id == 13) {
+				keep = false;
+			}
+			// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
+			if (token.id > 50540 && token.id <= 51865) {
+				obs_log(gf->log_level,
+					"Large time token found (%d), this shouldn't happen",
+					token.id);
+				return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
+			}
+
+			if (keep) {
+				text += token_str;
+				tokenIds += std::to_string(token.id) + " (" +
+					    std::string(token_str) + "), ";
+				tokens.push_back(token);
+			}
+			obs_log(gf->log_level, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
+				token.id, token_str, token.p, token.t_dtw, keep);
 		}
 		sentence_p /= (float)n_tokens;
-
-		// convert text to lowercase
-		std::string text_lower(text);
-		std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
-		// trim whitespace (use lambda)
-		text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
-					      [](unsigned char ch) { return !std::isspace(ch); })
-					 .base(),
-				 text_lower.end());
+		obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
+		obs_log(gf->log_level, "Token IDs: %s", tokenIds.c_str());
+
+		// if suppression is enabled, check if the text is in the suppression list
+		if (!gf->suppress_sentences.empty()) {
+			std::string suppress_sentences_copy = gf->suppress_sentences;
+			size_t pos = 0;
+			std::string token;
+			while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
+				token = suppress_sentences_copy.substr(0, pos);
+				suppress_sentences_copy.erase(0, pos + 1);
+				if (text == token) {
+					obs_log(gf->log_level, "Suppressing sentence: %s",
+						text.c_str());
+					return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
+				}
+			}
+		}
 
 		if (gf->log_words) {
 			obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
-				to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
+				to_timestamp(t1).c_str(), sentence_p, text.c_str());
 		}
 
-		if (text_lower.empty() || text_lower == ".") {
-			return {DETECTION_RESULT_SILENCE, "", 0, 0};
+		if (text.empty() || text == ".") {
+			return {DETECTION_RESULT_SILENCE, "", 0, 0, {}};
 		}
 
-		return {DETECTION_RESULT_SPEECH, text_lower, offset_ms, offset_ms + duration_ms};
+		return {DETECTION_RESULT_SPEECH, text, offset_ms, offset_ms + duration_ms, tokens};
 	}
 }
 
@@ -307,14 +368,14 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 				num_new_frames_from_infos -= info_from_buf.frames;
 				circlebuf_push_front(&gf->info_buffer, &info_from_buf,
 						     size_of_audio_info);
-				last_step_in_segment =
-					true; // this is the final step in the segment
+				// this is the final step in the segment
+				last_step_in_segment = true;
 				break;
 			}
 		}
 
 	obs_log(gf->log_level,
-		"with %lu remaining to full segment, popped %d info-frames, pushing into buffer at %lu",
+		"with %lu remaining to full segment, popped %d info-frames, pushing at %lu (overlap)",
 		remaining_frames_to_full_segment, num_new_frames_from_infos, gf->last_num_frames);
 
@@ -340,7 +401,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 		}
 	} else {
 		gf->last_num_frames = num_new_frames_from_infos;
-		obs_log(gf->log_level, "first segment, %d frames to process",
+		obs_log(gf->log_level, "first segment, no overlap exists, %d frames to process",
 			(int)(gf->last_num_frames));
 	}
 
@@ -352,51 +413,92 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 	auto start = std::chrono::high_resolution_clock::now();
 
 	// resample to 16kHz
-	float *output[MAX_PREPROC_CHANNELS];
-	uint32_t out_frames;
+	float *resampled_16khz[MAX_PREPROC_CHANNELS];
+	uint32_t resampled_16khz_frames;
 	uint64_t ts_offset;
-	audio_resampler_resample(gf->resampler, (uint8_t **)output, &out_frames, &ts_offset,
+	audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz,
+				 &resampled_16khz_frames, &ts_offset,
 				 (const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames);
 
-	obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, (int)out_frames,
-		(float)out_frames / WHISPER_SAMPLE_RATE * 1000.0f);
+	obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels,
+		(int)resampled_16khz_frames,
+		(float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f);
 
 	bool skipped_inference = false;
 	uint32_t speech_start_frame = 0;
-	uint32_t speech_end_frame = out_frames;
+	uint32_t speech_end_frame = resampled_16khz_frames;
 
 	if (gf->vad_enabled) {
-		std::vector<float> vad_input(output[0], output[0] + out_frames);
+		std::vector<float> vad_input(resampled_16khz[0],
+					     resampled_16khz[0] + resampled_16khz_frames);
 		gf->vad->process(vad_input);
 
-		auto stamps = gf->vad->get_speech_timestamps();
+		std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
 		if (stamps.size() == 0) {
+			obs_log(gf->log_level, "VAD detected no speech in %d frames",
+				resampled_16khz_frames);
 			skipped_inference = true;
+			// prevent copying the buffer to the beginning (overlap)
+			gf->last_num_frames = 0;
+			last_step_in_segment = false;
 		} else {
-			speech_start_frame = stamps[0].start;
+			speech_start_frame = (stamps[0].start < 3000) ? 0 : stamps[0].start;
 			speech_end_frame = stamps.back().end;
-			obs_log(gf->log_level, "VAD detected speech from %d to %d",
-				speech_start_frame, speech_end_frame);
+			uint32_t number_of_frames = speech_end_frame - speech_start_frame;
+
+			obs_log(gf->log_level,
+				"VAD detected speech from %d to %d (%d frames, %d ms)",
+				speech_start_frame, speech_end_frame, number_of_frames,
+				number_of_frames * 1000 / WHISPER_SAMPLE_RATE);
+
+			// if the speech segment is less than 1 second - put the audio back into the buffer
+			// to be handled in the next iteration
+			if (number_of_frames > 0 && number_of_frames < WHISPER_SAMPLE_RATE) {
+				// convert speech_start_frame and speech_end_frame to original sample rate
+				speech_start_frame =
+					speech_start_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
+				speech_end_frame =
+					speech_end_frame * gf->sample_rate / WHISPER_SAMPLE_RATE;
+				number_of_frames = speech_end_frame - speech_start_frame;
+
+				// use memmove to copy the speech segment to the beginning of the buffer
+				for (size_t c = 0; c < gf->channels; c++) {
+					memmove(gf->copy_buffers[c],
+						gf->copy_buffers[c] + speech_start_frame,
+						number_of_frames * sizeof(float));
+				}
+
+				obs_log(gf->log_level,
+					"Speech segment is less than 1 second, moving %d to %d (len %d) to buffer start",
+					speech_start_frame, speech_end_frame, number_of_frames);
+				// no processing of the segment
+				skipped_inference = true;
+				// reset the last_num_frames to the number of frames in the buffer
+				gf->last_num_frames = number_of_frames;
+				// prevent copying the buffer to the beginning (overlap)
+				last_step_in_segment = false;
+			}
 		}
 	}
 
 	if (!skipped_inference) {
 		// run inference
 		const struct DetectionResultWithText inference_result = run_whisper_inference(
-			gf, output[0] + speech_start_frame, speech_end_frame - speech_start_frame);
+			gf, resampled_16khz[0] + speech_start_frame,
+			speech_end_frame - speech_start_frame, speech_start_frame == 0);
 
 		if (inference_result.result == DETECTION_RESULT_SPEECH) {
 			// output inference result to a text source
 			set_text_callback(gf, inference_result);
 		} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
 			// output inference result to a text source
-			set_text_callback(gf, {inference_result.result, "[silence]", 0, 0});
+			set_text_callback(gf, {inference_result.result, "[silence]", 0, 0, {}});
 		}
 	} else {
 		if (gf->log_words) {
 			obs_log(LOG_INFO, "skipping inference");
 		}
-		set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0});
+		set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0, {}});
 	}
 
 	// end of timer
@@ -407,6 +509,12 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 		(int)duration);
 
 	if (last_step_in_segment) {
+		const uint64_t overlap_size_ms =
+			(uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
+		obs_log(gf->log_level,
+			"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
+			gf->overlap_frames, overlap_size_ms,
+			gf->last_num_frames - gf->overlap_frames);
 		for (size_t c = 0; c < gf->channels; c++) {
 			// This is the last step in the segment - reset the copy buffer (include overlap frames)
 			// move overlap frames from the end of the last copy_buffers to the beginning
@@ -416,8 +524,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 			// zero out the rest of the buffer, just in case
 			memset(gf->copy_buffers[c] + gf->overlap_frames, 0,
 			       (gf->frames - gf->overlap_frames) * sizeof(float));
-			gf->last_num_frames = gf->overlap_frames;
 		}
+		gf->last_num_frames = gf->overlap_frames;
 	}
 }
diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h
index 3e189fe..6798e92 100644
--- a/src/whisper-utils/whisper-processing.h
+++ b/src/whisper-utils/whisper-processing.h
@@ -1,12 +1,32 @@
 #ifndef WHISPER_PROCESSING_H
 #define WHISPER_PROCESSING_H
 
+#include <whisper.h>
+
 // buffer size in msec
-#define DEFAULT_BUFFER_SIZE_MSEC 2000
+#define DEFAULT_BUFFER_SIZE_MSEC 3000
 // overlap in msec
 #define DEFAULT_OVERLAP_SIZE_MSEC 100
+#define MAX_OVERLAP_SIZE_MSEC 1000
+#define MIN_OVERLAP_SIZE_MSEC 100
+
+enum DetectionResult {
+	DETECTION_RESULT_UNKNOWN = 0,
+	DETECTION_RESULT_SILENCE = 1,
+	DETECTION_RESULT_SPEECH = 2,
+	DETECTION_RESULT_SUPPRESSED = 3,
+};
+
+struct DetectionResultWithText {
+	DetectionResult result;
+	std::string text;
+	uint64_t start_timestamp_ms;
+	uint64_t end_timestamp_ms;
+	std::vector<whisper_token_data> tokens;
+};
 
 void whisper_loop(void *data);
-struct whisper_context *init_whisper_context(const std::string &model_path);
+struct whisper_context *init_whisper_context(const std::string &model_path,
+					     struct transcription_filter_data *gf);
 
 #endif // WHISPER_PROCESSING_H
diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp
index fad35ad..ad619f8 100644
--- a/src/whisper-utils/whisper-utils.cpp
+++ b/src/whisper-utils/whisper-utils.cpp
@@ -5,7 +5,7 @@
 
 #include <obs-module.h>
 
-void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s)
+void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
 {
 	// update the whisper model path
 	std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
@@ -13,9 +13,12 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
 
 	if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
 	    is_external_model) {
-		// model path changed, reload the model
-		obs_log(gf->log_level, "model path changed from %s to %s",
-			gf->whisper_model_path.c_str(), new_model_path.c_str());
+
+		if (gf->whisper_model_path != new_model_path) {
+			// model path changed
+			obs_log(gf->log_level, "model path changed from %s to %s",
+				gf->whisper_model_path.c_str(), new_model_path.c_str());
+		}
 
 		// check if the new model is external file
 		if (!is_external_model) {
@@ -76,6 +79,21 @@ void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t
 		obs_log(gf->log_level, "Model path did not change: %s == %s",
 			gf->whisper_model_path.c_str(), new_model_path.c_str());
 	}
+
+	const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
+
+	if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
+		// dtw_token_timestamps changed
+		obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
+			gf->enable_token_ts_dtw, new_dtw_timestamps);
+		gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
+		shutdown_whisper_thread(gf);
+		start_whisper_thread_with_path(gf, gf->whisper_model_path);
+	} else {
+		// dtw_token_timestamps did not change
+		obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
+			gf->enable_token_ts_dtw, new_dtw_timestamps);
+	}
 }
 
 void shutdown_whisper_thread(struct transcription_filter_data *gf)
@@ -122,9 +140,12 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
 #else
 	std::string silero_vad_model_path = silero_vad_model_file;
 #endif
-	gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));
+	// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
+	// for silero vad parameters
+	gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000,
+				      200, 250));
 
-	gf->whisper_context = init_whisper_context(path);
+	gf->whisper_context = init_whisper_context(path, gf);
 	if (gf->whisper_context == nullptr) {
 		obs_log(LOG_ERROR, "Failed to initialize whisper context");
 		return;
@@ -133,3 +154,100 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
 	std::thread new_whisper_thread(whisper_loop, gf);
 	gf->whisper_thread.swap(new_whisper_thread);
 }
+
+// Finds start of 2-token overlap between two sequences of tokens
+// Returns a pair of indices of the first overlapping tokens in the two sequences
+// If no overlap is found, the function returns {-1, -1}
+// Allows for a single token mismatch in the overlap
+std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
+				       const std::vector<whisper_token_data> &seq2)
+{
+	if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
+		return {-1, -1};
+	}
+	for (size_t i = seq1.size() - 2; i >= seq1.size() / 2; --i) {
+		for (size_t j = 0; j < seq2.size() - 1; ++j) {
+			if (seq1[i].id == seq2[j].id) {
+				// Check if the next token in both sequences is the same
+				if (seq1[i + 1].id == seq2[j + 1].id) {
+					return {(int)i, (int)j};
+				}
+				// 1-skip check on seq1
+				if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) {
+					return {(int)i, (int)j};
+				}
+				// 1-skip check on seq2
+				if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) {
+					return {(int)i, (int)j};
+				}
+			}
+		}
+	}
+	return {-1, -1};
+}
+
+// Function to reconstruct a whole sentence from two sentences using overlap info
+// If no overlap is found, the function returns the concatenation of the two sequences
+std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
+						    const std::vector<whisper_token_data> &seq2)
+{
+	auto overlap = findStartOfOverlap(seq1, seq2);
+	std::vector<whisper_token_data> reconstructed;
+
+	if (overlap.first == -1 || overlap.second == -1) {
+		if (seq1.empty() && seq2.empty()) {
+			return reconstructed;
+		}
+		if (seq1.empty()) {
+			return seq2;
+		}
+		if (seq2.empty()) {
+			return seq1;
+		}
+
+		// Return concat of seq1 and seq2 if no overlap found
+		// check if the last token of seq1 == the first token of seq2
+		if (seq1.back().id == seq2.front().id) {
+			// don't add the last token of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
+		} else if (seq2.size() > 1ull && seq1.back().id == seq2[1].id) {
+			// check if the last token of seq1 == the second token of seq2
+			// don't add the last token of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
+			// don't add the first token of seq2
+			reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end());
+		} else if (seq1.size() > 1ull && seq1[seq1.size() - 2].id == seq2.front().id) {
+			// check if the second to last token of seq1 == the first token of seq2
+			// don't add the last two tokens of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2);
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
+		} else {
+			// add all tokens of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
+		}
+		return reconstructed;
+	}
+
+	// Add tokens from the first sequence up to the overlap
+	reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.begin() + overlap.first);
+
+	// Determine the length of the overlap
+	size_t overlapLength = 0;
+	while (overlap.first + overlapLength < seq1.size() &&
+	       overlap.second + overlapLength < seq2.size() &&
+	       seq1[overlap.first + overlapLength].id == seq2[overlap.second + overlapLength].id) {
+		overlapLength++;
+	}
+
+	// Add overlapping tokens
+	reconstructed.insert(reconstructed.end(), seq1.begin() + overlap.first,
+			     seq1.begin() + overlap.first + overlapLength);
+
+	// Add remaining tokens from the second sequence
+	reconstructed.insert(reconstructed.end(), seq2.begin() + overlap.second + overlapLength,
+			     seq2.end());
+
+	return reconstructed;
+}
diff --git a/src/whisper-utils/whisper-utils.h b/src/whisper-utils/whisper-utils.h
index 6e80b2f..bc941f8 100644
--- a/src/whisper-utils/whisper-utils.h
+++ b/src/whisper-utils/whisper-utils.h
@@ -7,8 +7,13 @@
 
 #include <string>
 
-void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s);
+void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
 void shutdown_whisper_thread(struct transcription_filter_data *gf);
 void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);
 
+std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
+				       const std::vector<whisper_token_data> &seq2);
+std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_token_data> &seq1,
+						    const std::vector<whisper_token_data> &seq2);
+
 #endif /* WHISPER_UTILS_H */
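
Note on the overlap merge: `reconstructSentence()` stitches the token streams of two consecutive inference passes by locating a two-token overlap (with one allowed skip) and emitting the shared run only once. A minimal sketch of the expected behavior, assuming the declarations from `src/whisper-utils/whisper-utils.h` and the whisper/OBS headers are on the include path; `make_tokens()` is a hypothetical helper for this example only, not part of the plugin:

```cpp
#include <cstdio>
#include <initializer_list>
#include <vector>

#include <whisper.h>

#include "whisper-utils/whisper-utils.h"

// Build a token vector from raw ids; only the .id field matters to the merge.
static std::vector<whisper_token_data> make_tokens(std::initializer_list<int> ids)
{
	std::vector<whisper_token_data> out;
	for (int id : ids) {
		whisper_token_data t = {};
		t.id = id;
		out.push_back(t);
	}
	return out;
}

int main()
{
	// Pass 1 decoded tokens for "... the quick brown"; pass 2, run on an
	// overlapping audio window, decoded "quick brown fox jumps".
	const auto seq1 = make_tokens({10, 11, 12, 13});
	const auto seq2 = make_tokens({12, 13, 14, 15});

	// findStartOfOverlap() returns {2, 0} here: seq1[2..3] == seq2[0..1].
	const auto merged = reconstructSentence(seq1, seq2);

	// Prints "10 11 12 13 14 15": the shared run (12, 13) appears once,
	// so words repeated across overlapping windows are not duplicated.
	for (const auto &t : merged)
		printf("%d ", t.id);
	printf("\n");
	return 0;
}
```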
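Note on the suppression check in `run_whisper_inference()`: the loop splits the multiline `suppress_sentences` setting on `\n` and compares each extracted line against the decoded text (the comparison target is `token`, as fixed above). One subtlety is that the final line carries no trailing `\n`, so a `find`-based loop alone never tests it. A standalone sketch of the full matching logic; `is_suppressed()` is illustrative only, not a function in the plugin:

```cpp
#include <string>

// Returns true if `text` exactly matches one of the newline-separated
// sentences in `suppress_sentences`, e.g. "Thank you for watching\nThank you".
static bool is_suppressed(const std::string &text, std::string suppress_sentences)
{
	size_t pos = 0;
	while ((pos = suppress_sentences.find('\n')) != std::string::npos) {
		if (text == suppress_sentences.substr(0, pos)) {
			return true;
		}
		suppress_sentences.erase(0, pos + 1);
	}
	// the last (or only) line has no trailing '\n' and must be checked too
	return !suppress_sentences.empty() && text == suppress_sentences;
}
```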