Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions gemma/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#include "hwy/base.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
Expand Down Expand Up @@ -58,8 +59,8 @@ size_t FloatsPerVector() {

// The k-cache and v-cache are setup without knowing NF. So if it hasn't been
// done already, reshape it to take NF into account.
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
void MaybeReshapeCache(const size_t default_cols, MatPtrT<KV_t>& cache) {
if (default_cols == cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}
Expand All @@ -71,13 +72,50 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) {
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
if (i + kFloatsPerTile <= qkv_dim) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
} else {
for (size_t j = 0; j < qkv_dim - i; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
}
}
for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim;
i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
}
}

// Zeros out a part of k and v that corresponds to out-of-bounds cache
// positions.
void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v,
size_t qkv_dim) {
const size_t kFloatsPerTile = 2 * FloatsPerVector();
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
for (size_t i = 0; i < kRoundedQkvDim; i += 2) {
k[i * kFloatsPerTile] = hwy::ConvertScalarTo<KV_t>(0.0f);
k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0.0f);
}
}
}
Expand Down Expand Up @@ -314,23 +352,51 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(),
qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();
const size_t kRoundedTokens =
hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector);
const size_t kRoundedNumInterleaved =
kRoundedTokens * div_qbatch.GetDivisor();

// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
ParallelFor(
Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
if (token_idx >= kRoundedTokens) {
return;
}
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector,
cache_pos);
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector,
cache_pos);
if (token_idx >= num_tokens) {
// Create a zero-filled K/V pair for padding for out-of-sequence
// tokens.
TransposeOOBKVCacheRow(k, v, qkv_dim);
return;
}
// --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen());
auto& kv_cache = qbatch.KV(qi).kv_cache;
Expand All @@ -341,22 +407,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;

HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
Expand Down
2 changes: 1 addition & 1 deletion gemma/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace gcpp {
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
void MaybeReshapeCache(size_t default_cols, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
Expand Down
3 changes: 3 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
#include "io/fields.h" // IFieldsVisitor
#include "io/io.h" // Path
#include "util/basics.h"
#include "hwy/detect_compiler_arch.h"

namespace gcpp {

constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16);

HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;

HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
Expand Down
22 changes: 13 additions & 9 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,6 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t token_batch = num_tokens * div_qbatch.GetDivisor();
const size_t total_tasks = token_batch * layer_config.heads;
size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks,
Expand All @@ -1716,11 +1715,9 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
params.clear();
for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) {
for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) {
const size_t head_offset = kv_head * qkv_dim * 2;
const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset;
params.push_back(Tile148Params{
.qi_index = qi,
.kv_offset = kv_offset,
.kv_head = kv_head,
});
for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = qbatch.Pos(qi) + batch_idx;
Expand All @@ -1746,7 +1743,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism,
// current tile is full so start new tile.
params.push_back(Tile148Params{
.qi_index = qi,
.kv_offset = kv_offset,
.kv_head = kv_head,
});
}
const size_t head = head_group + kHeadGroups * kv_head;
Expand Down Expand Up @@ -2157,13 +2154,20 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
auto& param = params[task];
auto& kT_cache = qbatch.KV(param.qi_index).k_cache;
const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
qkv_dim * 2 * kNF));
kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride());
kRoundedQkvDim * 2 * kNF));
kT.SetPtr(
kT_cache.Row(0) + qbatch.KV(param.qi_index)
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
kT_cache.Stride());
auto& vT_cache = qbatch.KV(param.qi_index).v_cache;
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
qkv_dim * 2 * kNF));
vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride());
kRoundedQkvDim * 2 * kNF));
vT.SetPtr(
vT_cache.Row(0) + qbatch.KV(param.qi_index)
.cache->KOrVOffset(layer_idx, param.kv_head, kNF),
vT_cache.Stride());
MatPtrT<float>& att_out =
param.i_of_n == 0 ? activations.att_out : activations.att_out_reps;
DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx,
Expand Down
26 changes: 14 additions & 12 deletions gemma/flash_attention_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,15 @@ void TestFlashAttention(size_t target_parallelism,
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t seq_len =
static_cast<size_t>(attention.div_seq_len.GetDivisor());
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache);
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
qbatch.KV(0).k_cache);
MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(),
qbatch.KV(0).v_cache);
auto& kvc = qbatch.KV(0).kv_cache;
const size_t kFloatsPerTile = 2 * FloatsPerVector();
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t kFloatsPerTile = 2 * kNF;
for (size_t h = 0; h < layer_config.heads; ++h) {
// Make strided views into the kv cache for
// this query and head.
Expand All @@ -160,12 +165,12 @@ void TestFlashAttention(size_t target_parallelism,
SetMat(h + layer_config.heads * 2, v);
for (size_t p = 0; p < tokens.size(); ++p) {
KV_t* HWY_RESTRICT k_src = k.Row(p);
KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * 2;
KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
head_offset * kFloatsPerTile / 2 +
p % kFloatsPerTile * kFloatsPerTile;
KV_t* HWY_RESTRICT k_dest =
qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) +
qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p);
KV_t* HWY_RESTRICT v_dest =
qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) +
qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p);

TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim);
}
Expand All @@ -176,9 +181,6 @@ void TestFlashAttention(size_t target_parallelism,
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t total_tasks =
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
Expand Down
4 changes: 2 additions & 2 deletions gemma/flash_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ struct Tile148Params {
uint32_t max_last_pos = 0;
// Index into the qbatch.KV is the same for each row in the tile.
uint32_t qi_index;
// Index into the kv_cache is the same for each row in the tile.
uint32_t kv_offset;
// kv_head is the same for each row in the tile.
uint32_t kv_head;
// In the original task, the index to the split tasks of the first split task.
uint32_t split_index = 0;
// The index of the split for running split attention.
Expand Down
29 changes: 14 additions & 15 deletions gemma/kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@

namespace gcpp {

// TODO: rays - Remove this once hwy is updated.
#ifndef HWY_ARCH_MAX_BYTES
#define HWY_ARCH_MAX_BYTES 256
#endif

// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
Expand All @@ -46,8 +41,13 @@ static size_t CappedSeqLen(const ModelConfig& config,
return inference_args.seq_len;
}

KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers,
size_t kv_heads, size_t qkv_dim, const Allocator& allocator)
: num_layers(num_layers),
kv_heads(kv_heads),
qkv_dim(qkv_dim),
rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)),
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
// WARNING: the rows and cols of k_cache and v_cache will be modified
// before use!
// The rows will be reduced by a factor of 2xkFloatsPerVector, and the
Expand All @@ -56,22 +56,21 @@ KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
// machine architecture, since kFloatsPerVector is architecture dependent.
// The change is shape is safe only if the padding is kPacked.
k_cache("k",
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
KOrVDefaultCols()),
allocator, MatPadding::kPacked),
v_cache("v",
Extents2D(HWY_MAX(kv_extents.rows,
2 * HWY_ARCH_MAX_BYTES / sizeof(float)),
kv_extents.cols / 2),
Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector),
KOrVDefaultCols()),
allocator, MatPadding::kPacked),
allocator_(allocator) {}

KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator)
: KVCache(
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator) {}
config.layer_configs.size(), config.layer_configs[0].kv_heads,
config.layer_configs[0].qkv_dim, allocator) {}

KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const RuntimeConfig& runtime_config,
Expand Down Expand Up @@ -135,7 +134,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
}

KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_);

CopyMat(kv_cache, copy.kv_cache);
return copy;
Expand Down
35 changes: 34 additions & 1 deletion gemma/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,38 @@ struct KVCache {
return {start_ptr, source_ptr};
}

// Returns the default size of a row in k_cache or v_cache, before scaling by
// 2 * kNF.
size_t KOrVDefaultCols() const {
return num_layers * kv_heads * rounded_qkv_dim;
}

// Returns an offset into a row of k_cache or v_cache at a position that is
// aligned to the tile size (a multiple of 2kNF).
size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF) const {
return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF;
}

// Returns an offset into k_cache at any given position.
size_t KOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF, const size_t pos) const {
return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2;
}

// Returns an offset into v_cache at any given position.
size_t VOffset(const size_t layer_idx, const size_t kv_head_idx,
const size_t kNF, const size_t pos) const {
return KOrVOffset(layer_idx, kv_head_idx, kNF) +
(pos % (2 * kNF)) * 2 * kNF;
}

// Saved sizes for computing offsets into the KV cache.
size_t num_layers = 0;
size_t kv_heads = 0;
size_t qkv_dim = 0;
size_t rounded_qkv_dim = 0;

static constexpr size_t kTileSize = 32;
std::optional<uint32_t> tiled_seq_len = std::nullopt;
// Default Format
Expand Down Expand Up @@ -159,7 +191,8 @@ struct KVCache {
const Allocator& allocator_;

// For use by other ctor and Copy()
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads,
size_t qkv_dim, const Allocator& allocator);
};

inline size_t KVCachePtr::SeqLen() const {
Expand Down
Loading