Skip to content

Commit

Permalink
Merge pull request ggerganov#5 from bobqianic/push
Browse files Browse the repository at this point in the history
Push
  • Loading branch information
bobqianic committed Feb 5, 2024
2 parents e2e5177 + 7a5a2e9 commit a0d4348
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 143 deletions.
50 changes: 0 additions & 50 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,56 +616,6 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(

}


namespace utf_8 {

// Number of continuation bytes announced by the bit pattern of `lead`:
// 110xxxxx -> 1, 1110xxxx -> 2, 11110xxx -> 3, anything else -> 0.
// Only the pattern is inspected; stray continuation bytes and out-of-range
// lead bytes yield 0 so that the splitter treats them as single-byte chunks.
static unsigned int expected_trailing(unsigned char lead) {
    if ((lead & 0xE0) == 0xC0) return 1; // 2-byte character
    if ((lead & 0xF0) == 0xE0) return 2; // 3-byte character
    if ((lead & 0xF8) == 0xF0) return 3; // 4-byte character
    return 0;                            // 1-byte character or invalid lead byte
}

// Returns true when `str` is well-formed UTF-8 (the empty string is valid).
//
// Besides requiring every lead byte to be followed by the right number of
// 10xxxxxx continuation bytes, this rejects sequences that have the right
// shape but are still ill-formed:
//   - overlong encodings (lead bytes C0/C1, E0 followed by 80..9F,
//     F0 followed by 80..8F),
//   - UTF-16 surrogates U+D800..U+DFFF (ED followed by A0..BF),
//   - code points above U+10FFFF (F4 followed by 90..BF, lead bytes F5..FF).
bool is_valid(const std::string &str) {
    const size_t size = str.size();
    size_t pos = 0;

    while (pos < size) {
        const unsigned char lead = static_cast<unsigned char>(str[pos]);

        if (lead < 0x80) { // ASCII
            ++pos;
            continue;
        }

        // Allowed range of the first continuation byte; narrowed below for
        // the lead bytes that would otherwise admit ill-formed sequences.
        size_t trailing = 0;
        unsigned char second_min = 0x80;
        unsigned char second_max = 0xBF;

        if (lead >= 0xC2 && lead <= 0xDF) {
            trailing = 1;
        } else if (lead >= 0xE0 && lead <= 0xEF) {
            trailing = 2;
            if (lead == 0xE0) second_min = 0xA0; // no overlong 3-byte forms
            if (lead == 0xED) second_max = 0x9F; // no surrogates
        } else if (lead >= 0xF0 && lead <= 0xF4) {
            trailing = 3;
            if (lead == 0xF0) second_min = 0x90; // no overlong 4-byte forms
            if (lead == 0xF4) second_max = 0x8F; // nothing above U+10FFFF
        } else {
            return false; // stray continuation byte, C0/C1, or F5..FF
        }

        if (size - pos <= trailing) return false; // character cut off at end of input

        const unsigned char second = static_cast<unsigned char>(str[pos + 1]);
        if (second < second_min || second > second_max) return false;

        // Remaining continuation bytes only need the 10xxxxxx pattern.
        for (size_t k = 2; k <= trailing; ++k) {
            const unsigned char next = static_cast<unsigned char>(str[pos + k]);
            if ((next & 0xC0) != 0x80) return false;
        }

        pos += trailing + 1;
    }

    return true;
}

// Splits `str` into one chunk per UTF-8 character, in input order.
//
// Chunks are delimited purely by byte patterns: a lead byte opens a chunk and
// up to the announced number of continuation bytes are appended to it. A
// sequence interrupted by a non-continuation byte, or by the end of input, is
// emitted as its own (incomplete) chunk, and the interrupting byte opens the
// next chunk. Bytes that are not a recognised lead byte become single-byte
// chunks. Callers can pass each chunk to is_valid() to decide how to render it.
std::vector<std::string> merge_and_split(const std::string &str) {
    std::vector<std::string> result;
    std::string buffer;
    unsigned int pending = 0; // continuation bytes still expected for `buffer`

    for (const unsigned char c : str) {
        if (pending > 0 && (c & 0xC0) == 0x80) {
            // Expected continuation byte: extend the current character.
            buffer += static_cast<char>(c);
            --pending;
            continue;
        }

        // Either the previous character is complete, or it was interrupted by
        // a byte that is not a continuation. In both cases flush what has
        // been gathered and start a new chunk at `c`.
        pending = expected_trailing(c);
        if (!buffer.empty()) result.push_back(buffer);
        buffer.assign(1, static_cast<char>(c));
    }

    if (!buffer.empty()) result.push_back(buffer);
    return result;
}

} // namespace utf_8

bool is_wav_buffer(const std::string buf) {
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
Expand Down
5 changes: 0 additions & 5 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,6 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
float repeat_penalty,
std::mt19937 & rng);

// UTF-8 helpers (defined in examples/common.cpp).
namespace utf_8{
// Returns true if every lead byte in `str` is followed by the number of
// continuation bytes it announces and no character is cut off at the end.
bool is_valid(const std::string &str);
// Splits `str` into one string per UTF-8 character, in order. A sequence that
// is cut short is returned as its own element, so each element can be checked
// individually with is_valid().
std::vector<std::string> merge_and_split(const std::string &str);
}

//
// Audio utils
//
Expand Down
37 changes: 16 additions & 21 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ struct whisper_params {
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
Expand Down Expand Up @@ -149,7 +148,6 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params)
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
Expand Down Expand Up @@ -197,7 +195,6 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
Expand Down Expand Up @@ -320,6 +317,10 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper


if (params.print_colors) {
std::string buffer;
float probability_sum = 0;
int count = 0;

for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
Expand All @@ -328,26 +329,21 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
}

const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
// if (utf_8::is_valid(text)) {
// printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
// } else {
printf("%s[_%i_]%s", k_colors[col].c_str(), whisper_full_get_token_id(ctx, i, j), "\033[0m");
// }
buffer += whisper_full_get_token_text(ctx, i, j);
probability_sum += whisper_full_get_token_p (ctx, i, j);
count++;
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(probability_sum/static_cast<float>(count), 3)*float(k_colors.size()))));

if (whisper_utf8_is_valid(buffer.c_str())) {
printf("%s%s%s", k_colors[col].c_str(), buffer.c_str(), "\033[0m");
buffer.clear();
probability_sum = 0;
count = 0;
}
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
for (auto &k : utf_8::merge_and_split(text)) {
if (utf_8::is_valid(k)) {
printf("%s", k.c_str());
} else {
for (auto l : k) {
printf("[_%i_]", l);
}
}
}
printf("%s", text);
}

if (params.tinydiarize) {
Expand Down Expand Up @@ -1016,7 +1012,6 @@ int run(int argc, const char ** argv) {
wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
Expand Down
8 changes: 0 additions & 8 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ struct whisper_params {
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool print_special = false;
bool print_colors = false;
Expand Down Expand Up @@ -136,7 +135,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
Expand Down Expand Up @@ -192,7 +190,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
Expand Down Expand Up @@ -462,10 +459,6 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content);
}
if (req.has_file("split_on_word"))
{
params.split_on_word = parse_str_to_bool(req.get_file_value("split_on_word").content);
}
if (req.has_file("no_timestamps"))
{
params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content);
Expand Down Expand Up @@ -738,7 +731,6 @@ int main(int argc, char ** argv) {

wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
Expand Down
Loading

0 comments on commit a0d4348

Please sign in to comment.