5 changes: 4 additions & 1 deletion evals/benchmark_helper.cc
@@ -38,7 +38,10 @@ namespace gcpp {

GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
: initializer_value_(gcpp::InternalInit()),
ctx_(threading),
env_(ctx_),
gemma_(loader, inference, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
2 changes: 2 additions & 0 deletions evals/benchmark_helper.h
@@ -125,6 +125,8 @@ class GemmaEnv {
MatMulEnv& MutableEnv() { return env_; }

private:
// Declared first so that InternalInit() runs before the members below are
// constructed (non-static data members are initialized in declaration order).
int initializer_value_ = 0;
ThreadingContext ctx_;
MatMulEnv env_;
Gemma gemma_;
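GemmaEnv now relies on the C++ guarantee that non-static data members are constructed in declaration order: because `initializer_value_` is declared before `ctx_`, `env_`, and `gemma_`, initializing it from `gcpp::InternalInit()` forces that call to run before any of the heavier members. A minimal sketch of the idiom, with hypothetical names standing in for the gemma.cpp types, is:

```cpp
#include <cstdio>

// Hypothetical stand-in for gcpp::InternalInit(): returns an int purely so it
// can appear in a member-initializer list.
static int GlobalSetup() {
  std::puts("GlobalSetup() runs first");
  return 0;
}

struct Heavy {
  Heavy() { std::puts("Heavy members constructed afterwards"); }
};

class Env {
 public:
  // Non-static members are initialized in declaration order, so init_token_
  // (declared first) forces GlobalSetup() to run before heavy_ is built.
  Env() : init_token_(GlobalSetup()), heavy_() {}

 private:
  int init_token_;  // the value itself is never read
  Heavy heavy_;
};

int main() {
  Env env;
  return 0;
}
```

The guarantee comes from the declaration order of the members, not from the order written in the initializer list; most compilers warn (`-Wreorder`) when the two disagree, which keeps the trick from silently breaking.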
2 changes: 0 additions & 2 deletions evals/gemma_batch_bench.cc
@@ -153,5 +153,3 @@ int main(int argc, char** argv) {

return RUN_ALL_TESTS();
}


1 change: 0 additions & 1 deletion evals/gemma_test.cc
@@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();
gcpp::GemmaTest::InitEnv(argc, argv);
int ret = RUN_ALL_TESTS();
gcpp::GemmaTest::DeleteEnv();
11 changes: 11 additions & 0 deletions gemma/activations.h
@@ -54,6 +54,11 @@ struct AttentionActivations {
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_bf(MatFactory("q_bf", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim,
allocator)),
q_T(MatFactory("q_T", layer_config.qkv_dim,
config.vocab_size == 0
? batch_size * layer_config.heads * 3
@@ -88,12 +93,14 @@
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}

void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!

pre_att_rms_out.OverrideRows(batch_size);
@@ -105,6 +112,7 @@ }
}

MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.

MatStorageT<float> pre_att_rms_out;
@@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
@@ -141,6 +150,7 @@

void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
@@ -151,6 +161,7 @@

const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<BF16> q_bf;
MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
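`AttentionActivations` owns the activation matrices, while `AttentionActivationsPtrs` holds non-owning views whose active row count is shrunk per batch through `OverrideRows`. A new buffer such as `q_bf` therefore has to be threaded through several places: the owning member and its `MatFactory` call, the row-pointer attachment, the view member, and both `SetBatchSize` methods. A stripped-down sketch of that storage/view pattern, using hypothetical `OwnedMat`/`MatView` types rather than the real `MatStorageT`/`MatPtrT`, might look like:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for MatStorageT / MatPtrT.
template <typename T>
struct OwnedMat {
  OwnedMat(size_t max_rows, size_t cols)
      : data(max_rows * cols), rows(max_rows), cols(cols) {}
  std::vector<T> data;
  size_t rows, cols;
};

template <typename T>
struct MatView {
  MatView() = default;
  explicit MatView(OwnedMat<T>& m)
      : ptr(m.data.data()), rows(m.rows), cols(m.cols) {}
  // Shrink the active row count for small batches; capacity is unchanged.
  void OverrideRows(size_t new_rows) { rows = new_rows; }
  T* ptr = nullptr;
  size_t rows = 0, cols = 0;
};

// Storage struct: allocates q and its BF16 mirror q_bf with the same shape.
struct AttentionBuffers {
  AttentionBuffers(size_t max_batch, size_t dim)
      : q(max_batch, dim), q_bf(max_batch, dim) {}
  OwnedMat<float> q;
  OwnedMat<uint16_t> q_bf;  // raw 16-bit storage standing in for BF16
};

// View struct: every mirrored buffer must also be updated in SetBatchSize.
struct AttentionBufferPtrs {
  explicit AttentionBufferPtrs(AttentionBuffers& b) : q(b.q), q_bf(b.q_bf) {}
  void SetBatchSize(size_t batch) {
    q.OverrideRows(batch);
    q_bf.OverrideRows(batch);
  }
  MatView<float> q;
  MatView<uint16_t> q_bf;
};
```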
88 changes: 37 additions & 51 deletions gemma/flash_attention.cc
@@ -154,25 +154,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,

// Calculates the complete attention outputs for a single row of q.
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];

CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
// TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
@@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
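The loop above delegates the per-timestep work to `SingleFlashAttentionStep`, whose body is outside this hunk. It performs the streaming ("online") softmax update that flash attention is built on: keep a running maximum `m` and denominator `d`, rescale the accumulated output whenever the maximum grows, and defer the division by `d` until the end. A scalar sketch of that recurrence (an illustration of the math, not the repo's vectorized implementation, and ignoring the `att_cap` soft-capping applied to each score) is:

```cpp
#include <cmath>
#include <cstddef>

// Scalar illustration of the streaming ("online") softmax accumulation used
// by flash attention.
struct OnlineSoftmaxState {
  float m = -INFINITY;  // running maximum of scores seen so far
  float d = 0.0f;       // running sum of exp(score - m)
};

// Accumulates one timestep: score x with value row v[0..dim). att_out holds
// the unnormalized weighted sum and must be zero-initialized before the first
// step; divide by state.d once after the last step.
void OnlineSoftmaxStep(float x, const float* v, size_t dim,
                       OnlineSoftmaxState& state, float* att_out) {
  const float new_max = x > state.m ? x : state.m;
  const float scale_old = std::exp(state.m - new_max);  // rescales history
  const float w_new = std::exp(x - new_max);            // weight of this step
  state.d = state.d * scale_old + w_new;
  for (size_t i = 0; i < dim; ++i) {
    att_out[i] = att_out[i] * scale_old + w_new * v[i];
  }
  state.m = new_max;
}
```

Because the per-row state is just `m`, `d`, and the unnormalized output, the kernel can stream over an arbitrarily long KV history in one pass instead of materializing the full softmax.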
@@ -193,19 +187,15 @@
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<float>& q,
const size_t k_pos, const MatPtrT<BF16>& q,
const MatPtrT<KV_t>& k) {
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;

hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
MakeSpan(q_bf, qkv_dim), 0);
results[i] =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
}
return hn::LoadU(df, results);
}
@@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -396,7 +386,7 @@ void TileFlashAttention(
// This is the result of 4 rows of Q against NF K timesteps, with positions
// given by k_offsets[0..NF].
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
@@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
sum3 = hn::MulAdd(q_3, k_vec, sum3);
}
}
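With `q` already stored as BF16, the tile kernel now promotes each query scalar to f32 once, right before the fused multiply-add, instead of the earlier BF16-then-float double `ConvertScalarTo`. Stripping away the gathers and vector registers, the quantity `QDotKTilex4` accumulates is a 4-row-by-`Lanes(df)` tile of Q·K dot products; a scalar reference with simplified types (plain `float` arrays in place of `MatPtrT<BF16>` and `KV_t`) would be:

```cpp
#include <cstddef>
#include <cstdint>

// Scalar reference for the 4 x num_k tile of dot products that QDotKTilex4
// accumulates with vector gathers and MulAdd. Simplified types: q and k_base
// are plain float arrays here, whereas the real code reads BF16 q and KV_t k.
// sums[r * num_k + c] += dot(q row r, k row c) for r in [0,4), c in [0,num_k).
void QDotKTile4Reference(const float* q, const uint32_t q_offsets[4],
                         const float* k_base, const int32_t* k_offsets,
                         size_t num_k, size_t qkv_dim,
                         float* sums /* 4 * num_k, pre-initialized */) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t c = 0; c < num_k; ++c) {
      float acc = 0.0f;
      for (size_t i = 0; i < qkv_dim; ++i) {
        acc += q[q_offsets[r] + i] * k_base[k_offsets[c] + i];
      }
      sums[r * num_k + c] += acc;
    }
  }
}
```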
@@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
Tile4FlashState TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -500,51 +486,40 @@ Tile4FlashState TileFlashAttention4(
}
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df_compress;

while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap,
state.row_states[0].max, state.row_states[0].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap,
state.row_states[1].max, state.row_states[1].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap,
state.row_states[2].max, state.row_states[2].d,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap,
state.row_states[3].max, state.row_states[3].d,
v.Row(k_pos), v.Cols(),
@@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
query_norm_scale, layer_idx, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
// Compress q to q_bf (BF16) once up front; the attention dot products below
// then read BF16 q directly instead of re-converting it per timestep.
ParallelFor(
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) {
CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(
df, activations.q.Row(row), activations.q.Cols(), tls,
MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
});
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;

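This `ParallelFor` is the heart of the change: `q` is compressed to BF16 once per `FlashAttention` call, replacing the per-dot-product `CompressTraits<BF16>::Compress` calls removed above from `SingleFlashAttention`, `QDotKVector`, and `TileFlashAttention4`, so the conversion cost is amortized over all heads and timesteps. A serial sketch of the same idea, using a hand-written round-to-nearest-even converter instead of Highway's compressor, is:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustration only: round-to-nearest-even float32 -> bfloat16 (keep the upper
// 16 bits of the float, with rounding; NaN handling omitted). The real code
// uses Highway's CompressTraits<BF16>::Compress and parallelizes over rows.
static uint16_t FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // ties to even
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

// One-time compression of the query activations: after this, every attention
// dot product can read BF16 q directly instead of converting per timestep.
void CompressQToBF16(const float* q, size_t rows, size_t cols, size_t q_stride,
                     uint16_t* q_bf, size_t q_bf_stride) {
  for (size_t r = 0; r < rows; ++r) {
    const float* src = q + r * q_stride;
    uint16_t* dst = q_bf + r * q_bf_stride;
    for (size_t c = 0; c < cols; ++c) {
      dst[c] = FloatToBF16(src[c]);
    }
  }
}
```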
@@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
last_pos[offset] = last;
min_last_pos = HWY_MIN(min_last_pos, last);
max_last_pos = HWY_MAX(max_last_pos, last);
q_offsets[offset] =
activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
activations.q_bf.Row(0);
out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
activations.att_out.Row(0);
const size_t kv_index = head / kHeadGroups;
@@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// kNFx8HTileSize. In this case, qT is never used. Some tasks might
// use qT and some might not, which is why the more general condition
// is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k,
TileFlashAttention(activations.q_bf, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, k,
TileFlashAttention4(activations.q_bf, q_offsets, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
@@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
break;
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v,
activations.q_bf.Row(0) + q_offsets[offset], k, v,
layer_idx, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx, worker);
2 changes: 1 addition & 1 deletion gemma/flash_attention.h
@@ -45,7 +45,7 @@ namespace gcpp {
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
4 changes: 3 additions & 1 deletion io/io.cc
@@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
return true;
}

void InternalInit() {
int InternalInit() {
// The return value is unused except to force member-init-list ordering in GemmaEnv.
return 0;
}

uint64_t IOBatch::Read(const File& file) const {
2 changes: 1 addition & 1 deletion io/io.h
@@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);

// No-op in open-source. Must be called at the beginning of a binary, before
// any I/O or flag usage.
void InternalInit();
int InternalInit();

} // namespace gcpp

1 change: 0 additions & 1 deletion paligemma/paligemma_test.cc
@@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
gcpp::InternalInit();

gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;