From 39cfad0dee803d04756b34850020662f390ca45c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 5 Nov 2023 19:43:45 +0200
Subject: [PATCH] whisper : add support for new distilled Whisper models (#1424)

* whisper : add support for new distilled Whisper models

* whisper : print log when using distilled models
---
 whisper.cpp | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/whisper.cpp b/whisper.cpp
index 17ef4d9e8ab..3e36d362054 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3940,6 +3940,7 @@ static void whisper_process_logits(
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev]       = -INFINITY;
 
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4558,6 +4559,7 @@ int whisper_full_with_state(
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
@@ -4569,6 +4571,17 @@ int whisper_full_with_state(
         }
     }
 
+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+
+        // distilled models require the "no_timestamps" token
+        // TODO: add input parameter (#1229)
+        if (is_distil) {
+            log("%s: using distilled model - forcing no_timestamps\n", __func__);
+            prompt_init.push_back(whisper_token_not(ctx));
+        }
+    }
+
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
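
Example usage (not part of the patch): a minimal caller-side sketch, assuming a converted distil-whisper GGML file at the hypothetical path "models/ggml-distil-medium.en.bin" and one second of silent 16 kHz PCM as placeholder audio. It goes through the public whisper.h API only; the distilled-model handling added above runs automatically inside whisper_full(), so the caller does not need to do anything special.

#include "whisper.h"

#include <cstdio>
#include <vector>

int main() {
    // hypothetical path to a converted distil-whisper GGML model
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-distil-medium.en.bin");
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // one second of silence as placeholder input - real code would load 16 kHz mono PCM from a WAV file
    std::vector<float> pcm(WHISPER_SAMPLE_RATE, 0.0f);

    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    params.language = "en";

    // for a model with n_text_layer == 2, whisper_full_with_state() logs
    // "using distilled model - forcing no_timestamps" and appends the
    // no_timestamps token to the initial prompt before decoding
    if (whisper_full(ctx, params, pcm.data(), (int) pcm.size()) != 0) {
        fprintf(stderr, "failed to run transcription\n");
        whisper_free(ctx);
        return 1;
    }

    // print the transcribed segments
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }

    whisper_free(ctx);
    return 0;
}

On the detection itself: the released distil-whisper checkpoints keep the full encoder but use only two decoder layers, which is why n_text_layer == 2 serves as the heuristic here; the TODO referencing #1229 indicates this is a stopgap until an explicit input parameter is added.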