From 208917a21003f9e896c7e89d225cefc3aeec84c6 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Sat, 10 Aug 2024 18:24:15 +0800 Subject: [PATCH 1/3] Implement `start_pos` per query for batch interface --- evals/benchmark_helper.cc | 3 +- gemma/gemma-inl.h | 126 ++++++++++++++++++++++++-------------- gemma/gemma.cc | 13 ++-- gemma/gemma.h | 4 +- 4 files changed, 91 insertions(+), 55 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 92983372..9f58ebf2 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -171,7 +171,8 @@ std::vector> GemmaEnv::BatchQueryModel2( gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity}; runtime_config_.batch_stream_token = batch_stream_token; inference_args_.CopyTo(runtime_config_); - model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, + model_->GenerateBatch(runtime_config_, prompts, + std::vector(num_queries, 0), KVCaches(&kv_caches_[0], num_queries), timing_info); return res; } diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index c1c5ba43..96e8a63a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -233,7 +233,7 @@ class GemmaAttention { // Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes // KV directly to KVCache. - HWY_NOINLINE void ComputeQKV(const size_t batch_start, + HWY_NOINLINE void ComputeQKV(const MultiplePositions& batch_start, const size_t num_interleaved) { PROFILER_ZONE("Gen.Attention.QKV"); // For the computation of Q, K, and V, it is useful to remember that @@ -255,9 +255,9 @@ class GemmaAttention { // Single query and no wraparound means we can use a matmul and write // directly into the KV cache with a stride of kCachePosSize. if (num_queries_ == 1 && - batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) { + batch_start[0] + num_tokens_ <= div_seq_len_.GetDivisor()) { const size_t kv_ofs = - batch_start * kCachePosSize + layer_ * kCacheLayerSize; + batch_start[0] * kCachePosSize + layer_ * kCacheLayerSize; // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; MatMul_4x4( @@ -275,7 +275,7 @@ class GemmaAttention { const size_t batch_idx = interleaved_idx / num_queries_; KVCache& kv_cache = kv_caches_[query_idx]; const size_t cache_pos = - div_seq_len_.Remainder(batch_start + batch_idx); + div_seq_len_.Remainder(batch_start[query_idx] + batch_idx); const size_t kv_offset = cache_pos * kCachePosSize + layer_ * kCacheLayerSize; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; @@ -295,7 +295,7 @@ class GemmaAttention { const size_t interleaved_idx = task / kKVHeads; const size_t query_idx = interleaved_idx % num_queries_; const size_t batch_idx = interleaved_idx / num_queries_; - const size_t pos = batch_start + batch_idx; + const size_t pos = batch_start[query_idx] + batch_idx; const size_t cache_pos = div_seq_len_.Remainder(pos); const size_t kv_offset = cache_pos * kCachePosSize + layer_ * kCacheLayerSize + @@ -374,7 +374,7 @@ class GemmaAttention { } } - HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start, + HWY_NOINLINE void DotSoftmaxWeightedSum(const MultiplePositions& batch_start, const size_t num_interleaved) { PROFILER_ZONE("Gen.Attention.DotSoftmax"); GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale(); @@ -398,7 +398,7 @@ class GemmaAttention { activations_.q.Batch(interleaved_idx) + head * kQStride; // Apply rope and scaling to Q. - const size_t pos = batch_start + batch_idx; + const size_t pos = batch_start[query_idx] + batch_idx; PositionalEncodingQK(q, pos, layer_, kQueryScale, q); const size_t start_pos = StartPos(pos, layer_); @@ -440,13 +440,12 @@ class GemmaAttention { } public: - GemmaAttention(size_t interleaved_start, size_t num_tokens, + GemmaAttention(const MultiplePositions& interleaved_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, hwy::ThreadPool& pool) - : interleaved_start_(interleaved_start), - num_tokens_(num_tokens), + : num_tokens_(num_tokens), num_queries_(num_queries), layer_(layer), activations_(activations), @@ -454,12 +453,22 @@ class GemmaAttention { div_seq_len_(div_seq_len), kv_caches_(kv_caches), pool_(pool) { - HWY_DASSERT(interleaved_start_ % num_queries_ == 0); + HWY_DASSERT(std::all_of(interleaved_start.cbegin(), + interleaved_start.cend(), [this](size_t pos) { + return pos % num_queries_ == 0; + })); HWY_DASSERT(num_queries_ <= kv_caches_.size()); + + batch_start_.reserve(interleaved_start.size()); + for (auto i = interleaved_start.cbegin(); + i != interleaved_start.cend(); ++i) { + batch_start_.emplace_back(*i / num_queries_); + } } HWY_INLINE void operator()() { - const size_t batch_start = interleaved_start_ / num_queries_; + const MultiplePositions batch_start(batch_start_.data(), + batch_start_.size()); const size_t num_interleaved = num_tokens_ * num_queries_; ComputeQKV(batch_start, num_interleaved); @@ -468,7 +477,7 @@ class GemmaAttention { } private: - const size_t interleaved_start_; + std::vector batch_start_; const size_t num_tokens_; const size_t num_queries_; const size_t layer_; @@ -480,7 +489,8 @@ class GemmaAttention { }; template -HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, +HWY_NOINLINE void Attention(LayerAttentionType type, + const MultiplePositions& interleaved_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, @@ -495,7 +505,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, // this code for non-Griffin models. if constexpr (TConfig::kGriffinLayers > 0) { HWY_ASSERT(num_queries == 1); - GriffinRecurrent(interleaved_start, num_tokens, num_queries, + GriffinRecurrent(interleaved_start[0], num_tokens, num_queries, layer, activations, layer_weights, kv_caches, pool); } @@ -599,10 +609,10 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) { template HWY_NOINLINE void TransformerLayer( - size_t num_tokens, size_t num_queries, size_t pos, size_t layer, - const CompressedLayer* layer_weights, Activations& activations, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - hwy::ThreadPool& pool) { + size_t num_tokens, size_t num_queries, const MultiplePositions& pos, + size_t layer, const CompressedLayer* layer_weights, + Activations& activations, const hwy::Divisor& div_seq_len, + const KVCaches& kv_caches, hwy::ThreadPool& pool) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_interleaved = num_tokens * num_queries; auto type = TConfig::kLayerConfig[layer]; @@ -688,7 +698,8 @@ class PrefillState { template HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts, - const size_t prefill_per_query, const size_t pos, + const size_t prefill_per_query, + const MultiplePositions& pos, const size_t query_idx_start, const CompressedWeights& weights, const RuntimeConfig& runtime_config, @@ -719,14 +730,19 @@ class PrefillState { HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start); for (size_t ti = 0; ti < tbatch_size; ++ti) { const int token = prompts[qi][tbatch_start + ti]; - EmbedToken(token, ti, pos + ti, weights, activations.x); + EmbedToken(token, ti, pos[qi] + ti, weights, + activations.x); } + const size_t tbatch_pos = pos[qi] + tbatch_start; + const MultiplePositions prefill_tbatch_pos(&tbatch_pos, + kPrefillQueries); + // Transformer with one batch of tokens from a single query. for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const auto* layer_weights = weights.GetLayer(layer); TransformerLayer(tbatch_size, kPrefillQueries, - pos + tbatch_start, layer, + prefill_tbatch_pos, layer, layer_weights, activations, div_seq_len, prefill_kv_caches, inner_pool); } @@ -735,7 +751,7 @@ class PrefillState { for (size_t ti = 0; ti < tbatch_size; ++ti) { const int token = prompts[qi][tbatch_start + ti]; runtime_config.StreamToken(query_idx_start + qi, - pos + tbatch_start + ti, token, 0.0f); + tbatch_pos + ti, token, 0.0f); } } // for tbatch_start }); @@ -749,7 +765,7 @@ class PrefillState { // `num_tokens == 1`. template HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, - size_t num_queries, size_t pos, + size_t num_queries, const MultiplePositions& pos, const CompressedWeights& weights, Activations& activations, const hwy::Divisor& div_seq_len, @@ -759,14 +775,15 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { const size_t query_idx = token_idx % num_queries; - const size_t logical_pos = (pos + token_idx) / num_queries; + const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries; const float token_f = tokens[token_idx]; layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1); } } constexpr size_t kModelDim = TConfig::kModelDim; for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { - EmbedToken(tokens[token_idx], token_idx, pos, weights, + const size_t query_idx = token_idx % num_queries; + EmbedToken(tokens[token_idx], token_idx, pos[query_idx], weights, activations.x); } @@ -778,7 +795,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { - const size_t logical_pos = (pos + token_idx) / num_queries; + const size_t query_idx = token_idx % num_queries; + const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries; layers_output(token_idx % num_queries, logical_pos, "blocks", layer, activations.x.Batch(token_idx), kModelDim); } @@ -790,7 +808,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { const size_t query_idx = token_idx % num_queries; - const size_t logical_pos = (pos + token_idx) / num_queries; + const size_t logical_pos = (pos[query_idx] + token_idx) / num_queries; layers_output(query_idx, logical_pos, "final_norm", -1, activations.x.Batch(token_idx), kModelDim); } @@ -897,9 +915,10 @@ class TokenStreamer { template void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const RuntimeConfig& runtime_config, - const MultiplePromptsTokens& prompts, const size_t pos, - const size_t query_idx_start, const KVCaches& kv_caches, - PerClusterPools& pools, TimingInfo& timing_info) { + const MultiplePromptsTokens& prompts, + const MultiplePositions& pos, const size_t query_idx_start, + const KVCaches& kv_caches, PerClusterPools& pools, + TimingInfo& timing_info) { constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kVocabSize = TConfig::kVocabSize; const CompressedWeights& weights = @@ -921,10 +940,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; RangeChecks(max_tokens, max_generated_tokens, max_prompt_size); - if (pos >= max_tokens) { - fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos, - max_tokens); - return; + for (auto i = pos.cbegin(); i != pos.cend(); ++i) { + if (*i >= max_tokens) { + fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", *i, + max_tokens); + return; + } } // If no sample_func is provided, we use top-k sampling. @@ -953,7 +974,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); } - size_t interleaved_pos = (pos + prefill_per_query) * num_queries; + std::vector interleaved_pos(pos.size()); + std::transform( + pos.cbegin(), pos.cend(), interleaved_pos.begin(), + [&](size_t v) { return (v + prefill_per_query) * num_queries; }); // Storage for the last generated token from each query, passed to the next // Transformer() call. @@ -972,10 +996,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, gen_per_query < HWY_MIN(max_tokens, max_generated_tokens); ++gen_per_query) { // Decode: generate one token for each query. - Transformer(gen_tokens.data(), /*num_tokens=*/1, num_queries, - interleaved_pos, weights, activations, div_seq_len, - kv_caches, pool, runtime_config.layers_output); - interleaved_pos += num_queries; + Transformer( + gen_tokens.data(), /*num_tokens=*/1, num_queries, + MultiplePositions(interleaved_pos.data(), interleaved_pos.size()), + weights, activations, div_seq_len, kv_caches, pool, + runtime_config.layers_output); + for (auto& v: interleaved_pos) { + v += num_queries; + } bool all_queries_eos = true; PROFILER_ZONE("Gen.Embedding"); @@ -1016,19 +1044,22 @@ void GenerateSingleT(const ByteStorageT& weights_u8, activations.Allocate(num_queries); const MultiplePromptsTokens prompts(&prompt, num_queries); + const MultiplePositions positions(&pos, num_queries); const KVCaches kv_caches{&kv_cache, num_queries}; - GenerateT(weights_u8, activations, runtime_config, prompts, pos, - qbatch_start, kv_caches, pools, timing_info); + GenerateT(weights_u8, activations, runtime_config, prompts, + positions, qbatch_start, kv_caches, pools, timing_info); } template void GenerateBatchT(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const MultiplePromptsTokens& prompts, size_t pos, + const MultiplePromptsTokens& prompts, + const MultiplePositions& pos, const KVCaches& kv_caches, PerClusterPools& pools, TimingInfo& timing_info) { - HWY_ASSERT(prompts.size() == kv_caches.size()); + HWY_ASSERT(prompts.size() == pos.size() && + prompts.size() == kv_caches.size()); // Griffin does not support query batching. const size_t max_qbatch_size = (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size; @@ -1044,9 +1075,10 @@ void GenerateBatchT(const ByteStorageT& weights_u8, HWY_MIN(num_queries - qbatch_start, max_qbatch_size); const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start], qbatch_size); + const MultiplePositions qbatch_pos(&pos[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); GenerateT(weights_u8, activations, runtime_config, qbatch_prompts, - pos, qbatch_start, qbatch_kv, pools, timing_info); + qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info); } } @@ -1067,8 +1099,8 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers) void GenerateBatch( // NOLINT(misc-definitions-in-headers) GEMMA_CONFIG, const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, - size_t pos, const KVCaches& kv_caches, PerClusterPools& pools, - TimingInfo& timing_info) { + const MultiplePositions& pos, const KVCaches& kv_caches, + PerClusterPools& pools, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) (weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ce5f07e4..57275142 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -66,7 +66,8 @@ Gemma::~Gemma() { TimingInfo& timing_info); \ extern void GenerateBatch(CONFIGT, const ByteStorageT& weights_u8, \ const RuntimeConfig& runtime_config, \ - const MultiplePromptsTokens& prompts, size_t pos, \ + const MultiplePromptsTokens& prompts, \ + const MultiplePositions& pos, \ const KVCaches& kv_caches, PerClusterPools& pools, \ TimingInfo& timing_info); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); @@ -87,9 +88,9 @@ template struct GenerateBatchT { void operator()(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const MultiplePromptsTokens& prompts, size_t pos, - const KVCaches& kv_caches, PerClusterPools& pools, - TimingInfo& timing_info) const { + const MultiplePromptsTokens& prompts, + const MultiplePositions& pos, const KVCaches& kv_caches, + PerClusterPools& pools, TimingInfo& timing_info) const { GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info); } @@ -109,8 +110,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, - size_t start_pos, const KVCaches& kv_caches, - TimingInfo& timing_info) { + const MultiplePositions& start_pos, + const KVCaches& kv_caches, TimingInfo& timing_info) { pools_.StartSpinning(); CallForModelAndWeight(info_.model, info_.weight, weights_u8_, diff --git a/gemma/gemma.h b/gemma/gemma.h index 990449ec..b1469a6a 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -143,6 +143,7 @@ struct TimingInfo { using PromptTokens = hwy::Span; using MultiplePromptsTokens = hwy::Span; +using MultiplePositions = hwy::Span; using KVCaches = hwy::Span; class Gemma { @@ -164,7 +165,8 @@ class Gemma { size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info); void GenerateBatch(const RuntimeConfig& runtime_config, - const MultiplePromptsTokens& prompts, size_t start_pos, + const MultiplePromptsTokens& prompts, + const MultiplePositions& start_pos, const KVCaches& kv_caches, TimingInfo& timing_info); private: From 477ffc126367a3f2b6afd2e334742e649cb98db2 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Sat, 10 Aug 2024 19:55:10 +0800 Subject: [PATCH 2/3] Fix build issues when tests are enabled --- BUILD.bazel | 4 ++-- CMakeLists.txt | 2 +- ops/{matvec_test.cc => gemma_matvec_test.cc} | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename ops/{matvec_test.cc => gemma_matvec_test.cc} (99%) diff --git a/BUILD.bazel b/BUILD.bazel index 77f199b9..a2b647ab 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -58,10 +58,10 @@ cc_test( ) cc_test( - name = "matvec_test", + name = "gemma_matvec_test", size = "small", timeout = "long", - srcs = ["ops/matvec_test.cc"], + srcs = ["ops/gemma_matvec_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["hwy_ops_test"], diff --git a/CMakeLists.txt b/CMakeLists.txt index 6bc3d865..a9487833 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,7 @@ set(GEMMA_TEST_FILES backprop/optimize_test.cc ops/ops_test.cc ops/matmul_test.cc - ops/matvec_test.cc + ops/gemma_matvec_test.cc evals/gemma_test.cc ) diff --git a/ops/matvec_test.cc b/ops/gemma_matvec_test.cc similarity index 99% rename from ops/matvec_test.cc rename to ops/gemma_matvec_test.cc index 3e915aff..d9232048 100644 --- a/ops/matvec_test.cc +++ b/ops/gemma_matvec_test.cc @@ -33,7 +33,7 @@ // clang-format off #undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "ops/matvec_test.cc" // NOLINT +#define HWY_TARGET_INCLUDE "ops/gemma_matvec_test.cc" // NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" From 5c98189415da049977511dcb3eb0f875acbbe611 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Mon, 12 Aug 2024 02:33:10 +0800 Subject: [PATCH 3/3] Fix the position calculation issue in the generation phase --- gemma/gemma-inl.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 96e8a63a..8476b5c1 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, TokenStreamer token_streamer(runtime_config); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { gen_tokens[query_idx] = prompts[query_idx][prefill_per_query]; - (void)token_streamer(query_idx_start + query_idx, prefill_per_query, + (void)token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query, gen_tokens[query_idx], 0.0f); } @@ -1020,9 +1021,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const int token = sample_token(logits, kVocabSize); timing_info.NotifyGenerated(prefill_start, gen_start); - const bool is_eos = token_streamer(query_idx_start + query_idx, - prefill_per_query + 1 + gen_per_query, - token, logits[token]); + const bool is_eos = + token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query + 1 + gen_per_query, + token, logits[token]); all_queries_eos &= is_eos; gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; }