Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node : add audio_ctx and audio buffer params #2123

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/addon.node/__test__/whisper.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const whisperParamsMock = {
comma_in_time: false,
translate: true,
no_timestamps: false,
audio_ctx: 0,
};

describe("Run whisper.node", () => {
Expand Down
47 changes: 39 additions & 8 deletions examples/addon.node/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct whisper_params {
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;
int32_t audio_ctx = 0;

float word_thold = 0.01f;
float entropy_thold = 2.4f;
Expand Down Expand Up @@ -46,6 +47,8 @@ struct whisper_params {

std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};

std::vector<float> pcmf32 = {}; // mono-channel F32 PCM
};

struct whisper_print_user_data {
Expand Down Expand Up @@ -125,13 +128,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {}

int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {

if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}

if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
if (params.fname_inp.empty() && params.pcmf32.empty()) {
fprintf(stderr, "error: no input files or audio buffer specified\n");
return 2;
}

Expand All @@ -151,16 +153,29 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
return 3;
}

// if params.pcmf32 is provided, set params.fname_inp to "buffer"
// this is simpler than further modifications in the code
if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear();
params.fname_inp.emplace_back("buffer");
}

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];

std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
// read the input audio file if params.pcmf32 is not provided
if (params.pcmf32.empty()) {
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
}
} else {
pcmf32 = params.pcmf32;
}

// print system information
Expand All @@ -180,12 +195,13 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
params.no_timestamps ? 0 : 1,
params.audio_ctx);

fprintf(stderr, "\n");
}
Expand All @@ -212,6 +228,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx;

wparams.speed_up = params.speed_up;

Expand Down Expand Up @@ -311,14 +328,28 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();

Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
std::vector<float> pcmf32_vec;
if (pcmf32Value.IsTypedArray()) {
Napi::Float32Array pcmf32 = pcmf32Value.As<Napi::Float32Array>();
size_t length = pcmf32.ElementLength();
pcmf32_vec.reserve(length);
for (size_t i = 0; i < length; i++) {
pcmf32_vec.push_back(pcmf32[i]);
}
}

params.language = language;
params.model = model;
params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu;
params.no_prints = no_prints;
params.no_timestamps = no_timestamps;
params.audio_ctx = audio_ctx;
params.pcmf32 = pcmf32_vec;
params.comma_in_time = comma_in_time;

Napi::Function callback = info[1].As<Napi::Function>();
Expand Down
9 changes: 8 additions & 1 deletion examples/addon.node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@ const whisperParams = {
comma_in_time: false,
translate: true,
no_timestamps: false,
audio_ctx: 0,
};

const arguments = process.argv.slice(2);
const params = Object.fromEntries(
arguments.reduce((pre, item) => {
if (item.startsWith("--")) {
return [...pre, item.slice(2).split("=")];
const [key, value] = item.slice(2).split("=");
if (key === "audio_ctx") {
whisperParams[key] = parseInt(value);
} else {
whisperParams[key] = value;
}
return pre;
}
return pre;
}, [])
Expand Down
Loading