Skip to content

Commit

Permalink
release : v1.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Feb 28, 2023
1 parent d5c6d5c commit 92d4c5c
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ publish: publish-trigger
\n\
cd /path/to/whisper.cpp/bindings/ios\n\
git commit\n\
git tag 1.2.0\n\
git tag 1.2.1\n\
git push origin master --tags\n\
"

Expand Down
29 changes: 29 additions & 0 deletions Sources/whisper/include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ extern "C" {
int n_samples,
int n_threads);

// Convert RAW PCM audio to a log mel spectrogram, applying a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the provided whisper context.
//   ctx       - context that receives the resulting spectrogram
//   samples   - input PCM samples; n_samples is presumably the number of floats in this buffer (confirm in whisper.cpp)
//   n_threads - presumably the number of threads used for the conversion (confirm in whisper.cpp)
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context* ctx,
const float* samples,
int n_samples,
int n_threads);


// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
Expand Down Expand Up @@ -233,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);

// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
//   ctx       - the whisper context whose decoder produced the logits
//   tokens    - array of n_tokens token data entries handed over by the decoder
//               (presumably the tokens decoded so far - confirm in whisper.cpp)
//   logits    - writable logits buffer; changes made here are seen by the rest of the logits processing
//   user_data - the logits_filter_callback_user_data value from whisper_full_params, passed through unchanged
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);

// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
Expand All @@ -257,6 +277,7 @@ extern "C" {
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_tokens; // max tokens per segment (0 = no limit)

// [EXPERIMENTAL] speed-up techniques
Expand All @@ -274,6 +295,7 @@ extern "C" {

// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253

float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
Expand Down Expand Up @@ -303,6 +325,10 @@ extern "C" {
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;

// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};

WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
Expand All @@ -329,6 +355,9 @@ extern "C" {
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);

// Language id associated with the current context
// (the value stored in the context: 0, i.e. english, by default; whisper_full() updates it
// when it auto-detects the language or builds the multilingual prompt)
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);

// Get the start and end time of the specified segment.
WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
Expand Down
104 changes: 90 additions & 14 deletions Sources/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,14 +592,16 @@ struct whisper_context {

mutable std::mt19937 rng; // used for sampling at t > 0.0

int lang_id = 0; // english by default

// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy

// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
int32_t exp_n_audio_ctx = 0; // 0 - use default

void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
Expand Down Expand Up @@ -803,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));

// this is the memory required by one decoder
const size_t mem_required_decoder =
Expand Down Expand Up @@ -2903,7 +2905,7 @@ const char * whisper_print_system_info(void) {

struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
/*.strategy =*/ strategy,

/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
Expand All @@ -2922,6 +2924,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0,

/*.speed_up =*/ false,
Expand All @@ -2933,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en",

/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,

/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
Expand All @@ -2958,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,

/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};

switch (strategy) {
Expand Down Expand Up @@ -2988,9 +2995,35 @@ static void whisper_exp_compute_token_level_timestamps(
float thold_pt,
float thold_ptsum);

// strip leading whitespace (in place)
// a character counts as whitespace when std::isspace() says so for its unsigned char value
static inline void ltrim(std::string &s) {
    std::size_t n_lead = 0;
    while (n_lead < s.size() && std::isspace(static_cast<unsigned char>(s[n_lead]))) {
        ++n_lead;
    }

    // drop the whitespace prefix in one go (no-op when n_lead == 0)
    s.erase(0, n_lead);
}

// strip trailing whitespace (in place)
// a character counts as whitespace when std::isspace() says so for its unsigned char value
static inline void rtrim(std::string &s) {
    std::size_t n_keep = s.size();
    while (n_keep > 0 && std::isspace(static_cast<unsigned char>(s[n_keep - 1]))) {
        --n_keep;
    }

    // cut everything after the last non-whitespace character (no-op when nothing trails)
    s.erase(n_keep);
}

// trim from both ends (in place)
// Strips trailing whitespace via rtrim() and then leading whitespace via ltrim(),
// so "whitespace" is whatever std::isspace() reports for each character.
static inline void trim(std::string &s) {
rtrim(s);
ltrim(s);
}

// decide whether a segment may be split right before the token text `txt`
// - word splitting disabled: every token is a valid split point
// - word splitting enabled:  only tokens whose text starts with a space qualify
// txt must be a valid C string; it is only inspected when split_on_word is true
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
    return !split_on_word || txt[0] == ' ';
}

// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
auto segment = ctx.result_all.back();

int res = 1;
Expand All @@ -3005,11 +3038,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
}

const auto txt = whisper_token_to_str(&ctx, token.id);

const int cur = strlen(txt);

if (acc + cur > max_len && i > 0) {
if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
// split here
if (split_on_word) {
trim(text);
}

ctx.result_all.back().text = std::move(text);
ctx.result_all.back().t1 = token.t0;
ctx.result_all.back().tokens.resize(i);
Expand Down Expand Up @@ -3037,16 +3073,26 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
}
}

if (split_on_word) {
trim(text);
}
ctx.result_all.back().text = std::move(text);

return res;
}

// Symbols that are unlikely to be spoken and can therefore be suppressed while decoding.
// Mirrors the non-speech token list of the reference OpenAI implementation
// (whisper/tokenizer.py, non_speech_tokens): ASCII punctuation, CJK quotation brackets,
// repeated bracket/dash sequences and musical symbols.
// Every entry must be non-empty: the consumer also looks up " " + entry, so an empty
// entry would end up suppressing the plain space token.
static const std::vector<std::string> non_speech_tokens = {
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
};

// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
Expand Down Expand Up @@ -3102,6 +3148,31 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;

if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}

// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
for (const std::string & token : non_speech_tokens) {
const std::string suppress_tokens[] = {token, " " + token};
for (const std::string & suppress_token : suppress_tokens) {
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
}
}
}

// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" -")] = -INFINITY;
}
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
logits[vocab.token_to_id.at(" '")] = -INFINITY;
}
}

// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
{
Expand Down Expand Up @@ -3449,7 +3520,7 @@ int whisper_full(
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
return -3;
}

ctx->lang_id = lang_id;
params.language = whisper_lang_str(lang_id);

fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
Expand Down Expand Up @@ -3546,6 +3617,7 @@ int whisper_full(
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
ctx->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
Expand Down Expand Up @@ -3782,7 +3854,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});

int cur_c = 0;
uint32_t cur_c = 0;

for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
Expand All @@ -3793,7 +3865,7 @@ int whisper_full(

auto & cur = beam_candidates[cur_c++];

while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
++cur_c;
}

Expand Down Expand Up @@ -4069,7 +4141,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
}
}
if (params.new_segment_callback) {
Expand Down Expand Up @@ -4113,7 +4185,7 @@ int whisper_full(
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);

if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, params.max_len);
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
}
}
if (params.new_segment_callback) {
Expand Down Expand Up @@ -4266,6 +4338,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->result_all.size();
}

// Returns the language id stored in the context (its lang_id member).
// ctx is dereferenced without a NULL check, so it must point to a valid context.
int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->lang_id;
}

int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t0;
}
Expand Down

0 comments on commit 92d4c5c

Please sign in to comment.