Rocm warp size fix (#5402)
This PR enables building the following extensions for AMD GPUs with warp size 32:
- transformer_inference
- quantizer
- random_ltd


This PR works stand-alone for torch versions <= 2.0. For later torch versions, #5401 must be merged in addition to this PR.
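
For context on the mechanism: at build time the op builder queries rocminfo for the GPU's wavefront size and passes it to the compiler as a `-DROCM_WAVEFRONT_SIZE` define, which the HIP kernels consume as a compile-time constant. A minimal sketch of the consuming side (illustrative only; the real definitions are in the diff below):

```cpp
// The op builder adds something like:
//   hipcc ... -DROCM_WAVEFRONT_SIZE=32 ...
// so device code can treat the wavefront width as a compile-time constant.
#ifdef __HIP_PLATFORM_AMD__
constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;  // 32 on NAVI3x, 64 on CDNA parts
#else
constexpr int hw_warp_size = 32;  // NVIDIA warps are always 32 lanes wide
#endif
```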

Unit test results (rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2) on NAVI3x:

**transformer_inference:**
pytest --color=yes --durations=0 --verbose -s -m "inference_ops" -rF -n 4 unit/ops/transformer/inference

Before this PR:
===== 674 failed, 622 skipped, 8 warnings, 1728 errors in 69.37s (0:01:09) =====

After this PR:
========== 476 failed, 1062 passed, 1486 skipped, 8 warnings in 9.31s ==========

**quantizer:**
pytest --color=yes --durations=0 --verbose -s -m "inference_ops" -rF -n 4 unit/ops/quantizer

Before this PR:
==== 244 failed, 8 warnings in 30.53s ====

After this PR:
====== 186 failed, 58 passed, 8 warnings in 8.89s ======

I could not find any random_ltd-related unit tests to run.

Fixes: 
#4753
#5474
ROCm#68

cc: @jithunnair-amd

---------

Co-authored-by: rraminen@amd.com <rraminen>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
rraminen and loadams committed May 17, 2024
1 parent d3dd8e7 commit 76c9c69
Showing 5 changed files with 41 additions and 13 deletions.
csrc/includes/ds_kernel_utils.h (1 addition & 1 deletion)
```diff
@@ -23,7 +23,7 @@ used throughout the codebase.
 #ifdef __HIP_PLATFORM_AMD__
 
 // constexpr variant of warpSize for templating
-constexpr int hw_warp_size = 64;
+constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
 #define HALF_PRECISION_AVAILABLE = 1
 #include <hip/hip_cooperative_groups.h>
 #include <hip/hip_fp16.h>
```
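
A constexpr warp size matters here because the runtime `warpSize` builtin cannot be used as a template argument. As a hypothetical illustration (not part of this diff), a warp-level reduction templated on `hw_warp_size` might look like:

```cpp
// Hypothetical example: butterfly reduction templated on the compile-time
// wavefront width; instantiate as warp_reduce_sum<hw_warp_size>(val).
template <int WARP_SIZE>
__device__ float warp_reduce_sum(float val)
{
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_xor(val, offset, WARP_SIZE);  // HIP's non-sync shuffle
    }
    return val;
}
```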
csrc/random_ltd/token_sort.cu (1 addition & 1 deletion)
```diff
@@ -16,7 +16,7 @@ constexpr int mem_vals = granularity / sizeof(int32_t);
 constexpr int max_buffer_size = (threads + 1) * mem_vals;
 
 #ifdef __HIP_PLATFORM_AMD__
-constexpr int warp_size = 64;
+constexpr int warp_size = ROCM_WAVEFRONT_SIZE;
 #else
 constexpr int warp_size = 32;
 #endif
```
csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu (2 additions & 10 deletions)
```diff
@@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
                                      rope_theta, \
                                      max_out_tokens);
 
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
-    if (threads_per_head == 4) { \
-        LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
-    } else if (threads_per_head == 8) { \
-        LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
-    } else if (threads_per_head == 16) { \
-        LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
-    } else if (threads_per_head == 32) { \
-        LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
-    } else if (threads_per_head == 64) { \
+    if (threads_per_head == 64) { \
         LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
     } else { \
         assert(false); \
```
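
Two details worth noting: `and` is the standard C++ alternative token for `&&`, and the new guard compiles the 64-lane dispatch only for wave64 AMD builds; wave32 AMD builds presumably fall through to the generic 4/8/16/32 path in the `#else` branch, which this hunk does not show. A compressed sketch of the pattern, using a hypothetical `launch_rotary` stand-in for the real macro:

```cpp
#include <cassert>

template <int THREADS_PER_HEAD>
void launch_rotary();  // hypothetical stand-in for LAUNCH_ROT_POS_EMB_HALF

void dispatch(int threads_per_head)
{
#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
    // Wave64 builds compile only the 64-lane-per-head configuration.
    if (threads_per_head == 64) { launch_rotary<64>(); }
    else { assert(false); }
#else
    // CUDA and wave32 AMD builds keep the generic configurations.
    if (threads_per_head == 4) { launch_rotary<4>(); }
    else if (threads_per_head == 8) { launch_rotary<8>(); }
    else if (threads_per_head == 16) { launch_rotary<16>(); }
    else if (threads_per_head == 32) { launch_rotary<32>(); }
    else { assert(false); }
#endif
}
```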
deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h (1 addition & 1 deletion)
```diff
@@ -23,7 +23,7 @@ used throughout the codebase.
 #ifdef __HIP_PLATFORM_AMD__
 
 // constexpr variant of warpSize for templating
-constexpr int hw_warp_size = 64;
+constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
 #define HALF_PRECISION_AVAILABLE = 1
 #include <hip/hip_cooperative_groups.h>
 #include <hip/hip_fp16.h>
```
op_builder/builder.py (36 additions & 0 deletions)
```diff
@@ -107,6 +107,8 @@ def assert_no_cuda_mismatch(name=""):
 
 class OpBuilder(ABC):
     _rocm_version = None
+    _rocm_gpu_arch = None
+    _rocm_wavefront_size = None
     _is_rocm_pytorch = None
     _is_sycl_enabled = None
     _loaded_ops = {}
@@ -229,6 +231,32 @@ def installed_rocm_version():
         OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
         return OpBuilder._rocm_version
 
+    @staticmethod
+    def get_rocm_gpu_arch():
+        if OpBuilder._rocm_gpu_arch:
+            return OpBuilder._rocm_gpu_arch
+        rocm_gpu_arch_cmd = "/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*'"
+        try:
+            result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True)
+            rocm_gpu_arch = result.decode('utf-8').strip()
+        except subprocess.CalledProcessError:
+            rocm_gpu_arch = ""
+        OpBuilder._rocm_gpu_arch = rocm_gpu_arch
+        return OpBuilder._rocm_gpu_arch
+
+    @staticmethod
+    def get_rocm_wavefront_size():
+        if OpBuilder._rocm_wavefront_size:
+            return OpBuilder._rocm_wavefront_size
+        rocm_wavefront_size_cmd = "/opt/rocm/bin/rocminfo | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'"
+        try:
+            result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True)
+            rocm_wavefront_size = result.decode('utf-8').strip()
+        except subprocess.CalledProcessError:
+            rocm_wavefront_size = "32"
+        OpBuilder._rocm_wavefront_size = rocm_wavefront_size
+        return OpBuilder._rocm_wavefront_size
+
     def include_paths(self):
         '''
         Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
@@ -520,6 +548,8 @@ def jit_load(self, verbose=True):
 
         if self.is_rocm_pytorch():
             cxx_args.append("-D__HIP_PLATFORM_AMD__=1")
+            os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
+            cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
 
         op_module = load(name=self.name,
                          sources=self.strip_empty_entries(sources),
@@ -650,6 +680,12 @@ def builder(self):
 
         if self.is_rocm_pytorch():
             compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")
+            #cxx compiler args are required to compile cpp files
+            compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
+            #nvcc compiler args are required to compile hip files
+            compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
+            if self.get_rocm_gpu_arch():
+                os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
 
         cuda_ext = ExtensionBuilder(name=self.absolute_name(),
                                     sources=self.strip_empty_entries(self.sources()),
```
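
One caveat with build-time flags like this: if a build path ever misses the define, compilation fails only at the first use of `ROCM_WAVEFRONT_SIZE`. A hypothetical compile-time guard (not part of this PR) could make such a failure explicit:

```cpp
// Hypothetical sanity check for the flag introduced by this PR.
#ifdef __HIP_PLATFORM_AMD__
#ifndef ROCM_WAVEFRONT_SIZE
#error "ROCM_WAVEFRONT_SIZE is not defined; expected -DROCM_WAVEFRONT_SIZE=32 or 64 from op_builder"
#endif
static_assert(ROCM_WAVEFRONT_SIZE == 32 || ROCM_WAVEFRONT_SIZE == 64,
              "unsupported ROCm wavefront size");
#endif
```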
