Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

whisper : improve handling of prompts #1981

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would document the 224 number here for quick reference.

Same in

// maximum of whisper_n_text_ctx()/2 tokens are used

fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
Expand Down
9 changes: 7 additions & 2 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to

if (n_max_tokens < (int) res.size()) {
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
return -1;
return -(int) res.size();
}

for (int i = 0; i < (int) res.size(); i++) {
Expand Down Expand Up @@ -5313,7 +5313,12 @@ int whisper_full_with_state(
// initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
if (n_needed < 0) {
prompt_tokens.resize(-n_needed);
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
}
prompt_tokens.resize(n_needed);
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}
Expand Down
4 changes: 3 additions & 1 deletion whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ extern "C" {
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns -1 on failure
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
Expand Down Expand Up @@ -503,6 +503,8 @@ extern "C" {

// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// use whisper_tokenize() to convert text to tokens
// maximum of whisper_n_text_ctx()/2 tokens are used
const char * initial_prompt;
const whisper_token * prompt_tokens;
int prompt_n_tokens;
Expand Down
Loading