Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ set(PARAKEET_SRC
src/joint.cpp
src/tdt.cpp
src/rnnt.cpp
src/transducer_batch.cpp
src/tokenizer.cpp
src/search.cpp
src/transcription.cpp)
Expand Down
209 changes: 207 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
#include "audio_io.hpp"
#include "streaming.hpp"
#include "transcription.hpp"
#include "ggml_graph.hpp" // pk::set_num_threads
#include "ggml_graph.hpp" // pk::set_num_threads, pk::global_backend
#include "backend.hpp" // pk::ensure_weights_realized
#include "encoder.hpp"
#include "prediction.hpp"
#include "joint.hpp"
#include "tdt.hpp"
#include "rnnt.hpp"
#include "transducer_batch.hpp"
#include "mel.hpp"
#include "mel_gpu.hpp"
#include "ggml.h"
#include "gguf.h"
#include <chrono>
Expand Down Expand Up @@ -851,6 +860,198 @@ static int cmd_bench_batch(int argc, char** argv) {
return 0;
}

// ---------------------------------------------------------------------------
// bench-decode: encode ONE clip once, then time DECODE only -- serial (N
// separate tdt_greedy/rnnt_greedy calls over N copies of the encoder output)
// vs batched (transducer_greedy_batch over the same N copies) -- at several
// batch sizes. Reports decode wall-clock and the batched/serial speedup so the
// GPU win from batched decode can be measured in isolation from the encoder.
//
// The encoder cost is paid once and excluded; only the transducer decode loop
// is timed. Each rep is averaged over R repetitions (best/min recorded). The
// b=0 batched ids are compared against the serial ids (same clip, decode is
// deterministic) as a correctness sanity check.
// ---------------------------------------------------------------------------
static int cmd_bench_decode(int argc, char** argv) {
std::string model, audio;
std::string batch_sizes_str = "1,4,8,16";
int threads = 0; // 0 == unset -> use the persistent-backend default
int reps = 5;
for (int i = 0; i < argc; ++i) {
if (std::strcmp(argv[i], "--model") == 0 && i + 1 < argc) {
model = argv[++i];
} else if (std::strcmp(argv[i], "--audio") == 0 && i + 1 < argc) {
audio = argv[++i];
} else if (std::strcmp(argv[i], "--batch-sizes") == 0 && i + 1 < argc) {
batch_sizes_str = argv[++i];
} else if (std::strcmp(argv[i], "--threads") == 0 && i + 1 < argc) {
threads = std::atoi(argv[++i]);
} else if (std::strcmp(argv[i], "--reps") == 0 && i + 1 < argc) {
reps = std::atoi(argv[++i]);
}
}
if (model.empty() || audio.empty()) {
std::fprintf(stderr,
"usage: parakeet-cli bench-decode --model <m.gguf> --audio <wav> "
"[--batch-sizes 1,4,8,16] [--threads N] [--reps R]\n");
return 2;
}
if (reps < 1) reps = 1;

// Parse --batch-sizes (comma-separated positive ints).
std::vector<int> batch_sizes;
{
std::stringstream ss(batch_sizes_str);
std::string tok;
while (std::getline(ss, tok, ',')) {
size_t b = tok.find_first_not_of(" \t");
if (b == std::string::npos) continue;
size_t e = tok.find_last_not_of(" \t");
int v = std::atoi(tok.substr(b, e - b + 1).c_str());
if (v > 0) batch_sizes.push_back(v);
}
}
if (batch_sizes.empty()) {
std::fprintf(stderr,
"parakeet-cli bench-decode: no valid --batch-sizes (want e.g. 1,4,8,16)\n");
return 2;
}

if (threads > 0) pk::set_num_threads(threads);
int reported_threads = threads > 0 ? threads : 8; // kDefaultThreads

// Load the model components over the lower-level loader (we need the encoder
// / prediction / joint pieces, not the high-level Model::transcribe path).
pk::ModelLoader ml;
if (!ml.load(model)) {
std::fprintf(stderr, "parakeet-cli bench-decode: failed to load model %s\n",
model.c_str());
return 1;
}
pk::ensure_weights_realized(ml);
pk::Encoder enc(ml);
pk::PredictionNet pred(ml);
pk::Joint joint(ml);
const auto& cfg = ml.config();
const int blank = (int)cfg.blank_id;
const int maxs = (int)cfg.max_symbols;
const std::vector<int32_t> durations = cfg.tdt_durations;

// Load the WAV.
pk::Audio a;
if (!pk::load_audio_16k_mono(audio, a)) {
std::fprintf(stderr, "parakeet-cli bench-decode: failed to load audio %s\n",
audio.c_str());
return 1;
}

// Mel front end (GpuMel on a non-CPU backend, else FFT MelFrontend), exactly
// as model.cpp's transcribe path does.
std::vector<float> feats;
int n_mels = 0, T = 0;
if (std::string(pk::global_backend().device_name()) != "cpu") {
pk::GpuMel gmel(ml);
gmel.compute(a.samples, feats, n_mels, T);
} else {
pk::MelFrontend mel(ml);
mel.compute(a.samples, feats, n_mels, T);
}

// Encoder -> enc_out [d_model, Tout] (channels-first); transpose to row-major
// enc_row [Tout, d_model] as the decoders expect.
std::vector<float> enc_out;
int dm = 0, Tout = 0;
enc.forward(feats, n_mels, T, enc_out, dm, Tout);
std::vector<float> enc_row((size_t)Tout * dm);
for (int t = 0; t < Tout; ++t)
for (int c = 0; c < dm; ++c)
enc_row[(size_t)t * dm + c] = enc_out[(size_t)c * Tout + t];

const bool use_tdt = !durations.empty();
auto decode_serial_one = [&]() -> std::vector<int32_t> {
return use_tdt
? pk::tdt_greedy(pred, joint, enc_row, Tout, dm, durations, blank, maxs, nullptr)
: pk::rnnt_greedy(pred, joint, enc_row, Tout, dm, blank, maxs, nullptr);
};

using clock = std::chrono::steady_clock;
auto ms_since = [](clock::time_point t0) {
return std::chrono::duration<double, std::milli>(clock::now() - t0).count();
};

// Warm up (untimed): realize weights / CUDA kernels for both paths.
std::vector<int32_t> serial_ref = decode_serial_one();
{
std::vector<std::vector<float>> encs1{enc_row};
std::vector<int> Ts1{Tout};
std::vector<std::vector<int32_t>> ids1;
pk::transducer_greedy_batch(pred, joint, encs1, Ts1, dm, durations,
blank, maxs, ids1, nullptr);
}

struct Row { int B; double serial_ms; double batched_ms; double speedup;
double serial_cps; double batched_cps; };
std::vector<Row> rows;
rows.reserve(batch_sizes.size());
bool sanity_ok = true;

for (int B : batch_sizes) {
std::vector<std::vector<float>> encs((size_t)B, enc_row);
std::vector<int> Ts((size_t)B, Tout);

// SERIAL: B separate single-clip decodes, best of R reps.
double serial_ms = 1e300;
for (int r = 0; r < reps; ++r) {
auto t0 = clock::now();
for (int b = 0; b < B; ++b) (void)decode_serial_one();
serial_ms = std::min(serial_ms, ms_since(t0));
}

// BATCHED: one transducer_greedy_batch over the B copies, best of R reps.
double batched_ms = 1e300;
std::vector<std::vector<int32_t>> ids_last;
for (int r = 0; r < reps; ++r) {
std::vector<std::vector<int32_t>> ids;
auto t0 = clock::now();
pk::transducer_greedy_batch(pred, joint, encs, Ts, dm, durations,
blank, maxs, ids, nullptr);
batched_ms = std::min(batched_ms, ms_since(t0));
ids_last = std::move(ids);
}

// Sanity: batched ids[0] must equal the serial decode of the same clip.
if (!ids_last.empty() && ids_last[0] != serial_ref) {
sanity_ok = false;
std::fprintf(stderr,
"parakeet-cli bench-decode: WARN B=%d batched ids[0] != serial "
"(%zu vs %zu tokens) -- batched decode may be buggy\n",
B, ids_last[0].size(), serial_ref.size());
}

double speedup = batched_ms > 0.0 ? serial_ms / batched_ms : 0.0;
double serial_cps = serial_ms > 0.0 ? (double)B / (serial_ms / 1000.0) : 0.0;
double batched_cps = batched_ms > 0.0 ? (double)B / (batched_ms / 1000.0) : 0.0;
rows.push_back({B, serial_ms, batched_ms, speedup, serial_cps, batched_cps});
}

// Human-readable table to stderr.
std::fprintf(stderr,
"\nbench-decode: clip Tout=%d frames, d_model=%d, decoder=%s, threads=%d, "
"reps=%d (best-of), backend=%s\n",
Tout, dm, use_tdt ? "tdt" : "rnnt", reported_threads, reps,
pk::global_backend().device_name());
std::fprintf(stderr, " %-6s %-12s %-12s %-10s %-14s %-14s\n",
"B", "serial_ms", "batched_ms", "speedup", "serial_cps", "batched_cps");
for (const Row& r : rows) {
std::fprintf(stderr, " %-6d %-12.2f %-12.2f %-10.2f %-14.1f %-14.1f\n",
r.B, r.serial_ms, r.batched_ms, r.speedup,
r.serial_cps, r.batched_cps);
}
std::fprintf(stderr, " sanity (batched ids[0]==serial): %s\n",
sanity_ok ? "OK" : "MISMATCH (see WARN above)");
return 0;
}

