Multi device kernel (#473)
* multi device develop

* change build.sh

* cuda tools move into kernels

* multi device develop

* multi device develop

* multi device develop

* solve multi device develop

* solve arm compile error

* arm print_vec

* fix regress error

* solve cuda compile error

* solve multi device print vec

* split node class with its derived class

* lightseq x86 unit test

* add x86 unit test

* Canonical Namespace

* add pybind compile

* Lsflow develop (#463)

* Fix new arch context build check(#441)

problem:
    LinearOp::forward gets the cuBLAS handle without checking whether the context has been built.
    LinearOp::backward checks whether the context has been built before getting the cuBLAS handle.

solution:
    Modify LinearOp::forward to also check that the context is built before getting the cuBLAS handle.
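
A minimal sketch of that fix, assuming the context object exposes a "has the context been built" check and a handle getter (the type and method names below are illustrative placeholders, not the exact LightSeq API):

```cpp
#include <stdexcept>

// Hypothetical stand-ins for the real Context / LinearOp classes.
struct Context {
  bool is_built() const { return built_; }
  void build() { built_ = true; /* create cuBLAS handle, streams, ... */ }
  void* cublas_handle() const {
    if (!built_) throw std::runtime_error("context is not built");
    return handle_;
  }
  bool built_ = false;
  void* handle_ = nullptr;
};

// Before the fix, forward() fetched the handle unconditionally;
// after the fix it mirrors backward() and checks the context first.
void linear_forward(Context& ctx) {
  if (!ctx.is_built()) {
    ctx.build();
  }
  void* handle = ctx.cublas_handle();
  (void)handle;  // ... launch the GEMM with `handle` ...
}
```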

* fix config reference bug (#453)

* developing lsflow

* add split_head op and its test (#454)

* lsflow develop

* format note message

* lsflow tune

* add notes for context class

* add note for lsflow

* op example

* Update CODEOWNERS (#457)

* Gpt infer (#456)

* add split head for beam search

* alter checkin

* make launch_transform_0213 more clear (#459)

* fix operator compile

* make launch_transform_0213 more clear (#460)

* change max_shape to max_shape_size

* correct bias shape in split_head (#461)

* add unit test for x86 cpu kernel

---------

Co-authored-by: Kangmo Kim <kangmo.kim@gmail.com>
Co-authored-by: Ying Xiong <xiongying.taka@bytedance.com>
Co-authored-by: Xiaohui Wang <wangxiaohui.neo@bytedance.com>

* fix sys.path

* fix sys.path (#466)

* jit build support pure cpu machine

* robust builder for x86 and cuda

* fix compile error

* develop test_ls_layer

* format

* fix training compile problem

* add mkl gemm for f32 and s8 (#470)

* test for encoder layer (#471)

* test encoder layer

* fix cuda free error

* lightseq multi device develop

* fix crf op error

* avoid import training directories

* fix strided_batch_gemm config data type

* add debug message

* convert shape from vector<int> to vector<size_t>

* change debug log format

* format

* remove useless dropout

* add sdpa layer into multi head attention layer

* fix conflict parameter: is_post_ln and pre_or_po...

* fix CMakeLists.txt compile

* multi kernel develop

* lightseq transformer.cu fix

* fix linear col/row major

* fix compile error

* fix lightseq post_ln network structure

* add transformer example & print error message

* add shape message

* format

* fix concat error

* add shape message

* fix beam search bug

---------

Co-authored-by: Kangmo Kim <kangmo.kim@gmail.com>
Co-authored-by: Ying Xiong <xiongying.taka@bytedance.com>
Co-authored-by: Xiaohui Wang <wangxiaohui.neo@bytedance.com>
4 people committed Mar 2, 2023
1 parent 6bec6dd commit 2ead283
Showing 245 changed files with 15,178 additions and 3,211 deletions.
126 changes: 90 additions & 36 deletions CMakeLists.txt
@@ -1,48 +1,74 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(LightSeq LANGUAGES C CXX CUDA)

find_package(CUDA 11 REQUIRED)

option(USE_NEW_ARCH "inference with new arch" OFF)
option(FP16_MODE "inference with fp16" OFF)
option(DEBUG_MODE "debug computation result" OFF)
option(MEM_DEBUG "debug memory message" OFF)
option(DYNAMIC_API "build dynamic lightseq api library" OFF)
option(USE_TRITONBACKEND "build tritonbackend for lightseq" OFF)

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)

# setting compiler flags
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -Xcompiler -Wall")

if(DYNAMIC_API)
# dynamic link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
set(HDF5_USE_STATIC_LIBRARIES OFF)
else()
# static link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Static")
set(HDF5_USE_STATIC_LIBRARIES ON)
endif()

set(Protobuf_USE_STATIC_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
option(USE_PYBIND "build lightseq with pybind interface" OFF)

if(USE_NEW_ARCH)
add_definitions(-DNEW_ARCH)

set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86 87)
set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)

# setting compiler flags
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -Xcompiler -Wall")

if(DYNAMIC_API)
# dynamic link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
set(HDF5_USE_STATIC_LIBRARIES OFF)
else()
# static link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Static")
set(HDF5_USE_STATIC_LIBRARIES ON)
endif()

set(Protobuf_USE_STATIC_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(DEVICE_ARCHITECTURES_LIST cuda x86 arm)
list(FIND DEVICE_ARCHITECTURES_LIST ${DEVICE_ARCH} DEVICE_INDEX)

if(DEVICE_INDEX EQUAL 0)
add_definitions(-DLIGHTSEQ_cuda)
set(DEVICE_ARCHITECTURE cuda)
elseif(DEVICE_INDEX EQUAL 1)
add_definitions(-DLIGHTSEQ_x86)
set(DEVICE_ARCHITECTURE x86)
elseif(DEVICE_INDEX EQUAL 2)
add_definitions(-DLIGHTSEQ_arm)
set(DEVICE_ARCHITECTURE arm)
else()
message(
WARNING "compiled with -DDEVICE_ARCHITECTURE=${DEVICE_ARCHITECTURE}")
message(
FATAL_ERROR
"-DDEVICE_ARCHITECTURE=\$\{device\} must in value of list [${DEVICE_ARCHITECTURES_LIST}]"
)
return()
endif()
message(STATUS "compile with device ${DEVICE_ARCHITECTURE} ${index}")

if(DEVICE_INDEX GREATER 0 AND FP16_MODE)
message(FATAL_ERROR "CPU device does not have fp16 version")
return()
endif()

if(DEBUG_MODE)
add_definitions(-DDEBUG_MODE)
set_option(MEM_DEBUG ON)
set(MEM_DEBUG ON)
message(STATUS "Build using debug mode")
endif()

@@ -61,13 +87,12 @@ if(USE_NEW_ARCH)
set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}
${CUDA_PATH}/include
lightseq/csrc/kernels/includes
lightseq/csrc/kernels/${DEVICE_ARCHITECTURE}/includes
lightseq/csrc/layers_new/includes
lightseq/csrc/lsflow/includes
lightseq/csrc/models/includes
lightseq/csrc/ops_new/includes
lightseq/csrc/proto/includes
lightseq/csrc/tools/includes)
lightseq/csrc/proto/includes)

set(COMMON_LIB_DIRS ${CUDA_PATH}/lib64)

@@ -77,21 +102,50 @@ if(USE_NEW_ARCH)
link_directories(${COMMON_LIB_DIRS})

add_subdirectory(3rdparty/pybind11)
add_subdirectory(lightseq/csrc/tools)
add_subdirectory(lightseq/csrc/kernels)
add_subdirectory(lightseq/csrc/layers_new)
add_subdirectory(lightseq/csrc/kernels/${DEVICE_ARCHITECTURE})
add_subdirectory(lightseq/csrc/lsflow)
add_subdirectory(lightseq/csrc/models)
add_subdirectory(lightseq/csrc/ops_new)
add_subdirectory(lightseq/csrc/layers_new)
add_subdirectory(lightseq/csrc/models)
add_subdirectory(lightseq/csrc/proto)
add_subdirectory(lightseq/csrc/pybind)
add_subdirectory(lightseq/csrc/example)
if(USE_PYBIND)
message(STATUS "compile with pybind")
add_subdirectory(lightseq/csrc/pybind)
endif()
if(USE_TRITONBACKEND)
add_subdirectory(lightseq/csrc/triton_backend)
endif()

else()

find_package(CUDA 11 REQUIRED)

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)

# setting compiler flags
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -Xcompiler -Wall")

if(DYNAMIC_API)
# dynamic link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
set(HDF5_USE_STATIC_LIBRARIES OFF)
else()
# static link to cuda libraries and protobuf
set(CMAKE_CUDA_RUNTIME_LIBRARY "Static")
set(HDF5_USE_STATIC_LIBRARIES ON)
endif()

set(Protobuf_USE_STATIC_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_ARCHITECTURES 60 61 70 75 80 86)

set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDA_PATH}/include)
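
The per-device macros added in the CMakeLists.txt changes above (LIGHTSEQ_cuda, LIGHTSEQ_x86, LIGHTSEQ_arm) let the C++ sources branch at compile time; a minimal illustration of that dispatch (not code from this commit):

```cpp
// Illustrative only: how source code might dispatch on the LIGHTSEQ_<arch>
// macros that the CMake logic above defines via add_definitions().
#include <cstdio>

const char* lightseq_device_name() {
#if defined(LIGHTSEQ_cuda)
  return "cuda";
#elif defined(LIGHTSEQ_x86)
  return "x86";
#elif defined(LIGHTSEQ_arm)
  return "arm";
#else
  return "unknown";
#endif
}

int main() {
  std::printf("compiled for device: %s\n", lightseq_device_name());
  return 0;
}
```
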
6 changes: 4 additions & 2 deletions build.sh
@@ -1,5 +1,7 @@
if [ ! -d 'build' ]; then
mkdir build
fi

cd build && cmake -DUSE_NEW_ARCH=OFF -DUSE_TRITONBACKEND=ON -DDEBUG_MODE=OFF -DFP16_MODE=ON -DMEM_DEBUG=OFF .. && make -j${nproc}
# DEVICE_ARCH could be cuda/x86/arm
cd build && cmake -DUSE_NEW_ARCH=ON -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=ON -DFP16_MODE=OFF -DMEM_DEBUG=OFF .. && make -j${nproc}
# you can use a command like the one below to compile lightseq with the pybind interface:
# sudo PATH=$PATH:/usr/local/hdf5 CUDACXX=/usr/local/cuda/bin/nvcc DEVICE_ARCH=x86 ENABLE_FP32=1 ENABLE_DEBUG=0 ENABLE_NEW_ARCH=1 python3 setup.py install
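
For a CPU-only build, the same script would presumably be invoked with DEVICE_ARCH=x86 and FP16 disabled (the CMake logic above rejects FP16_MODE on non-CUDA devices); a sketch, not a command taken from this commit:

```sh
# Hypothetical CPU-only configuration using the same options shown in build.sh.
mkdir -p build && cd build
cmake -DUSE_NEW_ARCH=ON -DDEVICE_ARCH=x86 -DUSE_TRITONBACKEND=OFF \
      -DDEBUG_MODE=OFF -DFP16_MODE=OFF -DMEM_DEBUG=OFF ..
make -j"$(nproc)"
```
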
2 changes: 1 addition & 1 deletion docs/build.md
@@ -1,7 +1,7 @@
# Build from source code

## Requirements
- cudatoolkit-dev >= 11, < 12
- cudatoolkit-dev >= 10.1 < 11
- protobuf >= 3.13
- cmake >= 3.18

8 changes: 4 additions & 4 deletions lightseq/csrc/example/bert_example.cc
@@ -30,7 +30,7 @@ int main(int argc, char* argv[]) {
}
}

auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel(
auto model = lightseq::LSModelFactory::GetInstance().CreateModel(
"Bert", model_weights_path, max_batch_size);

void* d_input;
@@ -57,10 +57,10 @@ int main(int argc, char* argv[]) {
std::cout << "infer preprocessing finished" << std::endl;

/* ---step5. infer and log--- */
for (int i = 0; i < 10; i++) {
for (int i = 0; i < 1; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
print_time_duration(start, "one infer time", 0);
// print_time_duration(start, "one infer time");
}

for (int i = 0; i < model->get_output_size(); i++) {
@@ -73,7 +73,7 @@ int main(int argc, char* argv[]) {
}
std::cout << std::endl;

print_vec(d_output, "output", 5);
lightseq::print_vec(d_output, "output", 5);
}

return 0;
91 changes: 91 additions & 0 deletions lightseq/csrc/example/transformer_example.cc
@@ -0,0 +1,91 @@
#include "model_base.h"
#include "util.h"

/**
@file
Example of how to run transformer inference using our implementation.
*/

int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];

std::vector<int> example_input = {63, 47, 65, 1507, 88, 74,
10, 2057, 362, 9, 284, 6};
int eg_seq_len = example_input.size();
int batch_size = 1, batch_seq_len = example_input.size();
if (argc == 4) {
batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}

int max_batch_size = std::max(4, batch_size);
std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < batch_seq_len; ++j) {
host_input.push_back(example_input[j % eg_seq_len]);
}
}

auto model = lightseq::LSModelFactory::GetInstance().CreateModel(
"Transformer", model_weights_path, max_batch_size);

void* d_input;
CHECK_GPU_ERROR(
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len));
CHECK_GPU_ERROR(cudaMemcpy(d_input, host_input.data(),
sizeof(int) * batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

// model->benchmark_mode(true);
model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});

for (int i = 0; i < model->get_output_size(); i++) {
void* d_output;
std::vector<int> shape = model->get_output_max_shape(i);
int total_size = 1;
for (int j = 0; j < shape.size(); j++) {
total_size *= shape[j];
}
CHECK_GPU_ERROR(cudaMalloc(&d_output, total_size * sizeof(int)));
model->set_output_ptr(i, d_output);
}
CHECK_GPU_ERROR(cudaStreamSynchronize(0));
std::cout << "infer preprocessing finished" << std::endl;

std::chrono::duration<double> elapsed;
int iter = 0;
/* ---step5. infer and log--- */
for (int i = 0; i < 1; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
auto finish = std::chrono::high_resolution_clock::now();
if (i >= 5) {
iter++;
elapsed += finish - start;
}
}

std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
<< " ms" << std::endl;

for (int i = 0; i < model->get_output_size(); i++) {
const void* d_output;
d_output = static_cast<const float*>(model->get_output_ptr(i));
std::vector<int> shape = model->get_output_shape(i);
std::cout << "output shape: ";
int size = 1;
for (int j = 0; j < shape.size(); j++) {
std::cout << shape[j] << " ";
size *= shape[j];
}
std::cout << std::endl;

if (!i)
lightseq::print_vec((int*)d_output, "output", size);
else
lightseq::print_vec((float*)d_output, "output", size);
}

return 0;
}
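
The example reads the model path from argv[1] and optionally batch_size and batch_seq_len from argv[2] and argv[3]; assuming the build produces a transformer_example binary (the actual target name and weight file are placeholders here), a run might look like:

```sh
# Hypothetical invocation; the binary name depends on the CMake target in
# lightseq/csrc/example, and the .hdf5 path is a placeholder.
./transformer_example /path/to/transformer.hdf5 8 16   # batch_size=8, batch_seq_len=16
```
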
7 changes: 7 additions & 0 deletions lightseq/csrc/kernels/arm/CMakeLists.txt
@@ -0,0 +1,7 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

cmake_minimum_required(VERSION 3.18)
set(lightseq_kernel_files gemm.cc utils.cc)

add_library(lightseq_kernels STATIC ${lightseq_kernel_files})
target_include_directories(lightseq_kernels INTERFACE includes)
5 changes: 5 additions & 0 deletions lightseq/csrc/kernels/arm/gemm.cc
@@ -0,0 +1,5 @@
#include "kernel_headers.h"

namespace lightseq {
namespace arm {} // namespace arm
} // namespace lightseq
13 changes: 13 additions & 0 deletions lightseq/csrc/kernels/arm/includes/kernel_headers.h
@@ -0,0 +1,13 @@
#pragma once

#include <math_constants.h>
#include <type_traits>
#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <stdexcept>
#include <functional>

#include "utils.h"
9 changes: 9 additions & 0 deletions lightseq/csrc/kernels/arm/includes/utils.h
@@ -0,0 +1,9 @@
#include "cstdio"
#include "iostream"

namespace lightseq {

template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);

}
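
The header above only declares print_vec; a definition for the CPU backends would plausibly look like the sketch below (illustrative only — the actual utils.cc implementation in this commit may differ, e.g. in formatting or in which explicit instantiations it provides):

```cpp
#include <iostream>
#include <string>

namespace lightseq {

// Hypothetical definition matching the declaration in utils.h above.
template <typename T>
void print_vec(const T* outv, std::string outn, int num_output_ele) {
  std::cout << outn << ": ";
  for (int i = 0; i < num_output_ele; ++i) {
    std::cout << outv[i] << " ";
  }
  std::cout << std::endl;
}

// Because the template is declared in the header but defined in a .cc file,
// explicit instantiations would typically be required, e.g.:
// template void print_vec<float>(const float*, std::string, int);
// template void print_vec<int>(const int*, std::string, int);

}  // namespace lightseq
```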
