From 8a46034af7cbed1c35edac89d62f8d5773507fef Mon Sep 17 00:00:00 2001 From: bobqianic <129547291+bobqianic@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:36:51 +0000 Subject: [PATCH 1/2] Add files via upload --- whisper.cpp | 204 +++++++++++++++++++++++++++++++++++++--------------- whisper.h | 11 ++- 2 files changed, 156 insertions(+), 59 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index aebbb5295a8..ac31b792d20 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -402,6 +402,8 @@ struct whisper_segment { std::vector tokens; + double no_speech_probs; + bool speaker_turn_next; }; @@ -4525,7 +4527,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.thold_pt =*/ 0.01f, /*.thold_ptsum =*/ 0.01f, /*.max_len =*/ 0, - /*.split_on_word =*/ false, /*.max_tokens =*/ 0, /*.speed_up =*/ false, @@ -4542,7 +4543,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.detect_language =*/ false, /*.suppress_blank =*/ true, - /*.suppress_non_speech_tokens =*/ false, + /*.suppress_non_speech_tokens =*/ true, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -4613,65 +4614,144 @@ static void whisper_exp_compute_token_level_timestamps( float thold_pt, float thold_ptsum); -//static inline bool should_split_on_word(const char * txt, bool split_on_word) { -// if (!split_on_word) return true; -// -// return txt[0] == ' '; -//} +static bool whisper_utf8_is_valid(const std::string &str) { + uint64_t count = 0; // Count of bytes in the current UTF-8 character + + for (unsigned char c : str) { + if (count == 0) { + if ((c >> 5) == 0b110) count = 1; // 2-byte character + else if ((c >> 4) == 0b1110) count = 2; // 3-byte character + else if ((c >> 3) == 0b11110) count = 3; // 4-byte character + else if ((c >> 7) == 0b0) count = 0; // 1-byte character + else return false; // Invalid UTF-8 + } else { + if ((c >> 6) != 0b10) return false; // Subsequent bytes should start with 10 + count--; + } + } + + return count == 0; // Ensure all UTF-8 characters are complete +} + +static bool whisper_utf8_is_valid(const char * str) { + std::string new_str(str); + return whisper_utf8_is_valid(new_str); +} + +static std::vector> whisper_utf8_merge_and_split(const std::string &str) { + std::vector> result; + std::string buffer; + uint64_t count = 0; // Count of bytes in the current UTF-8 character + + for (unsigned char c : str) { + if (count == 0) { + header: + if ((c >> 5) == 0b110) count = 1; // 2-byte character + else if ((c >> 4) == 0b1110) count = 2; // 3-byte character + else if ((c >> 3) == 0b11110) count = 3; // 4-byte character + else count = 0; // Invalid UTF-8 || 1-byte character + if (!buffer.empty()) result.emplace_back(buffer, true); + buffer.clear(); + buffer += static_cast(c); + } else { + if ((c >> 6) != 0b10) { + goto header; + } // Subsequent bytes should start with 10 + buffer += static_cast(c); + count--; + } + } + + if (!buffer.empty()) result.emplace_back(buffer, false); + return result; +} + +static std::vector whisper_split_tokens_on_utf8(struct whisper_context & ctx, whisper_segment & segment, bool special) { + std::vector words; + + std::string text; + std::vector raw; + int64_t t0 = -1; + int64_t t1 = -1; + + for (const auto & token : segment.tokens) { + if (special == false && token.id >= whisper_token_beg(&ctx)) { + continue; + } + if (t0 < 0) {t0 = token.t0;} + t1 = token.t1; + text += whisper_token_to_str(&ctx, token.id); + raw.push_back(token); + + if (whisper_utf8_is_valid(text)) { + words.push_back({t0, t1, text, raw, segment.no_speech_probs, segment.speaker_turn_next}); + t0 = -1; + t1 = -1; + raw.clear(); + text = ""; + } + } + + return words; +} // wrap the last segment into segments with max_len number of words // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool special) { + const static std::set unicode_language = {"zh", "ja", "th", "lo", "my", "yue"}; + const static std::string punctuation = R"(!"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~)"; + auto segment = state.result_all.back(); - int res = 1; - int acc = 0; + std::vector words; - std::string text; + if (unicode_language.find(whisper_lang_str(ctx.state->lang_id)) != unicode_language.end()) { + // split on utf-8 + words = whisper_split_tokens_on_utf8(ctx, segment, special); + } else { + // split on spaces and punctuation + auto subwords = whisper_split_tokens_on_utf8(ctx, segment, special); -// for (int i = 0; i < (int) segment.tokens.size(); i++) { -// const auto & token = segment.tokens[i]; -// if (token.id >= whisper_token_eot(&ctx)) { -// continue; -// } -// -// const auto txt = whisper_token_to_str(&ctx, token.id); -// const int cur = strlen(txt); -// -// if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { -// state.result_all.back().text = std::move(text); -// state.result_all.back().t1 = token.t0; -// state.result_all.back().tokens.resize(i); -// state.result_all.back().speaker_turn_next = false; -// -// state.result_all.push_back({}); -// state.result_all.back().t0 = token.t0; -// state.result_all.back().t1 = segment.t1; -// -// // add tokens [i, end] to the new segment -// state.result_all.back().tokens.insert( -// state.result_all.back().tokens.end(), -// segment.tokens.begin() + i, -// segment.tokens.end()); -// -// state.result_all.back().speaker_turn_next = segment.speaker_turn_next; -// -// acc = 0; -// text = ""; -// -// segment = state.result_all.back(); -// i = -1; -// -// res++; -// } else { -// acc += cur; -// text += txt; -// } -// } -// -// state.result_all.back().text = std::move(text); + for (auto & subword : subwords) { + if (subword.tokens[0].id >= whisper_token_beg(&ctx) || subword.text[0] == ' ' || punctuation.find(subword.text) != std::string::npos) { + words.push_back(subword); + } else { + words.back().t1 = subword.t1; + words.back().text += subword.text; + words.back().tokens.insert(words.back().tokens.end(), subword.tokens.begin(), subword.tokens.end()); + } + } + } - return res; + state.result_all.pop_back(); + + if (max_len == 1) { + state.result_all.insert(state.result_all.end(), words.begin(), words.end()); + return static_cast(words.size()); + } else { + int acc = 0; + int n_new = 0; + whisper_segment temp = {}; + + for (auto & word : words) { + if (acc == 0) {temp.t0 = word.t0;} + temp.t1 = word.t1; + temp.text += word.text; + temp.tokens.insert(temp.tokens.end(), word.tokens.begin(), word.tokens.end()); + temp.speaker_turn_next = word.speaker_turn_next; + temp.no_speech_probs = word.no_speech_probs; + + if (acc + 1 >= max_len) { + state.result_all.push_back(temp); + temp = {}; + acc = 0; + n_new ++; + } else { + acc++; + } + } + return n_new; + } } // ref: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/decoding.py#L689-L693 @@ -5927,6 +6007,7 @@ int whisper_full_with_state( const auto seek_delta = best_decoder.seek_delta; const auto result_len = best_decoder.sequence.result_len; + const auto non_speech_probs = best_decoder.sequence.no_speech_probs; const auto & tokens_cur = best_decoder.sequence.tokens; @@ -5965,18 +6046,19 @@ int whisper_full_with_state( auto text_callback = [&](int t1, int token_offset, int end) { int n_new = 1; - result_all.push_back({ t0, t1, text, {} , speaker_turn_next }); + result_all.push_back({ t0, t1, text, {}, non_speech_probs, speaker_turn_next }); for (int j = std::max(0, token_offset); j <= end; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps(*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + } - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); - } + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.print_special); } + if (params.new_segment_callback) { params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } @@ -6192,6 +6274,14 @@ int whisper_full_lang_id(struct whisper_context * ctx) { return ctx->state->lang_id; } +double whisper_full_get_segment_no_speech_probs_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].no_speech_probs; +} + +double whisper_full_get_segment_no_speech_probs(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].no_speech_probs; +} + int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { return state->result_all[i_segment].t0; } diff --git a/whisper.h b/whisper.h index a321e76044a..f9bdf36a0ce 100644 --- a/whisper.h +++ b/whisper.h @@ -451,8 +451,7 @@ extern "C" { bool token_timestamps; // enable token-level timestamps float thold_pt; // timestamp token probability threshold (~0.01) float thold_ptsum; // timestamp token sum probability threshold (~0.01) - int max_len; // max segment length in characters - bool split_on_word; // split on word rather than on token (when used with max_len) + int max_len; // max segment length in characters (0 = no limit) int max_tokens; // max tokens per segment (0 = no limit) // [EXPERIMENTAL] speed-up techniques @@ -570,6 +569,10 @@ extern "C" { // Language id associated with the provided state WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + // Get the no speech probability of the specified segment + WHISPER_API double whisper_full_get_segment_no_speech_probs (struct whisper_context * ctx, int i_segment); + WHISPER_API double whisper_full_get_segment_no_speech_probs_from_state(struct whisper_state * state, int i_segment); + // Get the start and end time of the specified segment WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); @@ -605,6 +608,10 @@ extern "C" { WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); + // Check if the string is valid UTF-8 + WHISPER_API bool whisper_utf8_is_valid(const char * str); + + //////////////////////////////////////////////////////////////////////////// // Temporary helpers needed for exposing ggml interface From 7a5a2e9a3aea4ec2682226d83c50686a2296efe6 Mon Sep 17 00:00:00 2001 From: bobqianic <129547291+bobqianic@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:37:27 +0000 Subject: [PATCH 2/2] Add files via upload --- examples/common.cpp | 50 -------------------------------------- examples/common.h | 5 ---- examples/main/main.cpp | 37 ++++++++++++---------------- examples/server/server.cpp | 8 ------ 4 files changed, 16 insertions(+), 84 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index ee9bd47df77..3dd2a248d86 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -616,56 +616,6 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat( } - -namespace utf_8 { - bool is_valid(const std::string &str) { - uint64_t count = 0; // Count of bytes in the current UTF-8 character - - for (unsigned char c : str) { - if (count == 0) { - if ((c >> 5) == 0b110) count = 1; // 2-byte character - else if ((c >> 4) == 0b1110) count = 2; // 3-byte character - else if ((c >> 3) == 0b11110) count = 3; // 4-byte character - else if ((c >> 7) == 0b0) count = 0; // 1-byte character - else return false; // Invalid UTF-8 - } else { - if ((c >> 6) != 0b10) return false; // Subsequent bytes should start with 10 - count--; - } - } - - return count == 0; // Ensure all UTF-8 characters are complete - } - - std::vector merge_and_split(const std::string &str) { - std::vector result; - std::string buffer; - uint64_t count = 0; // Count of bytes in the current UTF-8 character - - for (unsigned char c : str) { - if (count == 0) { - header: - if ((c >> 5) == 0b110) count = 1; // 2-byte character - else if ((c >> 4) == 0b1110) count = 2; // 3-byte character - else if ((c >> 3) == 0b11110) count = 3; // 4-byte character - else count = 0; // Invalid UTF-8 || 1-byte character - if (!buffer.empty()) result.push_back(buffer); - buffer.clear(); - buffer += static_cast(c); - } else { - if ((c >> 6) != 0b10) { - goto header; - } // Subsequent bytes should start with 10 - buffer += static_cast(c); - count--; - } - } - - if (!buffer.empty()) result.push_back(buffer); - return result; - } -} - bool is_wav_buffer(const std::string buf) { // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html diff --git a/examples/common.h b/examples/common.h index dd6b85f9693..09094a1b8a1 100644 --- a/examples/common.h +++ b/examples/common.h @@ -131,11 +131,6 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat( float repeat_penalty, std::mt19937 & rng); -namespace utf_8{ - bool is_valid(const std::string &str); - std::vector merge_and_split(const std::string &str); -} - // // Audio utils // diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b4114f6c910..591a09f8f57 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -77,7 +77,6 @@ struct whisper_params { bool detect_language = false; bool diarize = false; bool tinydiarize = false; - bool split_on_word = false; bool no_fallback = false; bool output_txt = false; bool output_vtt = false; @@ -149,7 +148,6 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } @@ -197,7 +195,6 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); @@ -320,6 +317,10 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper if (params.print_colors) { + std::string buffer; + float probability_sum = 0; + int count = 0; + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); @@ -328,26 +329,21 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); -// if (utf_8::is_valid(text)) { -// printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); -// } else { - printf("%s[_%i_]%s", k_colors[col].c_str(), whisper_full_get_token_id(ctx, i, j), "\033[0m"); -// } + buffer += whisper_full_get_token_text(ctx, i, j); + probability_sum += whisper_full_get_token_p (ctx, i, j); + count++; + const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(probability_sum/static_cast(count), 3)*float(k_colors.size())))); + + if (whisper_utf8_is_valid(buffer.c_str())) { + printf("%s%s%s", k_colors[col].c_str(), buffer.c_str(), "\033[0m"); + buffer.clear(); + probability_sum = 0; + count = 0; + } } } else { const char * text = whisper_full_get_segment_text(ctx, i); - for (auto &k : utf_8::merge_and_split(text)) { - if (utf_8::is_valid(k)) { - printf("%s", k.c_str()); - } else { - for (auto l : k) { - printf("[_%i_]", l); - } - } - } + printf("%s", text); } if (params.tinydiarize) { @@ -1016,7 +1012,6 @@ int run(int argc, const char ** argv) { wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0; wparams.thold_pt = params.word_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; - wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; wparams.debug_mode = params.debug_mode; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 69c04bf3a0a..b9c566a3a68 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -73,7 +73,6 @@ struct whisper_params { bool detect_language = false; bool diarize = false; bool tinydiarize = false; - bool split_on_word = false; bool no_fallback = false; bool print_special = false; bool print_colors = false; @@ -136,7 +135,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); @@ -192,7 +190,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } @@ -462,10 +459,6 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content); } - if (req.has_file("split_on_word")) - { - params.split_on_word = parse_str_to_bool(req.get_file_value("split_on_word").content); - } if (req.has_file("no_timestamps")) { params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content); @@ -738,7 +731,6 @@ int main(int argc, char ** argv) { wparams.thold_pt = params.word_thold; wparams.max_len = params.max_len == 0 ? 60 : params.max_len; - wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; wparams.debug_mode = params.debug_mode;