// Run a subcommand, then free the process-global backend while the GPU driver is
// still alive (the subcommand's local Model is already destroyed by the time it
// returns, releasing its device weight buffer). Avoids the CUDA "driver shutting
Expand All @@ -870,6 +1071,8 @@ int main(int argc, char** argv) {
return run_and_shutdown(cmd_quantize, argc - 2, argv + 2);
if (argc >= 2 && std::strcmp(argv[1], "bench-batch") == 0)
return run_and_shutdown(cmd_bench_batch, argc - 2, argv + 2);
if (argc >= 2 && std::strcmp(argv[1], "bench-decode") == 0)
return run_and_shutdown(cmd_bench_decode, argc - 2, argv + 2);
if (argc >= 2 && std::strcmp(argv[1], "bench") == 0)
return run_and_shutdown(cmd_bench, argc - 2, argv + 2);
std::fprintf(stderr,
Expand All @@ -882,6 +1085,8 @@ int main(int argc, char** argv) {
" parakeet-cli bench --model <model.gguf> --manifest <file> "
"[--decoder ctc|tdt] [--threads N] [--json <out>]\n"
" parakeet-cli bench-batch --model <model.gguf> --manifest <file> "
"[--decoder ctc|tdt] [--threads N] [--batch-sizes 1,4,8] [--json <out>]\n");
"[--decoder ctc|tdt] [--threads N] [--batch-sizes 1,4,8] [--json <out>]\n"
" parakeet-cli bench-decode --model <model.gguf> --audio <wav> "
"[--batch-sizes 1,4,8,16] [--threads N] [--reps R]\n");
return 2;
}
21 changes: 21 additions & 0 deletions src/decode_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once
#include <cmath>
namespace pk {
// argmax over a[0..n): first index of the max (matches torch.max tie-break).
inline int decode_argmax(const float* a, int n) {
int best = 0; float bv = a[0];
for (int i = 1; i < n; ++i) if (a[i] > bv) { bv = a[i]; best = i; }
return best;
}
// NeMo rescaled max_prob confidence over a[0..n) at index k:
// conf = (N*p_max - 1)/(N - 1), p_max = softmax(a)[k]. Stable softmax.
inline float decode_max_prob_conf(const float* a, int n, int k) {
float mx = a[0];
for (int i = 1; i < n; ++i) if (a[i] > mx) mx = a[i];
double denom = 0.0;
for (int i = 0; i < n; ++i) denom += std::exp((double)a[i] - (double)mx);
const double p_max = std::exp((double)a[k] - (double)mx) / denom;
const double N = (double)n;
return (float)((N * p_max - 1.0) / (N - 1.0));
}
} // namespace pk
39 changes: 39 additions & 0 deletions src/joint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,45 @@ void Joint::step_logits(const float* enc_proj_t,
assert(ok && "step_logits graph failed");
}

