diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 860e934d..6fbd3f35 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -38,7 +38,10 @@ namespace gcpp {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
+    : initializer_value_(gcpp::InternalInit()),
+      ctx_(threading),
+      env_(ctx_),
+      gemma_(loader, inference, ctx_) {
   const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 2380dbf7..3f97c21e 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -125,6 +125,8 @@ class GemmaEnv {
   MatMulEnv& MutableEnv() { return env_; }
 
  private:
+  // This is used to ensure that InternalInit is called before anything else.
+  int initializer_value_ = 0;
   ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 90e46d45..45531ea2 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -153,5 +153,3 @@ int main(int argc, char** argv) {
 
   return RUN_ALL_TESTS();
 }
-
-
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 95a0aa7d..abd8c90c 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
   gcpp::GemmaTest::InitEnv(argc, argv);
   int ret = RUN_ALL_TESTS();
   gcpp::GemmaTest::DeleteEnv();
diff --git a/gemma/activations.h b/gemma/activations.h
index cd1621f9..cfd174f6 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -54,6 +54,11 @@ struct AttentionActivations {
                      ? layer_config.heads * 3 * layer_config.qkv_dim
                      : layer_config.heads * layer_config.qkv_dim,
                  allocator)),
+        q_bf(MatFactory("q_bf", batch_size,
+                        config.vocab_size == 0
+                            ? layer_config.heads * 3 * layer_config.qkv_dim
+                            : layer_config.heads * layer_config.qkv_dim,
+                        allocator)),
         q_T(MatFactory("q_T", layer_config.qkv_dim,
                        config.vocab_size == 0
                            ? batch_size * layer_config.heads * 3
@@ -88,12 +93,14 @@ struct AttentionActivations {
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     q.AllocateAndAttachRowPtrs(row_ptrs);
+    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
 
     pre_att_rms_out.OverrideRows(batch_size);
@@ -105,6 +112,7 @@ struct AttentionActivations {
   }
 
   MatStorageT<float> q;  // query
+  MatStorageT<BF16> q_bf;
   MatStorageT q_T;  // Transposed to maximize attention speed.
 
   MatStorageT pre_att_rms_out;
@@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
                            const AttentionActivations& activations)
       : AttentionActivationsPtrs(config, seq_len) {
     q = activations.q;
+    q_bf = activations.q_bf;
     q_T = activations.q_T;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
@@ -141,6 +150,7 @@ struct AttentionActivationsPtrs {
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
@@ -151,6 +161,7 @@ struct AttentionActivationsPtrs {
 
   const ModelConfig& config;
   MatPtrT<float> q;
+  MatPtrT<BF16> q_bf;
   MatPtrT q_T;
   MatPtrT pre_att_rms_out;
   MatPtrT att;
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index d2d13f7c..65485376 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -154,7 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
 
 // Calculates the complete attention outputs for a single row of q.
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
-                          const float* HWY_RESTRICT q, const MatPtrT& k,
+                          const BF16* HWY_RESTRICT q, const MatPtrT& k,
                           const MatPtrT& v, const size_t layer_idx,
                           const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
@@ -162,17 +162,12 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df;
-  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
-                                 0);
 
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   // TODO: Mixed-mode can be further improved for Turin: we can demote right
   // before we do the dot product instruction, rather than promote both to f32.
   // But some potential accuracy loss there, needs evaluation first.
-  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+  float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+    float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
@@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<float>& q,
+               const size_t k_pos, const MatPtrT<BF16>& q,
                const MatPtrT& k) {
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
 
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
-                                   MakeSpan(q_bf, qkv_dim), 0);
-    results[i] =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+    results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }
@@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const StridedView& qT, const MatPtrT& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx,
@@ -396,7 +386,7 @@ void TileFlashAttention(
 // This is the result of 4 rows of Q against NF K timesteps, with positions
 // given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
                  const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k,
                  const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
                  VF& sum2, VF& sum3) {
@@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
@@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 Tile4FlashState TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx,
@@ -500,18 +486,13 @@ Tile4FlashState TileFlashAttention4(
   }
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df_compress;
 
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x0 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap,
                                state.row_states[0].max, state.row_states[0].d,
                                v.Row(k_pos), v.Cols(),
@@ -519,10 +500,8 @@ Tile4FlashState TileFlashAttention4(
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x1 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap,
                                state.row_states[1].max, state.row_states[1].d,
                                v.Row(k_pos), v.Cols(),
@@ -530,10 +509,8 @@ Tile4FlashState TileFlashAttention4(
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x2 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x2, activations.config.att_cap,
                                state.row_states[2].max, state.row_states[2].d,
                                v.Row(k_pos), v.Cols(),
@@ -541,10 +518,8 @@ Tile4FlashState TileFlashAttention4(
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x3 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap,
                                state.row_states[3].max, state.row_states[3].d,
                                v.Row(k_pos), v.Cols(),
@@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                                query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
+  // Compress q to q_bf.
+  ParallelFor(
+      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      /*cluster_idx=*/0, Callers::kFlashAttention,
+      [&](size_t row, size_t worker) {
+        CompressPerThread tls;
+        const hn::ScalableTag<float> df;
+        CompressTraits<BF16>::Compress(
+            df, activations.q.Row(row), activations.q.Cols(), tls,
+            MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
+      });
   const LayerConfig& layer_config =
       activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
@@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       last_pos[offset] = last;
       min_last_pos = HWY_MIN(min_last_pos, last);
      max_last_pos = HWY_MAX(max_last_pos, last);
-      q_offsets[offset] =
-          activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+      q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
+                          activations.q_bf.Row(0);
       out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
                             activations.att_out.Row(0);
       const size_t kv_index = head / kHeadGroups;
@@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           // kNFx8HTileSize. In this case, qT is never used. Some tasks might
           // use qT and some might not, which is why the more general condition
           // is used above to catch all cases where qT will be used.
-          TileFlashAttention(activations.q, q_offsets, qT, k,
+          TileFlashAttention(activations.q_bf, q_offsets, qT, k,
                              start_positions[offset], last_pos, min_last_pos,
                              max_last_pos, v, layer_idx, activations,
                             activations.att_out, out_offsets, ctx, worker);
        } else if (kVTileSize == 4) {
-          TileFlashAttention4(activations.q, q_offsets, k,
+          TileFlashAttention4(activations.q_bf, q_offsets, k,
                               start_positions[offset], last_pos, min_last_pos,
                               max_last_pos, v, layer_idx, activations,
                               activations.att_out, out_offsets, ctx, worker);
@@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           break;
         } else {
           SingleFlashAttention(start_positions[offset], last_pos[offset],
-                               activations.q.Row(0) + q_offsets[offset], k, v,
+                               activations.q_bf.Row(0) + q_offsets[offset], k, v,
                                layer_idx, activations,
                                activations.att_out.Row(0) + out_offsets[offset],
                                ctx, worker);
diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h
index 099fc697..236c7dc3 100644
--- a/gemma/flash_attention.h
+++ b/gemma/flash_attention.h
@@ -45,7 +45,7 @@ namespace gcpp {
       ThreadingContext& ctx, size_t worker);                              \
                                                                           \
   Tile4FlashState TileFlashAttention4(                                    \
-      const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,    \
+      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,     \
      const MatPtrT& k, size_t start_pos,                                  \
       const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos,         \
       size_t max_last_pos, const MatPtrT& v, size_t layer_idx,            \
diff --git a/io/io.cc b/io/io.cc
index bd0d72b0..8114276a 100644
--- a/io/io.cc
+++ b/io/io.cc
@@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
   return true;
 }
 
-void InternalInit() {
+int InternalInit() {
+  // currently unused, except for init list ordering in GemmaEnv.
+  return 0;
 }
 
 uint64_t IOBatch::Read(const File& file) const {
diff --git a/io/io.h b/io/io.h
index f90a636c..d051715a 100644
--- a/io/io.h
+++ b/io/io.h
@@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
 
 // No-op in open-source. Must be called at the beginning of a binary, before
 // any I/O or flag usage.
-void InternalInit();
+int InternalInit();
 
 }  // namespace gcpp
 
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 0a7401a7..1075a0a0 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
   gcpp::GemmaEnv env(argc, argv);
   gcpp::s_env = &env;
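The GemmaEnv change relies on a standard C++ guarantee: non-static data members are initialized in declaration order, so declaring the dummy `int initializer_value_` above `ctx_`, `env_`, and `gemma_` forces `gcpp::InternalInit()` (which now returns `int` precisely so it can appear in a member initializer) to run before those members are constructed. A minimal standalone sketch of the same idiom; `GlobalInit`, `Engine`, and `Acquire` are illustrative names, not from this repository:

```cpp
#include <cstdio>

// Hypothetical stand-in for gcpp::InternalInit(): returns int so the call can
// appear in a member initializer.
static int GlobalInit() {
  std::printf("GlobalInit runs first\n");
  return 0;
}

class Engine {
 public:
  Engine() : init_dummy_(GlobalInit()), resource_(Acquire()) {}

 private:
  static int Acquire() {
    std::printf("resource_ is initialized second\n");
    return 42;
  }

  // Declared first, so its initializer (and therefore GlobalInit) runs before
  // resource_ is initialized, regardless of how the initializer list is
  // ordered in the constructor.
  int init_dummy_ = 0;
  int resource_ = 0;
};

int main() { Engine e; }
```

Because the guarantee comes from declaration order rather than initializer-list order, the new member only works as intended while it stays at the top of GemmaEnv's private section.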
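The flash_attention.cc changes all follow one idea: instead of demoting the float query row to BF16 inside every dot product (the removed `HWY_ALIGN BF16 q_bf[kMaxQKVDim]` scratch buffers and per-call `Compress`), the whole `q` activation is demoted once per layer into the new `q_bf` matrix, and every Q·K dot product then reads the precompressed row. A rough scalar sketch of that before/after shape, using hypothetical `ToBF16`/`FromBF16` helpers rather than the Highway `CompressTraits`/`Dot` machinery in the real code (these helpers truncate, whereas real BF16 conversion typically rounds):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical scalar BF16 demotion/promotion, for illustration only.
static uint16_t ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, top mantissa
}
static float FromBF16(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Before: each Q.K dot product re-demotes the q row it reads, so over a full
// sequence the same row is converted once per timestep.
float DotBefore(const std::vector<float>& q, const std::vector<uint16_t>& k) {
  float sum = 0.0f;
  for (size_t i = 0; i < q.size(); ++i) {
    sum += FromBF16(ToBF16(q[i])) * FromBF16(k[i]);
  }
  return sum;
}

// After (as in the patch): demote each q row once into q_bf, then every
// timestep's dot product reads the precompressed copy.
std::vector<uint16_t> PrecompressQ(const std::vector<float>& q) {
  std::vector<uint16_t> q_bf(q.size());
  for (size_t i = 0; i < q.size(); ++i) q_bf[i] = ToBF16(q[i]);
  return q_bf;
}
float DotAfter(const std::vector<uint16_t>& q_bf,
               const std::vector<uint16_t>& k) {
  float sum = 0.0f;
  for (size_t i = 0; i < q_bf.size(); ++i) {
    sum += FromBF16(q_bf[i]) * FromBF16(k[i]);
  }
  return sum;
}
```

The conversion cost thus drops from once per (row, timestep) pair to once per row, at the price of holding one extra BF16 copy of q in the new activation buffer.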