6 changes: 3 additions & 3 deletions gemma/activations.h
@@ -104,8 +104,8 @@ struct AttentionActivations {
// `inv_timescale*` are not batched.
}

MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.
MatStorageT<float> q; // query
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.

MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
@@ -151,7 +151,7 @@ struct AttentionActivationsPtrs {

const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<float> q_T;
MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
MatPtrT<float> att_out;
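The storage change above halves the bytes held and streamed for the transposed query: q_T now stores BF16 (2 bytes per element) instead of float (4 bytes), at the cost of rounding each value to an 8-bit mantissa. Below is a minimal standalone sketch of that demotion, assuming only Highway's hwy::bfloat16_t and hwy::ConvertScalarTo; the buffer names and test values are illustrative, not gemma.cpp's.

#include <cstddef>
#include <cstdio>
#include <vector>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

// Demotes one float query row to BF16, as the new q_T storage type requires.
// Round-tripping back to float shows the precision given up: BF16 keeps the
// float exponent range but only 8 mantissa bits.
int main() {
  const std::vector<float> q_row = {0.123456789f, -3.14159265f,
                                    1024.0009765625f};
  std::vector<hwy::bfloat16_t> q_bf(q_row.size());
  for (size_t i = 0; i < q_row.size(); ++i) {
    q_bf[i] = hwy::ConvertScalarTo<hwy::bfloat16_t>(q_row[i]);
  }
  for (size_t i = 0; i < q_row.size(); ++i) {
    const float back = hwy::ConvertScalarTo<float>(q_bf[i]);
    std::printf("%zu: f32=%.9f bf16=%.9f (storage %zu vs %zu bytes)\n", i,
                q_row[i], back, sizeof(float), sizeof(hwy::bfloat16_t));
  }
  return 0;
}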
15 changes: 13 additions & 2 deletions gemma/attention.cc
@@ -57,16 +57,27 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];

CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);

if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const float score = Dot(q, k.Row(pos), k.Cols());
const float score =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
att[pos] = score;
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float score = Dot(q, k.Row(pos_modulo), k.Cols());
const float score =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
att[pos_modulo] = score;
}
}
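The QDotK change above follows a compress-once, dot-many pattern: the query row is demoted to BF16 a single time, and the (typically much longer) loop over KV timesteps then runs mixed-precision dot products against it. The following scalar reference sketch shows that pattern without gemma.cpp's CompressTraits/Dot helpers; the function and variable names are assumptions, and accumulation stays in f32 as in the existing Dot kernels.

#include <cstddef>
#include <vector>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

using BF16 = hwy::bfloat16_t;

// Scalar reference for the BF16 QDotK pattern: demote q once, then promote
// both q and k elements to f32 inside the dot product and accumulate in f32.
// Each k_rows[pos] is assumed to hold at least q.size() elements.
std::vector<float> QDotKReference(
    const std::vector<float>& q,
    const std::vector<std::vector<BF16>>& k_rows) {
  // Compress q once per query; the loop over KV timesteps reuses it.
  std::vector<BF16> q_bf(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    q_bf[i] = hwy::ConvertScalarTo<BF16>(q[i]);
  }

  std::vector<float> att(k_rows.size());
  for (size_t pos = 0; pos < k_rows.size(); ++pos) {
    float sum = 0.0f;  // f32 accumulator.
    for (size_t i = 0; i < q.size(); ++i) {
      sum += hwy::ConvertScalarTo<float>(q_bf[i]) *
             hwy::ConvertScalarTo<float>(k_rows[pos][i]);
    }
    att[pos] = sum;
  }
  return att;
}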
120 changes: 81 additions & 39 deletions gemma/flash_attention.cc
@@ -58,7 +58,7 @@ static constexpr size_t kNFx8HTileSize = 8;
// q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
// Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@@ -69,12 +69,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break;
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
for (size_t qi = 0; qi < qbatch_size; ++qi) {
for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] =
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
hwy::ConvertScalarTo<BF16>(
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
}
}
}
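The comment at the top of TransposeQ describes the layout swap: q is [batch, qbatch][head, qkv_dim], while q_t is [qkv_dim][qbatch, head, batch] so that consecutive elements of a q_t row share the same KV timestep. A simplified scalar sketch of that index mapping with the new BF16 destination is shown below (plain nested std::vector instead of MatPtrT/MatStorageT; names are illustrative):

#include <cstddef>
#include <vector>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

using BF16 = hwy::bfloat16_t;

// q[b * qbatch + qi] holds num_heads * qkv_dim floats for one token.
// Returns q_t with qkv_dim rows of qbatch * num_heads * batch BF16 values,
// matching the destination layout described in the TransposeQ comment.
std::vector<std::vector<BF16>> TransposeQReference(
    const std::vector<std::vector<float>>& q, size_t batch, size_t qbatch,
    size_t num_heads, size_t qkv_dim) {
  std::vector<std::vector<BF16>> q_t(
      qkv_dim, std::vector<BF16>(qbatch * num_heads * batch));
  for (size_t d = 0; d < qkv_dim; ++d) {
    for (size_t qi = 0; qi < qbatch; ++qi) {
      for (size_t h = 0; h < num_heads; ++h) {
        for (size_t b = 0; b < batch; ++b) {
          q_t[d][(qi * num_heads + h) * batch + b] =
              hwy::ConvertScalarTo<BF16>(q[b * qbatch + qi][h * qkv_dim + d]);
        }
      }
    }
  }
  return q_t;
}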
@@ -158,8 +159,19 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];

CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols());
// TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
@@ -169,7 +181,8 @@
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
float x =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
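SingleFlashAttention keeps a running maximum m and denominator d and folds each new score into att_out via SingleFlashAttentionStep. This is the streaming-softmax recurrence at the heart of flash attention; the sketch below is a scalar reconstruction of that general recurrence, not the exact helper in this file, and soft-capping is omitted.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One streaming-softmax step: fold score x and value row v into the running
// state (m = max score so far, d = softmax denominator, out = unnormalized
// weighted sum of V). Dividing out by d after the last timestep yields the
// normalized attention output.
void StreamingSoftmaxStep(float x, const std::vector<float>& v, float& m,
                          float& d, std::vector<float>& out) {
  const float new_m = std::max(m, x);
  const float scale = std::exp(m - new_m);  // rescale old accumulators
  const float p = std::exp(x - new_m);      // weight of the new timestep
  d = d * scale + p;
  for (size_t i = 0; i < v.size(); ++i) {
    out[i] = out[i] * scale + p * v[i];
  }
  m = new_m;
}

Initializing m to a very negative value and d and out to zero makes the first step reduce to an ordinary softmax over a single element, which matches how the running state is seeded before the position loop.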
@@ -179,25 +192,31 @@
// the dot products of NF rows of Q for a single K timestep.
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q,
const size_t k_pos, const MatPtrT<float>& q,
const MatPtrT<KV_t>& k) {
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;

hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
MakeSpan(q_bf, qkv_dim), 0);
results[i] =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
}
return hn::LoadU(df, results);
}

// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
// precision.
// Returns an NF Q rows by 8 K rows tile of Q.K dot products.
// This is the result of NF rows of Q against 8 K timesteps, with positions
// given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
// consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
@@ -211,8 +230,13 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
for (int i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]);
}

const hn::Rebind<BF16, DF> dbfh;
using VBF = hn::Vec<decltype(dbfh)>;

for (size_t i = 0; i < k.Cols(); ++i) {
VF q_vec = hn::Load(df, q);
const VBF q_vec_bf = hn::Load(dbfh, q);
const VF q_vec = hn::PromoteTo(df, q_vec_bf);
VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]);
@@ -264,17 +288,14 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(const MatPtrT<float>& q,
const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
@@ -291,7 +312,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);
const float* HWY_RESTRICT qT_row = qT.Row(0);
const BF16* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
@@ -300,8 +321,7 @@
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
x7);
QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
@@ -390,13 +410,17 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
VI k_offsets_vec = hn::LoadU(di, k_offsets);
for (size_t i = 0; i < k.Cols(); ++i) {
VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
sum0 = hn::MulAdd(q_0, k_vec, sum0);
VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
sum1 = hn::MulAdd(q_1, k_vec, sum1);
VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
sum2 = hn::MulAdd(q_2, k_vec, sum2);
VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
sum3 = hn::MulAdd(q_3, k_vec, sum3);
}
}
@@ -478,32 +502,50 @@ void TileFlashAttention4(const MatPtrT<float>& q,
out_offsets, v.Cols());
position += kHTileSize;
}
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];
CompressPerThread tls;
const hn::ScalableTag<float> df_compress;

while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
if (position <= last_pos[0]) {
// Past the last position, x0 doesn't count.
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x0 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0]);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x1 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1]);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x2 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2]);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
float x3 =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3]);
@@ -722,9 +764,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once.
StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
StridedView<BF16> qT =
StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q
// above. The condition used for transposing Q above is more general
3 changes: 2 additions & 1 deletion ops/dot-inl.h
@@ -413,7 +413,8 @@ using DotKernelDefault =
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
DotKernelDefault());
}

// Adapter for two pointers, no bounds checking.