diff --git a/include/dtype.h b/include/dtype.h
index 5fbf9c03..de72bcee 100644
--- a/include/dtype.h
+++ b/include/dtype.h
@@ -38,4 +38,17 @@ enum DeviceKind {
     iCPU = 0,
     iGPU,
 };
+
+enum NormType {
+    RMS = 0,
+    LN,
+};
+
+enum ActivationType {
+    RELU = 0,
+    GELU,
+    SWIGLU,
+    SILU,
+};
+
 } // namespace xft
diff --git a/include/layers.h b/include/layers.h
new file mode 100644
index 00000000..34f6aa52
--- /dev/null
+++ b/include/layers.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#pragma once
+
+#include "dtype.h"
+
+namespace xft {
+
+void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
+        int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
+        int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
+        const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
+        const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
+        const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
+        const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);
+
+} // namespace xft
\ No newline at end of file
diff --git a/include/layers_attention.h b/include/layers_attention.h
index 7752e18e..100cde81 100644
--- a/include/layers_attention.h
+++ b/include/layers_attention.h
@@ -61,7 +61,7 @@ void invokeAttention(DataType dt,
 void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
         int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output,
         int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight,
-        const void *valueWeight, const void *attnOutWeight, const void *queryBias = nullptr,
-        const void *keyBias = nullptr, const void *valueBias = nullptr, const void *attnOutBias = nullptr);
+        const void *valueWeight, const void *attnOutWeight, const float *queryBias = nullptr,
+        const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr);
 
 } // namespace xft
\ No newline at end of file
diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp
index 9bef638e..48ae9b2d 100644
--- a/src/layers/attention.cpp
+++ b/src/layers/attention.cpp
@@ -24,8 +24,8 @@ namespace xft {
 void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
         int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output,
         int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight,
-        const void *valueWeight, const void *attnOutWeight, const void *queryBias, const void *keyBias,
-        const void *valueBias, const void *attnOutBias) {
+        const void *valueWeight, const void *attnOutWeight, const float *queryBias, const float *keyBias,
+        const float *valueBias, const float *attnOutBias) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
@@ -81,7 +81,7 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
                 llama_attention_hub;
 
         static DecoderContext *ctx;
-        static KVCacheManager *kvCacheMgr;
+        static KVCacheManager *kvCacheMgr;
 
         if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
             if (ctx != nullptr) delete ctx;
@@ -89,7 +89,8 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
             ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
                     maxPositions, maxPosEmbed, -1, 0, 1);
             ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
-            kvCacheMgr = new KVCacheManager(1);
+            if (kvCacheMgr != nullptr) delete kvCacheMgr;
+            kvCacheMgr = new KVCacheManager(1);
         }
 
         // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
@@ -102,10 +103,10 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
         auto it_created = llama_attention_hub.find(llama_attention_key);
         if (it_created == llama_attention_hub.end()) {
             llama_attention = new Attention(0, ctx);
-            llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
-                    (float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
-                    (float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
-                    nullptr, false);
+            llama_attention->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias,
+                    (const float *)keyWeight, nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr,
+                    valueBias, (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, false, nullptr, nullptr,
+                    false);
             llama_attention_hub[llama_attention_key] = llama_attention;
             printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
         } else {
@@ -120,8 +121,8 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
         int workers = 1;
         int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
         kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
-        KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
-        KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
+        KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
+        KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
 
         llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
                 presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
@@ -130,7 +131,7 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
                 llama_attention_hub;
 
         static DecoderContext *ctx;
-        static KVCacheManager *kvCacheMgr;
+        static KVCacheManager *kvCacheMgr;
 
         if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
            if (ctx != nullptr) delete ctx;
@@ -138,7 +139,8 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
             ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
                     maxPositions, maxPosEmbed, -1, 0, 1);
             ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
-            kvCacheMgr = new KVCacheManager(1);
+            if (kvCacheMgr != nullptr) delete kvCacheMgr;
+            kvCacheMgr = new KVCacheManager(1);
         }
 
         // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
@@ -151,10 +153,10 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
         auto it_created = llama_attention_hub.find(llama_attention_key);
         if (it_created == llama_attention_hub.end()) {
             llama_attention = new Attention(0, ctx);
-            llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
-                    (float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
-                    (float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
-                    nullptr, false);
+            llama_attention->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias,
+                    (const float *)keyWeight, nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr,
+                    valueBias, (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, false, nullptr, nullptr,
+                    false);
             llama_attention_hub[llama_attention_key] = llama_attention;
             printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
         } else {
@@ -163,14 +165,14 @@ void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHe
 
         ctx->resize(batchSize, inputSeqLen, pastSeqLen);
         hpj::Matrix<float> actBuffers;
-        actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
+        actBuffers.Resize(batchSize * inputSeqLen, hiddenSize);
         float *attnMask = prepareAttnMask(ctx, step);
 
         int workers = 1;
         int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
         kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
-        KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
-        KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
+        KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
+        KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
 
         llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
                 presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
diff --git a/src/layers/decoder_layer.cpp b/src/layers/decoder_layer.cpp
new file mode 100644
index 00000000..6a7c798f
--- /dev/null
+++ b/src/layers/decoder_layer.cpp
@@ -0,0 +1,207 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include "decoder_layer.h"
+#include "attention.h"
+#include "kvcache_manager.h"
+#include "layer_norm.h"
+#include "layers_attention.h"
+#include "layers_mlp.h"
+#include "mlp_llama.h"
+#include "rms_norm.h"
+
+#include <limits>
+
+namespace xft {
+
+template <typename DataT, typename NormT>
+void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
+        int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
+        int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
+        const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
+        const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
+        const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
+        const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr) {
+
+    // TODO: will deprecate attention mask in future, so need to change this
+    auto prepareAttnMask = [&](DecoderContext *ctx, int step) {
+        int seqLen = ctx->inputSeqLen;
+        int accSeqLen = pastSeqLen + currentSeqLen;
+        float *mask = nullptr;
+
+        auto getAttnMask = [](int sizeRequired) {
+            static float *attnMask;
+            static int maskSize = 0;
+            if (maskSize < sizeRequired) {
+                if (attnMask) free(attnMask);
+                attnMask = (float *)xft::alloc(sizeRequired * sizeof(float));
+                maskSize = sizeRequired;
+            }
+            return attnMask;
+        };
+
+        if (step == 0) {
+            int sizeRequired = ctx->batchSize * seqLen * seqLen;
+            mask = getAttnMask(sizeRequired);
+            for (int b = 0; b < ctx->batchSize; ++b) {
+                auto pmask = mask + b * seqLen * seqLen;
+                for (int i = 0; i < seqLen; ++i) {
+                    memset(pmask + i * seqLen, 0, (i + 1) * sizeof(float)); // bottom left are 0
+                    std::fill_n(pmask + i * seqLen + i + 1, seqLen - i - 1, std::numeric_limits<float>::lowest());
+                }
+            }
+        } else if (seqLen > 1) {
+            int sizeRequired = ctx->batchSize * accSeqLen * seqLen;
+            mask = getAttnMask(sizeRequired);
+            for (int b = 0; b < ctx->batchSize; ++b) {
+                auto pmask = mask + b * accSeqLen * seqLen;
+                int pastLen = accSeqLen - seqLen;
+                for (int i = 0; i < seqLen; ++i) {
+                    memset(pmask + i * accSeqLen, 0, (pastLen + i + 1) * sizeof(float));
+                    std::fill_n(pmask + i * accSeqLen + pastLen + i + 1, seqLen - i - 1,
+                            std::numeric_limits<float>::lowest());
+                }
+            }
+        } else {
+            int sizeRequired = ctx->batchSize * accSeqLen;
+            mask = getAttnMask(sizeRequired);
+            memset(mask, 0, ctx->batchSize * accSeqLen * sizeof(float)); // all elements are 0
+        }
+
+        return mask;
+    };
+
+    using DECODER = Decoder<Attention<DataT, LlamaRotaryEmbedding, NormT>, LlamaMLP<DataT>>;
+    static std::unordered_map<std::string, DECODER *> llama_layer_hub;
+    static DecoderContext *ctx;
+    static KVCacheManager *kvCacheMgr;
+
+    std::string actType;
+    if (at == ActivationType::SILU)
+        actType = "silu";
+    else if (at == ActivationType::RELU)
+        actType = "relu";
+    else if (at == ActivationType::GELU)
+        actType = "gelu";
+    else if (at == ActivationType::SWIGLU)
+        actType = "swiglu";
+    else
+        printf(">> unsupported activation type\n");
+
+    if (ctx == nullptr
+            || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
+        if (ctx != nullptr) delete ctx;
+        printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
+        ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
+                0, maxPositions, maxPosEmbed, -1, 0, 1);
+        ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+        if (kvCacheMgr != nullptr) delete kvCacheMgr;
+        kvCacheMgr = new KVCacheManager(1);
+    }
+
+    // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
+    std::stringstream weights_addr;
+    weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << gateWeight
+            << "_" << upWeight << "_" << downWeight << "_" << dt << "_" << at << "_" << nt << "_" << attHeadDim
+            << "_" << attHeadNum << "_" << kvHeadNum;
+    std::string llama_layer_key = weights_addr.str();
+    DECODER *llama_layer;
+
+    auto it_created = llama_layer_hub.find(llama_layer_key);
+    if (it_created == llama_layer_hub.end()) {
+        llama_layer = new DECODER(ctx, 0);
+        llama_layer->setWeights(ctx, (const float *)queryWeight, nullptr, nullptr, queryBias, (const float *)keyWeight,
+                nullptr, nullptr, keyBias, (const float *)valueWeight, nullptr, nullptr, valueBias,
+                (const float *)attnOutWeight, nullptr, nullptr, attnOutBias, ln1Gamma, ln1Beta,
+                (const float *)gateWeight, nullptr, nullptr, nullptr, (const float *)upWeight, nullptr, nullptr,
+                nullptr, ln2Gamma, ln2Beta, (const float *)downWeight, nullptr, nullptr, false);
+        llama_layer_hub[llama_layer_key] = llama_layer;
+        printf(">> create llama_layer_key: %s\n", llama_layer_key.c_str());
+    } else {
+        llama_layer = it_created->second;
+    }
+
+    ctx->resize(batchSize, inputSeqLen, pastSeqLen);
+    hpj::Matrix<float> actBuffers;
+    actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
+    float *attnMask = prepareAttnMask(ctx, step);
+
+    int workers = 1;
+    int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
+    kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
+    KVCacheTensor &presentKey = kvCacheMgr->getKey(0);
+    KVCacheTensor &presentValue = kvCacheMgr->getValue(0);
+
+    float *attnOut = (float *)(ctx->tmpBuf.Data());
+
+    llama_layer->forwardAttention(ctx, (float *)input, actBuffers.Data(), attnOut, attnMask,
+            presentKey, // presentKey,
+            presentValue, // presentValue,
+            inputSeqLen, // inputSeqLen,
+            pastSeqLen, // pastSeqLen
+            step == 0, // useSelfAttn,
+            true, // doLnBefore,
+            nullptr);
+
+    llama_layer->forwardFFN(ctx, attnOut, (float *)output, inputStride, outputStride, true);
+}
+
+void invokeLayerLLaMA(DataType dt, ActivationType at, NormType nt, int batchSize, int inputSeqLen, int attHeadDim,
+        int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step,
+        int hiddenSize, int intermediateSize, void *output, int outputStride, const void *input, int inputStride,
+        const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
+        const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma, const float *ln2Beta,
+        const void *gateWeight, const void *upWeight, const void *downWeight, const float *queryBias = nullptr,
+        const float *keyBias = nullptr, const float *valueBias = nullptr, const float *attnOutBias = nullptr) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (dt == DataType::bf16) {
+        if (nt == NormType::RMS)
+            LayerLLaMAImpl<bfloat16_t, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                    maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                    outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                    attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                    attnOutBias);
+        else if (nt == NormType::LN) {
+            LayerLLaMAImpl<bfloat16_t, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                    maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                    outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                    attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                    attnOutBias);
+        } else {
+            printf(">> unsupported norm type\n");
+        }
+    } else if (dt == DataType::fp16) {
+        if (nt == NormType::RMS)
+            LayerLLaMAImpl<float16_t, RmsNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                    maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                    outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                    attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                    attnOutBias);
+        else if (nt == NormType::LN) {
+            LayerLLaMAImpl<float16_t, LayerNorm>(dt, at, nt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+                    maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize, output,
+                    outputStride, input, inputStride, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+                    attnOutWeight, ln2Gamma, ln2Beta, gateWeight, upWeight, downWeight, queryBias, keyBias, valueBias,
+                    attnOutBias);
+        } else {
+            printf(">> unsupported norm type\n");
+        }
+    } else {
+        printf(">> unsupported data type\n");
+    }
+}
+
+} // namespace xft
diff --git a/tests/ut/layers_test.cpp b/tests/ut/layers_test.cpp
new file mode 100644
index 00000000..c33ba0d1
--- /dev/null
+++ b/tests/ut/layers_test.cpp
@@ -0,0 +1,149 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+
+#include "bfloat16.h"
+#include "float16.h"
+#include "layers.h"
+#include "gtest/gtest.h"
+
+template <typename T>
+static void compareLayerLLaMA(int step, int batchSize, int inputSeqLen, int pastSeqLen, int currentSeqLen,
+        int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int hiddenSize,
+        int intermediateSize, const float *ln1Gamma, const float *ln1Beta, const void *queryWeight,
+        const void *keyWeight, const void *valueWeight, const void *attnOutWeight, const float *ln2Gamma,
+        const float *ln2Beta, const float *gateW, const float *upW, const float *downW) {
+    // Create input
+    float *input = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+    float *ourOutput = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+    memset(ourOutput, 0, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+
+    for (int i = 0; i < batchSize * inputSeqLen * hiddenSize; ++i) {
+        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
+    }
+
+    xft::DataType dt = xft::DataType::unknown;
+    if constexpr (std::is_same<T, bfloat16_t>::value) {
+        dt = xft::DataType::bf16;
+    } else if constexpr (std::is_same<T, float16_t>::value) {
+        dt = xft::DataType::fp16;
+    } else {
+        printf("Unsupported data type\n");
+        GTEST_FAIL();
+        return;
+    }
+
+    auto start = std::chrono::high_resolution_clock::now();
+    invokeLayerLLaMA(dt, xft::ActivationType::SILU, xft::NormType::RMS, batchSize, inputSeqLen, attHeadDim, attHeadNum,
+            kvHeadNum, maxPositions, maxPosEmbed, pastSeqLen, currentSeqLen, step, hiddenSize, intermediateSize,
+            (void *)ourOutput, hiddenSize, input, hiddenSize, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
+            attnOutWeight, ln2Gamma, ln2Beta, gateW, upW, downW);
+    auto end = std::chrono::high_resolution_clock::now();
+    float during_time = std::chrono::duration<float>(end - start).count();
+    printf("[ RUNTIME ] XFT::invokeLayerLLaMA %.6f sec\n", during_time);
+
+    free(input);
+    free(ourOutput);
+}
+
+template <typename T>
+void test_LayerLLaMA(void) {
+    int maxPosEmbed = 4096;
+    int maxPositions = maxPosEmbed;
+    int hiddenSize = 4096;
+    int intermediateSize = 11008;
+    int attHeadNum = 32;
+    int attHeadDim = hiddenSize / attHeadNum;
+    int kvHeadNum = 32;
+    int qSize = attHeadDim * attHeadNum;
+    int kvSize = attHeadDim * kvHeadNum;
+
+    float *ln1Gamma = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
+    float *ln1Beta = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
+    float *qkvProj = (float *)aligned_alloc(64, hiddenSize * (qSize + 2 * kvSize) * sizeof(float));
+    float *oProj = (float *)aligned_alloc(64, hiddenSize * hiddenSize * sizeof(float));
+
+    float *ln2Gamma = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
+    float *ln2Beta = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
+    float *gateW = (float *)aligned_alloc(64, hiddenSize * intermediateSize * sizeof(float));
+    float *upW = (float *)aligned_alloc(64, hiddenSize * intermediateSize * sizeof(float));
+    float *downW = (float *)aligned_alloc(64, intermediateSize * hiddenSize * sizeof(float));
+
+    for (int i = 0; i < hiddenSize; ++i) {
+        ln1Gamma[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        ln1Beta[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        ln2Gamma[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        ln2Beta[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    for (int i = 0; i < hiddenSize * (qSize + 2 * kvSize); ++i) {
+        qkvProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    for (int i = 0; i < hiddenSize * hiddenSize; ++i) {
+        oProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    for (int i = 0; i < hiddenSize * intermediateSize; ++i) {
+        gateW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        upW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        downW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    int step = 0;
+    int batchSize = 1;
+    int inputSeqLen = 18;
+    int pastSeqLen = 0;
+    int currentSeqLen = inputSeqLen;
+    int nextTokenNum = 1;
+
+    compareLayerLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+            maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
+            qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
+    pastSeqLen += inputSeqLen;
+    currentSeqLen = nextTokenNum;
+    compareLayerLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+            maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
+            qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
+    pastSeqLen += nextTokenNum;
+    compareLayerLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, kvHeadNum,
+            maxPositions, maxPosEmbed, hiddenSize, intermediateSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
+            qkvProj + kvSize, oProj, ln2Gamma, ln2Beta, gateW, upW, downW);
+
+    free(ln1Gamma);
+    free(ln1Beta);
+    free(qkvProj);
+    free(oProj);
+    free(ln2Gamma);
+    free(ln2Beta);
+    free(gateW);
+    free(upW);
+    free(downW);
+}
+
+TEST(LayerLLaMA, bfloat16_t) {
+    test_LayerLLaMA<bfloat16_t>();
+}
+
+TEST(LayerLLaMA, float16_t) {
+    test_LayerLLaMA<float16_t>();
+}
+
+int main(int argc, char **argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}
\ No newline at end of file
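Usage note (illustrative sketch, not part of the patch): the unit test above is the reference for how the new `invokeLayerLLaMA` entry point is driven across prefill and decode. The snippet below condenses that call pattern into one function. The function name `runLayerLLaMAOnce` and the buffer names are made up for the example; weights and biases are plain `float` arrays exactly as in `tests/ut/layers_test.cpp`, and the decode call keeps `inputSeqLen` at the prompt length while only `pastSeqLen`, `currentSeqLen`, and `step` advance, mirroring the test.

```cpp
#include <cstddef>
#include <vector>

#include "layers.h"

// Hypothetical driver: run one LLaMA decoder layer through a prefill step and
// a single decode step via xft::invokeLayerLLaMA. Buffer contents are dummies.
void runLayerLLaMAOnce() {
    int hiddenSize = 4096, intermediateSize = 11008;
    int attHeadNum = 32, kvHeadNum = 32, attHeadDim = hiddenSize / attHeadNum;
    int maxPosEmbed = 4096, maxPositions = maxPosEmbed;
    int batchSize = 1, promptLen = 18;

    auto buf = [](size_t n) { return std::vector<float>(n, 0.01f); };
    std::vector<float> input = buf(batchSize * promptLen * hiddenSize), output = input;
    std::vector<float> ln1G = buf(hiddenSize), ln1B = buf(hiddenSize);
    std::vector<float> ln2G = buf(hiddenSize), ln2B = buf(hiddenSize);
    std::vector<float> qW = buf(hiddenSize * hiddenSize), kW = qW, vW = qW, oW = qW;
    std::vector<float> gateW = buf(hiddenSize * intermediateSize), upW = gateW, downW = gateW;

    // Prefill: step 0, nothing in the KV cache, the whole prompt is the current chunk.
    xft::invokeLayerLLaMA(xft::DataType::bf16, xft::ActivationType::SILU, xft::NormType::RMS, batchSize, promptLen,
            attHeadDim, attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, /*pastSeqLen=*/0,
            /*currentSeqLen=*/promptLen, /*step=*/0, hiddenSize, intermediateSize, output.data(), hiddenSize,
            input.data(), hiddenSize, ln1G.data(), ln1B.data(), qW.data(), kW.data(), vW.data(), oW.data(),
            ln2G.data(), ln2B.data(), gateW.data(), upW.data(), downW.data());

    // Decode: step 1, the prompt now sits in the KV cache and one new token is processed.
    xft::invokeLayerLLaMA(xft::DataType::bf16, xft::ActivationType::SILU, xft::NormType::RMS, batchSize, promptLen,
            attHeadDim, attHeadNum, kvHeadNum, maxPositions, maxPosEmbed, /*pastSeqLen=*/promptLen,
            /*currentSeqLen=*/1, /*step=*/1, hiddenSize, intermediateSize, output.data(), hiddenSize, input.data(),
            hiddenSize, ln1G.data(), ln1B.data(), qW.data(), kW.data(), vW.data(), oW.data(), ln2G.data(), ln2B.data(),
            gateW.data(), upW.data(), downW.data());
}
```

Because the layer object is cached under a key built from the weight pointers, the weight buffers are expected to stay alive and keep the same addresses across steps.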