From 7f0dc9bd6a7106dcfc033341f89e68a75b9ea5b1 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Mon, 19 Jun 2023 09:43:23 -0700 Subject: [PATCH 01/15] add HuggingFace mirror to download ggml model --- models/download-ggml-model.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index 4440b94eb14..23dba76f5f9 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -22,7 +22,7 @@ function get_script_path() { models_path="$(get_script_path)" # Whisper models -models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" ) +models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" ) # list available models function list_models { @@ -50,6 +50,12 @@ if [[ ! " ${models[@]} " =~ " ${model} " ]]; then exit 1 fi +# check if model contains `tdrz` and update the src and pfx accordingly +if [[ $model == *"tdrz"* ]]; then + src="https://huggingface.co/akashmjn/tinydiarize-whisper.cpp" + pfx="resolve/main/ggml" +fi + # download ggml model printf "Downloading ggml model $model from '$src' ...\n" From 62c851bf825a5682444e94a84da61b028e7c44b1 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Mon, 19 Jun 2023 09:54:19 -0700 Subject: [PATCH 02/15] support tdrz via simple hack overriding solm tokens --- whisper.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 5f3888c7916..58ffca341c2 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -382,14 +382,14 @@ struct whisper_vocab { id token_eot = 50256; id token_sot = 50257; + id token_solm = 50359; // ?? TODO@Akash - rename appropriately id token_prev = 50360; - id token_solm = 50361; // ?? id token_not = 50362; // no timestamps - id token_beg = 50363; + id token_beg = 50363; // begin timestamps // available tasks - static const id token_translate = 50358; - static const id token_transcribe = 50359; + static const id token_translate = 50358; // TODO@Akash - technically it's 50357 for .en models + static const id token_transcribe = 50359; // TODO@Akash - technically it's 50358 for .en models bool is_multilingual() const { return n_vocab == 51865; @@ -3521,7 +3521,7 @@ static void whisper_process_logits( // suppress sot and solm tokens logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_solm] = -INFINITY; + // logits[vocab.token_solm] = -INFINITY; // suppress task tokens logits[vocab.token_translate] = -INFINITY; @@ -4500,7 +4500,6 @@ int whisper_full_with_state( prompt_past.push_back(tokens_cur[i].id); } - // store the text from this iteration if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { int i0 = 0; auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); @@ -4517,6 +4516,10 @@ int whisper_full_with_state( text += whisper_token_to_str(ctx, tokens_cur[i].id); } + if (tokens_cur[i].id == whisper_token_solm(ctx)){ + text += " [SPEAKER TURN]"; + }; + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); From c8e1ed6f8281e83abc8cbd54483c24c767bcde66 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Mon, 19 Jun 2023 15:53:47 -0700 Subject: [PATCH 03/15] fix incorrect translate/transcribe token_ids that are not static const --- bindings/go/whisper.go | 4 +- .../whispercpp/WhisperCppJnaLibrary.java | 4 +- whisper.cpp | 37 ++++++++++--------- whisper.h | 4 +- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 8a5efa7de0c..d2ea756679e 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -271,12 +271,12 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token { // Task tokens func Whisper_token_translate() Token { - return Token(C.whisper_token_translate()) + return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx))) } // Task tokens func Whisper_token_transcribe() Token { - return Token(C.whisper_token_transcribe()) + return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx))) } // Performance information diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index c1fb4f8e3b0..ad9faa0be70 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library { int whisper_token_lang(Pointer ctx, int lang_id); // Task tokens - int whisper_token_translate(); - int whisper_token_transcribe(); + int whisper_token_translate (Pointer ctx); + int whisper_token_transcribe(Pointer ctx); // Performance information from the default state. void whisper_print_timings(Pointer ctx); diff --git a/whisper.cpp b/whisper.cpp index 58ffca341c2..79112b896d3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -380,16 +380,17 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; - id token_eot = 50256; - id token_sot = 50257; - id token_solm = 50359; // ?? TODO@Akash - rename appropriately - id token_prev = 50360; - id token_not = 50362; // no timestamps - id token_beg = 50363; // begin timestamps - - // available tasks - static const id token_translate = 50358; // TODO@Akash - technically it's 50357 for .en models - static const id token_transcribe = 50359; // TODO@Akash - technically it's 50358 for .en models + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // ?? TODO@Akash - rename appropriately + id token_prev = 50360; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps bool is_multilingual() const { return n_vocab == 51865; @@ -966,8 +967,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con if (vocab.is_multilingual()) { vocab.token_eot++; vocab.token_sot++; - vocab.token_prev++; + vocab.token_translate++; + vocab.token_transcribe++; vocab.token_solm++; + vocab.token_prev++; vocab.token_not++; vocab.token_beg++; } @@ -3228,12 +3231,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { return whisper_token_sot(ctx) + 1 + lang_id; } -whisper_token whisper_token_translate(void) { - return whisper_vocab::token_translate; +whisper_token whisper_token_translate(struct whisper_context * ctx) { + return ctx->vocab.token_translate; } -whisper_token whisper_token_transcribe(void) { - return whisper_vocab::token_transcribe; +whisper_token whisper_token_transcribe(struct whisper_context * ctx) { + return ctx->vocab.token_transcribe; } void whisper_print_timings(struct whisper_context * ctx) { @@ -4018,9 +4021,9 @@ int whisper_full_with_state( state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { - prompt_init.push_back(whisper_token_translate()); + prompt_init.push_back(whisper_token_translate(ctx)); } else { - prompt_init.push_back(whisper_token_transcribe()); + prompt_init.push_back(whisper_token_transcribe(ctx)); } } diff --git a/whisper.h b/whisper.h index e983c7d4fa3..6525b47df03 100644 --- a/whisper.h +++ b/whisper.h @@ -284,8 +284,8 @@ extern "C" { WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); // Task tokens - WHISPER_API whisper_token whisper_token_translate (void); - WHISPER_API whisper_token whisper_token_transcribe(void); + WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); // Performance information from the default state. WHISPER_API void whisper_print_timings(struct whisper_context * ctx); From 700c2829134ca737798eca55cd9905988d9a49b7 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Tue, 20 Jun 2023 11:21:35 -0700 Subject: [PATCH 04/15] add apollo 13 sample for tdrz demo --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index c7b05a9a6d9..73aa15ef528 100644 --- a/Makefile +++ b/Makefile @@ -302,12 +302,16 @@ samples: @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg @wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav + @wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3 @echo "Converting to 16-bit WAV ..." @ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav + @rm samples/*.ogg @ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav @rm samples/mm1.wav + @ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav + @rm samples/a13.mp3 # # Models From 4083a393ddc0e23754ff900e7cb74262b4e4d1b1 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Mon, 26 Jun 2023 04:25:37 -0700 Subject: [PATCH 05/15] render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token --- examples/main/main.cpp | 2 +- whisper.cpp | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ff62f74b887..16d67645ca6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -279,7 +279,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { + if (id >= whisper_token_eot(ctx) and id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable? continue; } } diff --git a/whisper.cpp b/whisper.cpp index 79112b896d3..a1f30043f84 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -984,6 +984,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; + } else if (i == vocab.token_solm) { // TODO@Akash make this configurable + word = " [SPEAKER TURN]"; } else if (i == vocab.token_prev) { word = "[_PREV_]"; } else if (i == vocab.token_not) { @@ -4514,15 +4516,12 @@ int whisper_full_with_state( // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx) && + tokens_cur[i].id != whisper_token_solm(ctx)) { } else { text += whisper_token_to_str(ctx, tokens_cur[i].id); } - if (tokens_cur[i].id == whisper_token_solm(ctx)){ - text += " [SPEAKER TURN]"; - }; - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); From 713c5b61ad024f3da296539f25af95c97028dbad Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Tue, 27 Jun 2023 10:28:53 -0700 Subject: [PATCH 06/15] extend whisper_segment with speaker_turn_next field and save in json output --- examples/main/main.cpp | 5 ++++- whisper.cpp | 19 ++++++++++++++++--- whisper.h | 3 +++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 16d67645ca6..5ea7ca09530 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -566,6 +566,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(ctx, i); start_obj(nullptr); start_obj("timestamps"); @@ -576,11 +577,13 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper value_i("from", t0 * 10, false); value_i("to", t1 * 10, true); end_obj(false); - value_s("text", text, !params.diarize); + value_s("text", text, !params.diarize); // TODO@Akash - make configurable with flag if (params.diarize && pcmf32s.size() == 2) { value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); } + // TODO@Akash - make configurable with flag + value_b("speaker_turn_next", speaker_turn_next, true); end_obj(i == (n_segments - 1)); } diff --git a/whisper.cpp b/whisper.cpp index a1f30043f84..d562423f79c 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -404,6 +404,8 @@ struct whisper_segment { std::string text; std::vector tokens; + + bool speaker_turn_next; }; // medium @@ -4510,6 +4512,7 @@ int whisper_full_with_state( auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; + bool speaker_turn_next; for (int i = 0; i < (int) tokens_cur.size(); i++) { //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, @@ -4517,11 +4520,16 @@ int whisper_full_with_state( // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx) && - tokens_cur[i].id != whisper_token_solm(ctx)) { + tokens_cur[i].id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable with flag (may not want it in text) } else { text += whisper_token_to_str(ctx, tokens_cur[i].id); } + // record if speaker turn was predicted after current segment + if (tokens_cur[i].id == whisper_token_solm(ctx)){ + speaker_turn_next = true; + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); @@ -4540,7 +4548,7 @@ int whisper_full_with_state( //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4566,6 +4574,7 @@ int whisper_full_with_state( i--; t0 = t1; i0 = i + 1; + speaker_turn_next = false; } } @@ -4584,7 +4593,7 @@ int whisper_full_with_state( } } - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4764,6 +4773,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) return ctx->state->result_all[i_segment].t1; } +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].speaker_turn_next; +} + const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { return state->result_all[i_segment].text.c_str(); } diff --git a/whisper.h b/whisper.h index 6525b47df03..b3450f790d4 100644 --- a/whisper.h +++ b/whisper.h @@ -460,6 +460,9 @@ extern "C" { WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + // Get whether the next segment is predicted as a speaker turn + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + // Get the text of the specified segment WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); From edd23488de95eb384d96699890b118c9c4febbee Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Tue, 27 Jun 2023 11:26:06 -0700 Subject: [PATCH 07/15] fix failing go build --- bindings/go/whisper.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index d2ea756679e..e605d8e0c85 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -270,12 +270,12 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token { } // Task tokens -func Whisper_token_translate() Token { +func (ctx *Context) Whisper_token_translate() Token { return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx))) } // Task tokens -func Whisper_token_transcribe() Token { +func (ctx *Context) Whisper_token_transcribe() Token { return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx))) } From 77825ecfff081f873eb42f3a5440f1d35da31f10 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Tue, 27 Jun 2023 19:18:33 -0700 Subject: [PATCH 08/15] slipped in some python syntax whoops --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 5ea7ca09530..6c74a3df3ff 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -279,7 +279,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx) and id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable? + if (id >= whisper_token_eot(ctx) && id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable? continue; } } From 59e905513109e403d56bf232d25cffece35ba3a0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 20:24:57 +0300 Subject: [PATCH 09/15] whisper : finalize tinydiarize support (add flag + fixes) --- examples/main/main.cpp | 142 ++++++++++++++++++++++++----------------- whisper.cpp | 119 +++++++++++++++++++--------------- whisper.h | 10 ++- 3 files changed, 155 insertions(+), 116 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6c74a3df3ff..2ba7e6991eb 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -68,28 +68,32 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; - bool speed_up = false; - bool translate = false; - bool detect_language= false; - bool diarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool output_txt = false; - bool output_vtt = false; - bool output_srt = false; - bool output_wts = false; - bool output_csv = false; - bool output_jsn = false; - bool output_lrc = false; - bool print_special = false; - bool print_colors = false; - bool print_progress = false; - bool no_timestamps = false; - - std::string language = "en"; + bool speed_up = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool output_csv = false; + bool output_jsn = false; + bool output_lrc = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + + std::string language = "en"; std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; + std::string model = "models/ggml-base.en.bin"; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line std::vector fname_inp = {}; std::vector fname_out = {}; @@ -115,41 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -297,6 +302,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper printf("%s%s", speaker.c_str(), text); } + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + // with timestamps or speakers: each segment on new line if (!params.no_timestamps || params.diarize) { printf("\n"); @@ -564,9 +575,9 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(ctx, i); start_obj(nullptr); start_obj("timestamps"); @@ -577,13 +588,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper value_i("from", t0 * 10, false); value_i("to", t1 * 10, true); end_obj(false); - value_s("text", text, !params.diarize); // TODO@Akash - make configurable with flag + value_s("text", text, !params.diarize && !params.tinydiarize); if (params.diarize && pcmf32s.size() == 2) { value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); } - // TODO@Akash - make configurable with flag - value_b("speaker_turn_next", speaker_turn_next, true); + + if (params.tinydiarize) { + value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true); + } end_obj(i == (n_segments - 1)); } @@ -780,6 +793,12 @@ int main(int argc, char ** argv) { exit(0); } + if (params.diarize && params.tinydiarize) { + fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n"); + whisper_print_usage(argc, argv, params); + exit(0); + } + // whisper init struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); @@ -821,11 +840,12 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", + params.tinydiarize ? "tdrz = 1, " : "", params.no_timestamps ? 0 : 1); fprintf(stderr, "\n"); @@ -856,6 +876,8 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; diff --git a/whisper.cpp b/whisper.cpp index 494cc4b5387..14ced531f9b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -387,8 +387,9 @@ struct whisper_vocab { id token_translate = 50357; id token_transcribe = 50358; // other special tokens - id token_solm = 50359; // ?? TODO@Akash - rename appropriately + id token_tdrz = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn id token_prev = 50360; + id token_solm = 50361; // start of lm ? id token_not = 50362; // no timestamps id token_beg = 50363; // begin timestamps @@ -971,8 +972,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con vocab.token_sot++; vocab.token_translate++; vocab.token_transcribe++; - vocab.token_solm++; + vocab.token_tdrz++; vocab.token_prev++; + vocab.token_solm++; vocab.token_not++; vocab.token_beg++; } @@ -986,10 +988,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; - } else if (i == vocab.token_solm) { // TODO@Akash make this configurable - word = " [SPEAKER TURN]"; + } else if (i == vocab.token_tdrz) { + word = "[_TDRZ_]"; } else if (i == vocab.token_prev) { word = "[_PREV_]"; + } else if (i == vocab.token_solm) { + word = "[_SOLM_]"; } else if (i == vocab.token_not) { word = "[_NOT_]"; } else if (i == vocab.token_beg) { @@ -3215,6 +3219,10 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) { return ctx->vocab.token_sot; } +whisper_token whisper_token_tdrz(struct whisper_context * ctx) { + return ctx->vocab.token_tdrz; +} + whisper_token whisper_token_prev(struct whisper_context * ctx) { return ctx->vocab.token_prev; } @@ -3312,51 +3320,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result = { - /*.strategy =*/ strategy, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ true, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.split_on_word =*/ false, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - /*.detect_language =*/ false, - - /*.suppress_blank =*/ true, + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ true, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.tdrz_enable =*/ false, + + /*.initial_prompt =*/ nullptr, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + /*.detect_language =*/ false, + + /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ false, - /*.temperature =*/ 0.0f, - /*.max_initial_ts =*/ 1.0f, - /*.length_penalty =*/ -1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.4f, - /*.entropy_thold =*/ 2.4f, - /*.logprob_thold =*/ -1.0f, - /*.no_speech_thold =*/ 0.6f, + /*.temperature_inc =*/ 0.4f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, - /*.greedy =*/ { + /*.greedy =*/ { /*.best_of =*/ -1, }, @@ -3528,7 +3538,12 @@ static void whisper_process_logits( // suppress sot and solm tokens logits[vocab.token_sot] = -INFINITY; - // logits[vocab.token_solm] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; + + // [TDRZ] when tinydiarize is disabled, suppress tdrz token + if (params.tdrz_enable == false) { + logits[vocab.token_tdrz] = -INFINITY; + } // suppress task tokens logits[vocab.token_translate] = -INFINITY; @@ -4512,21 +4527,19 @@ int whisper_full_with_state( auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; - bool speaker_turn_next; + bool speaker_turn_next = false; for (int i = 0; i < (int) tokens_cur.size(); i++) { //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx) && - tokens_cur[i].id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable with flag (may not want it in text) - } else { + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { text += whisper_token_to_str(ctx, tokens_cur[i].id); } - // record if speaker turn was predicted after current segment - if (tokens_cur[i].id == whisper_token_solm(ctx)){ + // [TDRZ] record if speaker turn was predicted after current segment + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_tdrz(ctx)) { speaker_turn_next = true; } @@ -4548,7 +4561,7 @@ int whisper_full_with_state( //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); + result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } diff --git a/whisper.h b/whisper.h index b3450f790d4..300cac65f0a 100644 --- a/whisper.h +++ b/whisper.h @@ -277,6 +277,7 @@ extern "C" { // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_tdrz(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); @@ -358,6 +359,9 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) + // [EXPERIMENTAL] [TDRZ] tinydiarize + bool tdrz_enable; // enable tinydiarize speaker turn detection + // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call const char * initial_prompt; @@ -491,9 +495,9 @@ extern "C" { // Temporary helpers needed for exposing ggml interface - WHISPER_API int whisper_bench_memcpy(int n_threads); - WHISPER_API const char * whisper_bench_memcpy_str(int n_threads); - WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + WHISPER_API int whisper_bench_memcpy (int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads); WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); #ifdef __cplusplus From 8ee5af481c452410024b4e9351162ebbf121b43b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 20:29:05 +0300 Subject: [PATCH 10/15] whisper : tdrz support for word-level timestamps (respect max_len) --- whisper.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index 14ced531f9b..c0757a3eb37 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3447,6 +3447,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta state.result_all.back().text = std::move(text); state.result_all.back().t1 = token.t0; state.result_all.back().tokens.resize(i); + state.result_all.back().speaker_turn_next = false; state.result_all.push_back({}); state.result_all.back().t0 = token.t0; @@ -3458,6 +3459,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta segment.tokens.begin() + i, segment.tokens.end()); + state.result_all.back().speaker_turn_next = segment.speaker_turn_next; + acc = 0; text = ""; From 5fa32daf20470a1d241aced26f9aba522da0f802 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 20:44:33 +0300 Subject: [PATCH 11/15] java : try to fix tests after adding tdrz_enable flag --- .../ggerganov/whispercpp/params/WhisperFullParams.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 07e68948ef8..3a4ec2b9f08 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -137,6 +137,14 @@ public void speedUp(boolean enable) { /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; + /** Enable tinydiarize (default = false) */ + public CBool tdrz_enable; + + /** Enable tinydiarize (default = false) */ + public void tdrzEnable(boolean enable) { + tdrz_enable = enable ? CBool.TRUE : CBool.FALSE; + } + /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. */ public String initial_prompt; From 6828be73fef643db52f1272b23301199d139ea82 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 20:51:09 +0300 Subject: [PATCH 12/15] main : remove TODO leftover --- examples/main/main.cpp | 74 +++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2ba7e6991eb..ce4bf9826db 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -119,42 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -284,7 +284,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { if (params.print_special == false) { const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx) && id != whisper_token_solm(ctx)) { // TODO@Akash - make configurable? + if (id >= whisper_token_eot(ctx)) { continue; } } From 09c32a6834ff51c3dfc152815211385513bd7154 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 21:01:16 +0300 Subject: [PATCH 13/15] java : fix params order list after adding "tdrz_enable" --- .../github/ggerganov/whispercpp/params/WhisperFullParams.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 3a4ec2b9f08..7765561eca8 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -310,7 +310,7 @@ protected List getFieldOrder() { "no_context", "single_segment", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", - "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "new_segment_callback", "new_segment_callback_user_data", From baabe2c9b2993ef79a64f14183b5a74ffbb0fb87 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Jul 2023 21:42:02 +0300 Subject: [PATCH 14/15] whisper : fix solm and add nosp token --- whisper.cpp | 34 +++++++++++++++++----------------- whisper.h | 4 ++-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index c0757a3eb37..fb489b38453 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -387,9 +387,9 @@ struct whisper_vocab { id token_translate = 50357; id token_transcribe = 50358; // other special tokens - id token_tdrz = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn id token_prev = 50360; - id token_solm = 50361; // start of lm ? + id token_nosp = 50361; id token_not = 50362; // no timestamps id token_beg = 50363; // begin timestamps @@ -972,9 +972,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con vocab.token_sot++; vocab.token_translate++; vocab.token_transcribe++; - vocab.token_tdrz++; - vocab.token_prev++; vocab.token_solm++; + vocab.token_prev++; + vocab.token_nosp++; vocab.token_not++; vocab.token_beg++; } @@ -988,12 +988,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; - } else if (i == vocab.token_tdrz) { - word = "[_TDRZ_]"; - } else if (i == vocab.token_prev) { - word = "[_PREV_]"; } else if (i == vocab.token_solm) { word = "[_SOLM_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_nosp) { + word = "[_NOSP_]"; } else if (i == vocab.token_not) { word = "[_NOT_]"; } else if (i == vocab.token_beg) { @@ -3219,16 +3219,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) { return ctx->vocab.token_sot; } -whisper_token whisper_token_tdrz(struct whisper_context * ctx) { - return ctx->vocab.token_tdrz; +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; } whisper_token whisper_token_prev(struct whisper_context * ctx) { return ctx->vocab.token_prev; } -whisper_token whisper_token_solm(struct whisper_context * ctx) { - return ctx->vocab.token_solm; +whisper_token whisper_token_nosp(struct whisper_context * ctx) { + return ctx->vocab.token_nosp; } whisper_token whisper_token_not(struct whisper_context * ctx) { @@ -3539,13 +3539,13 @@ static void whisper_process_logits( // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; - // suppress sot and solm tokens + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_solm] = -INFINITY; + logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now - // [TDRZ] when tinydiarize is disabled, suppress tdrz token + // [TDRZ] when tinydiarize is disabled, suppress solm token if (params.tdrz_enable == false) { - logits[vocab.token_tdrz] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; } // suppress task tokens @@ -4542,7 +4542,7 @@ int whisper_full_with_state( } // [TDRZ] record if speaker turn was predicted after current segment - if (params.tdrz_enable && tokens_cur[i].id == whisper_token_tdrz(ctx)) { + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { speaker_turn_next = true; } diff --git a/whisper.h b/whisper.h index 300cac65f0a..c08723bbb2b 100644 --- a/whisper.h +++ b/whisper.h @@ -277,9 +277,9 @@ extern "C" { // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_tdrz(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); From 34bdf9854393e4e8f17fd47e4194b2f1b7fa5af3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 4 Jul 2023 09:44:16 +0300 Subject: [PATCH 15/15] main : print tinydiarize help --- examples/main/main.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ce4bf9826db..344b6877882 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -187,6 +187,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");