From b9305fb02440d5cd566d32b508bee9f9c13dda15 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 23 Apr 2024 02:28:04 +0000 Subject: [PATCH] add --- colossalai/kernel/extensions | 1 - colossalai/kernel/extensions/README.md | 140 ++++ colossalai/kernel/extensions/__init__.py | 35 + .../kernel/extensions/base_extension.py | 82 +++ colossalai/kernel/extensions/cpp_extension.py | 134 ++++ .../kernel/extensions/cpu_adam/__init__.py | 4 + .../extensions/cpu_adam/cpu_adam_arm.py | 41 ++ .../extensions/cpu_adam/cpu_adam_x86.py | 54 ++ colossalai/kernel/extensions/csrc/__init__.py | 11 + .../extensions/csrc/arm/cpu_adam_arm.cpp | 304 ++++++++ .../kernel/extensions/csrc/arm/cpu_adam_arm.h | 201 ++++++ .../kernel/extensions/csrc/common/micros.h | 224 ++++++ .../extensions/csrc/common/mp_type_traits.h | 38 + .../kernel/extensions/csrc/common/target.h | 134 ++++ .../extensions/csrc/cuda/activation_kernel.cu | 75 ++ .../cuda/context_kv_cache_memcpy_kernel.cu | 182 +++++ .../cuda/decode_kv_cache_memcpy_kernel.cu | 162 +++++ .../extensions/csrc/cuda/funcs/cast_functor.h | 74 ++ .../extensions/csrc/cuda/funcs/op_functor.h | 92 +++ .../cuda/fused_rotary_emb_and_cache_kernel.cu | 481 ++++++++++++ .../csrc/cuda/get_cos_and_sin_kernel.cu | 215 ++++++ .../csrc/cuda/include/block_reduce.h | 184 +++++ .../extensions/csrc/cuda/layer_norm_kernel.cu | 683 ++++++++++++++++++ .../kernel/extensions/csrc/cuda/moe_kernel.cu | 662 +++++++++++++++++ .../csrc/cuda/multi_tensor_adam_kernel.cu | 146 ++++ .../csrc/cuda/multi_tensor_apply.cuh | 130 ++++ .../csrc/cuda/multi_tensor_l2norm_kernel.cu | 387 ++++++++++ .../csrc/cuda/multi_tensor_lamb_kernel.cu | 354 +++++++++ .../csrc/cuda/multi_tensor_scale_kernel.cu | 125 ++++ .../csrc/cuda/multi_tensor_sgd_kernel.cu | 167 +++++ .../extensions/csrc/cuda/pybind/inference.cpp | 84 +++ .../csrc/cuda/pybind/layer_norm.cpp | 141 ++++ .../extensions/csrc/cuda/pybind/moe.cpp | 97 +++ .../extensions/csrc/cuda/pybind/optimizer.cpp | 49 ++ .../cuda/pybind/scaled_masked_softmax.cpp | 70 ++ .../scaled_upper_triang_masked_softmax.cpp | 54 ++ .../csrc/cuda/rms_layernorm_kernel.cu | 426 +++++++++++ .../csrc/cuda/scaled_masked_softmax.h | 500 +++++++++++++ .../csrc/cuda/scaled_masked_softmax_kernel.cu | 89 +++ .../cuda/scaled_upper_triang_masked_softmax.h | 538 ++++++++++++++ ...aled_upper_triang_masked_softmax_kernel.cu | 75 ++ .../csrc/cuda/utils/gpu_launch_config.h | 78 ++ .../extensions/csrc/cuda/utils/micros.h | 18 + .../csrc/cuda/utils/nvgpu_dev_info.h | 60 ++ .../csrc/cuda/utils/vec_type_traits.h | 83 +++ .../csrc/cuda/utils/vector_copy_utils.h | 52 ++ .../kernel/extensions/csrc/scaled_softmax.py | 190 +++++ .../kernel/extensions/csrc/x86/cpu_adam.cpp | 446 ++++++++++++ .../kernel/extensions/csrc/x86/cpu_adam.h | 185 +++++ .../kernel/extensions/cuda_extension.py | 109 +++ .../extensions/flash_attention/__init__.py | 14 + .../flash_attention_dao_cuda.py | 96 +++ .../flash_attention/flash_attention_npu.py | 62 ++ .../flash_attention_sdpa_cuda.py | 56 ++ .../kernel/extensions/inference/__init__.py | 3 + .../inference/inference_ops_cuda.py | 35 + .../kernel/extensions/layernorm/__init__.py | 3 + .../extensions/layernorm/layernorm_cuda.py | 24 + colossalai/kernel/extensions/moe/__init__.py | 3 + colossalai/kernel/extensions/moe/moe_cuda.py | 29 + .../kernel/extensions/optimizer/__init__.py | 3 + .../optimizer/fused_optimizer_cuda.py | 34 + .../kernel/extensions/softmax/__init__.py | 4 + .../softmax/scaled_masked_softmax_cuda.py | 32 + ...aled_upper_triangle_masked_softmax_cuda.py | 34 + .../kernel/extensions/triton_extension.py | 21 + colossalai/kernel/extensions/utils.py | 229 ++++++ colossalai/shardformer/layer/linear.py | 1 + examples/language/llama2/attn.py | 353 ++++++++- 69 files changed, 9900 insertions(+), 2 deletions(-) delete mode 120000 colossalai/kernel/extensions create mode 100644 colossalai/kernel/extensions/README.md create mode 100644 colossalai/kernel/extensions/__init__.py create mode 100644 colossalai/kernel/extensions/base_extension.py create mode 100644 colossalai/kernel/extensions/cpp_extension.py create mode 100644 colossalai/kernel/extensions/cpu_adam/__init__.py create mode 100644 colossalai/kernel/extensions/cpu_adam/cpu_adam_arm.py create mode 100644 colossalai/kernel/extensions/cpu_adam/cpu_adam_x86.py create mode 100644 colossalai/kernel/extensions/csrc/__init__.py create mode 100644 colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.cpp create mode 100644 colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.h create mode 100644 colossalai/kernel/extensions/csrc/common/micros.h create mode 100644 colossalai/kernel/extensions/csrc/common/mp_type_traits.h create mode 100644 colossalai/kernel/extensions/csrc/common/target.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/activation_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/funcs/cast_functor.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/funcs/op_functor.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/get_cos_and_sin_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/include/block_reduce.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/layer_norm_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/moe_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_adam_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_apply.cuh create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_scale_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/inference.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/layer_norm.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/moe.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/optimizer.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp create mode 100644 colossalai/kernel/extensions/csrc/cuda/rms_layernorm_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu create mode 100644 colossalai/kernel/extensions/csrc/cuda/utils/gpu_launch_config.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/utils/micros.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/utils/nvgpu_dev_info.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/utils/vec_type_traits.h create mode 100644 colossalai/kernel/extensions/csrc/cuda/utils/vector_copy_utils.h create mode 100644 colossalai/kernel/extensions/csrc/scaled_softmax.py create mode 100644 colossalai/kernel/extensions/csrc/x86/cpu_adam.cpp create mode 100644 colossalai/kernel/extensions/csrc/x86/cpu_adam.h create mode 100644 colossalai/kernel/extensions/cuda_extension.py create mode 100644 colossalai/kernel/extensions/flash_attention/__init__.py create mode 100644 colossalai/kernel/extensions/flash_attention/flash_attention_dao_cuda.py create mode 100644 colossalai/kernel/extensions/flash_attention/flash_attention_npu.py create mode 100644 colossalai/kernel/extensions/flash_attention/flash_attention_sdpa_cuda.py create mode 100644 colossalai/kernel/extensions/inference/__init__.py create mode 100644 colossalai/kernel/extensions/inference/inference_ops_cuda.py create mode 100644 colossalai/kernel/extensions/layernorm/__init__.py create mode 100644 colossalai/kernel/extensions/layernorm/layernorm_cuda.py create mode 100644 colossalai/kernel/extensions/moe/__init__.py create mode 100644 colossalai/kernel/extensions/moe/moe_cuda.py create mode 100644 colossalai/kernel/extensions/optimizer/__init__.py create mode 100644 colossalai/kernel/extensions/optimizer/fused_optimizer_cuda.py create mode 100644 colossalai/kernel/extensions/softmax/__init__.py create mode 100644 colossalai/kernel/extensions/softmax/scaled_masked_softmax_cuda.py create mode 100644 colossalai/kernel/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py create mode 100644 colossalai/kernel/extensions/triton_extension.py create mode 100644 colossalai/kernel/extensions/utils.py mode change 120000 => 100644 examples/language/llama2/attn.py diff --git a/colossalai/kernel/extensions b/colossalai/kernel/extensions deleted file mode 120000 index e8eb45a54893..000000000000 --- a/colossalai/kernel/extensions +++ /dev/null @@ -1 +0,0 @@ -../../extensions \ No newline at end of file diff --git a/colossalai/kernel/extensions/README.md b/colossalai/kernel/extensions/README.md new file mode 100644 index 000000000000..b9bde7742be9 --- /dev/null +++ b/colossalai/kernel/extensions/README.md @@ -0,0 +1,140 @@ +# 🔌 Extensions + +## 📌 Table of Contents + +- [🔌 Extensions](#-extensions) + - [📌 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [🪅 Design](#-design) + - [🛠 API Usage](#-api-usage) + - [🏗 Write a customized extension](#-write-a-customized-extension) + - [✏️ Acknowledgement](#️-acknowledgement) + +## 📚 Introduction + +This module is a designed to offer extensions to the existing ColossalAI framework. It is designed to be a collection of high-performance kernels to speed up the training and inference process. Different from writing an individual kernel, the `extensions` module offers a layer of abstraction to collate kernels written in different compiler backends and for different hardware backends in an organized way. Please see the design and usage in the sections below. + +## 🪅 Design + +The `extensions` module is a sub-module of the `colossalai.kernel` module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the `colossalai.kernel.extensions` path for runtime build. + +As we want to support multi-backend kernels, we have to consider multiple compiler options such as `torch.jit`, `CUDA`, `triton` and multiple hardware backends such as `CPU`, `GPU` and `NPU`. To make it easy for the users, we have abstract away the kernels into extensions and expose a single loader to the user for each kind of kernel. + +For example, if the user wants to use the CPU Adam kernel, he can just call `load()` on the kernel loader. The kernel loader will automatically select the correct extension based on the current hardware and compiler backend. The user does not need to worry about the details of the kernel implementation. For example, if the user is using ARM CPU, then Arm kernel will be built and loaded. If it is a X86 CPU, then it is the X86 kernel that will be loaded. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +![](https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/extensions.png?raw=true) + +## 🛠 API Usage + +To make the `colossalai.kernel` easy to use, we expose some simple APIs and you can use them based on your scenario. + +- Case 1: Simply load a kernel + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +- Case 2: Load a specific kernel + +This case applies if you are familiar with the extensions available. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel by giving the kernel name +kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm") +``` + +- Case 3: Register your own extension + +This case applies if you know how to write an extension. If you do not know how, you can refer to the section below. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader +from colossalai.kernel.base_extension import _Extension + +# create your own extension class +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + # implementation here + ... + +# register your extension +# you can use the priority value to make sure your kernel will be loaded by default +CPUAdamLoader.register_extension(MyExtension) + +# load the kernel +kernel = CPUAdamLoader().load() +``` + +## 🏗 Write a customized extension + +It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you should get familiar with the process quickly. + +You just need to inherit the `_Extension` base class or other backend-specific classes such as `_CudaExtension` and implement the abstract methods. Then, you need to register your extension to the kernel loader based on the Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, current hardware, compiler backend. + +```python +from colossalai.kernel.base_extension import _Extension + + +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + def is_available(self) -> bool: + """ + Return if the required hardware can be found. + """ + ... + + def assert_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + ... + + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + """ + If this kernel can be built AOT, it should return an extension object + to Python setuptools for compilation. + """ + ... + + def build_jit(self) -> Callable: + """ + Build extension kernel just in time. + """ + ... + + def load(self): + """ + The API called by the user to get the kernel. + """ + ... + +``` + +## ✏️ Acknowledgement + +This module is written from scratch but we learnt a lot by looking into [DeepSpeed' +s op_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder). We wish to acknowledge their great work and contributions to the open-source community. diff --git a/colossalai/kernel/extensions/__init__.py b/colossalai/kernel/extensions/__init__.py new file mode 100644 index 000000000000..1e936eec69cc --- /dev/null +++ b/colossalai/kernel/extensions/__init__.py @@ -0,0 +1,35 @@ +from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension +from .inference import InferenceOpsCudaExtension +from .layernorm import LayerNormCudaExtension +from .moe import MoeCudaExtension +from .optimizer import FusedOptimizerCudaExtension +from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension + +ALL_EXTENSIONS = [ + CpuAdamArmExtension, + CpuAdamX86Extension, + LayerNormCudaExtension, + MoeCudaExtension, + FusedOptimizerCudaExtension, + InferenceOpsCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionSdpaCudaExtension, + FlashAttentionNpuExtension, +] + +__all__ = [ + "CpuAdamArmExtension", + "CpuAdamX86Extension", + "LayerNormCudaExtension", + "MoeCudaExtension", + "FusedOptimizerCudaExtension", + "InferenceOpsCudaExtension", + "ScaledMaskedSoftmaxCudaExtension", + "ScaledUpperTriangleMaskedSoftmaxCudaExtension", + "FlashAttentionDaoCudaExtension", + "FlashAttentionSdpaCudaExtension", + "FlashAttentionNpuExtension", +] diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py new file mode 100644 index 000000000000..0c79c0a9e9f5 --- /dev/null +++ b/colossalai/kernel/extensions/base_extension.py @@ -0,0 +1,82 @@ +import hashlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Union + +__all__ = ["_Extension"] + + +class _Extension(ABC): + def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1): + self._name = name + self._support_aot = support_aot + self._support_jit = support_jit + self.priority = priority + + @property + def name(self): + return self._name + + @property + def support_aot(self): + return self._support_aot + + @property + def support_jit(self): + return self._support_jit + + @staticmethod + def get_jit_extension_folder_path(): + """ + Kernels which are compiled during runtime will be stored in the same cache folder for reuse. + The folder is in the path ~/.cache/colossalai/torch_extensions/. + The name of the follows a common format: + torch._- + + The suffix is the hash value of the path of the `colossalai` file. + """ + import torch + + import colossalai + from colossalai.accelerator import get_accelerator + + # get torch version + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + + # get device version + device_name = get_accelerator().name + device_version = get_accelerator().get_version() + + # use colossalai's file path as hash + hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest() + + # concat + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}" + cache_directory = os.path.join(home_directory, extension_directory) + return cache_directory + + @abstractmethod + def is_available(self) -> bool: + """ + Check if the hardware required by the kernel is available. + """ + + @abstractmethod + def assert_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + + @abstractmethod + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + pass + + @abstractmethod + def build_jit(self) -> Callable: + pass + + @abstractmethod + def load(self) -> Callable: + pass diff --git a/colossalai/kernel/extensions/cpp_extension.py b/colossalai/kernel/extensions/cpp_extension.py new file mode 100644 index 000000000000..3adb65fb8f4e --- /dev/null +++ b/colossalai/kernel/extensions/cpp_extension.py @@ -0,0 +1,134 @@ +import importlib +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension + +__all__ = ["_CppExtension"] + + +class _CppExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=True, support_jit=True, priority=priority) + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op = None + + # build-related variables + self.prebuilt_module_path = "colossalai._C" + self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}" + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("csrc"), path) + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + + # get the current file path + # iteratively check the parent directory + # if the parent directory is "extensions", then the current file path is the root directory + # otherwise, the current file path is inside the root directory + current_file_path = Path(__file__) + while True: + if current_file_path.name == "extensions": + break + else: + current_file_path = current_file_path.parent + extension_module_path = current_file_path + code_abs_path = extension_module_path.joinpath(code_path) + return str(code_abs_path) + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def build_aot(self) -> "CppExtension": + from torch.utils.cpp_extension import CppExtension + + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + def build_jit(self) -> None: + from torch.utils.cpp_extension import load + + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + def load(self): + try: + op_kernel = self.import_op() + except (ImportError, ModuleNotFoundError): + # if import error occurs, it means that the kernel is not pre-built + # so we build it jit + op_kernel = self.build_jit() + + return op_kernel diff --git a/colossalai/kernel/extensions/cpu_adam/__init__.py b/colossalai/kernel/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000000..d5c69e902a80 --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/__init__.py @@ -0,0 +1,4 @@ +from .cpu_adam_arm import CpuAdamArmExtension +from .cpu_adam_x86 import CpuAdamX86Extension + +__all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"] diff --git a/colossalai/kernel/extensions/cpu_adam/cpu_adam_arm.py b/colossalai/kernel/extensions/cpu_adam/cpu_adam_arm.py new file mode 100644 index 000000000000..61c4f3ed0697 --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/cpu_adam_arm.py @@ -0,0 +1,41 @@ +import platform + +from ..cpp_extension import _CppExtension + + +class CpuAdamArmExtension(_CppExtension): + def __init__(self): + super().__init__(name="cpu_adam_arm") + + def is_available(self) -> bool: + # only arm allowed + return platform.machine() == "aarch64" + + def assert_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "aarch64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}" + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/colossalai/kernel/extensions/cpu_adam/cpu_adam_x86.py b/colossalai/kernel/extensions/cpu_adam/cpu_adam_x86.py new file mode 100644 index 000000000000..4789f2f32665 --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/cpu_adam_x86.py @@ -0,0 +1,54 @@ +import platform + +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads + + +class CpuAdamX86Extension(_CudaExtension): + def __init__(self): + super().__init__(name="cpu_adam_x86") + + def is_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_available() + + def assert_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_compatible() + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("x86/cpu_adam.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-lcudart", + "-lcublas", + "-g", + "-Wno-reorder", + "-fopenmp", + "-march=native", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/csrc/__init__.py b/colossalai/kernel/extensions/csrc/__init__.py new file mode 100644 index 000000000000..0eac28d23e24 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/__init__.py @@ -0,0 +1,11 @@ +from .layer_norm import MixedFusedLayerNorm as LayerNorm +from .multihead_attention import MultiHeadAttention +from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax + +__all__ = [ + "LayerNorm", + "MultiHeadAttention", + "FusedScaleMaskSoftmax", + "ScaledUpperTriangMaskedSoftmax", + "AttnMaskType", +] diff --git a/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.cpp b/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.cpp new file mode 100644 index 000000000000..a715a2711576 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.cpp @@ -0,0 +1,304 @@ +#include "cpu_adam_arm.h" + +void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4 = vdivq_f32(grad_4, loss_scale_vec); + } + float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i); + float32x4_t variance_4 = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i); + float32x4_t param_4 = simd_load_offset(_params, param_dtype, i); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4); + } + momentum_4 = vmulq_f32(momentum_4, betta1_4); + momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4); + variance_4 = vmulq_f32(variance_4, betta2_4); + grad_4 = vmulq_f32(grad_4, grad_4); + variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4); + grad_4 = vsqrtq_f32(variance_4); + grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt); + grad_4 = vdivq_f32(momentum_4, grad_4); + if (_weight_decay > 0 && _adamw_mode) { + param_4 = vfmaq_f32(param_4, param_4, weight_decay_4); + } + param_4 = vfmaq_f32(param_4, grad_4, step_size_4); + simd_store_offset(_params, param_dtype, param_4, i); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = scalar_load_offset(grads, grad_dtype, k); + if (loss_scale > 0) { + grad /= loss_scale; + } + float param = scalar_load_offset(_params, param_dtype, k); + float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k); + float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k); + if (_weight_decay > 0 && !_adamw_mode) { + grad = param * _weight_decay + grad; + } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { + param += w_decay * param; + } + param = grad * step_size + param; + + scalar_store_offset(_params, param_dtype, param, k); + scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k); + scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k); + } + } + } +} + +void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + float32x4_t grad_4[4]; + float32x4_t momentum_4[4]; + float32x4_t variance_4[4]; + float32x4_t param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); + } + momentum_4[j] = + simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); + variance_4[j] = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); + param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); + } + momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); + momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); + variance_4[j] = vmulq_f32(variance_4[j], betta2_4); + grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); + variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); + grad_4[j] = vsqrtq_f32(variance_4[j]); + grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); + grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); + } + param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); + simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], + i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], + i + SIMD_WIDTH * j); + } + } + } +#endif + if (_param_size > rounded_size) { + Step_1(scalar_seek_offset(_params, param_dtype, rounded_size), + scalar_seek_offset(grads, grad_dtype, rounded_size), + scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), + scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), + (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, + exp_avg_sq_dtype, loss_scale); + } +} + +void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + float32x4_t grad_4[8]; + float32x4_t momentum_4[8]; + float32x4_t variance_4[8]; + float32x4_t param_4[8]; +#pragma unroll 4 + for (int j = 0; j < 8; j++) { + grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); + } + momentum_4[j] = + simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); + variance_4[j] = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); + param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); + } + momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); + momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); + variance_4[j] = vmulq_f32(variance_4[j], betta2_4); + grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); + variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); + grad_4[j] = vsqrtq_f32(variance_4[j]); + grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); + grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); + } + param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); + simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], + i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], + i + SIMD_WIDTH * j); + } + } + } +#endif + if (_param_size > rounded_size) { + Step_4(scalar_seek_offset(_params, param_dtype, rounded_size), + scalar_seek_offset(grads, grad_dtype, rounded_size), + scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), + scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), + (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, + exp_avg_sq_dtype, loss_scale); + } +} + +void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2, + float epsilon, float weight_decay, + bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale) { + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + this->IncrementStep(step, beta1, beta2); + this->update_state(lr, epsilon, weight_decay, bias_correction); + this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(), + exp_avg_sq_c.data_ptr(), params_c.numel(), + params_c.scalar_type(), grads_c.scalar_type(), + exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale); +} + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "CPUAdamOptimizer") + .def(py::init()) + .def("step", &AdamOptimizer::step); +} diff --git a/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.h b/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.h new file mode 100644 index 000000000000..c731850edc31 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/arm/cpu_adam_arm.h @@ -0,0 +1,201 @@ +#pragma once +#include +#include + +#include + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__aarch64__) +#include +#define SIMD_WIDTH 4 + +inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: { + auto ptr_f = reinterpret_cast(ptr); + return vld1q_f32(ptr_f + offset); + } + case at::ScalarType::Half: { + auto ptr_h = reinterpret_cast(ptr); + return vcvt_f32_f16(vld1_f16(ptr_h + offset)); + } + // case at::ScalarType::BFloat16: { + // auto ptr_b = reinterpret_cast(ptr); + // return vcvt_f32_bf16(vld1_bf16(ptr_b + offset)); + // } + default: + AT_ERROR("Unsupported dtype"); + break; + } +} +inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) { + return simd_load_offset(ptr, dtype, 0); +} + +inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: { + auto ptr_f = reinterpret_cast(ptr); + vst1q_f32(ptr_f + offset, data); + break; + } + case at::ScalarType::Half: { + auto ptr_h = reinterpret_cast(ptr); + vst1_f16(ptr_h + offset, vcvt_f16_f32(data)); + break; + } + // case at::ScalarType::BFloat16: { + // auto ptr_b = reinterpret_cast(ptr); + // vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data)); + // break; + // } + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) { + return simd_store_offset(ptr, dtype, data, 0); +} + +inline float32x4_t simd_set(float value) { + auto val = static_cast(value); + return vdupq_n_f32(val); +} + +#endif + +inline float scalar_load_offset(const void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + return *(reinterpret_cast(ptr) + offset); + case at::ScalarType::Half: + return static_cast( + *(reinterpret_cast(ptr) + offset)); + // case at::ScalarType::BFloat16: + // return static_cast( + // *(reinterpret_cast(ptr) + offset)); + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + *(reinterpret_cast(ptr) + offset) = data; + break; + case at::ScalarType::Half: + *(reinterpret_cast(ptr) + offset) = data; + break; + // case at::ScalarType::BFloat16: + // *(reinterpret_cast(ptr) + offset) = data; + break; + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + return reinterpret_cast(ptr) + offset; + case at::ScalarType::Half: + return reinterpret_cast(ptr) + offset; + // case at::ScalarType::BFloat16: + // return reinterpret_cast(ptr) + offset; + default: + AT_ERROR("Unsupported dtype"); + break; + } +} +#define STEP(SPAN) \ + void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \ + void *_exp_avg_sq, size_t _param_size, \ + at::ScalarType param_dtype, at::ScalarType grad_dtype, \ + at::ScalarType exp_avg_dtype, \ + at::ScalarType exp_avg_sq_dtype, float loss_scale = -1); + +class AdamOptimizer { + private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; + + public: + AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, + float eps = 1e-8, float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) {} + ~AdamOptimizer() {} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, + bool bias_correction) { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + + void step(size_t step, float lr, float beta1, float beta2, float epsilon, + float weight_decay, bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale); +}; diff --git a/colossalai/kernel/extensions/csrc/common/micros.h b/colossalai/kernel/extensions/csrc/common/micros.h new file mode 100644 index 000000000000..fd489d764127 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/common/micros.h @@ -0,0 +1,224 @@ +/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ + +#pragma once + +#include + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \ + TYPE, NAME, ...) \ + if (HIGH_PRECISION) { \ + const bool high_precision = true; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + } else { \ + const bool high_precision = false; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) { \ + case at::ScalarType::Float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \ + if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::Half) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Half && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + } else { \ + AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ + "'"); \ + } diff --git a/colossalai/kernel/extensions/csrc/common/mp_type_traits.h b/colossalai/kernel/extensions/csrc/common/mp_type_traits.h new file mode 100644 index 000000000000..5275732194ab --- /dev/null +++ b/colossalai/kernel/extensions/csrc/common/mp_type_traits.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include "micros.h" + +namespace colossalAI { +namespace common { + +template +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +template +struct ScalarTypeTrait { + using Type = + typename std::conditional::Type, + T>::type; +}; + +} // namespace common +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/common/target.h b/colossalai/kernel/extensions/csrc/common/target.h new file mode 100644 index 000000000000..ee3072f62d71 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/common/target.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include + +namespace colossalAI { +namespace common { + +class Target { + public: + enum class OS : int { + Unk = -1, + Linux, + Windows, + }; + enum class Arch : int { + Unk = -1, + X86, + Arm, + NVGPU, + AMDGPU, + Ascend, + }; + enum class BitLen : int { + Unk = -1, + k32, + k64, + }; + + explicit Target(OS os, Arch arch, BitLen bitlen) + : os_(os), arch_(arch), bitlen_(bitlen) {} + + bool defined() const { + return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk); + } + + std::string str() const { + std::string s{"OS: "}; + switch (os_) { + case OS::Unk: + s += "Unk"; + break; + case OS::Linux: + s += "Linux"; + break; + case OS::Windows: + s += "Windows"; + break; + default: + throw std::invalid_argument("Invalid OS type!"); + } + s += "\t"; + s += "Arch: "; + + switch (arch_) { + case Arch::Unk: + s += "Unk"; + break; + case Arch::X86: + s += "X86"; + break; + case Arch::Arm: + s += "Arm"; + break; + case Arch::NVGPU: + s += "NVGPU"; + break; + case Arch::AMDGPU: + s += "AMDGPU"; + break; + case Arch::Ascend: + s += "Ascend"; + break; + default: + throw std::invalid_argument("Invalid Arch type!"); + } + s += "\t"; + s += "BitLen: "; + + switch (bitlen_) { + case BitLen::Unk: + s += "Unk"; + break; + case BitLen::k32: + s += "k32"; + break; + case BitLen::k64: + s += "k64"; + break; + default: + throw std::invalid_argument("Invalid target bit length!"); + } + + return s; + } + + OS os() const { return os_; } + Arch arch() const { return arch_; } + BitLen bitlen() const { return bitlen_; } + + static Target DefaultX86Target(); + static Target DefaultArmTarget(); + static Target DefaultRocmTarget(); + static Target DefaultAscendTarget(); + + static Target DefaultCUDATarget() { + return Target(OS::Linux, Arch::NVGPU, BitLen::k64); + } + + friend std::ostream& operator<<(std::ostream& os, const Target& target); + friend bool operator==(const Target& lhs, const Target& rhs); + friend bool operator!=(const Target& lhs, const Target& rhs); + + private: + OS os_{OS::Unk}; + Arch arch_{Arch::Unk}; + BitLen bitlen_{BitLen::Unk}; +}; + +std::ostream& operator<<(std::ostream& os, const Target& target) { + std::cout << target.str() << std::endl; +} +bool operator==(const Target& lhs, const Target& rhs) { + return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) && + (lhs.bitlen_ == rhs.bitlen_); +} +bool operator!=(const Target& lhs, const Target& rhs) { + return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) && + (lhs.bitlen_ != rhs.bitlen_); +} + +} // namespace common +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/activation_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/activation_kernel.cu new file mode 100644 index 000000000000..372b303875cb --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/activation_kernel.cu @@ -0,0 +1,75 @@ +#include +#include +#include + +#include "../common/micros.h" +#include "../common/mp_type_traits.h" + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + using MT = typename colossalAI::common::MPTypeTrait::Type; + return static_cast((static_cast(x)) / (static_cast(1.0f) + expf(static_cast(-x)))); +} + +template +__global__ void act_and_mul_kernel( + const scalar_t* __restrict__ ins_data, + scalar_t* __restrict__ outs_data, + const int64_t numel) { + using MT = typename colossalAI::common::MPTypeTrait::Type; + + int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); + const int64_t grid_size = blockDim.x * gridDim.x; + if(idx > numel) { + return; + } + + for(int64_t i = idx; i < numel; i += grid_size) { + scalar_t x = ins_data[i]; + scalar_t y = ins_data[i+numel]; + outs_data[i] = static_cast(static_cast(ACT_FN(x)) * static_cast(y)); + } +} + +// Note(LiuYang):This func is designed for calculation mode like +// silu(x[:half_1stdim]) * (x[half_1stdim:]) +torch::Tensor silu_and_mul(const torch::Tensor& ins) +{ + // Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api + // to manipulate ins_shape which is IntArrayRef + auto ins_shape = ins.sizes().vec(); + + ins_shape[0] = ins_shape[0]/2; + if (ins_shape[0] == 1) { + ins_shape.erase(ins_shape.begin()); + } + auto outs = torch::zeros(ins_shape,ins.options()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Note(Liuyang): numel of ins must be divisible by 2 + int64_t numel = ((torch::numel(ins)) >> 1); + + // Note(LiuYang): For better performance for special case of which input is [2, 64, 11008], now + // I comment this part code,because it also cost a little time to calculate a better config + // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0); + // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1); + // dim3 grid = config.grid; + // dim3 block = config.block; + + dim3 grid((numel+255)/256); + dim3 block(256); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + ins.scalar_type(), + "silu_and_mul", + act_and_mul_kernel><<>>( + ins.data_ptr(), + outs.data_ptr(), + numel + );) + + AT_CUDA_CHECK(cudaGetLastError()); + return outs; +} diff --git a/colossalai/kernel/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000000..3300fad47796 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu @@ -0,0 +1,182 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" + +template +__global__ void context_kv_cache_memcpy_kernel( + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cu_seqlens, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int batch_size, + const int block_table_stride, + const int64_t key_stride, + const int64_t value_stride +) +{ + const int seq_token_id = blockIdx.x; + const int seq_id = blockIdx.y; + const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size]; + + if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { + return ; + } + + const int block_offset = seq_token_id % block_size; + const int hidden_size = head_num * head_dim; + const int total_token_id = cu_seqlens[seq_id] + seq_token_id; + int head_id; + int head_offset; + int64_t key_src_id; + int64_t value_src_id; + int64_t target_id; + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + head_id = i / head_dim; + head_offset = i % head_dim; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); + } + + // tail process + if (!Aligned) { + for (; i < hidden_size; ++i ) { + head_id = i / head_dim; + head_offset = i % head_dim; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + } + +} + +template +void apply_context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(2); + int batch_size = block_tables.size(0); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(max_seq_len_in_batch, batch_size); + dim3 block(std::min(thread_nums, 512)); + +#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + context_kv_cache_memcpy_kernel<<>>( \ + key.data_ptr(), \ + value.data_ptr(), \ + key_cache.data_ptr(), \ + value_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + cu_seqlens.data_ptr(), \ + block_tables.data_ptr(), \ + head_num, \ + head_dim, \ + block_size, \ + batch_size, \ + block_table_stride, \ + key_stride, \ + value_stride \ + ); \ + } while(0) + +#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ + do { \ + switch (vec_size) { \ + case 1: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", vec_size); \ + break; \ + } \ + } while(0) + + + if (aligned) { + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true); + } + else { + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false); + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "context_kv_cache_memcpy", + apply_context_kv_cache_memcpy( + key, + value, + key_cache, + value_cache, + sequence_lengths, + cu_seqlens, + block_tables, + max_seq_len_in_batch + );) +} diff --git a/colossalai/kernel/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000000..3fcceac6b942 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -0,0 +1,162 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" + +template +__global__ void decode_kv_cache_memcpy_kernel( + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int64_t key_stride, + const int64_t value_stride, + const int block_table_stride +) +{ + const int seq_id = blockIdx.x; + const int seq_len = sequence_lengths[seq_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size]; + const int hidden_size = head_num * head_dim; + + if ( block_id < 0 ) { + return ; + } + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); + } + + if (!Aligned) { + for (; i < hidden_size; ++i ) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + } + +} + +template +void apply_decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(2); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + +#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + decode_kv_cache_memcpy_kernel<<>>( \ + key.data_ptr(), \ + value.data_ptr(), \ + key_cache.data_ptr(), \ + value_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + block_tables.data_ptr(), \ + head_num, \ + head_dim, \ + block_size, \ + key_stride, \ + value_stride, \ + block_table_stride \ + ); \ + } while(0) + +#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size) \ + do { \ + switch (__vec_size) { \ + case 1: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", __vec_size); \ + break; \ + } \ + } while(0) + + if (aligned) { + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size); + } + else { + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size); + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "decode_kv_cache_memcpy", + apply_decode_kv_cache_memcpy( + key, + value, + key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} diff --git a/colossalai/kernel/extensions/csrc/cuda/funcs/cast_functor.h b/colossalai/kernel/extensions/csrc/cuda/funcs/cast_functor.h new file mode 100644 index 000000000000..623e1cdeb290 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/funcs/cast_functor.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "../utils/micros.h" + +// Note(LiuYang): This file provides base math operation for data type +// include POD and cuda built-in type such as half and __nv_bfloat16 + +namespace colossalAI { +namespace cuda { +namespace funcs { + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = at::Half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = at::BFloat16; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; + +template +struct CastFunctor : public std::unary_function { + HOSTDEVICE To operator()(From val) { return static_cast(val); } +}; + +#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \ + FUNCTION_MODIFIER) \ + template <> \ + struct CastFunctor : public std::unary_function { \ + FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \ + }; + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y), + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val), + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val), + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val), + DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162, + __float2bfloat162_rn(val), DEVICE) + +#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION +} // namespace funcs +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/funcs/op_functor.h b/colossalai/kernel/extensions/csrc/cuda/funcs/op_functor.h new file mode 100644 index 000000000000..0398ea97b539 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/funcs/op_functor.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "../utils/micros.h" + +namespace colossalAI { +namespace cuda { +namespace funcs { + +enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; + +// Note(LiuYang): This file provides base math operation for data type +// include POD and cuda built-in type such as half and __nv_bfloat16 +template +struct BinaryOpFunctor; + +#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \ + FUNCTION_MODIFIER, ARGS...) \ + template \ + struct BinaryOpFunctor \ + : public std::binary_function { \ + FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \ + }; + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs, + HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs, + HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs, + HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs, + HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs), + HOSTDEVICE, typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs), + HOSTDEVICE, typename T) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd, + __hadd(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd, + __hadd2(lhs, rhs), DEVICE) + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, + __hadd(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd, + __hadd2(lhs, rhs), DEVICE) +#else +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd, + __float2bfloat16(__bfloat162float(lhs) + + __bfloat162float(rhs)), + DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, BinaryOpType::kAdd, + __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs), + __high2float(lhs) + __high2float(rhs)), + DEVICE) +#endif + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul, + __hmul(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul, + __hmul2(lhs, rhs), DEVICE) + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, + __hmul(lhs, rhs), DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul, + __hmul2(lhs, rhs), DEVICE) +#else +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul, + __float2bfloat16(__bfloat162float(lhs) * + __bfloat162float(rhs)), + DEVICE) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, BinaryOpType::kMul, + __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs), + __high2float(lhs) * __high2float(rhs)), + DEVICE) +#endif + +#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION + +} // namespace funcs +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu new file mode 100644 index 000000000000..8feb6b343620 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -0,0 +1,481 @@ +// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" +#include "../common/mp_type_traits.h" + +template +__device__ void apply_emb_rotary_compute( + scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr, + const m_scalar_t* __restrict__ sin_ptr, const int64_t stride, + const int token_id, const int shard_block_size, const int half_head_dim, + const int head_num, const int head_dim) { + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * stride + (i / half_head_dim) * head_dim + head_offset; + + copy_vector(x, src + addr_offset); + copy_vector(y, src + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); + } + + copy_vector(src + addr_offset, out_x); + copy_vector(src + addr_offset + half_head_dim, out_y); + } +} + +template +__device__ void apply_kv_memcopy( + scalar_t* __restrict__ src, scalar_t* __restrict__ cache, + const int64_t stride, const int token_id, const int block_id, + const int hidden_size, const int block_size, const int block_offset, + const int head_dim, const int half_head_dim) { + for (int i = threadIdx.x * VecSize; i < hidden_size / 2; + i += blockDim.x * VecSize) { + const int head_id = i / half_head_dim; + const int head_offset = i % half_head_dim; + const int64_t src_id = token_id * stride + head_id * head_dim + head_offset; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(cache + target_id, src + src_id); + copy_vector(cache + target_id + half_head_dim, + src + src_id + half_head_dim); + } +} + +template +__device__ void cos_sin_memory_access( + const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, + m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id, + const int shard_block_size, const int cos_stride, const int sin_stride, + const int half_head_dim) { + for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { + // We assume that the value of head_dim is less than 128*128. + const int shard_offset = (i % shard_block_size) / VecSize; + const int shard_head = + (i / shard_block_size) * shard_block_size + i % VecSize * 32; + cos_ptr[shard_head + shard_offset] = static_cast(cos[token_id * cos_stride + i]); + sin_ptr[shard_head + shard_offset] = static_cast(sin[token_id * sin_stride + i]); + } +} + +template +__device__ void apply_k_rotary_emb_compute( + scalar_t* __restrict__ key, scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, const int64_t key_stride, + const int64_t value_stride, const int token_id, + const int block_table_stride, const int head_num, const int head_dim, + const int kv_head_num, const int block_size, const int half_head_dim, + const int shard_block_size) { + const int seq_len = sequence_lengths[token_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = + block_tables[token_id * block_table_stride + seq_len / block_size]; + + if (block_id < 0) { + return; + } + + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * key_stride + (i / half_head_dim) * head_dim + head_offset; + const int64_t target_id = block_id * head_num * head_dim * block_size + + (i / half_head_dim) * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(x, key + addr_offset); + copy_vector(y, key + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); + } + + copy_vector(key_cache + target_id, out_x); + copy_vector(key_cache + target_id + half_head_dim, + out_y); + } + + // apply value memcopy + apply_kv_memcopy( + value, value_cache, value_stride, token_id, block_id, head_num * head_dim, + block_size, block_offset, head_dim, half_head_dim); +} + +template +__global__ void rotary_embedding_and_cache_copy_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + scalar_t* __restrict__ value, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int64_t query_stride, + const int64_t key_stride, + const int64_t value_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int block_table_stride, + const int head_num, + const int head_dim, + const int kv_head_num, + const int block_size +) { + + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key and copy kv + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); +} + +template +__global__ void rotary_embedding_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + const int64_t query_stride, + const int64_t key_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int head_num, + const int head_dim, + const int kv_head_num +) { + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); +} + +template +void apply_rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + int block_size = key_cache.size(2); + + int64_t query_stride = query.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + int block_table_stride = block_tables.stride(0); + + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 2: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 4: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +template +void apply_rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + + int query_stride = query.stride(0); + int key_stride = key.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 2: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 4: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision) +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, + query.scalar_type(), + "rotary_embedding_and_cache_copy", + apply_rotary_embedding_and_cache_copy( + query, + key, + value, + cos, + sin, + key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} + +void rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin, // [total_tokens, head_dim] + bool high_precision +){ + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, + query.scalar_type(), + "rotary_embedding", + apply_rotary_embedding( + query, + key, + cos, + sin + );) +} diff --git a/colossalai/kernel/extensions/csrc/cuda/get_cos_and_sin_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/get_cos_and_sin_kernel.cu new file mode 100644 index 000000000000..15aea740e6f9 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/get_cos_and_sin_kernel.cu @@ -0,0 +1,215 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" +#include "stdio.h" + +template +__device__ void apply_cos_and_sin_memcopy( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int head_dim, + const int dest_offset_id, + const int src_offset_id + ) { + + int begin_id = threadIdx.x * VecSize; + + for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){ + copy_vector(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id); + copy_vector(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id); + } + + if (!Aligned) { + for (; begin_id < head_dim; ++begin_id ) { + cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id]; + sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id]; + } + } +} + +template +__global__ void apply_get_context_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cumsum_lengths, + const int batch_size, + const int head_dim +) { + int token_id = blockIdx.x; + if ( token_id >= sequence_lengths[blockIdx.y] ) { + return ; + } + + int src_offset_id = token_id * head_dim; + int dest_offset_id = src_offset_id; + + if (blockIdx.y > 0) { + dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim; + } + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); + +} + +template +__global__ void apply_get_decode_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int batch_size, + const int head_dim +) { + int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim; + int dest_offset_id = blockIdx.y * head_dim; + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); +} + +template +void apply_get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + int token_num = cos.size(0); + int head_dim = cos.size(1); + int batch_size = sequence_lengths.size(0); + + at::Tensor cumsum_lengths; + + int vec_size = get_vec_size(cos); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int block_size_y; + int block_size_x; + + if (is_prompts) { + block_size_y = batch_size; + block_size_x = max_seq_len_in_batch; + // TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on. + cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32); + } + else{ + block_size_y = batch_size; + block_size_x = 1; + } + + int thread_nums = (head_dim + vec_size - 1) / vec_size; + + dim3 grid(block_size_x, block_size_y); + dim3 block(std::min(thread_nums, 512)); + +#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + if (is_prompts){ \ + apply_get_context_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + cumsum_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + else { \ + apply_get_decode_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + } while(0) + +#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ + do { \ + switch (vec_size) { \ + case 1: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", vec_size); \ + break; \ + } \ + } while(0) + + if (aligned) { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true); + } + else { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + cos.scalar_type(), + "get_cos_and_sin", + apply_get_cos_and_sin( + cos_cache, + sin_cache, + cos, + sin, + sequence_lengths, + max_seq_len_in_batch, + is_prompts + );) +} diff --git a/colossalai/kernel/extensions/csrc/cuda/include/block_reduce.h b/colossalai/kernel/extensions/csrc/cuda/include/block_reduce.h new file mode 100644 index 000000000000..6f6db6f774ab --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/include/block_reduce.h @@ -0,0 +1,184 @@ +#pragma once + +#include +#include +#include + +#include "../funcs/op_functor.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +const float kReduceFloatInfNeg = -100000000.f; +const float kReduceFloatInfPos = 100000000.f; +const int kWarpSize = 32; +const unsigned int kWarpReduceMask = 0xffffffff; + +enum class ReduceType { kMax = 0, kSum }; + +template +struct GetOpForReduceType; + +template +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; + +template +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; + +#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ + for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = \ + OP(*(VAL_PTR + offset), \ + __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \ + } + +#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES) + +#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \ + DEFAULT_VALUE, REDUCE_TYPE) \ + __shared__ T shm[LANES][32]; \ + int lane_id = threadIdx.x & 0x1f; \ + int warp_id = threadIdx.x >> 5; \ + \ + warp_reduce(VAL_PTR); \ + if (lane_id == 0) { \ + for (int offset = 0; offset < LANES; ++offset) { \ + shm[offset][warp_id] = *(VAL_PTR + offset); \ + } \ + } \ + __syncthreads(); \ + \ + for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \ + ? shm[offset][lane_id] \ + : static_cast(DEFAULT_VALUE); \ + } \ + warp_reduce(VAL_PTR); + +template +__forceinline__ __device__ void warp_reduce(T* pval) { + typename GetOpForReduceType::Op op; + COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes); +} + +template +__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() { + if constexpr (rtype == ReduceType::kSum) { + return static_cast(0.0f); + } else if constexpr (rtype == ReduceType::kMax) { + return static_cast(kReduceFloatInfNeg); + } +} + +template +__forceinline__ __device__ void block_reduce(T* pval) { + constexpr T kDefaultValue = GetDefaultValueForBlockReduce(); + typename GetOpForReduceType::Op op; + COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue, + rtype); +} + +#undef COLOSSAL_SHFL_FUNCTION +#undef COLOSSAL_WARP_REDUCE_IMPL +#undef COLOSSAL_BLOCK_REDUCE_IMPL + +template +__device__ __forceinline__ T reduce_block_into_lanes( + T* x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op( + T* x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = + fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/layer_norm_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/layer_norm_kernel.cu new file mode 100644 index 000000000000..8239adc9f369 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/layer_norm_kernel.cu @@ -0,0 +1,683 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" +#include "../common/micros.h" + +template +__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template +__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, + U& mu, U& sigma2, U& count) { + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA * mu + nB * muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template +__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1, + const int n2, const int i1, U& mu, U& sigma2, + U* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu = U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1 * n2; + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l + k]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + U muB = WARP_SHFL(mu, srcLaneB); + U countB = WARP_SHFL(count, srcLaneB); + U sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2 * threadIdx.y]; + U sigma2B = ubuf[2 * threadIdx.y + 1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / U(n2), 0); + } + } +} + +template <> +__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals, + const int n1, const int n2, const int i1, + float& mu, float& sigma2, float* buf) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu = float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1 * n2; + int l = 8 * thrx; + if ((((size_t)lvals) & 3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l + 7 < n2; l += 8 * numx) { + for (int k = 0; k < 8; k += 2) { + float2 curr = __half22float2(*((__half2*)(lvals + l + k))); + cuWelfordOnlineSum(curr.x, mu, sigma2, count); + cuWelfordOnlineSum(curr.y, mu, sigma2, count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr, mu, sigma2, count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x + (1 << l)) & 31; + float muB = WARP_SHFL(mu, srcLaneB); + float countB = WARP_SHFL(count, srcLaneB); + float sigma2B = WARP_SHFL(sigma2, srcLaneB); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && + threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2 * wrt_y] = mu; + ubuf[2 * wrt_y + 1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2 * threadIdx.y]; + float sigma2B = ubuf[2 * threadIdx.y + 1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1] / float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2 / float(n2), 0); + } + } +} + +template +U rsqrt(U v) { + return U(1) / sqrt(v); +} +template <> +float rsqrt(float v) { + return rsqrtf(v); +} +template <> +double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ float* getPointer() { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} // namespace + +template +__global__ void cuApplyLayerNorm(V* __restrict__ output_vals, + U* __restrict__ mean, U* __restrict__ invvar, + const T* __restrict__ vals, const int n1, + const int n2, const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta) { + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu, sigma2; + cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); + const T* lvals = vals + i1 * n2; + V* ovals = output_vals + i1 * n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i += numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template +__device__ void cuLoadWriteStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = + curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template +__device__ void cuLoadAddStridedInputs( + const int i1_block, const int thr_load_row_off, const int thr_load_col_off, + const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2, + const T* input, const V* dout, const int i1_end, const int n2, + const U* __restrict__ mean, const U* __restrict__ invvar) { + int i1 = i1_block + thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1 * n2 + i2; + int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; + if (i2 < n2) { + U curr_input = static_cast(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += + curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template +__global__ void cuComputePartGradGammaBeta( + const V* __restrict__ dout, const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + U epsilon, U* part_grad_gamma, U* part_grad_beta) { + const int numsegs_n1 = + (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; + const int i1_beg_plus_one = + (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x + 1; + const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); + const int thr_load_row_off = + (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * + // blockDim.y + (blockDim.y - + // 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; + i1_block += blockDim.y * blockDim.y) { + cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, + row_stride, warp_buf1, warp_buf2, input, dout, + i1_end, n2, mean, invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k * blockDim.y; + int idx1 = row1 * row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; + warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1 * row_stride + threadIdx.x; + int idx2 = row2 * row_stride + threadIdx.x; + part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template +__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, const int n1, + const int n2, V* grad_gamma, + V* grad_beta) { + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = + part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = + part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; + ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; + sum_beta += part_grad_beta_ptr[warp_offset * n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx + nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx + nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template +__global__ void cuComputeGradInput(const V* __restrict__ dout, + const T* __restrict__ input, const int n1, + const int n2, const U* __restrict__ mean, + const U* __restrict__ invvar, U epsilon, + const V* gamma, T* grad_input) { + for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1 * n2; + const V* k_dout = dout + i1 * n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss * gamma[l + k]; + sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4 * thrx; + for (; l + 3 < n2; l += 4 * numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l + k]); + const U c_loss = static_cast(k_dout[l + k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2 * wrt_i] = sum_loss1; + buf[2 * wrt_i + 1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2 * read_i]; + sum_loss2 += buf[2 * read_i + 1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2 * threadIdx.x] = sum_loss1; + buf[2 * threadIdx.x + 1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y != 0) { + sum_loss1 = buf[2 * threadIdx.x]; + sum_loss2 = buf[2 * threadIdx.x + 1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1 * n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l += numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + +template +void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1, + int n2, double epsilon, const V* gamma, const V* beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32, 4, 1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; + cuApplyLayerNorm<<>>( + output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); +} + +void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, + at::Tensor* input, int n1, int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, double epsilon) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm(output->data_ptr(), + mean->data_ptr(), invvar->data_ptr(), + input->data_ptr(), n1, n2, epsilon, + gamma != NULL ? gamma->data_ptr() : NULL, + beta != NULL ? beta->data_ptr() : NULL);) +} + +template +void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, + at::Tensor* input, int n1, int n2, const V* gamma, + const V* beta, double epsilon, T* grad_input, + V* grad_gamma, V* grad_beta) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32, 4, 1); + const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = + 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size, n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr()); + + const dim3 threads3(32, 8, 1); + const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), part_size, + n1, n2, grad_gamma, grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32, 4, 1); + int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; + cuComputeGradInput<<>>( + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), gamma, + grad_input); +} + +void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, + at::Tensor* invvar, at::Tensor* input, int n1, + int n2, +#ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, +#else + at::IntList normalized_shape, +#endif + at::Tensor* gamma, at::Tensor* beta, + double epsilon, at::Tensor* grad_input, + at::Tensor* grad_gamma, at::Tensor* grad_beta) { + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->data_ptr(), mean->data_ptr(), + invvar->data_ptr(), input, n1, n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->data_ptr() : NULL, + gamma != NULL ? beta->data_ptr() : NULL, epsilon, + grad_input->data_ptr(), + gamma != NULL ? grad_gamma->data_ptr() : NULL, + gamma != NULL ? grad_beta->data_ptr() : NULL);) +} diff --git a/colossalai/kernel/extensions/csrc/cuda/moe_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/moe_kernel.cu new file mode 100644 index 000000000000..7b28dffe91a3 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/moe_kernel.cu @@ -0,0 +1,662 @@ +#include +#include +#include + +#include + +#include "block_reduce.h" + + +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::ReduceType; + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + block_reduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + block_reduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data_ptr(), res.data_ptr(), + mask[0].data_ptr(), k == 1 ? nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data_ptr(), expert_grad.data_ptr(), + mask[0].data_ptr(), k == 1 ? nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data_ptr(), res.data_ptr(), + logits.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data_ptr(), egrad.data_ptr(), + expert_tokens.data_ptr(), logits.data_ptr(), + wgrad.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data_ptr(), res.data_ptr(), s, e); + + return res; +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_adam_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_adam_kernel.cu new file mode 100644 index 000000000000..b7793b364f7a --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_adam_kernel.cu @@ -0,0 +1,146 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "../common/micros.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta1_correction, + const float beta2_correction, const float epsilon, const float lr, + adamMode_t mode, const float decay, const float div_scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_g *g = (T_g *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T_p *p = (T_p *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T_p *m = (T_p *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T_p *v = (T_p *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (div_scale > 0) r_g[ii] /= div_scale; + + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + DISPATCH_FLOAT_AND_HALF_FOR_G_P( + tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay, div_scale);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_apply.cuh b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_apply.cuh new file mode 100644 index 000000000000..799ccfa73637 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_apply.cuh @@ -0,0 +1,130 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 + Licensed under the MIT License. +*/ +#include +#include +#include +#include +#include +#include + +#include "../common/micros.h" + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void *addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a + // full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel(int chunk_size, + volatile int *noop_flag, T tl, + U callable, ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply( + int block_size, int chunk_size, const at::Tensor &noop_flag, + const std::vector> &tensor_lists, T callable, + ArgTypes... args) { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); + l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, + "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = + (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), + "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.data_ptr(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 + // << std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 + // << std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu new file mode 100644 index 000000000000..fe86a8104dd1 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu @@ -0,0 +1,387 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu +#include +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "../common/micros.h" +#include "include/block_reduce.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::reduce_block_into_lanes; +using colossalAI::cuda::utils::reduce_block_into_lanes_max_op; + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct L2NormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] += next * next; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] += next * next; + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val += vals[i]; + + float final = reduce_block_into_lanes(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] += final; + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +// Probably better to template, but since we are not likely to support other +// norm +template +struct MaxNormFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = (x_t *)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be + // sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); + + float final = reduce_block_into_lanes_max_op(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * + max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +__global__ void cleanup(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) *ret = sqrt(final); + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); + } +} + +__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, + int max_chunks_per_tensor, int norm_type, + float alpha, float beta) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + if (norm_type == 0) { + float final = reduce_block_into_lanes_max_op(vals, val); + if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; + } else { + float final = reduce_block_into_lanes(vals, val); + if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); + } + } + + if (per_tensor) { + float *output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + if (norm_type == 0) { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val = fmaxf(fabsf(val), fabsf(output_this_tensor[i])); + + float final = reduce_block_into_lanes_max_op(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = + alpha * ret_per_tensor[blockIdx.x] + beta * final; + } else { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * + ret_per_tensor[blockIdx.x] + + beta * final); + } + } +} + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { + bool per_tensor = + per_tensor_python.has_value() ? per_tensor_python.value() : false; + + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + auto output = at::zeros({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + if (per_tensor) { + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + ret_per_tensor = at::empty({ntensors}, float_options); + } else { + ret_per_tensor = at::empty({0}, float_options); + } + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, + per_tensor, max_chunks_per_tensor);) + + AT_CUDA_CHECK(cudaGetLastError()); + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup<<>>( + output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, + ret.data_ptr(), + per_tensor ? ret_per_tensor.data_ptr() : nullptr, per_tensor, + max_chunks_per_tensor); + + return std::tuple(ret, ret_per_tensor); +} + +// Compute and update grad norm +// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by +// L-2: gn = sqrt(a * gn^2 + b * n^2) +// L-inf: gn = a * gn + b * n +void multi_tensor_norm_out_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor out, + const float alpha, const float beta, const int norm_type) { + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), + "noop flag should be on the same device as tensors"); + // we don't need global thus uses empty here + auto output = at::empty({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = + (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + + // Although it is single write then read, still need to be zero + // Since tailing element also participate cleanup + output_per_tensor = + at::zeros({ntensors * max_chunks_per_tensor}, float_options); + + if (norm_type == 0) { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + MaxNormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) + } else { + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) + } + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to + // end. I could get rid of these by hacking the functor + multi tensor harness + // with persistence logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + + // Adding the following device guard since it happens sometimes that the + // tensors are on one device and the cuda stream is on another device which + // results in ILLEGAL MEM ACCESS error. + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + auto stream = at::cuda::getCurrentCUDAStream(); + cleanup_v2<<>>( + output.data_ptr(), output_per_tensor.data_ptr(), + ret.data_ptr(), out.data_ptr(), true, max_chunks_per_tensor, + norm_type, alpha, beta); + + return; +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu new file mode 100644 index 000000000000..82c02f36d80f --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu @@ -0,0 +1,354 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "../common/micros.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +typedef enum { + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode +} adamMode_t; + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +using MATH_T = float; + +template +struct LAMBStage1Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta3, + const float beta1_correction, const float beta2_correction, + const float epsilon, adamMode_t mode, const float decay, + const float *global_grad_norm, const float max_global_grad_norm) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float clipped_global_grad_norm = + (*global_grad_norm) > max_global_grad_norm + ? (*global_grad_norm) / max_global_grad_norm + : 1.0f; + + T *g = (T *)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T *m = (T *)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + T *v = (T *)tl.addresses[3][tensor_loc]; + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) && + is_aligned(p) && is_aligned(m) && is_aligned(v)) { + T l_g[ILP]; + T l_p[ILP]; + T l_m[ILP]; + T l_v[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(l_g, g, 0, i_start); + if (decay != 0) load_store(l_p, p, 0, i_start); + load_store(l_m, m, 0, i_start); + load_store(l_v, v, 0, i_start); + // unpack +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_g[ii] = l_g[ii]; + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = l_p[ii]; + } + r_m[ii] = l_m[ii]; + r_v[ii] = l_v[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + l_p[ii] = r_p[ii]; + l_m[ii] = r_m[ii]; + l_v[ii] = r_v[ii]; + } + // store + load_store(g, l_p, i_start, 0); + load_store(m, l_m, i_start, 0); + load_store(v, l_v, i_start, 0); + } + } else { + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + // special ?optimization? for lamb stage 1 + if (decay == 0) { + r_p[ii] = MATH_T(0); + } else { + r_p[ii] = p[i]; + } + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == MOMENT_MODE_0) { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + // L2 on scaled grad + scaled_grad = scaled_grad + decay * r_p[ii]; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = next_m_unbiased / denom; + } else { + MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; + r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; + r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + g[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } + } + } +}; + +// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. +// It computes new parameter value. +template +struct LAMBStage2Functor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, + const float *per_tensor_param_norm, const float *per_tensor_update_norm, + const float learning_rate, const float decay, bool use_nvlamb) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int tensor_num = tl.start_tensor_this_launch + tensor_loc; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + MATH_T ratio = learning_rate; + // nvlamb: apply adaptive learning rate to all parameters + // otherwise, only apply to those with non-zero weight decay + if (use_nvlamb || (decay != 0.0)) { + float param_norm = per_tensor_param_norm[tensor_num]; + float update_norm = per_tensor_update_norm[tensor_num]; + ratio = (update_norm != 0.0f && param_norm != 0.0f) + ? learning_rate * (param_norm / update_norm) + : learning_rate; + } + + T *update = (T *)tl.addresses[0][tensor_loc]; + update += chunk_idx * chunk_size; + + T *p = (T *)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && + is_aligned(update)) { + T r_p[ILP]; + T r_update[ILP]; + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_p, p, 0, i_start); + load_store(r_update, update, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = static_cast(r_p[ii]) - + (ratio * static_cast(r_update[ii])); + } + load_store(p, r_p, i_start, 0); + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { + MATH_T r_p[ILP]; + MATH_T r_update[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_p[ii] = p[i]; + r_update[ii] = update[i]; + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_p[ii] = r_p[ii] - (ratio * r_update[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + } + } + } + } + } +}; + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python) { + using namespace at; + // Master weight and 32bit momentum(potentially changing) is not handled by + // this So we assume every tensor are all in the same type + + bool use_nvlamb = + use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + // Handle grad averaging mode + float beta3 = 1.0f; + if (grad_averaging == 1) beta3 = 1 - beta1; + + std::vector> grad_list(tensor_lists.begin(), + tensor_lists.begin() + 1); + std::vector> param_list(tensor_lists.begin() + 1, + tensor_lists.begin() + 2); + + // Compute per tensor param norm + auto param_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); + + // We now in-place modify grad to store update before compute its norm + // Generally this is not a issue since people modify grad in step() method all + // the time We can also grab list of empty tensor to avoid this, but I'd like + // to save space/cpu code + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + LAMBStage1Functor(), beta1, beta2, + beta3, // 1-beta1 or 1 depends on averaging mode + bias_correction1, bias_correction2, epsilon, + (adamMode_t)mode, weight_decay, + global_grad_norm.data_ptr(), max_grad_norm);) + + // Compute update norms + auto update_norm_tuple = + multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); + + std::vector> grad_param_list( + tensor_lists.begin(), tensor_lists.begin() + 2); + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, + LAMBStage2Functor(), + std::get<1>(param_norm_tuple).data_ptr(), + std::get<1>(update_norm_tuple).data_ptr(), + lr, weight_decay, use_nvlamb);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_scale_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_scale_kernel.cu new file mode 100644 index 000000000000..0dec1d5d1445 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_scale_kernel.cu @@ -0,0 +1,125 @@ +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "multi_tensor_apply.cuh" +#include "../common/micros.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, + int src_offset) { + typedef + typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; +} + +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int *noop_gmem, + TensorListMetadata<2> &tl, + float scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + in_t *in = (in_t *)tl.addresses[0][tensor_loc]; + in += chunk_idx * chunk_size; + + out_t *out = (out_t *)tl.addresses[1][tensor_loc]; + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + bool finite = true; + in_t r_in[ILP]; + out_t r_out[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && + is_aligned(out)) { + for (int i_start = threadIdx.x; + i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_in, in, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } + // store + load_store(out, r_out, i_start, 0); + } + } else { + // Non-divergent exit condition for __syncthreads, not necessary here + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_in[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; + } + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point + // unrolling the write loop, since writes just fire off once their LDGs + // arrive. Put another way, the STGs are dependent on the LDGs, but not + // on each other. There is still compute ILP benefit from unrolling the + // loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; + } + } + } + if (!finite) + *noop_gmem = + 1; // Blindly fire off a write. These will race but that's ok. + } +}; + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale) { + using namespace at; + // The output (downscaled) type is always float. + // If build times suffer, think about where to put this dispatch, + // and what logic should be moved out of multi_tensor_apply. + + DISPATCH_FLOAT_AND_HALF( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_AND_HALF( + tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ScaleFunctor(), + scale);)) + AT_CUDA_CHECK(cudaGetLastError()); + + // AT_CUDA_CHECK(cudaDeviceSynchronize()); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu new file mode 100644 index 000000000000..d0cf786f8e6f --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu @@ -0,0 +1,167 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu +#include +#include +#include +#include +#include +#include + +#include "../common/micros.h" +#include "multi_tensor_apply.cuh" + +#define BLOCK_SIZE 512 +#define ILP 4 + +/** + * Perform fused SGD on multiple buffers + * N: number of tensors + * tl[0] : gradients + * tl[1] : weights + * tl[2] : momentum buffers + * tl[3] : fp16 weights (if appropriate) + * wd : weight_decay (scalar) + * momentum : momentum (scalar) + * dampening : momentum dampening (scalar) + * lr : learning rate (scalar) + * nesterov : enable nesterov (bool) + * first run : necessary for proper momentum handling & init + * wd_after_momentum : apply weight decay _after_ momentum instead of before + **/ +template +struct SGDFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl, + float wd, float momentum, float dampening, float lr, bool nesterov, + bool first_run, bool wd_after_momentum, float scale) { + // Early exit if we don't need to do anything + if (*noop_gmem) return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; + grad_in += chunk_idx * chunk_size; + + T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; + weight_in += chunk_idx * chunk_size; + + T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; + mom_in += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float incoming_moms[ILP]; + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + incoming_grads[ii] = static_cast(grad_in[i]) * scale; + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } + +// note for clarification to future michael: +// From a pure memory dependency perspective, there's likely no point unrolling +// the write loop, since writes just fire off once their LDGs arrive. +// Put another way, the STGs are dependent on the LDGs, but not on each other. +// There is still compute ILP benefit from unrolling the loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + // apply weight decay before momentum if necessary + if (wd != 0.f && !wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + if (momentum != 0.f) { + if (!first_run) + incoming_moms[ii] = incoming_moms[ii] * momentum + + (1.f - dampening) * incoming_grads[ii]; + else // initialize momentums to current incoming grads + incoming_moms[ii] = incoming_grads[ii]; + + if (nesterov) + incoming_grads[ii] += momentum * incoming_moms[ii]; + else + incoming_grads[ii] = incoming_moms[ii]; + } + + // Apply WD after momentum if desired + if (wd != 0.f && wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; + + // adjust the weight and write out + weight_in[i] += (-lr * incoming_grads[ii]); + + // also write out the new momentum + if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; + } + } + } + } +}; + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale) { + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); + + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), + "expected noop flag to be on the same device as tensors"); + + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type + // 1. fp16, fp16, fp16 + // 2. fp32, fp32, fp32 + // 3. fp16, fp32, fp32 + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. + + // Case 1. fp16, fp16, fp16, No + if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Half && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } + // Case 2. fp32, fp32, fp32 + else if (grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, dampening, + lr, nesterov, first_run, wd_after_momentum, scale); + } + // Case 3. fp16, fp32, fp32 + else if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor(), wd, momentum, + dampening, lr, nesterov, first_run, wd_after_momentum, + scale); + } else { + AT_ERROR( + "multi_tensor_sgd only supports some combinations of gradient & weight " + "types. Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, + ", num_lists: ", num_tensors); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/inference.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/inference.cpp new file mode 100644 index 000000000000..6a468fcb814a --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/inference.cpp @@ -0,0 +1,84 @@ +#include + +void decode_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables); // [batch_size, max_seq_len] + +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch); + +void rotary_embedding( + torch::Tensor& query, // [total_tokens, head_num, head_dim] + torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] + torch::Tensor& cos, // [total_tokens, head_dim] + torch::Tensor& sin, // [total_tokens, head_dim] + bool high_precision); + +void rotary_embedding_and_cache_copy( + torch::Tensor& query, // [num_tokens, head_num, head_dim] + torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] + torch::Tensor& value, // [num_tokens, num_heads, head_dim] + torch::Tensor& cos, // [num_tokens, head_dim] + torch::Tensor& sin, // [num_tokens, head_dim] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision); + +torch::Tensor silu_and_mul(const torch::Tensor& ins); + +void rms_layernorm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, bool is_prompts); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the decode stage."); + + m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the context stage."); + + m.def( + "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, + "Performing Rotary Embedding-related calculations and KVCache Memcopy."); + + m.def("rotary_embedding", &rotary_embedding, + "Performing Rotary Embedding-related calculations."); + + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); + + m.def("rms_layernorm", &rms_layernorm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm, + "In-place fused Add and RMS Normalization."); + + m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache."); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/layer_norm.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/layer_norm.cpp new file mode 100644 index 000000000000..b1f7c254349e --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/layer_norm.cpp @@ -0,0 +1,141 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include + +#include +#include + +#include "../../common/micros.h" + +namespace { + +void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert(input.sizes()[i + idiff] == normalized_shape[i]); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, + at::Tensor beta) { + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, + int &n2) { + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input, normalized_shape, n1, n2); +} + +void check_args(at::Tensor input, at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, int &n1, int &n2) { + check_args(input, normalized_shape, n1, n2); + check_args(normalized_shape, gamma, beta); +} +} // namespace + +void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, + at::Tensor *input, int n1, int n2, + at::IntArrayRef normalized_shape, at::Tensor *gamma, + at::Tensor *beta, double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine(at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = + at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = + at::empty({n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, + &gamma, &beta, epsilon); + + return {output, mean, invvar}; +} + +void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, + at::Tensor *invvar, at::Tensor *input, int n1, + int n2, at::IntArrayRef normalized_shape, + at::Tensor *gamma, at::Tensor *beta, + double epsilon, at::Tensor *grad_input, + at::Tensor *grad_gamma, at::Tensor *grad_beta); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, + at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, + double epsilon) { + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/moe.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/moe.cpp new file mode 100644 index 000000000000..8c0b89eb06d1 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/moe.cpp @@ -0,0 +1,97 @@ +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/optimizer.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/optimizer.cpp new file mode 100644 index 000000000000..94f132521771 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/optimizer.cpp @@ -0,0 +1,49 @@ +// modified from +// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu +#include + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float scale); + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, + bool wd_after_momentum, float scale); + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int mode, + const int bias_correction, const float weight_decay, + const float div_scale); + +void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, + const float beta2, const float epsilon, + const int step, const int bias_correction, + const float weight_decay, const int grad_averaging, + const int mode, at::Tensor global_grad_norm, + const float max_grad_norm, + at::optional use_nvlamb_python); + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multi_tensor_scale", &multi_tensor_scale_cuda, + "Fused overflow check + scale for a list of contiguous tensors"); + m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, + "Fused SGD optimizer for list of contiguous tensors"); + m.def("multi_tensor_adam", &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); + m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, + "Computes and apply update for LAMB optimizer"); + m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, + "Computes L2 norm for a list of contiguous tensors"); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp new file mode 100644 index 000000000000..8c2982b0cff9 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp @@ -0,0 +1,70 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include + +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads); + +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, + attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax:: + get_batch_per_block, + "Return Batch per block size."); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp b/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 000000000000..cbbc3706497a --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,54 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include + +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/colossalai/kernel/extensions/csrc/cuda/rms_layernorm_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/rms_layernorm_kernel.cu new file mode 100644 index 000000000000..c39e44d8725f --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -0,0 +1,426 @@ +/*This code from VLLM: + * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu + * with minor changes. */ + +#include +#include +#include +#include + + +#include "block_reduce.h" +#include "../common/micros.h" +#include "funcs/cast_functor.h" +#include "funcs/op_functor.h" + +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::ReduceType; +using colossalAI::cuda::funcs::TypeConverter; +using colossalAI::cuda::funcs::CastFunctor; +using colossalAI::cuda::funcs::BinaryOpFunctor; +using colossalAI::cuda::funcs::BinaryOpType; + +#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ + if (DATA_SIZE == 2) { \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } else { \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + general_##__VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } \ + +// optimized for half and bf16 +template +__global__ void rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + using scalar2_t = typename TypeConverter::Type; + BinaryOpFunctor mul_scalar2t; + __shared__ float s_variance; + + /* + * since the open-sourced LLM's hidden dimensions mainly range from + * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported + * hidden dimension limit to 8192, and each thread's capacity + * for caching input tensors to 8 (8192 = 8 * 1024) which + * will cause problems for extremely large models, such as + * Megatron-Turing NLG 530B with hidden dimensions up to 20480 + */ + scalar2_t x_local[4]; + + scalar2_t* out_ptr = (scalar2_t*)out; + const scalar2_t* input_ptr = (scalar2_t*)input; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + float v1 = CastFunctor()(x_local[cnt].x); + float v2 = CastFunctor()(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + scalar2_t s_variance_2 = CastFunctor()(s_variance); +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]); + } +} + +template +__global__ void general_rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = (float) input[id]; + variance += x_local[cnt] * x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + +// optimized for half and bf16 +template +__global__ void fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + using scalar2_t = typename TypeConverter::Type; + BinaryOpFunctor add_scalar2t; + BinaryOpFunctor mul_scalar2t; + + __shared__ float s_variance; + scalar2_t x_local[4]; + + scalar2_t* input_ptr = (scalar2_t*)input; + scalar2_t* residual_ptr = (scalar2_t*)residual; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]); + float v1 = CastFunctor()(x_local[cnt].x); + float v2 = CastFunctor()(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; + residual_ptr[id] = x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + scalar2_t s_variance_2 = CastFunctor()(s_variance); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]); + } +} + +template +__global__ void general_fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = (float) input[id]; + x_local[cnt] += (float) residual[id]; + variance += x_local[cnt] * x_local[cnt]; + residual[id] = (scalar_t) x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + +void rms_layernorm( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } else { + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 2: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 4: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 8: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + } + } +} + +void fused_add_rms_layernorm( + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } else { + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 2: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 4: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 8: + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + } + } +} diff --git a/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax.h b/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax.h new file mode 100644 index 000000000000..cbbe7f36ad38 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax.h @@ -0,0 +1,500 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include "utils/vector_copy_utils.h" + +namespace { + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu new file mode 100644 index 000000000000..2f968d30f106 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu @@ -0,0 +1,89 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#include +#include +#include + +#include "scaled_masked_softmax.h" +#include "../common/micros.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = torch::empty( + {batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h new file mode 100644 index 000000000000..bd2465beabd2 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,538 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "utils/vector_copy_utils.h" + +namespace { + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu b/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu new file mode 100644 index 000000000000..d9550dc2c2a5 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu @@ -0,0 +1,75 @@ +/*This code from NVIDIA Megatron: + * with minor changes. */ + +#include +#include +#include +#include +#include +#include +#include + +#include "scaled_upper_triang_masked_softmax.h" +#include "../common/micros.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), scale_factor, seq_len, + seq_len, attn_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 3d tensor with dimensions [attn_batches, seq_len, + // seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, seq_len, seq_len, attn_batches);); + + // backward pass is completely in-place + return output_grads; +} +} // namespace scaled_upper_triang_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/colossalai/kernel/extensions/csrc/cuda/utils/gpu_launch_config.h b/colossalai/kernel/extensions/csrc/cuda/utils/gpu_launch_config.h new file mode 100644 index 000000000000..b953c6587a64 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/utils/gpu_launch_config.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include "nvgpu_dev_info.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +struct GPULaunchConfig { + dim3 block{1, 1, 1}; + dim3 grid{1, 1, 1}; +}; + +static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info, + int64_t numel, int64_t vec_size) { + const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock(); + const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0]; + const int64_t kMinimumSize = 64; + const int64_t kMaximumSize = 512; + int64_t active_threads = (numel + vec_size - 1) / vec_size; + int64_t sm_num = dev_info.GetMultiProcessorCount(); + + // Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally + int64_t expected_threads_per_block = kMaximumSize; + + auto RoundUpToPowerOfTwo = [](int64_t x) { + bool is_power_of_two = false; + int64_t ret = 1; + int64_t y = x; + while (y > 0) { + is_power_of_two = ((ret ^ x) == 0); + y = (x >> 1); + ret = (ret << 1); + if (y > 0) is_power_of_two = false; + } + if (is_power_of_two) return x; + return ret; + }; + + if ((active_threads / (sm_num << 1)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 1)); + } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 2)); + } + + expected_threads_per_block = + std::max(expected_threads_per_block, kMinimumSize); + int64_t expect_block_per_grid = + ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + + if (expect_block_per_grid > max_blocks_per_grid) { + expect_block_per_grid = max_blocks_per_grid; + expected_threads_per_block = + (active_threads + expect_block_per_grid - 1) / expect_block_per_grid; + if (expected_threads_per_block > max_threads_per_block) + throw std::invalid_argument( + "Threads required for current input exceed for current GPU!"); + expected_threads_per_block = + RoundUpToPowerOfTwo(expected_threads_per_block); + expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + } + + GPULaunchConfig config; + config.block.x = expected_threads_per_block; + config.grid.x = expect_block_per_grid; + return config; +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/utils/micros.h b/colossalai/kernel/extensions/csrc/cuda/utils/micros.h new file mode 100644 index 000000000000..aaa2fc1ef1b9 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/utils/micros.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include + +#define CUDA_CHECK(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(status)); \ + } \ + } + +#define HOST __host__ +#define DEVICE __device__ +#define HOSTDEVICE __host__ __device__ diff --git a/colossalai/kernel/extensions/csrc/cuda/utils/nvgpu_dev_info.h b/colossalai/kernel/extensions/csrc/cuda/utils/nvgpu_dev_info.h new file mode 100644 index 000000000000..f4c017e754c3 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/utils/nvgpu_dev_info.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "micros.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +class NVGPUDevInfo { + public: + explicit NVGPUDevInfo(int device_num) : device_num_(device_num) { + CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num)); + } + + std::array GetMaxGridDims() const { + std::array ret; + ret[0] = prop_.maxGridSize[0]; + ret[1] = prop_.maxGridSize[1]; + ret[2] = prop_.maxGridSize[2]; + return ret; + } + + std::array GetMaxBlockDims() const { + std::array ret; + ret[0] = prop_.maxThreadsDim[0]; + ret[1] = prop_.maxThreadsDim[1]; + ret[2] = prop_.maxThreadsDim[2]; + return ret; + } + + std::array GetCapability() const { + std::array ret; + ret[0] = prop_.major; + ret[1] = prop_.minor; + return ret; + } + + int GetMultiProcessorCount() const { return prop_.multiProcessorCount; } + + int GetMaxThreadsPerMultiProcessor() const { + return prop_.maxThreadsPerMultiProcessor; + } + + int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; } + + private: + int device_num_; + cudaDeviceProp prop_; +}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/utils/vec_type_traits.h b/colossalai/kernel/extensions/csrc/cuda/utils/vec_type_traits.h new file mode 100644 index 000000000000..3ddd64df95fd --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/utils/vec_type_traits.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include + +#include + +namespace colossalAI { +namespace cuda { +namespace utils { + +template +struct VecTypeTrait {}; + +template +struct VecTypeTrait { + using Type = T; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = half; +}; + +template <> +struct VecTypeTrait { + using Type = half2; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/colossalai/kernel/extensions/csrc/cuda/utils/vector_copy_utils.h b/colossalai/kernel/extensions/csrc/cuda/utils/vector_copy_utils.h new file mode 100644 index 000000000000..5157ec738ca1 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -0,0 +1,52 @@ + +#pragma once + +#include +#include +#include + +#include "vec_type_traits.h" + +template +__device__ __inline__ void copy_vector(T *dst, const T *src) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + // Note(LiuYang): Here static_cast can't be used for cast between two pointer + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + // Since the maximum memory alignment length is 128 bits, we choose float4 + // here. + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst + 4)) = + *(reinterpret_cast(src + 4)); +} + +template +__device__ __inline__ void copy_zero_vector(T *dst) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = {0.0}; +} + +template +int get_vec_size(const torch::Tensor &tensor) { + uint64_t address = reinterpret_cast(tensor.data_ptr()); + const int max_aligned_size = 128; + const int dtype_size = sizeof(T) * 8; + + const int vec_size = max_aligned_size / sizeof(T) / 8; + + // Note(LiuYang): Performance of situation of which + // vec_size equals to 8 need to be profiled in the future + // if (address % (dtype_size * 8) == 0) { + // return std::min(8, vec_size); + // } + if (address % (dtype_size * 4) == 0) { + return std::min(4, vec_size); + } else if (address % (dtype_size * 2) == 0) { + return std::min(2, vec_size); + } else { + return 1; + } +} diff --git a/colossalai/kernel/extensions/csrc/scaled_softmax.py b/colossalai/kernel/extensions/csrc/scaled_softmax.py new file mode 100644 index 000000000000..7c220d60dd19 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/scaled_softmax.py @@ -0,0 +1,190 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader + +try: + from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax +except ImportError: + scaled_masked_softmax = None + scaled_upper_triang_masked_softmax = None + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type.value > 1: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type.value > 1: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/extensions/csrc/x86/cpu_adam.cpp b/colossalai/kernel/extensions/csrc/x86/cpu_adam.cpp new file mode 100644 index 000000000000..be9300c545c2 --- /dev/null +++ b/colossalai/kernel/extensions/csrc/x86/cpu_adam.cpp @@ -0,0 +1,446 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#include "cpu_adam.h" + +#include +#include +#include + +#include +#include +#include +#include + +// C++ interface + +void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data grad_4; + this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4); + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); + } + AVX_Data momentum_4; + this->simd_load(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); + + AVX_Data variance_4; + this->simd_load(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); + + AVX_Data param_4; + this->simd_load(param_half_precision, _params + i, params_cast_h + i, + param_4); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); + } + momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); + momentum_4.data = + SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); + variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data); + grad_4.data = SIMD_MUL(grad_4.data, grad_4.data); + variance_4.data = + SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data); + grad_4.data = SIMD_SQRT(variance_4.data); + grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = + SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data); + } + param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); + + this->simd_store(param_half_precision, _params + i, params_cast_h + i, + param_4); + this->simd_store(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); + this->simd_store(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k]; + if (loss_scale > 0) { + grad /= loss_scale; + } + float param = + param_half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = + momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k]; + float variance = variance_half_precision ? (float)variance_cast_h[k] + : _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { + grad = param * _weight_decay + grad; + } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { + param += w_decay * param; + } + param = grad * step_size + param; + + if (param_half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + if (momentum_half_precision) + momentum_cast_h[k] = (__half)(momentum); + else + _exp_avg[k] = momentum; + if (variance_half_precision) + variance_cast_h[k] = (__half)(variance); + else + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); + + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + AVX_Data grad_4[4]; + AVX_Data momentum_4[4]; + AVX_Data variance_4[4]; + AVX_Data param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + } + } + } +#endif + if (_param_size > rounded_size) + Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); +} + +void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, + float *_exp_avg_sq, size_t _param_size, + bool param_half_precision, bool grad_half_precision, + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = + (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + AVX_Data grad_4[8]; + AVX_Data momentum_4[8]; + AVX_Data variance_4[8]; + AVX_Data param_4[8]; +#pragma unroll 8 + for (int j = 0; j < 8; j++) { + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = + SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = + SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = + SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = + SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + } + } + } +#endif + if (_param_size > rounded_size) + Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size) + : _params + rounded_size), + (grad_half_precision ? (float *)(grads_cast_h + rounded_size) + : grads + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), + (_param_size - rounded_size), param_half_precision, + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); +} + +void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, + float epsilon, float weight_decay, + bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale) { + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float *params_ptr = (float *)params_c.data_ptr(); + float *grads_ptr = (float *)grads_c.data_ptr(); + float *exp_avg_ptr = (float *)exp_avg_c.data_ptr(); + float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr(); + + this->IncrementStep(step, beta1, beta2); + this->update_state(lr, epsilon, weight_decay, bias_correction); + this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, + params_c.numel(), (params.options().dtype() == at::kHalf), + (grads.options().dtype() == at::kHalf), + (exp_avg.options().dtype() == at::kHalf), + (exp_avg_sq.options().dtype() == at::kHalf), loss_scale); +} + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "CPUAdamOptimizer") + .def(py::init()) + .def("step", &Adam_Optimizer::step); +} diff --git a/colossalai/kernel/extensions/csrc/x86/cpu_adam.h b/colossalai/kernel/extensions/csrc/x86/cpu_adam.h new file mode 100644 index 000000000000..db1f26d5f6da --- /dev/null +++ b/colossalai/kernel/extensions/csrc/x86/cpu_adam.h @@ -0,0 +1,185 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + +#if defined(__AVX512__) +#define SIMD_WIDTH 16 +#define INTV __m256i +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_LOAD_HALF(x) \ + _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) + +#elif defined(__AVX256__) or defined(__AVX2__) +#define SIMD_WIDTH 8 +#define INTV __m128i +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define SIMD_STORE_HALF(x, d) \ + _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) + +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) or defined(__AVX2__) + __m256 data; +#endif + // float data_f[16]; +}; + +#endif + +#define STEP(SPAN) \ + void Step_##SPAN( \ + float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \ + size_t _param_size, bool param_half_precision = false, \ + bool grad_half_precision = false, bool momentum_half_precision = false, \ + bool variance_half_precision = false, float loss_scale = -1); + +class Adam_Optimizer { + public: + Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, + float eps = 1e-8, float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) {} + ~Adam_Optimizer() {} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, + bool bias_correction) { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + inline void simd_load(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + data.data = SIMD_LOAD_HALF(h_ptr); + } else { + data.data = SIMD_LOAD(ptr); + } + } + + inline void simd_store(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + SIMD_STORE_HALF(h_ptr, data.data); + } else { + SIMD_STORE(ptr, data.data); + } + } +#endif + + void step(size_t step, float lr, float beta1, float beta2, float epsilon, + float weight_decay, bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale); + + private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; +}; diff --git a/colossalai/kernel/extensions/cuda_extension.py b/colossalai/kernel/extensions/cuda_extension.py new file mode 100644 index 000000000000..f1e0095b29b6 --- /dev/null +++ b/colossalai/kernel/extensions/cuda_extension.py @@ -0,0 +1,109 @@ +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension +from .cpp_extension import _CppExtension +from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list + +__all__ = ["_CudaExtension"] + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 + + +class _CudaExtension(_CppExtension): + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME + + if not CUDA_HOME: + raise AssertionError( + "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions" + ) + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def build_jit(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME, load + + set_cuda_arch_list(CUDA_HOME) + + # get build dir + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + def build_aot(self) -> "CUDAExtension": + from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension + + set_cuda_arch_list(CUDA_HOME) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py new file mode 100644 index 000000000000..ea5b442aa58d --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -0,0 +1,14 @@ +from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension +from .flash_attention_npu import FlashAttentionNpuExtension +from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension + +try: + # TODO: remove this after updating openmoe example + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False + + +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"] diff --git a/colossalai/kernel/extensions/flash_attention/flash_attention_dao_cuda.py b/colossalai/kernel/extensions/flash_attention/flash_attention_dao_cuda.py new file mode 100644 index 000000000000..a2f2a52f1af4 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/flash_attention_dao_cuda.py @@ -0,0 +1,96 @@ +from ..base_extension import _Extension + + +class FlashAttentionDaoCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa + from flash_attn.bert_padding import index_first_axis, pad_input # noqa + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + + def load(self): + from typing import Optional + + import torch + from einops import rearrange + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.bert_padding import index_first_axis, pad_input + + def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices) + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + # [B, N, S, D] -> [B, S, N, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + b, s_q = q.shape[:2] + if cu_seqlens_q is not None: + # padded / padded causal + # unpad input: [B, S, N, D] -> [T, N, D] + q = _unpad_input(q, q_indices) + kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + ) + # pad output: [T, N, D] -> [B, S, N, D] + attn_output = pad_input(attn_output, q_indices, b, s_q) + else: + # causal / no attn mask + attn_output = flash_attn_func( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + ) + # [B, S, N, D] -> [B, N, S, D] + return attn_output.transpose(1, 2) + + return flash_attention diff --git a/colossalai/kernel/extensions/flash_attention/flash_attention_npu.py b/colossalai/kernel/extensions/flash_attention/flash_attention_npu.py new file mode 100644 index 000000000000..0e01cefa1112 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/flash_attention_npu.py @@ -0,0 +1,62 @@ +from ..base_extension import _Extension + + +class FlashAttentionNpuExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) + + def is_available(self) -> bool: + try: + import torch_npu + + return hasattr(torch_npu, "npu_fusion_attention") + except: + return False + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu." + ) + + def load(self): + from typing import Optional + + import torch + import torch_npu + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + num_heads = q.size(1) + return torch_npu.npu_fusion_attention( + q, + k, + v, + num_heads, + "BNSD", + atten_mask=attention_mask.bool(), + scale=scale, + keep_prob=1 - dropout_p, + )[0] + + return flash_attention diff --git a/colossalai/kernel/extensions/flash_attention/flash_attention_sdpa_cuda.py b/colossalai/kernel/extensions/flash_attention/flash_attention_sdpa_cuda.py new file mode 100644 index 000000000000..d3323a6aae27 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/flash_attention_sdpa_cuda.py @@ -0,0 +1,56 @@ +from ..base_extension import _Extension + + +class FlashAttentionSdpaCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False) + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.") + + def build_jit(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.") + + def load(self): + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + return torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=dropout_p, + scale=scale, + ) + + return flash_attention diff --git a/colossalai/kernel/extensions/inference/__init__.py b/colossalai/kernel/extensions/inference/__init__.py new file mode 100644 index 000000000000..c5ea424fa25d --- /dev/null +++ b/colossalai/kernel/extensions/inference/__init__.py @@ -0,0 +1,3 @@ +from .inference_ops_cuda import InferenceOpsCudaExtension + +__all__ = ["InferenceOpsCudaExtension"] diff --git a/colossalai/kernel/extensions/inference/inference_ops_cuda.py b/colossalai/kernel/extensions/inference/inference_ops_cuda.py new file mode 100644 index 000000000000..09ebfdabde88 --- /dev/null +++ b/colossalai/kernel/extensions/inference/inference_ops_cuda.py @@ -0,0 +1,35 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class InferenceOpsCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="inference_ops_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/pybind/inference.cpp", + "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/context_kv_cache_memcpy_kernel.cu", + "cuda/fused_rotary_emb_and_cache_kernel.cu", + "cuda/activation_kernel.cu", + "cuda/rms_layernorm_kernel.cu", + "cuda/get_cos_and_sin_kernel.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/colossalai/kernel/extensions/layernorm/__init__.py b/colossalai/kernel/extensions/layernorm/__init__.py new file mode 100644 index 000000000000..30e6c68eff89 --- /dev/null +++ b/colossalai/kernel/extensions/layernorm/__init__.py @@ -0,0 +1,3 @@ +from .layernorm_cuda import LayerNormCudaExtension + +__all__ = ["LayerNormCudaExtension"] diff --git a/colossalai/kernel/extensions/layernorm/layernorm_cuda.py b/colossalai/kernel/extensions/layernorm/layernorm_cuda.py new file mode 100644 index 000000000000..36cf73590a3c --- /dev/null +++ b/colossalai/kernel/extensions/layernorm/layernorm_cuda.py @@ -0,0 +1,24 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-maxrregcount=50"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/moe/__init__.py b/colossalai/kernel/extensions/moe/__init__.py new file mode 100644 index 000000000000..3b6aa24bf7f6 --- /dev/null +++ b/colossalai/kernel/extensions/moe/__init__.py @@ -0,0 +1,3 @@ +from .moe_cuda import MoeCudaExtension + +__all__ = ["MoeCudaExtension"] diff --git a/colossalai/kernel/extensions/moe/moe_cuda.py b/colossalai/kernel/extensions/moe/moe_cuda.py new file mode 100644 index 000000000000..7a4744d4dc42 --- /dev/null +++ b/colossalai/kernel/extensions/moe/moe_cuda.py @@ -0,0 +1,29 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class MoeCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="moe_cuda") + + def include_dirs(self): + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/moe.cpp", "cuda/moe_kernel.cu"]] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/optimizer/__init__.py b/colossalai/kernel/extensions/optimizer/__init__.py new file mode 100644 index 000000000000..6a0c8d7b8016 --- /dev/null +++ b/colossalai/kernel/extensions/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .fused_optimizer_cuda import FusedOptimizerCudaExtension + +__all__ = ["FusedOptimizerCudaExtension"] diff --git a/colossalai/kernel/extensions/optimizer/fused_optimizer_cuda.py b/colossalai/kernel/extensions/optimizer/fused_optimizer_cuda.py new file mode 100644 index 000000000000..41c6260aa30d --- /dev/null +++ b/colossalai/kernel/extensions/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/pybind/optimizer.cpp", + "cuda/multi_tensor_sgd_kernel.cu", + "cuda/multi_tensor_scale_kernel.cu", + "cuda/multi_tensor_adam_kernel.cu", + "cuda/multi_tensor_l2norm_kernel.cu", + "cuda/multi_tensor_lamb_kernel.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/colossalai/kernel/extensions/softmax/__init__.py b/colossalai/kernel/extensions/softmax/__init__.py new file mode 100644 index 000000000000..8833d93e73d0 --- /dev/null +++ b/colossalai/kernel/extensions/softmax/__init__.py @@ -0,0 +1,4 @@ +from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension +from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension + +__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"] diff --git a/colossalai/kernel/extensions/softmax/scaled_masked_softmax_cuda.py b/colossalai/kernel/extensions/softmax/scaled_masked_softmax_cuda.py new file mode 100644 index 000000000000..797638c3b132 --- /dev/null +++ b/colossalai/kernel/extensions/softmax/scaled_masked_softmax_cuda.py @@ -0,0 +1,32 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads + + +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_masked_softmax_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"] + ] + return ret + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/colossalai/kernel/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 000000000000..d48d542ade3a --- /dev/null +++ b/colossalai/kernel/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/pybind/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_kernel.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/triton_extension.py b/colossalai/kernel/extensions/triton_extension.py new file mode 100644 index 000000000000..9f0792f8ce68 --- /dev/null +++ b/colossalai/kernel/extensions/triton_extension.py @@ -0,0 +1,21 @@ +from .base_extension import _Extension + +__all__ = ["_TritonExtension"] + + +class _TritonExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=False, support_jit=True, priority=priority) + + def is_hardware_compatible(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def load(self): + return self.build_jit() diff --git a/colossalai/kernel/extensions/utils.py b/colossalai/kernel/extensions/utils.py new file mode 100644 index 000000000000..d5d87a77a9c0 --- /dev/null +++ b/colossalai/kernel/extensions/utils.py @@ -0,0 +1,229 @@ +import os +import re +import subprocess +import warnings +from typing import List + + +def print_rank_0(message: str) -> None: + """ + Print on only one process to avoid spamming. + """ + try: + import torch.distributed as dist + + if not dist.is_initialized(): + is_main_rank = True + else: + is_main_rank = dist.get_rank() == 0 + except ImportError: + is_main_rank = True + + if is_main_rank: + print(message) + + +def get_cuda_version_in_pytorch() -> List[int]: + """ + This function returns the CUDA version in the PyTorch build. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + import torch + + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + except: + raise ValueError( + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) + return torch_cuda_major, torch_cuda_minor + + +def get_cuda_bare_metal_version(cuda_dir) -> List[int]: + """ + Get the System CUDA version from nvcc. + + Args: + cuda_dir (str): the directory for CUDA Toolkit. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") + + if cuda_dir is None: + raise ValueError( + f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly." + ) + + # check for nvcc path + if not os.path.exists(nvcc_path): + raise FileNotFoundError( + f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME." + ) + + # parse the nvcc -v output to obtain the system cuda version + try: + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + except: + raise ValueError( + f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}" + ) + + return bare_metal_major, bare_metal_minor + + +def check_system_pytorch_cuda_match(cuda_dir): + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch() + + if bare_metal_major != torch_cuda_major: + raise Exception( + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." + ) + + if bare_metal_minor != torch_cuda_minor: + warnings.warn( + f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " + "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. " + "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions" + ) + return True + + +def get_pytorch_version() -> List[int]: + """ + This functions finds the PyTorch version. + + Returns: + A tuple of integers in the form of (major, minor, patch). + """ + import torch + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) + return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH + + +def check_pytorch_version(min_major_version, min_minor_version) -> bool: + """ + Compare the current PyTorch version with the minium required version. + + Args: + min_major_version (int): the minimum major version of PyTorch required + min_minor_version (int): the minimum minor version of PyTorch required + + Returns: + A boolean value. The value is True if the current pytorch version is acceptable and False otherwise. + """ + # get pytorch version + torch_major, torch_minor, _ = get_pytorch_version() + + # if the + if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): + raise RuntimeError( + f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) + + +def check_cuda_availability(): + """ + Check if CUDA is available on the system. + + Returns: + A boolean value. True if CUDA is available and False otherwise. + """ + import torch + + return torch.cuda.is_available() + + +def set_cuda_arch_list(cuda_dir): + """ + This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. + Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'. + """ + cuda_available = check_cuda_availability() + + # we only need to set this when CUDA is not available for cross-compilation + if not cuda_available: + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) + + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] + + if int(bare_metal_major) == 11: + if int(bare_metal_minor) == 0: + arch_list.append("8.0") + else: + arch_list.append("8.0") + arch_list.append("8.6") + + arch_list_str = ";".join(arch_list) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str + return False + return True + + +def get_cuda_cc_flag() -> List[str]: + """ + This function produces the cc flags for your GPU arch + + Returns: + The CUDA cc flags for compilation. + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) + for arch in torch.cuda.get_arch_list(): + res = re.search(r"sm_(\d+)", arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) + return cc_flag + + +def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: + """ + This function appends the threads flag to your nvcc args. + + Returns: + The nvcc compilation flags including the threads flag. + """ + from torch.utils.cpp_extension import CUDA_HOME + + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 7c8619ad8f5c..30eb3d9eae40 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -201,6 +201,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + elif self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py deleted file mode 120000 index 4e95c7bfa519..000000000000 --- a/examples/language/llama2/attn.py +++ /dev/null @@ -1 +0,0 @@ -../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py \ No newline at end of file diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py new file mode 100644 index 000000000000..6c048c3b18cf --- /dev/null +++ b/examples/language/llama2/attn.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import math +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + apply_rotary_pos_emb, + repeat_kv, +) + +from colossalai.accelerator import get_accelerator +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + +if get_accelerator().name == "cuda": + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.ops.rms_norm import rms_norm + + def _prepare_decoder_attention_mask( + self: LlamaModel, + attention_mask: torch.BoolTensor, + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ) -> Optional[torch.Tensor]: + """ + Decoder attetion mask + """ + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + tensors=( + torch.full( + size=(input_shape[0], past_key_values_length), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) # (bsz, past_key_values_length + q_len) + if attention_mask is not None and torch.all(attention_mask): + return None # Faster + return attention_mask + + def attention_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. + """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) + ) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) + q, k, v = ( + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), + ) + ) + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. + past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len + + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) + + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, + ), + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange( + output, pattern="... h d -> ... (h d)" + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) + output = self.o_proj(output) # (bsz, q_len, hidden_size) + return output, None, past_key_value + + def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Formard function for RMS Norm + """ + return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(attention_forward, module) + if isinstance(module, LlamaModel): + module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) + if isinstance(module, LlamaRMSNorm): + module.forward = MethodType(rms_norm_forward, module) + +elif get_accelerator().name == "npu": + import torch_npu + + class NPULlamaAttention(LlamaAttention): + use_flash: bool = True + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.setup() + + def setup(self): + self._softmax_scale = 1 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if not self.use_flash: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + else: + attn_output, *_ = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + self.num_heads, + "BNSD", + atten_mask=attention_mask.bool(), + scale=self._softmax_scale, + padding_mask=None, + pre_tockens=65535, + next_tockens=0, + keep_prob=1.0, + inner_precise=0, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum( + [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class NPURMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.__class__ = NPULlamaAttention + module.setup() + if isinstance(module, LlamaRMSNorm): + module.__class__ = NPURMSNorm