void Joint::step_logits_batch(const float* enc_proj_gathered,
const float* g, int pred_hidden, int n,
std::vector<float>& logits) const {
assert(pred_hidden == pred_hidden_ && "pred_hidden mismatch");
const int H = joint_hidden_;

// Batched per-step joint over N items on the PERSISTENT backend. Mirrors
// step_logits with a batch axis (ggml ne1 = N): each of the two matmuls is
// applied across all N columns at once, and the biases broadcast over N.
// N=1 reduces exactly to step_logits. The gathered enc_proj input holds one
// joint_hidden row per item (item k at offset k*H), and g holds one
// pred_hidden vector per item (item k at offset k*pred_hidden).
bool ok = pk::run_graph(0, 0,
[&](ggml_context* ctx) -> ggml_tensor* {
// Gathered enc_proj rows: [H, N].
int64_t ep_ne[2] = { H, n };
ggml_tensor* ep = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, ep_ne,
enc_proj_gathered, (size_t)H * n * sizeof(float));
// Batched pred-net output g: [P, N].
int64_t g_ne[2] = { pred_hidden_, n };
ggml_tensor* gv = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, g_ne,
g, (size_t)pred_hidden_ * n * sizeof(float));
// pred_proj = pred.weight·g + pred.bias (P->H). Weight ne=[P,H].
ggml_tensor* Wp = pk::clone_weight(ctx, ml_, "joint.pred.weight");
ggml_tensor* pp = ggml_mul_mat(ctx, Wp, gv); // [H, N]
ggml_tensor* bp = pk::clone_weight(ctx, ml_, "joint.pred.bias");
pp = ggml_add(ctx, pp, bp); // bp [H] broadcasts over N
// f = ReLU(enc_proj + pred_proj)
ggml_tensor* f = ggml_relu(ctx, ggml_add(ctx, ep, pp)); // [H, N]
// logits = joint_net.2.weight·f + joint_net.2.bias (H->V). Weight ne=[H,V].
ggml_tensor* Wo = pk::clone_weight(ctx, ml_, "joint.joint_net.2.weight");
ggml_tensor* y = ggml_mul_mat(ctx, Wo, f); // [V, N]
ggml_tensor* bo = pk::clone_weight(ctx, ml_, "joint.joint_net.2.bias");
y = ggml_add(ctx, y, bo); // bo [V] broadcasts over N
return y; // [V_plus, N]
}, logits);
assert(ok && "step_logits_batch graph failed");
}

void Joint::forward(const std::vector<float>& enc, int T, int enc_hidden,
const std::vector<float>& pred, int U, int pred_hidden,
std::vector<float>& logits, int& V_plus_out) const {
Expand Down
9 changes: 9 additions & 0 deletions src/joint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ class Joint {
const float* g, int pred_hidden,
std::vector<float>& logits) const;

// Batched per-step joint for N items.
// enc_proj_gathered: [joint_hidden * N], each item's enc_proj row for its
// current frame (item n at offset n*joint_hidden).
// g: [pred_hidden * N], batched pred output (item n at n*pred_hidden).
// logits out: [V_plus * N], item n at offset n*V_plus.
void step_logits_batch(const float* enc_proj_gathered,
const float* g, int pred_hidden, int n,
std::vector<float>& logits) const;

int joint_hidden() const { return joint_hidden_; }

// V_plus = vocab_size + 1 + num_durations
Expand Down
Loading
Loading