Add gpuDNN lib and optimize rms_norm perf.
changqi1 committed Mar 20, 2024
1 parent dbb4651 commit d8337e6
Showing 6 changed files with 81 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -29,8 +29,9 @@ dist/
/3rdparty/cmdline
/3rdparty/sentencepiece
/3rdparty/xdnn
/3rdparty/gpudnn


# MLServer
.envs
.metrics
.metrics
7 changes: 6 additions & 1 deletion CMakeLists.txt
@@ -93,15 +93,17 @@ endif()
include("cmake/mklml.cmake")
include("cmake/onednn.cmake")
include("cmake/xdnn.cmake")
include("cmake/gpudnn.cmake")
include("cmake/mkl.cmake")
include(GNUInstallDirs)

set(DEPEND_LIST "onednn" "xdnn_lib")
set(DEPEND_LIST "onednn" "xdnn_lib" "gpudnn_lib")

include_directories(${CMAKE_SOURCE_DIR}/3rdparty/)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/gpudnn)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/include)
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/src/kernels)
@@ -114,6 +116,7 @@ include_directories(${CMAKE_SOURCE_DIR}/src/common)
link_directories(${CMAKE_SOURCE_DIR}/src/kernels)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/src)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/gpudnn)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib)

find_package(oneCCL REQUIRED)
@@ -143,9 +146,11 @@ option(BUILD_WITH_SHARED_LIBS "Build with shared libraries" OFF)
if(BUILD_WITH_SHARED_LIBS)
message(STATUS "Notice: Building with shared libraries.")
list(APPEND 3RDPART_LIB_LIST "xdnn")
list(APPEND 3RDPART_LIB_LIST "gpu-dnn")
else()
message(STATUS "Notice: Building with static libraries.")
list(APPEND 3RDPART_LIB_LIST "xdnn_static")
list(APPEND 3RDPART_LIB_LIST "gpu-dnn")
endif()

# pipeline parallel feature
38 changes: 38 additions & 0 deletions cmake/gpudnn.cmake
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

cmake_minimum_required(VERSION 3.18)

# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif()

project(dependency NONE)

include(ExternalProject)

# cmake-format: off
ExternalProject_Add(gpudnn_lib
URL https://github.com/intel/xFasterTransformer/releases/download/gpuDNN/gpudnn_v0.1.tar.gz
URL_HASH MD5=05b3554413e454ed027014e44a5c7fe4
TIMEOUT 60
SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/gpudnn
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
# cmake-format: on
12 changes: 9 additions & 3 deletions src/layers/attention.h
@@ -189,8 +189,12 @@ class Attention {

// LayerNorm
this->norm.setWeight(gamma1, beta1, hiddenSize);
rmsNormWeight1 = sycl::malloc_device<float>(hiddenSize, *ctx->mmHelper->gpu_queue);
ctx->mmHelper->gpu_queue->memcpy(rmsNormWeight1, gamma1, hiddenSize * sizeof(float)).wait();
if constexpr (std::is_same_v<WeiT, float16_t>) {
float16_t gamma1_buf[hiddenSize];
float16_t::cvt_float_to_float16_MT(gamma1, gamma1_buf, hiddenSize);
rmsNormWeight1 = sycl::malloc_device<sycl::half>(hiddenSize, *ctx->mmHelper->gpu_queue);
ctx->mmHelper->gpu_queue->memcpy(rmsNormWeight1, gamma1_buf, hiddenSize * sizeof(WeiT)).wait();
}
}

#ifdef DEBUG
@@ -352,6 +356,7 @@ class Attention {
if (getScalingCoeff() != 0) { ctx->attFactor = getScalingCoeff(); }

TimeLine t4("MHA");
// FunTimer ft4;
if constexpr (!INPUT_AS_RESID) { // Swap inputBuffer and imBuffer
auto tmp = imBuffer.Data();
int rows = imBuffer.Rows(), cols = imBuffer.Cols(), stride = imBuffer.Stride();
@@ -373,6 +378,7 @@ class Attention {
else { fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen); }
}
t4.release();
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", ctx->mmHelper->gpu_index, "attention", ft4.elapsed());

// For multiple nodes inference, not the whole result buffer
hpj::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, imBuffer.Stride());
@@ -1040,7 +1046,7 @@ class Attention {
hpj::Vector<float> attnOutputWeightSum; // if weight is int8
hpj::Vector<float> attnOutputBias;

float *rmsNormWeight1;
sycl::half *rmsNormWeight1;

// Query/Key post op
QKPO_CLS qkpo;
10 changes: 7 additions & 3 deletions src/layers/mlp_llama.h
@@ -112,8 +112,12 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
if (normW) {
normWeight.Resize(hiddenSize);
memcpy(normWeight.Data(), normW, sizeof(float) * hiddenSize);
rmsNormWeight2 = sycl::malloc_device<float>(hiddenSize, *ctx->mmHelper->gpu_queue);
ctx->mmHelper->gpu_queue->memcpy(rmsNormWeight2, normW, hiddenSize * sizeof(float)).wait();
if constexpr (std::is_same_v<WeiT, float16_t>) {
float16_t normWeight_buf[hiddenSize];
float16_t::cvt_float_to_float16_MT(normW, normWeight_buf, hiddenSize);
rmsNormWeight2 = sycl::malloc_device<sycl::half>(hiddenSize, *ctx->mmHelper->gpu_queue);
ctx->mmHelper->gpu_queue->memcpy(rmsNormWeight2, normWeight_buf, hiddenSize * sizeof(float16_t)).wait();
}
}
}

@@ -445,7 +449,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {

// LlamaRMSNorm param
hpj::Vector<float> normWeight;
float *rmsNormWeight2;
sycl::half *rmsNormWeight2;

#ifdef DEBUG
Debugger dbg;
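Both the attention layer and the MLP above now stage their RMSNorm weights on the GPU in half precision: the fp32 gamma is converted on the host and copied into device USM once at weight-loading time, so the per-token norm kernel reads fp16 weights directly. A minimal sketch of that upload pattern, using plain SYCL with a simple conversion loop standing in for the project's ctx->mmHelper->gpu_queue and float16_t::cvt_float_to_float16_MT helpers (the function name below is illustrative, not part of the codebase):

#include <CL/sycl.hpp>
#include <vector>

// Convert fp32 norm weights to fp16 on the host, then park them in device USM.
sycl::half *uploadNormWeight(sycl::queue &q, const float *w, int n) {
    std::vector<sycl::half> buf(n);
    for (int i = 0; i < n; ++i) buf[i] = sycl::half(w[i]); // fp32 -> fp16
    sycl::half *dev = sycl::malloc_device<sycl::half>(n, q);
    q.memcpy(dev, buf.data(), n * sizeof(sycl::half)).wait();
    return dev;
}

The device pointer is then handed to the norm kernel in place of the old fp32 weight, removing a conversion and a host-to-device copy from the forward path.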
28 changes: 19 additions & 9 deletions src/utils/matmul_helper.h
@@ -36,6 +36,8 @@
#include <map>
#include <tuple>
#include <CL/sycl.hpp>
#include <dpct/device.hpp>
#include "gpu_layernorm_kernels.h"

#define USE_AMX_M 8

@@ -1294,23 +1296,29 @@ class MMHelper {
});
})
.wait();
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", gpu_index, "rope", t.elapsed());

// Reorder output
// FunTimer t2;
gpu_queue->memcpy(C_buf, packedC_buf, batchSize * seqLen * 3 * head_num * head_size * sizeof(float16_t)).wait();
float16_t::cvt_float16_to_float_MT(C_buf, query, batchSize * seqLen * 3 * head_num * head_size);
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", gpu_index, "rope", t.elapsed());
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", gpu_index, "memcpy", t2.elapsed());
}

void computeRMSNorm(float *output, const float *input, const float *weight, int rows, int cols) {
void computeRMSNorm(float *output, const float *input, const sycl::half *weight, int rows, int cols) {
// FunTimer t;
// float16_t I_buf[rows * cols];
// float16_t::cvt_float_to_float16_MT(input, I_buf, rows * cols);
float16_t *packedI_buf = packedI;
float16_t *packedA_buf = packedA;
sycl::half *packedI_buf = (sycl::half *)packedI;
sycl::half *packedA_buf = (sycl::half *)packedA;
// gpu_queue->memcpy(packedI_buf, I_buf, rows * cols * sizeof(float16_t)).wait();
rmsnorm_kernel(packedA_buf, packedI_buf, weight, rows, cols, cols, cols);
// rmsnorm_kernel(packedA_buf, packedI_buf, weight, rows, cols, cols, cols);
// gpu_queue->memcpy(I_buf, packedA_buf, rows * cols * sizeof(float16_t)).wait();
// float16_t::cvt_float16_to_float_MT(I_buf, output, rows * cols);
const float layernorm_eps = 1e-06;
fastertransformer::invokeGeneralT5LayerNorm(packedA_buf, packedI_buf, weight,
(const sycl::half*) nullptr, layernorm_eps,
rows, cols, gpu_queue);
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", gpu_index, "rmsnorm", t.elapsed());
}
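For reference, the T5-style layer norm invoked above (with a null bias pointer) is RMSNorm: each row is scaled by the reciprocal of its root mean square and then by the per-channel weight, with no mean subtraction. A small CPU sketch of the expected result, useful for checking the fp16 GPU path (a reference implementation, not the kernel itself):

#include <cmath>

// y[i] = x[i] / sqrt(mean(x^2) + eps) * gamma[i], applied row by row
void rmsNormRef(float *out, const float *in, const float *gamma,
                int rows, int cols, float eps = 1e-6f) {
    for (int r = 0; r < rows; ++r) {
        const float *x = in + r * cols;
        float *y = out + r * cols;
        float ss = 0.f;
        for (int c = 0; c < cols; ++c) ss += x[c] * x[c]; // sum of squares
        float inv = 1.f / std::sqrt(ss / cols + eps);     // 1 / RMS
        for (int c = 0; c < cols; ++c) y[c] = x[c] * inv * gamma[c];
    }
}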

@@ -1749,7 +1757,7 @@ class MMHelper {

if (copyC2G == true) {
// Reorder input
FunTimer t2;
// FunTimer t2;
float16_t A_buf[M * K];
float16_t::cvt_float_to_float16_MT(A, A_buf, M * K);
gpu_queue->memcpy(packed_input_mem.get_data_handle(), A_buf, M * K * sizeof(float16_t)).wait();
@@ -1765,15 +1773,17 @@ class MMHelper {
stream->wait();

if (postOp == true) {
// FunTimer t3;
if (M > 1)
sycl_sigmoid_mul(M, N / 2, packedC, ldc, packedC + N / 2, ldc, packedC, ldc);
else if (M == 1)
sycl_sigmoid_mul_M1(N / 2, packedC, ldc, packedC + N / 2, ldc, packedC, ldc);
// printf("xft_verbose,exec,gpu:%d,%s,%.6lf\n", gpu_index, "sycl_sigmoid_mul", t3.elapsed());
}

if (copyG2C == true) {
// Reorder output
FunTimer t3;
// FunTimer t3;
float16_t C_buf[M * N];
gpu_queue->memcpy(C_buf, packed_output_mem.get_data_handle(), M * N * sizeof(float16_t)).wait();
float16_t::cvt_float16_to_float_MT(C_buf, C, M * N);
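The postOp branch above fuses the gated activation of the LLaMA MLP on the GPU: the matmul writes the gate and up projections side by side (N/2 columns each), and sycl_sigmoid_mul combines them in place. Assuming the kernel implements the usual SiLU gating, a CPU sketch of the intended math looks like this (the exact activation is defined by the kernel, which is not shown in this diff):

#include <cmath>

// Assumed semantics: out = silu(gate) * up, with silu(x) = x * sigmoid(x)
void sigmoidMulRef(int M, int N, const float *gate, int ldg,
                   const float *up, int ldu, float *out, int ldo) {
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float g = gate[m * ldg + n];
            out[m * ldo + n] = g / (1.f + std::exp(-g)) * up[m * ldu + n];
        }
}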
@@ -1845,7 +1855,7 @@ class MMHelper {

if (copyC2G == true) {
// Reorder input
FunTimer t2;
// FunTimer t2;
float16_t A_buf[M * K];
// float16_t::cvt_float_to_float16_MT(A, A_buf, M * K);
#pragma omp parallel for
@@ -1879,7 +1889,7 @@ class MMHelper {

if (copyG2C == true) {
// Reorder output
FunTimer t4;
// FunTimer t4;
float16_t C_buf[M * N];
gpu_queue->memcpy(C_buf, packed_output_mem.get_data_handle(), M * N * sizeof(float16_t)).wait();
float16_t::cvt_float16_to_float_MT(C_buf, C, M * N);
