From c8d0f5fe9801862bdd7f63a949937a804d02cfb5 Mon Sep 17 00:00:00 2001 From: Akash Mahajan Date: Mon, 3 Jul 2023 23:45:00 -0700 Subject: [PATCH] whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058) * add HuggingFace mirror to download ggml model * support tdrz via simple hack overriding solm tokens * fix incorrect translate/transcribe token_ids that are not static const * add apollo 13 sample for tdrz demo * render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token * extend whisper_segment with speaker_turn_next field and save in json output * fix failing go build * slipped in some python syntax whoops * whisper : finalize tinydiarize support (add flag + fixes) * whisper : tdrz support for word-level timestamps (respect max_len) * java : try to fix tests after adding tdrz_enable flag * main : remove TODO leftover * java : fix params order list after adding "tdrz_enable" * whisper : fix solm and add nosp token * main : print tinydiarize help --------- Co-authored-by: Georgi Gerganov --- Makefile | 4 + bindings/go/whisper.go | 8 +- .../whispercpp/WhisperCppJnaLibrary.java | 4 +- .../whispercpp/params/WhisperFullParams.java | 10 +- examples/main/main.cpp | 140 +++++++++------ models/download-ggml-model.sh | 8 +- whisper.cpp | 168 +++++++++++------- whisper.h | 19 +- 8 files changed, 223 insertions(+), 138 deletions(-) diff --git a/Makefile b/Makefile index 045f711c696..caab8f3cdeb 100644 --- a/Makefile +++ b/Makefile @@ -308,12 +308,16 @@ samples: @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg @wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav + @wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3 @echo "Converting to 16-bit WAV ..." 
@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav + @rm samples/*.ogg @ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav @rm samples/mm1.wav + @ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav + @rm samples/a13.mp3 # # Models diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 8a5efa7de0c..e605d8e0c85 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -270,13 +270,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token { } // Task tokens -func Whisper_token_translate() Token { - return Token(C.whisper_token_translate()) +func (ctx *Context) Whisper_token_translate() Token { + return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx))) } // Task tokens -func Whisper_token_transcribe() Token { - return Token(C.whisper_token_transcribe()) +func (ctx *Context) Whisper_token_transcribe() Token { + return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx))) } // Performance information diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index c1fb4f8e3b0..ad9faa0be70 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library { int whisper_token_lang(Pointer ctx, int lang_id); // Task tokens - int whisper_token_translate(); - int whisper_token_transcribe(); + int whisper_token_translate (Pointer ctx); + int whisper_token_transcribe(Pointer ctx); // Performance information from the default state. void whisper_print_timings(Pointer ctx); diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 07e68948ef8..7765561eca8 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -137,6 +137,14 @@ public void speedUp(boolean enable) { /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; + /** Enable tinydiarize (default = false) */ + public CBool tdrz_enable; + + /** Enable tinydiarize (default = false) */ + public void tdrzEnable(boolean enable) { + tdrz_enable = enable ? CBool.TRUE : CBool.FALSE; + } + /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. 
*/ public String initial_prompt; @@ -302,7 +310,7 @@ protected List getFieldOrder() { "no_context", "single_segment", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", - "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "new_segment_callback", "new_segment_callback_user_data", diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ff62f74b887..344b6877882 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -68,28 +68,32 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; - bool speed_up = false; - bool translate = false; - bool detect_language= false; - bool diarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool output_txt = false; - bool output_vtt = false; - bool output_srt = false; - bool output_wts = false; - bool output_csv = false; - bool output_jsn = false; - bool output_lrc = false; - bool print_special = false; - bool print_colors = false; - bool print_progress = false; - bool no_timestamps = false; - - std::string language = "en"; + bool speed_up = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool output_csv = false; + bool output_jsn = false; + bool output_lrc = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + + std::string language = "en"; std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; + std::string model = "models/ggml-base.en.bin"; + + // [TDRZ] speaker turn string + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line std::vector fname_inp = {}; std::vector fname_out = {}; @@ -115,41 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == 
"--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == 
"--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -182,6 +187,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? 
"true" : "false"); @@ -297,6 +303,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper printf("%s%s", speaker.c_str(), text); } + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + // with timestamps or speakers: each segment on new line if (!params.no_timestamps || params.diarize) { printf("\n"); @@ -564,6 +576,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); @@ -576,11 +589,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper value_i("from", t0 * 10, false); value_i("to", t1 * 10, true); end_obj(false); - value_s("text", text, !params.diarize); + value_s("text", text, !params.diarize && !params.tinydiarize); if (params.diarize && pcmf32s.size() == 2) { value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); } + + if (params.tinydiarize) { + value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true); + } end_obj(i == (n_segments - 1)); } @@ -777,6 +794,12 @@ int main(int argc, char ** argv) { exit(0); } + if (params.diarize && params.tinydiarize) { + fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n"); + whisper_print_usage(argc, argv, params); + exit(0); + } + // whisper init struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); @@ -818,11 +841,12 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", + params.tinydiarize ? "tdrz = 1, " : "", params.no_timestamps ? 0 : 1); fprintf(stderr, "\n"); @@ -853,6 +877,8 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index 4440b94eb14..23dba76f5f9 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -22,7 +22,7 @@ function get_script_path() { models_path="$(get_script_path)" # Whisper models -models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" ) +models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" ) # list available models function list_models { @@ -50,6 +50,12 @@ if [[ ! 
" ${models[@]} " =~ " ${model} " ]]; then exit 1 fi +# check if model contains `tdrz` and update the src and pfx accordingly +if [[ $model == *"tdrz"* ]]; then + src="https://huggingface.co/akashmjn/tinydiarize-whisper.cpp" + pfx="resolve/main/ggml" +fi + # download ggml model printf "Downloading ggml model $model from '$src' ...\n" diff --git a/whisper.cpp b/whisper.cpp index 932ae6fe32d..fb489b38453 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -380,16 +380,18 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; - id token_eot = 50256; - id token_sot = 50257; - id token_prev = 50360; - id token_solm = 50361; // ?? - id token_not = 50362; // no timestamps - id token_beg = 50363; - - // available tasks - static const id token_translate = 50358; - static const id token_transcribe = 50359; + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps bool is_multilingual() const { return n_vocab == 51865; @@ -403,6 +405,8 @@ struct whisper_segment { std::string text; std::vector tokens; + + bool speaker_turn_next; }; // medium @@ -966,8 +970,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con if (vocab.is_multilingual()) { vocab.token_eot++; vocab.token_sot++; - vocab.token_prev++; + vocab.token_translate++; + vocab.token_transcribe++; vocab.token_solm++; + vocab.token_prev++; + vocab.token_nosp++; vocab.token_not++; vocab.token_beg++; } @@ -981,8 +988,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; + } else if (i == vocab.token_solm) { + word = "[_SOLM_]"; } else if (i == vocab.token_prev) { word = "[_PREV_]"; + } else if (i == vocab.token_nosp) { + word = "[_NOSP_]"; } else if (i == vocab.token_not) { word = "[_NOT_]"; } else if (i == vocab.token_beg) { @@ -3208,12 +3219,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) { return ctx->vocab.token_sot; } +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + whisper_token whisper_token_prev(struct whisper_context * ctx) { return ctx->vocab.token_prev; } -whisper_token whisper_token_solm(struct whisper_context * ctx) { - return ctx->vocab.token_solm; +whisper_token whisper_token_nosp(struct whisper_context * ctx) { + return ctx->vocab.token_nosp; } whisper_token whisper_token_not(struct whisper_context * ctx) { @@ -3228,12 +3243,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { return whisper_token_sot(ctx) + 1 + lang_id; } -whisper_token whisper_token_translate(void) { - return whisper_vocab::token_translate; +whisper_token whisper_token_translate(struct whisper_context * ctx) { + return ctx->vocab.token_translate; } -whisper_token whisper_token_transcribe(void) { - return whisper_vocab::token_transcribe; +whisper_token whisper_token_transcribe(struct whisper_context * ctx) { + return ctx->vocab.token_transcribe; } void whisper_print_timings(struct whisper_context * ctx) { @@ -3305,51 +3320,53 @@ struct 
whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result = { - /*.strategy =*/ strategy, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ true, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.split_on_word =*/ false, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - /*.detect_language =*/ false, - - /*.suppress_blank =*/ true, + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ true, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.tdrz_enable =*/ false, + + /*.initial_prompt =*/ nullptr, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + /*.detect_language =*/ false, + + /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ false, - /*.temperature =*/ 0.0f, - /*.max_initial_ts =*/ 1.0f, - /*.length_penalty =*/ -1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.4f, - /*.entropy_thold =*/ 2.4f, - /*.logprob_thold =*/ -1.0f, - /*.no_speech_thold =*/ 0.6f, + /*.temperature_inc =*/ 0.4f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, - /*.greedy =*/ { + /*.greedy =*/ { /*.best_of =*/ -1, }, @@ -3430,6 +3447,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta state.result_all.back().text = std::move(text); state.result_all.back().t1 = token.t0; state.result_all.back().tokens.resize(i); + state.result_all.back().speaker_turn_next = false; state.result_all.push_back({}); state.result_all.back().t0 = token.t0; @@ -3441,6 +3459,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta segment.tokens.begin() + i, segment.tokens.end()); + state.result_all.back().speaker_turn_next = segment.speaker_turn_next; + acc = 0; text = ""; @@ -3519,9 +3539,14 @@ static void whisper_process_logits( // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; - // suppress sot and solm tokens + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_solm] = -INFINITY; + logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now + + // [TDRZ] when tinydiarize is disabled, suppress solm token + if (params.tdrz_enable == false) { + logits[vocab.token_solm] = -INFINITY; + } // suppress task tokens 
logits[vocab.token_translate] = -INFINITY; @@ -4018,9 +4043,9 @@ int whisper_full_with_state( state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); if (params.translate) { - prompt_init.push_back(whisper_token_translate()); + prompt_init.push_back(whisper_token_translate(ctx)); } else { - prompt_init.push_back(whisper_token_transcribe()); + prompt_init.push_back(whisper_token_transcribe(ctx)); } } @@ -4500,23 +4525,27 @@ int whisper_full_with_state( prompt_past.push_back(tokens_cur[i].id); } - // store the text from this iteration if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { int i0 = 0; auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; + bool speaker_turn_next = false; for (int i = 0; i < (int) tokens_cur.size(); i++) { //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { text += whisper_token_to_str(ctx, tokens_cur[i].id); } + // [TDRZ] record if speaker turn was predicted after current segment + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + speaker_turn_next = true; + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); @@ -4535,7 +4564,7 @@ int whisper_full_with_state( //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4561,6 +4590,7 @@ int whisper_full_with_state( i--; t0 = t1; i0 = i + 1; + speaker_turn_next = false; } } @@ -4579,7 +4609,7 @@ int whisper_full_with_state( } } - result_all.push_back({ tt0, tt1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -4759,6 +4789,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) return ctx->state->result_all[i_segment].t1; } +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].speaker_turn_next; +} + const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { return state->result_all[i_segment].text.c_str(); } diff --git a/whisper.h b/whisper.h index e983c7d4fa3..c08723bbb2b 100644 --- a/whisper.h +++ b/whisper.h @@ -277,15 +277,16 @@ extern "C" { // Special tokens WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); 
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); // Task tokens - WHISPER_API whisper_token whisper_token_translate (void); - WHISPER_API whisper_token whisper_token_transcribe(void); + WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); // Performance information from the default state. WHISPER_API void whisper_print_timings(struct whisper_context * ctx); @@ -358,6 +359,9 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) + // [EXPERIMENTAL] [TDRZ] tinydiarize + bool tdrz_enable; // enable tinydiarize speaker turn detection + // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call const char * initial_prompt; @@ -460,6 +464,9 @@ extern "C" { WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + // Get whether the next segment is predicted as a speaker turn + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); + // Get the text of the specified segment WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); @@ -488,9 +495,9 @@ extern "C" { // Temporary helpers needed for exposing ggml interface - WHISPER_API int whisper_bench_memcpy(int n_threads); - WHISPER_API const char * whisper_bench_memcpy_str(int n_threads); - WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + WHISPER_API int whisper_bench_memcpy (int n_threads); + WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads); WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); #ifdef __cplusplus
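Below is a minimal sketch (not part of the patch) showing how a downstream C++ application might consume the tinydiarize API that this patch introduces: enabling `tdrz_enable` in `whisper_full_params` and querying `whisper_full_get_segment_speaker_turn_next()` per segment. The model path, the input file, and its format are illustrative assumptions: the example reads raw 16 kHz mono s16le PCM (e.g. produced with ffmpeg, mirroring the Makefile's conversion flags) instead of decoding WAV, and it assumes a tdrz model such as the `small.en-tdrz` one added to models/download-ggml-model.sh. In practice the same result is available from the CLI via the new flag, e.g. `./main -m models/ggml-small.en-tdrz.bin -f samples/a13.wav -tdrz`.

    // tdrz_example.cpp -- illustrative only; uses the API surface shown in this patch.
    #include "whisper.h"

    #include <cstdint>
    #include <cstdio>
    #include <fstream>
    #include <string>
    #include <vector>

    int main(int argc, char ** argv) {
        // hypothetical default paths, override via argv
        const std::string model_path = argc > 1 ? argv[1] : "models/ggml-small.en-tdrz.bin";
        const std::string pcm_path   = argc > 2 ? argv[2] : "samples/a13.pcm"; // raw s16le, 16 kHz, mono

        // read raw 16-bit PCM and convert to 32-bit float in [-1, 1]
        std::ifstream fin(pcm_path, std::ios::binary);
        if (!fin) { fprintf(stderr, "failed to open %s\n", pcm_path.c_str()); return 1; }
        fin.seekg(0, std::ios::end);
        const size_t n_bytes = (size_t) fin.tellg();
        fin.seekg(0, std::ios::beg);
        std::vector<int16_t> pcm16(n_bytes / sizeof(int16_t));
        fin.read((char *) pcm16.data(), pcm16.size() * sizeof(int16_t));

        std::vector<float> pcmf32(pcm16.size());
        for (size_t i = 0; i < pcm16.size(); ++i) {
            pcmf32[i] = float(pcm16[i]) / 32768.0f;
        }

        struct whisper_context * ctx = whisper_init_from_file(model_path.c_str());
        if (ctx == nullptr) { fprintf(stderr, "failed to load model %s\n", model_path.c_str()); return 1; }

        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        wparams.tdrz_enable = true; // [TDRZ] enable tinydiarize speaker turn detection (new in this patch)

        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
            fprintf(stderr, "whisper_full() failed\n");
            whisper_free(ctx);
            return 1;
        }

        const int n_segments = whisper_full_n_segments(ctx);
        for (int i = 0; i < n_segments; ++i) {
            const int64_t t0  = whisper_full_get_segment_t0(ctx, i);
            const int64_t t1  = whisper_full_get_segment_t1(ctx, i);
            const char * text = whisper_full_get_segment_text(ctx, i);

            // t0/t1 are in 10 ms units; multiply by 10 for milliseconds (as output_json does above)
            printf("[%6lld ms -> %6lld ms]%s", (long long) t0*10, (long long) t1*10, text);

            // new in this patch: was a speaker turn predicted after this segment?
            if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
                printf(" [SPEAKER_TURN]");
            }
            printf("\n");
        }

        whisper_free(ctx);
        return 0;
    }

The per-segment `speaker_turn_next` flag is the same value that main.cpp renders as " [SPEAKER_TURN]" in terminal output and that output_json emits as "speaker_turn_next"; a consumer segments speakers by splitting wherever the flag is true.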