Skip to content

Commit

Permalink
whisper : add detect-language mode (ggerganov#853)
Browse files Browse the repository at this point in the history
* add detectlanguage flag

* renaming and help

* no idea why that last one didn't commit

* run language detection if dl is set

* help message fix

* various fixes

* fix quitting

* fix language being english on print
  • Loading branch information
CRD716 committed May 2, 2023
1 parent 27f7e8c commit fd7288f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
7 changes: 7 additions & 0 deletions examples/main/main.cpp
Expand Up @@ -66,6 +66,7 @@ struct whisper_params {

bool speed_up = false;
bool translate = false;
bool detect_language= false;
bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
Expand Down Expand Up @@ -141,6 +142,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
Expand Down Expand Up @@ -191,6 +193,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
Expand Down Expand Up @@ -739,6 +742,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
Expand All @@ -761,6 +767,7 @@ int main(int argc, char ** argv) {
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
Expand Down
6 changes: 5 additions & 1 deletion whisper.cpp
Expand Up @@ -3312,6 +3312,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.prompt_n_tokens =*/ 0,

/*.language =*/ "en",
/*.detect_language =*/ false,

/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
Expand Down Expand Up @@ -3898,7 +3899,7 @@ int whisper_full_with_state(
}

// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
Expand All @@ -3910,6 +3911,9 @@ int whisper_full_with_state(
params.language = whisper_lang_str(lang_id);

fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
if (params.detect_language) {
return 0;
}
}

if (params.token_timestamps) {
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Expand Up @@ -365,6 +365,7 @@ extern "C" {

// for auto-detection, set to nullptr, "" or "auto"
const char * language;
bool detect_language;

// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
Expand Down

0 comments on commit fd7288f

Please sign in to comment.