From b9ca3f67217ad9290c9268f303c6cafa73b96c37 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:18:58 +0000 Subject: [PATCH 01/11] refactor(decode): share argmax/max_prob_conf in decode_common.hpp Co-Authored-By: Claude Opus 4.8 (1M context) --- src/decode_common.hpp | 21 +++++++++++++++++++++ src/rnnt.cpp | 33 +++------------------------------ src/tdt.cpp | 36 ++++-------------------------------- 3 files changed, 28 insertions(+), 62 deletions(-) create mode 100644 src/decode_common.hpp diff --git a/src/decode_common.hpp b/src/decode_common.hpp new file mode 100644 index 0000000..5bae7e0 --- /dev/null +++ b/src/decode_common.hpp @@ -0,0 +1,21 @@ +#pragma once +#include +namespace pk { +// argmax over a[0..n): first index of the max (matches torch.max tie-break). +inline int decode_argmax(const float* a, int n) { + int best = 0; float bv = a[0]; + for (int i = 1; i < n; ++i) if (a[i] > bv) { bv = a[i]; best = i; } + return best; +} +// NeMo rescaled max_prob confidence over a[0..n) at index k: +// conf = (N*p_max - 1)/(N - 1), p_max = softmax(a)[k]. Stable softmax. +inline float decode_max_prob_conf(const float* a, int n, int k) { + float mx = a[0]; + for (int i = 1; i < n; ++i) if (a[i] > mx) mx = a[i]; + double denom = 0.0; + for (int i = 0; i < n; ++i) denom += std::exp((double)a[i] - (double)mx); + const double p_max = std::exp((double)a[k] - (double)mx) / denom; + const double N = (double)n; + return (float)((N * p_max - 1.0) / (N - 1.0)); +} +} // namespace pk diff --git a/src/rnnt.cpp b/src/rnnt.cpp index e607720..3c32517 100644 --- a/src/rnnt.cpp +++ b/src/rnnt.cpp @@ -1,37 +1,10 @@ #include "rnnt.hpp" +#include "decode_common.hpp" #include #include namespace pk { -namespace { -// argmax over a[0..n) returning the first index of the maximum value. -// torch.max(dim) returns the FIRST max index on ties; match that. -int argmax(const float* a, int n) { - int best = 0; - float bv = a[0]; - for (int i = 1; i < n; ++i) { - if (a[i] > bv) { bv = a[i]; best = i; } - } - return best; -} - -// NeMo's rescaled `max_prob` confidence (method 'max_prob', alpha==1.0): -// conf = (N * p_max - 1) / (N - 1), p_max = softmax(logits)[k]. -// For RNN-T the confidence slice is the FULL joint output vector (V_plus = -// vocab + 1, blank included; no durations) — NeMo log_softmaxes the whole -// joint output. N == n (the slice size). Stable softmax (subtract the max). -float max_prob_conf_logits(const float* a, int n, int k) { - float mx = a[0]; - for (int i = 1; i < n; ++i) if (a[i] > mx) mx = a[i]; - double denom = 0.0; - for (int i = 0; i < n; ++i) denom += std::exp((double)a[i] - (double)mx); - const double p_max = std::exp((double)a[k] - (double)mx) / denom; - const double N = (double)n; - return (float)((N * p_max - 1.0) / (N - 1.0)); -} -} // namespace - RnntDecodeState rnnt_decode_init(const PredictionNet& pred) { RnntDecodeState st; st.state = pred.zero_state(); @@ -103,7 +76,7 @@ std::vector rnnt_decode_frames(const PredictionNet& pred, const Joint& joint.step_logits(enc_proj.data() + (size_t)t * H, g.data(), (int)g.size(), logits); - const int k = argmax(logits.data(), token_count); + const int k = decode_argmax(logits.data(), token_count); // Blank -> stop emitting at this frame and advance time. if (k == blank_id) break; @@ -117,7 +90,7 @@ std::vector rnnt_decode_frames(const PredictionNet& pred, const Joint& // max_prob confidence): frame = the (local) encoder frame t at // emission, conf = max_prob over the full joint output vector // (N = V_plus = vocab+1), span = 1 (RNN-T advances one frame). - const float conf = max_prob_conf_logits(logits.data(), token_count, k); + const float conf = decode_max_prob_conf(logits.data(), token_count, k); tokens->push_back(TokenInfo{ (int32_t)k, (int32_t)t, conf, 1 }); } st.last_token = (int32_t)k; diff --git a/src/tdt.cpp b/src/tdt.cpp index f731ce7..72b5f71 100644 --- a/src/tdt.cpp +++ b/src/tdt.cpp @@ -1,38 +1,10 @@ #include "tdt.hpp" +#include "decode_common.hpp" #include #include namespace pk { -namespace { -// argmax over a[0..n) returning the first index of the maximum value. -// torch.max(dim) returns the FIRST max index on ties; match that. -int argmax(const float* a, int n) { - int best = 0; - float bv = a[0]; - for (int i = 1; i < n; ++i) { - if (a[i] > bv) { bv = a[i]; best = i; } - } - return best; -} - -// NeMo's rescaled `max_prob` confidence (method 'max_prob', alpha==1.0): -// conf = (N * p_max - 1) / (N - 1), p_max = softmax(slice)[argmax]. -// Computed numerically from the RAW logit slice a[0..n): p_max is the softmax -// probability of the argmax over the slice (equivalently exp of the max -// log_softmax value), and N == n (the slice size = num token classes incl. -// blank). Stable softmax (subtract the max). -float max_prob_conf_logits(const float* a, int n, int k) { - float mx = a[0]; - for (int i = 1; i < n; ++i) if (a[i] > mx) mx = a[i]; - double denom = 0.0; - for (int i = 0; i < n; ++i) denom += std::exp((double)a[i] - (double)mx); - const double p_max = std::exp((double)a[k] - (double)mx) / denom; - const double N = (double)n; - return (float)((N * p_max - 1.0) / (N - 1.0)); -} -} // namespace - std::vector tdt_greedy(const PredictionNet& pred, const Joint& joint, const std::vector& enc, int T, int enc_hidden, const std::vector& durations, @@ -102,8 +74,8 @@ std::vector tdt_greedy(const PredictionNet& pred, const Joint& joint, g.data(), (int)g.size(), logits); // Split: token logits [0, token_count), duration logits [token_count, V_plus). - const int k = argmax(logits.data(), token_count); - const int d_k = argmax(logits.data() + token_count, num_dur); + const int k = decode_argmax(logits.data(), token_count); + const int d_k = decode_argmax(logits.data() + token_count, num_dur); skip = durations[d_k]; // Commit state + last_token ONLY when k != blank. @@ -117,7 +89,7 @@ std::vector tdt_greedy(const PredictionNet& pred, const Joint& joint, // (NeMo log_softmaxes that slice; exclude the duration // logits). N = token_count = vocab + 1. // span = durations[d_k] (the duration/skip applied to the token). - const float conf = max_prob_conf_logits(logits.data(), token_count, k); + const float conf = decode_max_prob_conf(logits.data(), token_count, k); tokens->push_back(TokenInfo{ (int32_t)k, (int32_t)t, conf, (int32_t)skip }); } From fc19481c85236b7b6d45a2659e24a290814ec7e5 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:23:30 +0000 Subject: [PATCH 02/11] feat(prediction): step_batch (batched LSTM, [H,N]) Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prediction.cpp | 81 ++++++++++++++++++++++++++++ src/prediction.hpp | 19 +++++++ tests/CMakeLists.txt | 5 +- tests/test_prediction_step_batch.cpp | 33 ++++++++++++ 4 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 tests/test_prediction_step_batch.cpp diff --git a/src/prediction.cpp b/src/prediction.cpp index d21ec04..5e5e2d5 100644 --- a/src/prediction.cpp +++ b/src/prediction.cpp @@ -112,6 +112,87 @@ void PredictionNet::step(int32_t token_id, bool is_sos, assert(ok && "pred-net step graph failed"); } +// --------------------------------------------------------------------------- +// Batched single-step advance: the same LSTM math as step(), but with a batch +// axis N. Inputs and state are laid out [H, N] in ggml (item n is column n, +// offset n*H in the flat host buffer). Gate slices become [H, N] views into the +// [4H, N] z, using z->nb[1] as the column stride. N=1 reduces to step(). +// --------------------------------------------------------------------------- +void PredictionNet::step_batch(const std::vector& token_ids, + const std::vector& is_sos, + const BatchedPredState& in, + std::vector& g, + BatchedPredState& out_state) const { + const int H = H_; + const int L = n_layers_; + const int N = (int)token_ids.size(); + assert(N > 0 && (int)is_sos.size() == N && "batch size mismatch"); + + // Lazily fetch the embedding table to the host (device-safe), exactly as + // step() does. + if (embed_host_.empty()) { + pk::ensure_weights_realized(ml_); + ggml_tensor* emb = ml_.tensor("decoder.prediction.embed.weight"); + assert(emb && "missing decoder.prediction.embed.weight"); + embed_host_.resize((size_t)vocab_p1_ * H); + ggml_backend_tensor_get(emb, embed_host_.data(), 0, ggml_nbytes(emb)); + } + + // Layer-0 input [H*N]: zeros for SOS items, else the embedding row. + std::vector x0((size_t)H * N, 0.0f); + for (int n = 0; n < N; ++n) { + if (!is_sos[n]) { + assert(token_ids[n] >= 0 && token_ids[n] < vocab_p1_ && "embedding id out of range"); + std::memcpy(&x0[(size_t)n * H], &embed_host_[(size_t)token_ids[n] * H], + (size_t)H * sizeof(float)); + } + } + + out_state.h.assign((size_t)L, std::vector((size_t)H * N)); + out_state.c.assign((size_t)L, std::vector((size_t)H * N)); + + bool ok = pk::run_graph(0, 0, [&](ggml_context* ctx) -> ggml_tensor* { + int64_t ne2[2] = { H, N }; + ggml_tensor* layer_in = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, ne2, + x0.data(), (size_t)H * N * sizeof(float)); + ggml_tensor* top_h = nullptr; + for (int l = 0; l < L; ++l) { + const std::string s = "_l" + std::to_string(l); + ggml_tensor* Wih = pk::clone_weight(ctx, ml_, + ("decoder.prediction.dec_rnn.lstm.weight_ih" + s).c_str()); + ggml_tensor* Whh = pk::clone_weight(ctx, ml_, + ("decoder.prediction.dec_rnn.lstm.weight_hh" + s).c_str()); + ggml_tensor* bih = pk::clone_weight(ctx, ml_, + ("decoder.prediction.dec_rnn.lstm.bias_ih" + s).c_str()); + ggml_tensor* bhh = pk::clone_weight(ctx, ml_, + ("decoder.prediction.dec_rnn.lstm.bias_hh" + s).c_str()); + ggml_tensor* h_in = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, ne2, + in.h[l].data(), (size_t)H * N * sizeof(float)); + ggml_tensor* c_in = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, ne2, + in.c[l].data(), (size_t)H * N * sizeof(float)); + // z = W_ih·x + b_ih + W_hh·h_in + b_hh [4H, N] + // (bias [4H] broadcasts over the N columns). + ggml_tensor* z = ggml_add(ctx, + ggml_add(ctx, ggml_mul_mat(ctx, Wih, layer_in), bih), + ggml_add(ctx, ggml_mul_mat(ctx, Whh, h_in), bhh)); + // Gate slices (i, f, g, o), each [H, N]; column stride is z->nb[1]. + ggml_tensor* i = ggml_sigmoid(ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], 0))); + ggml_tensor* f = ggml_sigmoid(ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], (size_t)H * sizeof(float)))); + ggml_tensor* gg = ggml_tanh (ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], (size_t)2 * H * sizeof(float)))); + ggml_tensor* o = ggml_sigmoid(ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], (size_t)3 * H * sizeof(float)))); + // c' = f*c_in + i*g ; h' = o*tanh(c') + ggml_tensor* c_out = ggml_add(ctx, ggml_mul(ctx, f, c_in), ggml_mul(ctx, i, gg)); + ggml_tensor* h_out = ggml_mul(ctx, o, ggml_tanh(ctx, c_out)); + pk::capture_graph_output(c_out, &out_state.c[l]); + pk::capture_graph_output(h_out, &out_state.h[l]); + layer_in = h_out; + top_h = h_out; + } + return top_h; + }, g); + assert(ok && "pred-net step_batch graph failed"); +} + // --------------------------------------------------------------------------- // Full-sequence forward pass (unchanged API; now driven by step() so there is a // single LSTM implementation). Carries (h, c) state across timesteps; the diff --git a/src/prediction.hpp b/src/prediction.hpp index 54cb6ca..6e5b8ed 100644 --- a/src/prediction.hpp +++ b/src/prediction.hpp @@ -14,6 +14,13 @@ struct PredState { std::vector> c; // c[layer] = [hidden] }; +// Batched LSTM state: one (h,c) per layer, each holding N items' columns laid +// out [H*N] (item n at offset n*H). Generalizes PredState to a batch. +struct BatchedPredState { + std::vector> h; // h[layer] size H*N + std::vector> c; // c[layer] size H*N +}; + // RNN-Transducer prediction network — NeMo RNNTDecoder prediction net. // // Architecture: @@ -67,6 +74,18 @@ class PredictionNet { std::vector& g, PredState& out_state) const; + // Advance the LSTM one token for N items at once. + // token_ids[n]: embedding index for item n (ignored where is_sos[n]). + // is_sos[n]: use the zero SOS input for item n (1=true). + // in: batched prior state (h[L],c[L] each [H*N]). + // g: OUT, top-layer h' for all items [H*N] (item n at n*H). + // out_state: OUT, new batched (h',c'). + void step_batch(const std::vector& token_ids, + const std::vector& is_sos, + const BatchedPredState& in, + std::vector& g, + BatchedPredState& out_state) const; + int hidden_size() const { return H_; } int num_layers() const { return n_layers_; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4100ee7..946f732 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,7 @@ pk_add_test(test_streaming_encoder) pk_add_test(test_ctc) pk_add_test(test_prediction) pk_add_test(test_prediction_step) +pk_add_test(test_prediction_step_batch) pk_add_test(test_joint) pk_add_test(test_transducer_core) pk_add_test(test_tdt_greedy) @@ -50,7 +51,7 @@ pk_add_test(test_capi_batch_json) set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling test_subsampling_batch test_relpos_attention test_relpos_attention_batch test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_eou test_streaming_encoder test_ctc test_prediction - test_prediction_step + test_prediction_step test_prediction_step_batch test_joint test_transducer_core test_tdt_greedy test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b @@ -61,7 +62,7 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te # These tests read fixtures/baselines via paths relative to the project root. set_tests_properties(test_mel test_mel_gpu test_subsampling test_subsampling_batch test_relpos_attention test_relpos_attention_batch test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_eou test_streaming_encoder - test_ctc test_prediction test_prediction_step + test_ctc test_prediction test_prediction_step test_prediction_step_batch test_joint test_transducer_core test_tdt_greedy test_timestamps_tokens test_timestamps test_transcribe_batch_ts diff --git a/tests/test_prediction_step_batch.cpp b/tests/test_prediction_step_batch.cpp new file mode 100644 index 0000000..bac4201 --- /dev/null +++ b/tests/test_prediction_step_batch.cpp @@ -0,0 +1,33 @@ +#include "prediction.hpp" +#include "model_loader.hpp" +#include "parity.hpp" +#include +#include +#include +#include +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); + if (!gguf) { std::fprintf(stderr, "env not set; skip\n"); return 77; } + pk::ModelLoader ml; if (!ml.load(gguf)) return 1; + pk::PredictionNet pred(ml); + const int H = pred.hidden_size(), L = pred.num_layers(); + std::vector toks = {5, 0, 42}; + std::vector sos = {0, 1, 0}; + const int N = 3; + std::vector> g_ref(N); + pk::PredState z = pred.zero_state(); + for (int n = 0; n < N; ++n) { pk::PredState os; pred.step(toks[n], sos[n], z, g_ref[n], os); } + pk::BatchedPredState bin; + bin.h.assign(L, std::vector((size_t)H*N, 0.0f)); + bin.c.assign(L, std::vector((size_t)H*N, 0.0f)); + std::vector gb; pk::BatchedPredState bout; + pred.step_batch(toks, sos, bin, gb, bout); + bool ok = (int)gb.size() == H*N; + for (int n = 0; n < N && ok; ++n) { + std::vector col(gb.begin()+(size_t)n*H, gb.begin()+(size_t)(n+1)*H); + ok = pktest::compare(col, g_ref[n], "predbatch.g", 1e-4f, 1e-4f) && ok; + } + // also check out_state top-layer h column matches g (sanity) and sizes + if (ok && ((int)bout.h.size()!=L || (int)bout.h[L-1].size()!=H*N)) { std::fprintf(stderr,"state shape\n"); ok=false; } + return ok ? 0 : 1; +} From e5846ef583fddfa4d05ea968afad7680b4660c36 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:27:39 +0000 Subject: [PATCH 03/11] docs(prediction): explain the 4H gate-slice stride in step_batch Co-Authored-By: Claude Opus 4.8 (1M context) --- src/prediction.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/prediction.cpp b/src/prediction.cpp index 5e5e2d5..3ef8066 100644 --- a/src/prediction.cpp +++ b/src/prediction.cpp @@ -175,7 +175,10 @@ void PredictionNet::step_batch(const std::vector& token_ids, ggml_tensor* z = ggml_add(ctx, ggml_add(ctx, ggml_mul_mat(ctx, Wih, layer_in), bih), ggml_add(ctx, ggml_mul_mat(ctx, Whh, h_in), bhh)); - // Gate slices (i, f, g, o), each [H, N]; column stride is z->nb[1]. + // Gate slices (i, f, g, o), each [H, N]. The view keeps z's FULL + // column stride (z->nb[1] = 4H elems), reading only H contiguous + // elements per column, so consecutive columns skip the other three + // gate blocks. (Do NOT change the stride to H*sizeof(float).) ggml_tensor* i = ggml_sigmoid(ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], 0))); ggml_tensor* f = ggml_sigmoid(ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], (size_t)H * sizeof(float)))); ggml_tensor* gg = ggml_tanh (ctx, ggml_cont(ctx, ggml_view_2d(ctx, z, H, N, z->nb[1], (size_t)2 * H * sizeof(float)))); From cc9834d3a890c103b989180b5af76455195b92fe Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:29:35 +0000 Subject: [PATCH 04/11] feat(joint): step_logits_batch (batched joint, [V_plus,N]) Co-Authored-By: Claude Opus 4.8 (1M context) --- src/joint.cpp | 39 ++++++++++++++++++++++++++++ src/joint.hpp | 9 +++++++ tests/CMakeLists.txt | 5 ++-- tests/test_joint_step_batch.cpp | 46 +++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/test_joint_step_batch.cpp diff --git a/src/joint.cpp b/src/joint.cpp index 8e20df2..90cc24e 100644 --- a/src/joint.cpp +++ b/src/joint.cpp @@ -103,6 +103,45 @@ void Joint::step_logits(const float* enc_proj_t, assert(ok && "step_logits graph failed"); } +void Joint::step_logits_batch(const float* enc_proj_gathered, + const float* g, int pred_hidden, int n, + std::vector& logits) const { + assert(pred_hidden == pred_hidden_ && "pred_hidden mismatch"); + const int H = joint_hidden_; + + // Batched per-step joint over N items on the PERSISTENT backend. Mirrors + // step_logits with a batch axis (ggml ne1 = N): each of the two matmuls is + // applied across all N columns at once, and the biases broadcast over N. + // N=1 reduces exactly to step_logits. The gathered enc_proj input holds one + // joint_hidden row per item (item k at offset k*H), and g holds one + // pred_hidden vector per item (item k at offset k*pred_hidden). + bool ok = pk::run_graph(0, 0, + [&](ggml_context* ctx) -> ggml_tensor* { + // Gathered enc_proj rows: [H, N]. + int64_t ep_ne[2] = { H, n }; + ggml_tensor* ep = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, ep_ne, + enc_proj_gathered, (size_t)H * n * sizeof(float)); + // Batched pred-net output g: [P, N]. + int64_t g_ne[2] = { pred_hidden_, n }; + ggml_tensor* gv = pk::graph_input_tensor(ctx, GGML_TYPE_F32, 2, g_ne, + g, (size_t)pred_hidden_ * n * sizeof(float)); + // pred_proj = pred.weight·g + pred.bias (P->H). Weight ne=[P,H]. + ggml_tensor* Wp = pk::clone_weight(ctx, ml_, "joint.pred.weight"); + ggml_tensor* pp = ggml_mul_mat(ctx, Wp, gv); // [H, N] + ggml_tensor* bp = pk::clone_weight(ctx, ml_, "joint.pred.bias"); + pp = ggml_add(ctx, pp, bp); // bp [H] broadcasts over N + // f = ReLU(enc_proj + pred_proj) + ggml_tensor* f = ggml_relu(ctx, ggml_add(ctx, ep, pp)); // [H, N] + // logits = joint_net.2.weight·f + joint_net.2.bias (H->V). Weight ne=[H,V]. + ggml_tensor* Wo = pk::clone_weight(ctx, ml_, "joint.joint_net.2.weight"); + ggml_tensor* y = ggml_mul_mat(ctx, Wo, f); // [V, N] + ggml_tensor* bo = pk::clone_weight(ctx, ml_, "joint.joint_net.2.bias"); + y = ggml_add(ctx, y, bo); // bo [V] broadcasts over N + return y; // [V_plus, N] + }, logits); + assert(ok && "step_logits_batch graph failed"); +} + void Joint::forward(const std::vector& enc, int T, int enc_hidden, const std::vector& pred, int U, int pred_hidden, std::vector& logits, int& V_plus_out) const { diff --git a/src/joint.hpp b/src/joint.hpp index 2163c01..8369cb2 100644 --- a/src/joint.hpp +++ b/src/joint.hpp @@ -66,6 +66,15 @@ class Joint { const float* g, int pred_hidden, std::vector& logits) const; + // Batched per-step joint for N items. + // enc_proj_gathered: [joint_hidden * N], each item's enc_proj row for its + // current frame (item n at offset n*joint_hidden). + // g: [pred_hidden * N], batched pred output (item n at n*pred_hidden). + // logits out: [V_plus * N], item n at offset n*V_plus. + void step_logits_batch(const float* enc_proj_gathered, + const float* g, int pred_hidden, int n, + std::vector& logits) const; + int joint_hidden() const { return joint_hidden_; } // V_plus = vocab_size + 1 + num_durations diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 946f732..159887e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -28,6 +28,7 @@ pk_add_test(test_prediction) pk_add_test(test_prediction_step) pk_add_test(test_prediction_step_batch) pk_add_test(test_joint) +pk_add_test(test_joint_step_batch) pk_add_test(test_transducer_core) pk_add_test(test_tdt_greedy) pk_add_test(test_timestamps_tokens) @@ -52,7 +53,7 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_eou test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch - test_joint test_transducer_core test_tdt_greedy + test_joint test_joint_step_batch test_transducer_core test_tdt_greedy test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou @@ -63,7 +64,7 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te set_tests_properties(test_mel test_mel_gpu test_subsampling test_subsampling_batch test_relpos_attention test_relpos_attention_batch test_conformer test_conformer_batch test_conv_eou test_encoder test_encoder_batch test_encoder_eou test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch - test_joint + test_joint test_joint_step_batch test_transducer_core test_tdt_greedy test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe diff --git a/tests/test_joint_step_batch.cpp b/tests/test_joint_step_batch.cpp new file mode 100644 index 0000000..a76b421 --- /dev/null +++ b/tests/test_joint_step_batch.cpp @@ -0,0 +1,46 @@ +#include "joint.hpp" +#include "prediction.hpp" +#include "encoder.hpp" +#include "model_loader.hpp" +#include "parity.hpp" +#include +#include +#include +#include +int main() { + const char* gguf = std::getenv("PARAKEET_TEST_GGUF"); + const char* base = std::getenv("PARAKEET_TEST_BASELINE"); + if (!gguf || !base) { std::fprintf(stderr, "env not set; skip\n"); return 77; } + pk::ModelLoader ml; if (!ml.load(gguf)) return 1; + std::vector mel; std::vector ms; + if (!pktest::load_baseline(base, "mel", mel, ms)) return 1; + const int n_mels=(int)ms[0], T=(int)ms[1]; + pk::Encoder enc(ml); std::vector eo; int dm=0,Tout=0; + enc.forward(mel,n_mels,T,eo,dm,Tout); + std::vector encr((size_t)Tout*dm); + for (int t=0;t ep; joint.precompute_enc_proj(encr, Tout, dm, ep); + const int Hj = joint.joint_hidden(), Vp = joint.V_plus(); + pk::PredictionNet pred(ml); const int Hp = pred.hidden_size(); + std::vector frames = {0, Tout/2, Tout-1}; + std::vector toks = {7, 13, 21}; + const int N = 3; + std::vector> gs(N); + pk::PredState z = pred.zero_state(); + for (int n=0;n> ref(N); + for (int n=0;n epg((size_t)Hj*N), gg((size_t)Hp*N); + for (int n=0;n lb; joint.step_logits_batch(epg.data(), gg.data(), Hp, N, lb); + bool ok = (int)lb.size()==Vp*N; + for (int n=0;n col(lb.begin()+(size_t)n*Vp, lb.begin()+(size_t)(n+1)*Vp); + ok = pktest::compare(col, ref[n], "jointbatch", 1e-3f, 1e-3f) && ok; + } + return ok?0:1; +} From 093aacc975bd435324023aeba987651788c3b644 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:38:22 +0000 Subject: [PATCH 05/11] feat(decode): transducer_greedy_batch (batched RNNT+TDT greedy, bit-exact parity) Co-Authored-By: Claude Opus 4.8 (1M context) --- CMakeLists.txt | 1 + src/transducer_batch.cpp | 188 ++++++++++++++++++++ src/transducer_batch.hpp | 21 +++ tests/CMakeLists.txt | 7 +- tests/test_transducer_greedy_batch.cpp | 54 ++++++ tests/test_transducer_greedy_batch_rnnt.cpp | 51 ++++++ 6 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 src/transducer_batch.cpp create mode 100644 src/transducer_batch.hpp create mode 100644 tests/test_transducer_greedy_batch.cpp create mode 100644 tests/test_transducer_greedy_batch_rnnt.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 39c4f95..a8afde8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,6 +76,7 @@ set(PARAKEET_SRC src/joint.cpp src/tdt.cpp src/rnnt.cpp + src/transducer_batch.cpp src/tokenizer.cpp src/search.cpp src/transcription.cpp) diff --git a/src/transducer_batch.cpp b/src/transducer_batch.cpp new file mode 100644 index 0000000..e387680 --- /dev/null +++ b/src/transducer_batch.cpp @@ -0,0 +1,188 @@ +#include "transducer_batch.hpp" +#include "decode_common.hpp" +#include +#include + +namespace pk { + +// Batched greedy transducer decode. This is a faithful transposition of the +// per-item loops in tdt.cpp (tdt_greedy) and rnnt.cpp (rnnt_decode_frames), +// run for N items in lockstep "rounds". Each round runs ONE batched prediction +// step and ONE batched joint step over all active items, then applies the +// EXACT per-item rule from the oracle to each item independently. +// +// Parity rationale: +// - The per-item `g_valid` cache (skip the LSTM forward on non-emitting steps) +// is a pure speed optimization: recomputing g from the SAME committed state +// yields an identical g. So we recompute g for every active item each round +// (one batched LSTM forward) without changing any result. +// - State recovery / masking: we only copy out_state columns into `committed` +// for items that emitted this round; non-emitting items keep their prior +// committed columns, exactly as the per-item loop leaves committed unchanged +// on a blank. +void transducer_greedy_batch( + const PredictionNet& pred, const Joint& joint, + const std::vector>& encs, + const std::vector& T, + int enc_hidden, + const std::vector& durations, + int blank_id, int max_symbols, + std::vector>& ids, + std::vector>* toks) { + + const bool is_tdt = !durations.empty(); + const int N = (int)encs.size(); + assert((int)T.size() == N); + + const int Hj = joint.joint_hidden(); + const int Hp = pred.hidden_size(); + const int L = pred.num_layers(); + const int Vp = joint.V_plus(); + const int num_dur = (int)durations.size(); + // RNNT: argmax over the full V_plus. TDT: argmax over the token slice + // (vocab+1), durations live in [token_count, V_plus). Mirrors the oracle: + // tdt.cpp token_count = V_plus - num_dur + // rnnt.cpp token_count = V_plus + const int token_count = is_tdt ? (Vp - num_dur) : Vp; + + // Per-item precomputed encoder projection [T[n], Hj]. + std::vector> ep(N); + for (int n = 0; n < N; ++n) { + joint.precompute_enc_proj(encs[n], T[n], enc_hidden, ep[n]); + } + + // Outputs. + ids.assign(N, {}); + if (toks) toks->assign(N, {}); + + // Per-item host state. + std::vector t(N, 0); + std::vector active(N, 0); + std::vector last_token(N, -1); + std::vector have_token(N, 0); + // Per-frame symbol counter (TDT symbols_added / RNNT emitted). + std::vector sym_at_frame(N, 0); + for (int n = 0; n < N; ++n) active[n] = (T[n] > 0) ? 1 : 0; + + // Committed batched LSTM state, zero-initialized [L][Hp*N]. + BatchedPredState committed; + committed.h.assign((size_t)L, std::vector((size_t)Hp * N, 0.0f)); + committed.c.assign((size_t)L, std::vector((size_t)Hp * N, 0.0f)); + + // Scratch reused across rounds. + std::vector token_ids(N); + std::vector is_sos(N); + std::vector g; // [Hp*N] + BatchedPredState out_state; + std::vector enc_proj_gathered((size_t)Hj * N); + std::vector logits; // [Vp*N] + + auto any_active = [&]() { + for (int n = 0; n < N; ++n) if (active[n]) return true; + return false; + }; + + while (any_active()) { + // (1) ONE batched prediction step from the committed state for ALL items. + // Build inputs from committed last_token/have_token. Inactive items still + // need valid inputs (their output is ignored): SOS / blank_id. + for (int n = 0; n < N; ++n) { + is_sos[n] = have_token[n] ? 0 : 1; + token_ids[n] = have_token[n] ? last_token[n] : (int32_t)blank_id; + } + pred.step_batch(token_ids, is_sos, committed, g, out_state); + + // (2) Gather each active item's enc_proj row for its current frame and + // run ONE batched joint step -> logits[Vp*N]. + for (int n = 0; n < N; ++n) { + int tf = t[n]; + if (tf < 0) tf = 0; + if (tf > T[n] - 1) tf = T[n] - 1; // clamp (boundary / inactive) + if (T[n] <= 0) tf = 0; // no frames: harmless, ignored + const float* src = (T[n] > 0) ? (ep[n].data() + (size_t)tf * Hj) : ep[n].data(); + if (T[n] > 0) { + std::memcpy(&enc_proj_gathered[(size_t)n * Hj], src, (size_t)Hj * sizeof(float)); + } else { + std::memset(&enc_proj_gathered[(size_t)n * Hj], 0, (size_t)Hj * sizeof(float)); + } + } + joint.step_logits_batch(enc_proj_gathered.data(), g.data(), Hp, N, logits); + + // (3) Per-item rule from the oracle. + for (int n = 0; n < N; ++n) { + if (!active[n]) continue; + const float* lz = logits.data() + (size_t)n * Vp; + const int k = decode_argmax(lz, token_count); + + if (is_tdt) { + // --- tdt.cpp inner iteration --- + const int d_k = decode_argmax(lz + token_count, num_dur); + int skip = durations[d_k]; + + if (k != blank_id) { + ids[n].push_back((int32_t)k); + if (toks) { + const float conf = decode_max_prob_conf(lz, token_count, k); + (*toks)[n].push_back(TokenInfo{ (int32_t)k, (int32_t)t[n], conf, + (int32_t)skip }); + } + last_token[n] = (int32_t)k; + have_token[n] = 1; + // Commit this item's columns from out_state into committed. + for (int l = 0; l < L; ++l) { + std::memcpy(&committed.h[l][(size_t)n * Hp], + &out_state.h[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + std::memcpy(&committed.c[l][(size_t)n * Hp], + &out_state.c[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + } + } + // ALWAYS: symbols_added += 1; t += skip; need_loop = (skip == 0). + sym_at_frame[n] += 1; + t[n] += skip; + + // Inner loop in tdt.cpp continues iff (need_loop && symbols_added + // < max_symbols), i.e. (skip == 0 && sym_at_frame < max_symbols). + // Otherwise the frame is done. Post-inner: tdt.cpp does + // if (skip == 0) skip = 1; // dead for t (skip is local) + // if (symbols_added == max_symbols) t += 1; + const bool frame_done = !(skip == 0 && sym_at_frame[n] < max_symbols); + if (frame_done) { + if (sym_at_frame[n] == max_symbols) t[n] += 1; + sym_at_frame[n] = 0; + } + } else { + // --- rnnt.cpp inner iteration --- + if (k == blank_id) { + // Blank -> stop emitting at this frame, advance time. + t[n] += 1; + sym_at_frame[n] = 0; + } else { + ids[n].push_back((int32_t)k); + if (toks) { + const float conf = decode_max_prob_conf(lz, token_count, k); + (*toks)[n].push_back(TokenInfo{ (int32_t)k, (int32_t)t[n], conf, 1 }); + } + last_token[n] = (int32_t)k; + have_token[n] = 1; + for (int l = 0; l < L; ++l) { + std::memcpy(&committed.h[l][(size_t)n * Hp], + &out_state.h[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + std::memcpy(&committed.c[l][(size_t)n * Hp], + &out_state.c[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + } + sym_at_frame[n] += 1; + // emitted == max_symbols exits the inner while -> advance frame. + if (sym_at_frame[n] >= max_symbols) { + t[n] += 1; + sym_at_frame[n] = 0; + } + } + } + + // Recompute activity after time update. + active[n] = (t[n] < T[n]) ? 1 : 0; + } + } +} + +} // namespace pk diff --git a/src/transducer_batch.hpp b/src/transducer_batch.hpp new file mode 100644 index 0000000..a207236 --- /dev/null +++ b/src/transducer_batch.hpp @@ -0,0 +1,21 @@ +#pragma once +#include "prediction.hpp" +#include "joint.hpp" +#include "decode_types.hpp" +#include +#include +namespace pk { +// Batched greedy decode for N utterances. encs[n]: row-major [T[n], enc_hidden]. +// durations empty -> RNNT (advance-by-1); non-empty -> TDT (advance-by-duration). +// Outputs per item: ids[n], and (if toks != nullptr) TokenInfo[n]. Produces +// output bit-identical to per-item rnnt_greedy / tdt_greedy. +void transducer_greedy_batch( + const PredictionNet& pred, const Joint& joint, + const std::vector>& encs, + const std::vector& T, // [N] per-item frame counts + int enc_hidden, + const std::vector& durations, // empty=RNNT + int blank_id, int max_symbols, + std::vector>& ids, // OUT [N][.] + std::vector>* toks); // OUT [N][.] or nullptr +} // namespace pk diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 159887e..95fe32c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,6 +31,8 @@ pk_add_test(test_joint) pk_add_test(test_joint_step_batch) pk_add_test(test_transducer_core) pk_add_test(test_tdt_greedy) +pk_add_test(test_transducer_greedy_batch) +pk_add_test(test_transducer_greedy_batch_rnnt) pk_add_test(test_timestamps_tokens) pk_add_test(test_timestamps) pk_add_test(test_transcribe_batch_ts) @@ -54,6 +56,7 @@ set_tests_properties(test_model_loader test_mel test_mel_gpu test_subsampling te test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch test_joint test_joint_step_batch test_transducer_core test_tdt_greedy + test_transducer_greedy_batch test_transducer_greedy_batch_rnnt test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b test_transcribe_ctc test_transcribe_rnnt test_transcribe_eou @@ -65,7 +68,9 @@ set_tests_properties(test_mel test_mel_gpu test_subsampling test_subsampling_bat test_conv_eou test_encoder test_encoder_batch test_encoder_eou test_streaming_encoder test_ctc test_prediction test_prediction_step test_prediction_step_batch test_joint test_joint_step_batch - test_transducer_core test_tdt_greedy test_timestamps_tokens + test_transducer_core test_tdt_greedy + test_transducer_greedy_batch test_transducer_greedy_batch_rnnt + test_timestamps_tokens test_timestamps test_transcribe_batch_ts test_tokenizer test_transcribe test_transcribe_speech test_transcribe_tdt test_transcribe_0_6b diff --git a/tests/test_transducer_greedy_batch.cpp b/tests/test_transducer_greedy_batch.cpp new file mode 100644 index 0000000..8438d74 --- /dev/null +++ b/tests/test_transducer_greedy_batch.cpp @@ -0,0 +1,54 @@ +#include "transducer_batch.hpp" +#include "tdt.hpp" +#include "prediction.hpp" +#include "joint.hpp" +#include "encoder.hpp" +#include "mel.hpp" +#include "audio_io.hpp" +#include "model_loader.hpp" +#include "parity.hpp" +#include +#include +#include +#include +#include +static bool toks_equal(const std::vector& a, const std::vector& b){ + if (a.size()!=b.size()) return false; + for (size_t i=0;i1e-4f) return false; + } + return true; +} +int main(){ + const char* gguf=std::getenv("PARAKEET_TEST_GGUF"); const char* base=std::getenv("PARAKEET_TEST_BASELINE"); + if(!gguf||!base){ std::fprintf(stderr,"env not set; skip\n"); return 77; } + pk::ModelLoader ml; if(!ml.load(gguf)) return 1; + const auto& cfg = ml.config(); + if (cfg.tdt_durations.empty()){ std::fprintf(stderr,"no TDT durations; skip\n"); return 77; } + // Compute a REAL speech mel so the decode actually emits tokens (the + // emit/commit/duration paths get exercised, not just all-blank). Falls back + // to nothing else: speech.wav ships in tests/fixtures. + pk::Audio audio; if(!pk::load_audio_16k_mono("tests/fixtures/speech.wav", audio)){ std::fprintf(stderr,"speech.wav load failed\n"); return 1; } + pk::MelFrontend melfe(ml); + std::vector mel; int n_mels=0, T0=0; + melfe.compute(audio.samples, mel, n_mels, T0); // row-major [n_mels, T0] + pk::Encoder enc(ml); + auto enc_row=[&](const std::vector& m,int Tn){ std::vector eo;int dm=0,to=0; enc.forward(m,n_mels,Tn,eo,dm,to); + std::vector r((size_t)to*dm); for(int t=0;t mel1((size_t)n_mels*T1); + for(int m=0;m r0,r1; + auto id0 = pk::tdt_greedy(pred,joint,e0,to0,dm,cfg.tdt_durations,blank,maxs,&r0); + auto id1 = pk::tdt_greedy(pred,joint,e1,to1,dm,cfg.tdt_durations,blank,maxs,&r1); + std::vector> encs={e0,e1}; std::vector Ts={to0,to1}; + std::vector> ids; std::vector> tk; + pk::transducer_greedy_batch(pred,joint,encs,Ts,dm,cfg.tdt_durations,blank,maxs,ids,&tk); + bool ok = ids.size()==2 && ids[0]==id0 && ids[1]==id1 && toks_equal(tk[0],r0) && toks_equal(tk[1],r1); + std::fprintf(stderr,"item0 ids %zu/%zu words; item1 ids %zu/%zu -> %s\n", ids[0].size(),id0.size(),ids[1].size(),id1.size(), ok?"OK":"FAIL"); + return ok?0:1; +} diff --git a/tests/test_transducer_greedy_batch_rnnt.cpp b/tests/test_transducer_greedy_batch_rnnt.cpp new file mode 100644 index 0000000..758e5bb --- /dev/null +++ b/tests/test_transducer_greedy_batch_rnnt.cpp @@ -0,0 +1,51 @@ +#include "transducer_batch.hpp" +#include "rnnt.hpp" +#include "prediction.hpp" +#include "joint.hpp" +#include "encoder.hpp" +#include "model_loader.hpp" +#include "parity.hpp" +#include +#include +#include +#include +#include +static bool toks_equal(const std::vector& a, const std::vector& b){ + if (a.size()!=b.size()) return false; + for (size_t i=0;i1e-4f) return false; + } + return true; +} +int main(){ + // RNNT-only model (the 110m anchor is TDT, not pure RNNT). Self-skip unless a + // dedicated RNNT GGUF is provided. + const char* gguf=std::getenv("PARAKEET_TEST_GGUF_RNNT"); const char* base=std::getenv("PARAKEET_TEST_BASELINE"); + if(!gguf||!base){ std::fprintf(stderr,"PARAKEET_TEST_GGUF_RNNT not set; skip\n"); return 77; } + pk::ModelLoader ml; if(!ml.load(gguf)) return 1; + const auto& cfg = ml.config(); + if (!cfg.tdt_durations.empty()){ std::fprintf(stderr,"model has TDT durations; not pure RNNT; skip\n"); return 77; } + std::vector mel; std::vector ms; + if(!pktest::load_baseline(base,"mel",mel,ms)) return 1; + const int n_mels=(int)ms[0], T0=(int)ms[1]; + pk::Encoder enc(ml); + auto enc_row=[&](const std::vector& m,int Tn){ std::vector eo;int dm=0,to=0; enc.forward(m,n_mels,Tn,eo,dm,to); + std::vector r((size_t)to*dm); for(int t=0;t mel1((size_t)n_mels*T1); + for(int m=0;m r0,r1; + auto id0 = pk::rnnt_greedy(pred,joint,e0,to0,dm,blank,maxs,&r0); + auto id1 = pk::rnnt_greedy(pred,joint,e1,to1,dm,blank,maxs,&r1); + std::vector> encs={e0,e1}; std::vector Ts={to0,to1}; + std::vector> ids; std::vector> tk; + std::vector no_dur{}; + pk::transducer_greedy_batch(pred,joint,encs,Ts,dm,no_dur,blank,maxs,ids,&tk); + bool ok = ids.size()==2 && ids[0]==id0 && ids[1]==id1 && toks_equal(tk[0],r0) && toks_equal(tk[1],r1); + std::fprintf(stderr,"item0 ids %zu/%zu words; item1 ids %zu/%zu -> %s\n", ids[0].size(),id0.size(),ids[1].size(),id1.size(), ok?"OK":"FAIL"); + return ok?0:1; +} From 782a8014677b5cc954ff64457997c686accb5edf Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:46:45 +0000 Subject: [PATCH 06/11] refactor(decode): extract commit_state lambda in transducer_greedy_batch; document max_symbols assumption Co-Authored-By: Claude Opus 4.8 (1M context) --- src/transducer_batch.cpp | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/transducer_batch.cpp b/src/transducer_batch.cpp index e387680..9fba293 100644 --- a/src/transducer_batch.cpp +++ b/src/transducer_batch.cpp @@ -82,6 +82,19 @@ void transducer_greedy_batch( return false; }; + // Commit item n's state columns (offset n*Hp, all L layers, both h and c) + // from out_state into committed. Used on every emit in both branches. + auto commit_state = [&](int n) { + for (int l = 0; l < L; ++l) { + std::memcpy(&committed.h[l][(size_t)n * Hp], + &out_state.h[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + std::memcpy(&committed.c[l][(size_t)n * Hp], + &out_state.c[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); + } + }; + + // Assumes max_symbols >= 1 (NeMo default 10). The per-round emit rule below + // mirrors the oracle inner loops, which never run at max_symbols==0. while (any_active()) { // (1) ONE batched prediction step from the committed state for ALL items. // Build inputs from committed last_token/have_token. Inactive items still @@ -96,7 +109,7 @@ void transducer_greedy_batch( // run ONE batched joint step -> logits[Vp*N]. for (int n = 0; n < N; ++n) { int tf = t[n]; - if (tf < 0) tf = 0; + if (tf < 0) tf = 0; // t[n] only grows from 0; lower clamp is defensive. Upper clamp matters for inactive/boundary columns. if (tf > T[n] - 1) tf = T[n] - 1; // clamp (boundary / inactive) if (T[n] <= 0) tf = 0; // no frames: harmless, ignored const float* src = (T[n] > 0) ? (ep[n].data() + (size_t)tf * Hj) : ep[n].data(); @@ -128,13 +141,7 @@ void transducer_greedy_batch( } last_token[n] = (int32_t)k; have_token[n] = 1; - // Commit this item's columns from out_state into committed. - for (int l = 0; l < L; ++l) { - std::memcpy(&committed.h[l][(size_t)n * Hp], - &out_state.h[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); - std::memcpy(&committed.c[l][(size_t)n * Hp], - &out_state.c[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); - } + commit_state(n); } // ALWAYS: symbols_added += 1; t += skip; need_loop = (skip == 0). sym_at_frame[n] += 1; @@ -164,12 +171,7 @@ void transducer_greedy_batch( } last_token[n] = (int32_t)k; have_token[n] = 1; - for (int l = 0; l < L; ++l) { - std::memcpy(&committed.h[l][(size_t)n * Hp], - &out_state.h[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); - std::memcpy(&committed.c[l][(size_t)n * Hp], - &out_state.c[l][(size_t)n * Hp], (size_t)Hp * sizeof(float)); - } + commit_state(n); sym_at_frame[n] += 1; // emitted == max_symbols exits the inner while -> advance frame. if (sym_at_frame[n] >= max_symbols) { From df4c85752a3f6298a093c1ef22e7061c2c432a6d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:49:34 +0000 Subject: [PATCH 07/11] feat(model): batched transducer decode in transcribe_pcm_batch[_with_timestamps] Co-Authored-By: Claude Opus 4.8 (1M context) --- src/model.cpp | 70 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/src/model.cpp b/src/model.cpp index d210c3c..31e9c69 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -11,6 +11,7 @@ #include "joint.hpp" #include "tdt.hpp" #include "rnnt.hpp" +#include "transducer_batch.hpp" #include "transcription.hpp" #include "decode_types.hpp" #include "backend.hpp" @@ -157,10 +158,35 @@ std::vector Model::transcribe_16k_batch( std::vector valid_Tout; encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); - // 3. Per-item decode (each enc_out is [d_model, valid_Tout[b]]). + // 3. Decode (each enc_out is [d_model, valid_Tout[b]]). std::vector outs(mb.B); - for (int b = 0; b < mb.B; ++b) - outs[b] = decode_enc_out(loader_, enc_outs[b], d_model, valid_Tout[b], use_tdt); + if (use_tdt) { + // Batched transducer (TDT/RNNT) greedy decode: build per-item row-major + // [T, d_model] from the channels-first [d_model, T] encoder outputs. + std::vector> encs(mb.B); + std::vector Ts(mb.B); + for (int b = 0; b < mb.B; ++b) { + const int tb = valid_Tout[b]; + Ts[b] = tb; + std::vector& er = encs[b]; + er.resize((size_t)tb * d_model); + for (int t = 0; t < tb; ++t) + for (int c = 0; c < d_model; ++c) + er[(size_t)t * d_model + c] = enc_outs[b][(size_t)c * tb + t]; + } + PredictionNet pred(loader_); + Joint joint(loader_); + std::vector> ids; + pk::transducer_greedy_batch(pred, joint, encs, Ts, d_model, + cfg.tdt_durations, (int)cfg.blank_id, + (int)cfg.max_symbols, ids, nullptr); + for (int b = 0; b < mb.B; ++b) + outs[b] = detokenize(loader_.tokenizer_pieces(), ids[b]); + } else { + // CTC stays per-item (no autoregressive decode to batch). + for (int b = 0; b < mb.B; ++b) + outs[b] = decode_enc_out(loader_, enc_outs[b], d_model, valid_Tout[b], use_tdt); + } return outs; } @@ -276,9 +302,41 @@ std::vector Model::transcribe_16k_batch_with_timestamps( encoder.forward_batch(mb, enc_outs, d_model, Tout, valid_Tout); std::vector outs(mb.B); - for (int b = 0; b < mb.B; ++b) - outs[b] = decode_enc_out_with_timestamps( - loader_, enc_outs[b], d_model, valid_Tout[b], use_tdt, frame_sec); + if (use_tdt) { + // Batched transducer (TDT/RNNT) greedy decode with timestamps. Build + // per-item row-major [T, d_model] from channels-first [d_model, T]. + std::vector> encs(mb.B); + std::vector Ts(mb.B); + for (int b = 0; b < mb.B; ++b) { + const int tb = valid_Tout[b]; + Ts[b] = tb; + std::vector& er = encs[b]; + er.resize((size_t)tb * d_model); + for (int t = 0; t < tb; ++t) + for (int c = 0; c < d_model; ++c) + er[(size_t)t * d_model + c] = enc_outs[b][(size_t)c * tb + t]; + } + PredictionNet pred(loader_); + Joint joint(loader_); + std::vector> ids; + std::vector> toks; + pk::transducer_greedy_batch(pred, joint, encs, Ts, d_model, + cfg.tdt_durations, (int)cfg.blank_id, + (int)cfg.max_symbols, ids, &toks); + // Assemble each Transcription exactly as decode_enc_out_with_timestamps' + // transducer tail does. + for (int b = 0; b < mb.B; ++b) { + Transcription& result = outs[b]; + result.text = detokenize(loader_.tokenizer_pieces(), ids[b]); + result.words = group_words(toks[b], loader_.tokenizer_pieces(), frame_sec); + result.tokens = std::move(toks[b]); + } + } else { + // CTC stays per-item (not a transducer; no autoregressive decode). + for (int b = 0; b < mb.B; ++b) + outs[b] = decode_enc_out_with_timestamps( + loader_, enc_outs[b], d_model, valid_Tout[b], use_tdt, frame_sec); + } return outs; } From 52cffea3716970de3ee8e71a076d1d6a660aa355 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 22:56:28 +0000 Subject: [PATCH 08/11] bench: add bench-decode (batched vs serial transducer decode timing) Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/cli/main.cpp | 209 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 207 insertions(+), 2 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 2988492..dbb7b27 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -5,7 +5,16 @@ #include "audio_io.hpp" #include "streaming.hpp" #include "transcription.hpp" -#include "ggml_graph.hpp" // pk::set_num_threads +#include "ggml_graph.hpp" // pk::set_num_threads, pk::global_backend +#include "backend.hpp" // pk::ensure_weights_realized +#include "encoder.hpp" +#include "prediction.hpp" +#include "joint.hpp" +#include "tdt.hpp" +#include "rnnt.hpp" +#include "transducer_batch.hpp" +#include "mel.hpp" +#include "mel_gpu.hpp" #include "ggml.h" #include "gguf.h" #include @@ -851,6 +860,198 @@ static int cmd_bench_batch(int argc, char** argv) { return 0; } +// --------------------------------------------------------------------------- +// bench-decode: encode ONE clip once, then time DECODE only -- serial (N +// separate tdt_greedy/rnnt_greedy calls over N copies of the encoder output) +// vs batched (transducer_greedy_batch over the same N copies) -- at several +// batch sizes. Reports decode wall-clock and the batched/serial speedup so the +// GPU win from batched decode can be measured in isolation from the encoder. +// +// The encoder cost is paid once and excluded; only the transducer decode loop +// is timed. Each rep is averaged over R repetitions (best/min recorded). The +// b=0 batched ids are compared against the serial ids (same clip, decode is +// deterministic) as a correctness sanity check. +// --------------------------------------------------------------------------- +static int cmd_bench_decode(int argc, char** argv) { + std::string model, audio; + std::string batch_sizes_str = "1,4,8,16"; + int threads = 0; // 0 == unset -> use the persistent-backend default + int reps = 5; + for (int i = 0; i < argc; ++i) { + if (std::strcmp(argv[i], "--model") == 0 && i + 1 < argc) { + model = argv[++i]; + } else if (std::strcmp(argv[i], "--audio") == 0 && i + 1 < argc) { + audio = argv[++i]; + } else if (std::strcmp(argv[i], "--batch-sizes") == 0 && i + 1 < argc) { + batch_sizes_str = argv[++i]; + } else if (std::strcmp(argv[i], "--threads") == 0 && i + 1 < argc) { + threads = std::atoi(argv[++i]); + } else if (std::strcmp(argv[i], "--reps") == 0 && i + 1 < argc) { + reps = std::atoi(argv[++i]); + } + } + if (model.empty() || audio.empty()) { + std::fprintf(stderr, + "usage: parakeet-cli bench-decode --model --audio " + "[--batch-sizes 1,4,8,16] [--threads N] [--reps R]\n"); + return 2; + } + if (reps < 1) reps = 1; + + // Parse --batch-sizes (comma-separated positive ints). + std::vector batch_sizes; + { + std::stringstream ss(batch_sizes_str); + std::string tok; + while (std::getline(ss, tok, ',')) { + size_t b = tok.find_first_not_of(" \t"); + if (b == std::string::npos) continue; + size_t e = tok.find_last_not_of(" \t"); + int v = std::atoi(tok.substr(b, e - b + 1).c_str()); + if (v > 0) batch_sizes.push_back(v); + } + } + if (batch_sizes.empty()) { + std::fprintf(stderr, + "parakeet-cli bench-decode: no valid --batch-sizes (want e.g. 1,4,8,16)\n"); + return 2; + } + + if (threads > 0) pk::set_num_threads(threads); + int reported_threads = threads > 0 ? threads : 8; // kDefaultThreads + + // Load the model components over the lower-level loader (we need the encoder + // / prediction / joint pieces, not the high-level Model::transcribe path). + pk::ModelLoader ml; + if (!ml.load(model)) { + std::fprintf(stderr, "parakeet-cli bench-decode: failed to load model %s\n", + model.c_str()); + return 1; + } + pk::ensure_weights_realized(ml); + pk::Encoder enc(ml); + pk::PredictionNet pred(ml); + pk::Joint joint(ml); + const auto& cfg = ml.config(); + const int blank = (int)cfg.blank_id; + const int maxs = (int)cfg.max_symbols; + const std::vector durations = cfg.tdt_durations; + + // Load the WAV. + pk::Audio a; + if (!pk::load_audio_16k_mono(audio, a)) { + std::fprintf(stderr, "parakeet-cli bench-decode: failed to load audio %s\n", + audio.c_str()); + return 1; + } + + // Mel front end (GpuMel on a non-CPU backend, else FFT MelFrontend), exactly + // as model.cpp's transcribe path does. + std::vector feats; + int n_mels = 0, T = 0; + if (std::string(pk::global_backend().device_name()) != "cpu") { + pk::GpuMel gmel(ml); + gmel.compute(a.samples, feats, n_mels, T); + } else { + pk::MelFrontend mel(ml); + mel.compute(a.samples, feats, n_mels, T); + } + + // Encoder -> enc_out [d_model, Tout] (channels-first); transpose to row-major + // enc_row [Tout, d_model] as the decoders expect. + std::vector enc_out; + int dm = 0, Tout = 0; + enc.forward(feats, n_mels, T, enc_out, dm, Tout); + std::vector enc_row((size_t)Tout * dm); + for (int t = 0; t < Tout; ++t) + for (int c = 0; c < dm; ++c) + enc_row[(size_t)t * dm + c] = enc_out[(size_t)c * Tout + t]; + + const bool use_tdt = !durations.empty(); + auto decode_serial_one = [&]() -> std::vector { + return use_tdt + ? pk::tdt_greedy(pred, joint, enc_row, Tout, dm, durations, blank, maxs, nullptr) + : pk::rnnt_greedy(pred, joint, enc_row, Tout, dm, blank, maxs, nullptr); + }; + + using clock = std::chrono::steady_clock; + auto ms_since = [](clock::time_point t0) { + return std::chrono::duration(clock::now() - t0).count(); + }; + + // Warm up (untimed): realize weights / CUDA kernels for both paths. + std::vector serial_ref = decode_serial_one(); + { + std::vector> encs1{enc_row}; + std::vector Ts1{Tout}; + std::vector> ids1; + pk::transducer_greedy_batch(pred, joint, encs1, Ts1, dm, durations, + blank, maxs, ids1, nullptr); + } + + struct Row { int B; double serial_ms; double batched_ms; double speedup; + double serial_cps; double batched_cps; }; + std::vector rows; + rows.reserve(batch_sizes.size()); + bool sanity_ok = true; + + for (int B : batch_sizes) { + std::vector> encs((size_t)B, enc_row); + std::vector Ts((size_t)B, Tout); + + // SERIAL: B separate single-clip decodes, best of R reps. + double serial_ms = 1e300; + for (int r = 0; r < reps; ++r) { + auto t0 = clock::now(); + for (int b = 0; b < B; ++b) (void)decode_serial_one(); + serial_ms = std::min(serial_ms, ms_since(t0)); + } + + // BATCHED: one transducer_greedy_batch over the B copies, best of R reps. + double batched_ms = 1e300; + std::vector> ids_last; + for (int r = 0; r < reps; ++r) { + std::vector> ids; + auto t0 = clock::now(); + pk::transducer_greedy_batch(pred, joint, encs, Ts, dm, durations, + blank, maxs, ids, nullptr); + batched_ms = std::min(batched_ms, ms_since(t0)); + ids_last = std::move(ids); + } + + // Sanity: batched ids[0] must equal the serial decode of the same clip. + if (!ids_last.empty() && ids_last[0] != serial_ref) { + sanity_ok = false; + std::fprintf(stderr, + "parakeet-cli bench-decode: WARN B=%d batched ids[0] != serial " + "(%zu vs %zu tokens) -- batched decode may be buggy\n", + B, ids_last[0].size(), serial_ref.size()); + } + + double speedup = batched_ms > 0.0 ? serial_ms / batched_ms : 0.0; + double serial_cps = serial_ms > 0.0 ? (double)B / (serial_ms / 1000.0) : 0.0; + double batched_cps = batched_ms > 0.0 ? (double)B / (batched_ms / 1000.0) : 0.0; + rows.push_back({B, serial_ms, batched_ms, speedup, serial_cps, batched_cps}); + } + + // Human-readable table to stderr. + std::fprintf(stderr, + "\nbench-decode: clip Tout=%d frames, d_model=%d, decoder=%s, threads=%d, " + "reps=%d (best-of), backend=%s\n", + Tout, dm, use_tdt ? "tdt" : "rnnt", reported_threads, reps, + pk::global_backend().device_name()); + std::fprintf(stderr, " %-6s %-12s %-12s %-10s %-14s %-14s\n", + "B", "serial_ms", "batched_ms", "speedup", "serial_cps", "batched_cps"); + for (const Row& r : rows) { + std::fprintf(stderr, " %-6d %-12.2f %-12.2f %-10.2f %-14.1f %-14.1f\n", + r.B, r.serial_ms, r.batched_ms, r.speedup, + r.serial_cps, r.batched_cps); + } + std::fprintf(stderr, " sanity (batched ids[0]==serial): %s\n", + sanity_ok ? "OK" : "MISMATCH (see WARN above)"); + return 0; +} + // Run a subcommand, then free the process-global backend while the GPU driver is // still alive (the subcommand's local Model is already destroyed by the time it // returns, releasing its device weight buffer). Avoids the CUDA "driver shutting @@ -870,6 +1071,8 @@ int main(int argc, char** argv) { return run_and_shutdown(cmd_quantize, argc - 2, argv + 2); if (argc >= 2 && std::strcmp(argv[1], "bench-batch") == 0) return run_and_shutdown(cmd_bench_batch, argc - 2, argv + 2); + if (argc >= 2 && std::strcmp(argv[1], "bench-decode") == 0) + return run_and_shutdown(cmd_bench_decode, argc - 2, argv + 2); if (argc >= 2 && std::strcmp(argv[1], "bench") == 0) return run_and_shutdown(cmd_bench, argc - 2, argv + 2); std::fprintf(stderr, @@ -882,6 +1085,8 @@ int main(int argc, char** argv) { " parakeet-cli bench --model --manifest " "[--decoder ctc|tdt] [--threads N] [--json ]\n" " parakeet-cli bench-batch --model --manifest " - "[--decoder ctc|tdt] [--threads N] [--batch-sizes 1,4,8] [--json ]\n"); + "[--decoder ctc|tdt] [--threads N] [--batch-sizes 1,4,8] [--json ]\n" + " parakeet-cli bench-decode --model --audio " + "[--batch-sizes 1,4,8,16] [--threads N] [--reps R]\n"); return 2; } From 85feda903405bc0096f4c1d792e2cbfb08888868 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 23:07:27 +0000 Subject: [PATCH 09/11] perf(decode): cache prediction-net g across rounds in transducer_greedy_batch (bit-exact; recovers B=1, fewer LSTM calls) Co-Authored-By: Claude Opus 4.8 (1M context) --- src/transducer_batch.cpp | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/transducer_batch.cpp b/src/transducer_batch.cpp index 9fba293..2ea0e46 100644 --- a/src/transducer_batch.cpp +++ b/src/transducer_batch.cpp @@ -64,6 +64,15 @@ void transducer_greedy_batch( std::vector sym_at_frame(N, 0); for (int n = 0; n < N; ++n) active[n] = (T[n] > 0) ? 1 : 0; + // Per-item prediction-net cache validity. 0 = stale (g column must be + // recomputed before the joint), 1 = fresh (the persistent `g` buffer still + // holds this item's correct column from a prior round). All stale initially + // so the first round runs pred from SOS. Mirrors tdt.cpp/rnnt.cpp `g_valid`: + // set false on emit (committed state advanced), reused otherwise. Bit-exact: + // recomputing g from an UNCHANGED committed state yields an identical g, so + // skipping the pred step on all-valid rounds reuses the same values. + std::vector g_valid(N, 0); + // Committed batched LSTM state, zero-initialized [L][Hp*N]. BatchedPredState committed; committed.h.assign((size_t)L, std::vector((size_t)Hp * N, 0.0f)); @@ -96,14 +105,28 @@ void transducer_greedy_batch( // Assumes max_symbols >= 1 (NeMo default 10). The per-round emit rule below // mirrors the oracle inner loops, which never run at max_symbols==0. while (any_active()) { - // (1) ONE batched prediction step from the committed state for ALL items. - // Build inputs from committed last_token/have_token. Inactive items still - // need valid inputs (their output is ignored): SOS / blank_id. + // (1) Batched prediction step from the committed state, but only when + // some active item's cache is stale. If every active item already has a + // fresh `g` column (no emit since it was last computed), the persistent + // `g` buffer still holds the correct values and we skip the LSTM forward + // entirely — the same per-item caching as tdt.cpp/rnnt.cpp, batched. + bool any_stale = false; for (int n = 0; n < N; ++n) { - is_sos[n] = have_token[n] ? 0 : 1; - token_ids[n] = have_token[n] ? last_token[n] : (int32_t)blank_id; + if (active[n] && !g_valid[n]) { any_stale = true; break; } + } + if (any_stale) { + // Build inputs from committed last_token/have_token. Inactive items + // still need valid inputs (their output is ignored): SOS / blank_id. + for (int n = 0; n < N; ++n) { + is_sos[n] = have_token[n] ? 0 : 1; + token_ids[n] = have_token[n] ? last_token[n] : (int32_t)blank_id; + } + pred.step_batch(token_ids, is_sos, committed, g, out_state); + // Every active item's g column is now fresh. (Recomputing a + // non-emitter's g from its unchanged committed state reproduces its + // cached value bit-for-bit, so marking it valid is exact.) + for (int n = 0; n < N; ++n) if (active[n]) g_valid[n] = 1; } - pred.step_batch(token_ids, is_sos, committed, g, out_state); // (2) Gather each active item's enc_proj row for its current frame and // run ONE batched joint step -> logits[Vp*N]. @@ -142,6 +165,7 @@ void transducer_greedy_batch( last_token[n] = (int32_t)k; have_token[n] = 1; commit_state(n); + g_valid[n] = 0; // committed state advanced -> g stale next round } // ALWAYS: symbols_added += 1; t += skip; need_loop = (skip == 0). sym_at_frame[n] += 1; @@ -172,6 +196,7 @@ void transducer_greedy_batch( last_token[n] = (int32_t)k; have_token[n] = 1; commit_state(n); + g_valid[n] = 0; // committed state advanced -> g stale next round sym_at_frame[n] += 1; // emitted == max_symbols exits the inner while -> advance frame. if (sym_at_frame[n] >= max_symbols) { From c3161d229e8fd1ed2e8596d5ad395a13ded31fc6 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 23:08:23 +0000 Subject: [PATCH 10/11] refactor(model): extract batch_enc_to_row_major helper (dedup batched-decode transpose) Co-Authored-By: Claude Opus 4.8 (1M context) --- src/model.cpp | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/src/model.cpp b/src/model.cpp index 31e9c69..fcaa594 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -143,6 +143,24 @@ static MelBatch build_mel_batch(const ModelLoader& loader, return mb; } +// Transpose the batched encoder outputs (channels-first [d_model, valid_Tout[b]]) +// into per-item row-major [valid_Tout[b], d_model] for the transducer decoder. +static void batch_enc_to_row_major(const std::vector>& enc_outs, + const std::vector& valid_Tout, int d_model, + std::vector>& encs, + std::vector& Ts) { + const int B = (int)enc_outs.size(); + encs.assign(B, {}); Ts.assign(B, 0); + for (int b = 0; b < B; ++b) { + const int tb = valid_Tout[b]; + Ts[b] = tb; + encs[b].resize((size_t)tb * d_model); + for (int t = 0; t < tb; ++t) + for (int c = 0; c < d_model; ++c) + encs[b][(size_t)t * d_model + c] = enc_outs[b][(size_t)c * tb + t]; + } +} + std::vector Model::transcribe_16k_batch( const std::vector>& pcms16k, Decoder decoder) const { const ParakeetConfig& cfg = loader_.config(); @@ -163,17 +181,9 @@ std::vector Model::transcribe_16k_batch( if (use_tdt) { // Batched transducer (TDT/RNNT) greedy decode: build per-item row-major // [T, d_model] from the channels-first [d_model, T] encoder outputs. - std::vector> encs(mb.B); - std::vector Ts(mb.B); - for (int b = 0; b < mb.B; ++b) { - const int tb = valid_Tout[b]; - Ts[b] = tb; - std::vector& er = encs[b]; - er.resize((size_t)tb * d_model); - for (int t = 0; t < tb; ++t) - for (int c = 0; c < d_model; ++c) - er[(size_t)t * d_model + c] = enc_outs[b][(size_t)c * tb + t]; - } + std::vector> encs; + std::vector Ts; + batch_enc_to_row_major(enc_outs, valid_Tout, d_model, encs, Ts); PredictionNet pred(loader_); Joint joint(loader_); std::vector> ids; @@ -305,17 +315,9 @@ std::vector Model::transcribe_16k_batch_with_timestamps( if (use_tdt) { // Batched transducer (TDT/RNNT) greedy decode with timestamps. Build // per-item row-major [T, d_model] from channels-first [d_model, T]. - std::vector> encs(mb.B); - std::vector Ts(mb.B); - for (int b = 0; b < mb.B; ++b) { - const int tb = valid_Tout[b]; - Ts[b] = tb; - std::vector& er = encs[b]; - er.resize((size_t)tb * d_model); - for (int t = 0; t < tb; ++t) - for (int c = 0; c < d_model; ++c) - er[(size_t)t * d_model + c] = enc_outs[b][(size_t)c * tb + t]; - } + std::vector> encs; + std::vector Ts; + batch_enc_to_row_major(enc_outs, valid_Tout, d_model, encs, Ts); PredictionNet pred(loader_); Joint joint(loader_); std::vector> ids; From da38ea1bbe44629e384c949d93fe300c17064d8b Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 31 May 2026 23:13:29 +0000 Subject: [PATCH 11/11] docs(decode): correct header comment to describe the g_valid cache Co-Authored-By: Claude Opus 4.8 (1M context) --- src/transducer_batch.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transducer_batch.cpp b/src/transducer_batch.cpp index 2ea0e46..4320c7b 100644 --- a/src/transducer_batch.cpp +++ b/src/transducer_batch.cpp @@ -12,10 +12,10 @@ namespace pk { // EXACT per-item rule from the oracle to each item independently. // // Parity rationale: -// - The per-item `g_valid` cache (skip the LSTM forward on non-emitting steps) -// is a pure speed optimization: recomputing g from the SAME committed state -// yields an identical g. So we recompute g for every active item each round -// (one batched LSTM forward) without changing any result. +// - The per-item `g_valid` cache (skip the batched LSTM forward on rounds where +// no active item emitted) is a pure speed optimization: recomputing g from the +// SAME committed state yields an identical g, so reusing the cached g column is +// byte-identical. Mirrors the per-item `g_valid` in tdt.cpp/rnnt.cpp. // - State recovery / masking: we only copy out_state columns into `committed` // for items that emitted this round; non-emitting items keep their prior // committed columns, exactly as the per-item loop leaves committed unchanged