Skip to content

Commit

Permalink
whisper : add support for new distilled Whisper models (#1424)
Browse files Browse the repository at this point in the history
* whisper : add support for new distilled Whisper models

* whisper : print log when using distilled models
  • Loading branch information
ggerganov committed Nov 5, 2023
1 parent 6d4d0b5 commit 39cfad0
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3940,6 +3940,7 @@ static void whisper_process_logits(
// suppress task tokens
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
logits[vocab.token_prev] = -INFINITY;

if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
Expand Down Expand Up @@ -4558,6 +4559,7 @@ int whisper_full_with_state(

// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };

if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id;
Expand All @@ -4569,6 +4571,17 @@ int whisper_full_with_state(
}
}

{
const bool is_distil = ctx->model.hparams.n_text_layer == 2;

// distilled models require the "no_timestamps" token
// TODO: add input parameter (#1229)
if (is_distil) {
log("%s: using distilled model - forcing no_timestamps\n", __func__);
prompt_init.push_back(whisper_token_not(ctx));
}
}

int seek = seek_start;

std::vector<whisper_token> prompt;
Expand Down

0 comments on commit 39cfad0

Please sign in to comment.