Skip to content

Commit

Permalink
[API] Add LLaMA decoder API. (#386)
Browse files Browse the repository at this point in the history
* [API] Add LLaMA decoder API.

* float16_t kv cache

* Add activation type.

* Add norm type.

* comments
  • Loading branch information
changqi1 committed May 13, 2024
1 parent cf81b38 commit bff98bf
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 21 deletions.
13 changes: 13 additions & 0 deletions include/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,17 @@ enum DeviceKind {
iCPU = 0,
iGPU,
};

enum NormType {
RMS = 0,
LN,
};

enum ActivationType {
RELU = 0,
GELU,
SWIGLU,
SILU,
};

} // namespace xft
29 changes: 29 additions & 0 deletions include/layers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include "dtype.h"

namespace xft {

void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);

} // namespace xft
4 changes: 2 additions & 2 deletions include/layers_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void invokeAttention(DataType dt,
void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output,
int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const void *queryBias = nullptr,
const void *keyBias = nullptr, const void *valueBias = nullptr, const void *attnOutBias = nullptr);
const void *valueWeight, const void *attnOutWeight, const float *queryBias = nullptr,
const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);

} // namespace xft
40 changes: 21 additions & 19 deletions src/layers/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace xft {
void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output,
int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight, const void *queryBias, const void *keyBias,
const void *valueBias, const void *attnOutBias) {
const void *valueWeight, const void *attnOutWeight, const float *queryBias, const float *keyBias,
const float *valueBias, const float *attnOutBias) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

Expand Down Expand Up @@ -81,15 +81,16 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
llama_attention_hub;

static DecoderContext *ctx;
static KVCacheManager<float> *kvCacheMgr;
static KVCacheManager<float16_t> *kvCacheMgr;

if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
maxPositions, maxPosEmbed, -1, 0, 1);
ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
kvCacheMgr = new KVCacheManager<float>(1);
if (kvCacheMgr != nullptr) delete kvCacheMgr;
kvCacheMgr = new KVCacheManager<float16_t>(1);
}

// create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
Expand All @@ -102,10 +103,10 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
auto it_created = llama_attention_hub.find(llama_attention_key);
if (it_created == llama_attention_hub.end()) {
llama_attention = new Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
(float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
(float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
nullptr, false);
llama_attention->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias,
(const float *)keyWeight, nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr,
valueBias, (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, false, nullptr, nullptr,
false);
llama_attention_hub[llama_attention_key] = llama_attention;
printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
} else {
Expand All @@ -120,8 +121,8 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);
KVCacheTensor<float16_t> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float16_t> &presentValue = kvCacheMgr->getValue(0);

llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
Expand All @@ -130,15 +131,16 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
llama_attention_hub;

static DecoderContext *ctx;
static KVCacheManager<float> *kvCacheMgr;
static KVCacheManager<float16_t> *kvCacheMgr;

if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
maxPositions, maxPosEmbed, -1, 0, 1);
ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
kvCacheMgr = new KVCacheManager<float>(1);
if (kvCacheMgr != nullptr) delete kvCacheMgr;
kvCacheMgr = new KVCacheManager<float16_t>(1);
}

// create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
Expand All @@ -151,10 +153,10 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
auto it_created = llama_attention_hub.find(llama_attention_key);
if (it_created == llama_attention_hub.end()) {
llama_attention = new Attention<float16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
(float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
(float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
nullptr, false);
llama_attention->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias,
(const float *)keyWeight, nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr,
valueBias, (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, false, nullptr, nullptr,
false);
llama_attention_hub[llama_attention_key] = llama_attention;
printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
} else {
Expand All @@ -163,14 +165,14 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe

ctx->resize(batchSize, inputSeqLen, pastSeqLen);
hpj::Matrix<float> actBuffers;
actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
actBuffers.Resize(batchSize * inputSeqLen, hiddenSize);
float *attnMask = prepareAttnMask(ctx, step);

int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);
KVCacheTensor<float16_t> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float16_t> &presentValue = kvCacheMgr->getValue(0);

llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
Expand Down
Loading

0 comments on commit bff98bf

Please sign in to comment.