[LLM Runtime] Beam Search Support of Fused Attention (#734)
DDEle committed Nov 27, 2023
1 parent 8188e6a commit ae95a29
Showing 6 changed files with 143 additions and 15 deletions.
6 changes: 4 additions & 2 deletions intel_extension_for_transformers/llm/runtime/graph/README.md
@@ -108,7 +108,7 @@ Argument description of WeightOnlyQuantConfig:
| weight_dtype | String | Data type of quantized weight: int4/int8 (default int4) |
| alg | String | Quantization algorithm: sym/asym (default sym) |
| group_size | Int | Group size (default: 32) |
| scale_dtype | String | Data type of scales: fp32/bf16 (dafault fp32) |
| scale_dtype | String | Data type of scales: fp32/bf16 (default fp32) |
| use_ggml | Bool | Enable ggml for quantization and inference (default: False) |
| use_quant | Bool | Whether to quantize the model (default: True) |
| use_cache | Bool | Use local quantized model if file exists (default: False) |
@@ -125,7 +125,8 @@ Argument description of generate function:
| batch_size | Int | Batch size for prompt processing (default: 512) |
| ctx_size | Int | Size of the prompt context (default: 512) |
| seed | Int | RNG seed (default: -1, use random seed for < 0) |
| threads | Int | Number of threads to use during computation (default: 8) |
| threads | Int | Number of threads to use during computation (default: min(available_core_num, OMP_NUM_THREADS)) |
| memory_dtype | String | Data type of the KV memory; one of f16, f32, auto (auto enables Fused Attention when possible, otherwise falls back to f16) (default: auto) |
| repetition_penalty| Float | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| num_beams | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| do_sample | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
@@ -138,6 +139,7 @@ Argument description of generate function:
| max_new_tokens | Int | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| streamer | Class | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| stopping_criteria | Class | Please refer to [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
| pad_token | Int | pad_token_id of [Transformer's generate](https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/text_generation#generation) |
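
As a minimal usage sketch (assuming the Transformers-style Python API described earlier in this README; the model id, prompt, and argument values are placeholders), these arguments are passed directly to `generate`:

```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "EleutherAI/gpt-j-6b"  # placeholder: any supported Hugging Face model id or local path
prompt = "Once upon a time"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids

# int4 weight-only quantization (see the WeightOnlyQuantConfig table above)
woq_config = WeightOnlyQuantConfig(weight_dtype="int4", group_size=32)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

# memory_dtype="auto" lets the runtime enable Fused Attention when possible;
# num_beams > 1 with do_sample=False selects beam search.
outputs = model.generate(inputs, max_new_tokens=32, num_beams=4, do_sample=False, memory_dtype="auto")
print(tokenizer.decode(outputs[0]))
```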

### 3. Multi-Round Chat

@@ -156,7 +156,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_
params.beam_size = num_beams;
params.do_sample = do_sample;
params.batch_size = batch_size;
params.beam_search = (num_beams > 1 && !do_sample) ? true : false;
params.beam_search = (num_beams > 1 && !do_sample);
params.top_k = top_k;
params.top_p = top_p;
params.temp = temperature;
@@ -171,7 +171,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_
params.memory_type = KV_MEM_TYPE_AUTO;
else
fprintf(stderr, "Unexpected memory dtype!");
if (params.beam_search) params.memory_type = KV_MEM_TYPE_F16; // TODO(Yi): NO MHA IN BEAM SEARCH
if (batch_size > 1) params.memory_type = KV_MEM_TYPE_F16; // TODO(Yi): NO MHA IN MULTI-BATCH

printf("beam_size: %d, do_sample: %d, top_k: %d, top_p: %f\n", params.beam_size, params.do_sample, params.top_k,
params.top_p);
@@ -78,7 +78,7 @@ void* init_gptj(int seed, int n_predict, int n_batch, int top_k, float top_p, fl
params.batch_size = batch_size;
params.beam_search = beam_search;
params.beam_size = beam_size;
params.memory_type = KV_MEM_TYPE_F16; // TODO MEMORY_AUTO for MHA
if (batch_size > 1) params.memory_type = KV_MEM_TYPE_F16; // TODO(Yi): NO MHA IN MULTI-BATCH
// params.use_mmap = false;
// params.use_mlock= true;
model_init_backend();
@@ -2052,6 +2052,84 @@ void jblas_reordered_attn_fp32_shift_rope_k(char* cache, const ne_fp16_t* cossin
}
}

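// Copies one batch's reordered bf16 K-cache into another batch's slot for positions
// [seq_off, seq_off + seq_size). In this layout the sequence dimension is tiled by N_TILE and the
// head dimension is packed in pairs (K_PACK) for the QK GEMM, so a partially filled N_TILE block
// is copied pack-by-pack while fully covered blocks are handled with a single memcpy per head.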
template <bool zero_padding>
void jblas_fusion_attn_fp32_batch_cpy_k_(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params) {
static constexpr auto N_TILE = 48;
static constexpr auto K_TILE = 32;
static constexpr auto K_PACK = 2;
const auto p = *params;
const auto pad_headsize = padto(p.head_size, K_TILE);
const auto pad_seq_max = padto(p.seq_max, N_TILE);
const auto step_head_num = pad_headsize * pad_seq_max;

const auto seq_unaligned = std::min(padto(p.seq_off, N_TILE) - p.seq_off, p.seq_size);
const auto size_aligned_cpy = pad_headsize * (padto(p.seq_off + p.seq_size, N_TILE) - padto(p.seq_off, N_TILE));
#pragma omp parallel for
for (int ihn = 0; ihn < p.heads_kv; ++ihn) {
const auto dst = reinterpret_cast<bf16*>(p.dst) + ihn * step_head_num;
const auto src = reinterpret_cast<bf16*>(p.src) + ihn * step_head_num;

if (seq_unaligned) {
const auto ii = p.seq_off % N_TILE;
const auto i_blk = p.seq_off - ii;
const auto off = i_blk * pad_headsize + ii * K_PACK;
for (int j = 0; j < pad_headsize; j += K_PACK) { // K-dim padding for QK_GEMM
memcpy(dst + off + j * N_TILE, src + off + j * N_TILE, sizeof(bf16) * K_PACK * seq_unaligned);
}
}
if constexpr (zero_padding) {
if (size_aligned_cpy) {
const auto off = padto(p.seq_off, N_TILE) * pad_headsize;
memcpy(dst + off, src + off, sizeof(bf16) * size_aligned_cpy);
}
} else {
assert(("Unimplemented!", false));
}
}
}
void jblas_fusion_attn_fp32_batch_cpy_k(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params) {
return params->no_zeroing ? jblas_fusion_attn_fp32_batch_cpy_k_<false>(params)
: jblas_fusion_attn_fp32_batch_cpy_k_<true>(params);
}

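// Copies one batch's reordered bf16 V-cache into another batch's slot for positions
// [seq_off, seq_off + seq_size). The V layout transposes the roles of the tiles: the head
// dimension is tiled by N_TILE while sequence positions are packed in pairs (K_PACK), so at most
// one unaligned position is copied element-wise before the aligned bulk memcpy.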
template <bool zero_padding>
void jblas_fusion_attn_fp32_batch_cpy_v_(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params) {
static constexpr auto N_TILE = 48;
static constexpr auto K_TILE = 32;
static constexpr auto K_PACK = 2;
const auto p = *params;
const auto pad_headsize = padto(p.head_size, N_TILE);
const auto pad_seq_max = padto(p.seq_max, K_TILE);
const auto step_head_num = pad_headsize * pad_seq_max;

const auto seq_off_aligned = padto(p.seq_off, K_PACK);
const auto seq_end_aligned = padto(p.seq_off + p.seq_size, K_TILE);
const auto seq_size_aligned = seq_end_aligned - seq_off_aligned;
#pragma omp parallel for collapse(2)
for (int ihn = 0; ihn < p.heads_kv; ++ihn) {
for (int j = 0; j < p.head_size; j += N_TILE) {
const auto dst = reinterpret_cast<bf16*>(p.dst) + ihn * step_head_num + pad_seq_max * j;
const auto src = reinterpret_cast<bf16*>(p.src) + ihn * step_head_num + pad_seq_max * j;
if (p.seq_off != seq_off_aligned) { // seq_size_unaligned must be 0 or 1 as K_PACK = 2
const auto off = (seq_off_aligned - K_PACK) * N_TILE + 1;
for (int jj = 0; jj < N_TILE; ++jj) dst[off + jj * K_PACK] = src[off + jj * K_PACK];
}
if constexpr (zero_padding) {
if (seq_off_aligned != seq_end_aligned) {
const auto off = seq_off_aligned * N_TILE;
memcpy(dst + off, src + off, sizeof(bf16) * N_TILE * seq_size_aligned);
}
} else {
assert(("Unimplemented!", false));
}
}
}
}
void jblas_fusion_attn_fp32_batch_cpy_v(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params) {
return params->no_zeroing ? jblas_fusion_attn_fp32_batch_cpy_v_<false>(params)
: jblas_fusion_attn_fp32_batch_cpy_v_<true>(params);
}

#ifdef __GNUC__
#pragma GCC pop_options
#endif
@@ -139,6 +139,17 @@ void jblas_reordered_attn_fp32_update_v(const jblas_fusion_attn_fp32_update_kv_a
void jblas_reordered_attn_fp32_shift_rope_k(char* cache, const ne_fp16_t* cossin, int batch_size, int heads_kv,
int head_size, int seq_max, int seq_keep);

typedef struct jblas_fusion_attn_fp32_batch_cpy_kv_args_t {
char* src;
char* dst;
int heads_kv, head_size, seq_off, seq_size, seq_max;
bool no_zeroing; // set to true to prevent zeroing unaligned seq
} jblas_fusion_attn_fp32_batch_cpy_kv_args_t;
// copy k-cache across batch from seq_off to (seq_off + seq_size)
void jblas_fusion_attn_fp32_batch_cpy_k(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params);
// copy v-cache across batch from seq_off to (seq_off + seq_size)
void jblas_fusion_attn_fp32_batch_cpy_v(const jblas_fusion_attn_fp32_batch_cpy_kv_args_t* params);

typedef struct jblas_reordered_attn_fp32_fp32_fwd_args_t {
float* Q;
char* K; // K/V should be of type and layout used in corresponding jblas_reordered_attn_xxx_update_kv
@@ -1561,7 +1561,7 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) {
// printf("n_head_kv=%s,multi_query_group_num=%s",model_hparams.n_head_kv,model_hparams.multi_query_group_num);
NE_ASSERT(("Can not set n_head_kv and multi_query_group_num at the same time",
model_hparams.n_head_kv == 0 || model_hparams.multi_query_group_num == 0 ||
model_hparams.n_head_kv != model_hparams.multi_query_group_num));
model_hparams.n_head_kv == model_hparams.multi_query_group_num));
attn_shape_t attn_shape = {
/* .batch_size = */ lparams.batch_size * lparams.beam_size,
/* .head_num = */ static_cast<int>(model_hparams.n_head),
@@ -2080,14 +2080,54 @@ static void ne_model_kv_cache_seq_cpy(struct model_context* ctx, const model_seq
}
}

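// Fused-attention (jblas) counterpart of ne_model_kv_cache_seq_cpy: for every layer, copies the
// K- and V-cache of sequence seq_id_src into seq_id_dst over positions [p0, p1) using the
// batch-copy kernels above.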
static void jblas_model_kv_cache_seq_cpy(struct model_context* ctx, const model_seq_id& seq_id_src,
const model_seq_id& seq_id_dst, const model_pos& p0, const model_pos& p1) {
const auto& kv_self = ctx->model.kv_self;
const auto& hparams = ctx->model.hparams;
const int heads_kv = hparams.multi_query_group_num > 0 ? hparams.multi_query_group_num : hparams.n_head;
const int head_size = hparams.n_embd / hparams.n_head;
const int n_ctx = ctx->n_ctx;
const auto kv_n_ctx_block = ctx->kv_n_ctx_block;
NE_ASSERT(("Invalid end position!", n_ctx >= p1));
kv_cache_info_t kv_cache_info;
kv_shape_t kv_shape{
/* .head_num = */ static_cast<uint32_t>(heads_kv),
/* .head_size = */ static_cast<uint32_t>(head_size),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
const auto k_bytes = kv_cache_info.k_bytes;
const auto v_bytes = kv_cache_info.v_bytes;

jblas_fusion_attn_fp32_batch_cpy_kv_args_t seq_cpy_param{
/* .src = */ nullptr,
/* .dst = */ nullptr,
/* .heads_kv = */ heads_kv,
/* .head_size = */ head_size,
/* .seq_off = */ p0,
/* .seq_size = */ p1 - p0,
/* .seq_max = */ n_ctx,
/* .no_zeroing = */ false,
};
for (int il = 0; il < ctx->model.layers.size(); ++il) {
const auto k_data = reinterpret_cast<char*>(kv_self.k->data) + il * kv_n_ctx_block * k_bytes;
seq_cpy_param.src = k_data + seq_id_src * k_bytes;
seq_cpy_param.dst = k_data + seq_id_dst * k_bytes;
jblas_fusion_attn_fp32_batch_cpy_k(&seq_cpy_param);

const auto v_data = reinterpret_cast<char*>(kv_self.v->data) + il * kv_n_ctx_block * v_bytes;
seq_cpy_param.src = v_data + seq_id_src * v_bytes;
seq_cpy_param.dst = v_data + seq_id_dst * v_bytes;
jblas_fusion_attn_fp32_batch_cpy_v(&seq_cpy_param);
}
}

void model_kv_cache_seq_cpy(struct model_context* ctx, const model_seq_id& seq_id_src, const model_seq_id& seq_id_dst,
const model_pos& p0, const model_pos& p1) {
if (ctx->model.kv_self.k->type != NE_TYPE_JBLAS) {
if (ctx->model.kv_self.k->type != NE_TYPE_JBLAS)
ne_model_kv_cache_seq_cpy(ctx, seq_id_src, seq_id_dst, p0, p1);
} else {
return;
// jblas_model_kv_cache_seq_cpy(ctx, seq_id_src, seq_id_dst, p0, p1);
}
else
jblas_model_kv_cache_seq_cpy(ctx, seq_id_src, seq_id_dst, p0, p1);
}

static ne_tensor* ne_model_kv_cache_seq_concat(struct ne_cgraph* cgraph, struct model_context* moctx,
@@ -2270,10 +2310,7 @@ void beam_search_kv_cache_reorder::update(const std::vector<uint32_t>& n_past,
const std::vector<std::tuple<int, int>>& kv_reorder_indices,
const std::vector<beam>& next_beams) {
// TODO: beam search does not support shifting the kv cache when prompt + new_tokens > n_ctx
if (ctx->model.kv_self.has_shift) {
fprintf(stderr, "%s: error: unimplement shifted kv cache update\n", __func__);
return;
}
NE_ASSERT(("error: unimplement shifted kv cache update\n", !ctx->model.kv_self.has_shift));
#ifdef NE_BEAM_SEARCH_VERBOSE_ON
printf("start to update kv cache for next step...\n");
#endif