From 8ace1ef7f692d2cb7ee7ff837f4c8628fa8b776c Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 8 Oct 2020 10:39:44 -0700
Subject: [PATCH] Enable CUDA Fuser for ROCm (#45965)

Summary:
This enables the CUDA fuser on ROCm and enables its tests.
Part of this patch is based on the work of Rohith Nallamaddi, thank you.
Errors are my own, of course.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45965

Reviewed By: seemethere

Differential Revision: D24170457

Pulled By: walterddr

fbshipit-source-id: 3dd25b3501a41d2f00acba3ce8642ce51c49c9a6
---
 codegen.cpp               |  4 +++
 executor.cpp              | 10 ++++--
 executor_utils.cpp        | 11 ++++++
 kernel_resource_strings.h | 73 +++++++++++++++++++++------------------
 4 files changed, 61 insertions(+), 37 deletions(-)

diff --git a/codegen.cpp b/codegen.cpp
index f6e791f..f976af9 100644
--- a/codegen.cpp
+++ b/codegen.cpp
@@ -113,7 +113,11 @@ class CudaKernelGenerator : private OptInConstDispatch {
     // Shared memory
     if (has_dynamic_smem || has_reductions) {
       indent() << "alignas("
+#ifndef __HIP_PLATFORM_HCC__
               << dataTypeSize(kernel_summary.largest_smem_data_type)
+#else
+              << 8 // for HIP, we want 8-aligned even for smaller datatypes
+#endif
               << ") extern __shared__ char array[];\n";
 
       if (has_dynamic_smem) {
diff --git a/executor.cpp b/executor.cpp
index a0df3c7..2500381 100644
--- a/executor.cpp
+++ b/executor.cpp
@@ -25,9 +25,13 @@ int FusionExecutor::fusion_id_counter_ = 0;
 
 std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
   // generating cuda code;
-  std::string code = std::string("namespace ") +
-      FusionExecutor::kernelNamespace() + " {\n" +
-      executor_utils::kernelPreamble() + kernel + "}\n";
+  std::string code = "";
+#ifdef __HIP_PLATFORM_HCC__
+  code += std::string("#include <hip/hip_runtime.h>\n") +
+      std::string("#include <hip/hip_fp16.h>\n");
+#endif
+  code += std::string("namespace ") + FusionExecutor::kernelNamespace() +
+      " {\n" + executor_utils::kernelPreamble() + kernel + "}\n";
 
   const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG");
   if (debug_env && atoi(debug_env)) {
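To make the executor.cpp change concrete: on ROCm the assembled translation
unit has to pull in the HIP runtime and fp16 headers itself, because the
code_fp16_support preamble string is emptied for HIP (see the
kernel_resource_strings.h hunks below). A rough sketch of the string
getStructuredCode() now produces on ROCm; the namespace name and the kernel
signature here are illustrative assumptions, not output taken from the patch:

    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>
    namespace CudaCodeGen { // assumed value of FusionExecutor::kernelNamespace()
    // ...executor_utils::kernelPreamble(): Tensor structs, fp16 support
    // (empty on HIP), random-number and reduction helpers...
    __global__ void kernel(Tensor<float, 1> T0, Tensor<float, 1> T1) {
      // ...generated kernel body...
    }
    } // namespace
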
diff --git a/executor_utils.cpp b/executor_utils.cpp
index af4e127..19f873c 100644
--- a/executor_utils.cpp
+++ b/executor_utils.cpp
@@ -272,10 +272,14 @@ NvrtcFunction nvrtcCompile(
         at::globalContext().getNVRTC().nvrtcDestroyProgram(&program));
   });
 
+#ifdef __HIP_PLATFORM_HCC__
+  std::vector<const char*> args = {"--std=c++14"};
+#else
   const std::string compute = "--gpu-architecture=compute_" +
       std::to_string(major) + std::to_string(minor);
   std::vector<const char*> args = {
       "--std=c++14", compute.c_str(), "-default-device"};
+#endif
 
   const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA");
   // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0;
@@ -346,6 +350,7 @@ NvrtcFunction nvrtcCompile(
   // TODO: We do go through different code path, should investigate whether this
   // has an impact on generated binary.
   const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN");
+#ifndef __HIP_PLATFORM_HCC__
 
   if (prefix_env) {
     FUSER_PERF_SCOPE("load CUBIN");
@@ -403,6 +408,12 @@ NvrtcFunction nvrtcCompile(
         options.data(),
         option_vals.data()));
   }
+#else
+  // load ptx directly
+  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
+      &(compiled_kernel_.module), ptx.data()));
+
+#endif
   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction(
       &(compiled_kernel_.function),
       compiled_kernel_.module,
diff --git a/kernel_resource_strings.h b/kernel_resource_strings.h
index d30eb3f..a601a95 100644
--- a/kernel_resource_strings.h
+++ b/kernel_resource_strings.h
@@ -12,7 +12,7 @@ typedef long long int int64_t;
 
 template <typename T, int N>
 struct Tensor {
-  T& operator[](int64_t ind) {
+  __device__ T& operator[](int64_t ind) {
     return data[ind];
   };
 
@@ -25,7 +25,7 @@ struct Tensor {
 // They will be an error as well since zero-length arrays are not allowed.
 template <typename T>
 struct Tensor<T, 0> {
-  T& operator[](int64_t) {
+  __device__ T& operator[](int64_t) {
     return *data;
   };
 
@@ -34,6 +34,9 @@ struct Tensor {
 )";
 
 // Code support for FP16 __half type and intrinsics
+#ifdef __HIP_PLATFORM_HCC__
+static auto code_fp16_support = R"()";
+#else
 static auto code_fp16_support = R"(
 #define __HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
 #define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short*>(&(var)))
@@ -55,7 +58,7 @@ __device__ float __half2float(const __half h) {
   return val;
 }
 )";
-
+#endif
 // struct and code for functions that need random number generation
 static auto code_random_number_gen = R"(
 class Philox {
@@ -184,6 +187,9 @@ __device__ float randLike(Philox rnd) {
 };
 )";
 
+// Note: We aggressively template functions taking dim3 arguments below,
+// because ROCm uses different types for the various dim3 variables and maps
+// them directly to intrinsics, but they're dim3 when used after modification.
 /*
  * EXAMPLE USAGE:
  * blockReduceSum<X_THREADS, Y_THREADS, Z_THREADS>
@@ -196,14 +202,14 @@ static auto code_template_block_reduction = R"(
 // participate, otherwise it is the number of threads. We could start with warp
 // reductions, then reduce the warps, this could save some shared memory, but
 // may actually be slower.
-template <bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T, typename Func>
+template <bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T, typename Func, typename _dim3ti, typename _dim3bd>
 __inline__ __device__ void blockReduce(
     T& out,
     const T inp_val,
     Func reduction_op,
-    const dim3& thread_idx,
-    const dim3& block_dim,
+    const _dim3ti& thread_idx,
+    const _dim3bd& block_dim,
     T* shared_mem,
     bool read_write_pred,
     T init_val) {
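The templating called out in the note above is the heart of the
kernel_resource_strings.h changes: under HIP, blockDim/gridDim/threadIdx
expand to builtin coordinate types that map directly to intrinsics and only
behave as a real dim3 after modification, so parameters declared
`const dim3&` would not bind to them directly. Deducing the type lets one
source serve both toolchains. A minimal sketch of the pattern, reusing a
cut-down version of the `size` helper from the hunk below together with a
hypothetical `probe` kernel:

    // _dim3 deduces to CUDA's dim3, or to HIP's builtin coordinate type.
    template <typename _dim3>
    __host__ __device__ __forceinline__ size_t size(const _dim3& d) {
      return (size_t)d.x * (size_t)d.y * (size_t)d.z;
    }

    __global__ void probe(size_t* out) {
      *out = size(blockDim); // binds without conversion on nvcc and hipcc alike
    }
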
@@ -324,49 +330,47 @@ static auto code_template_grid_reduction = R"(
 namespace reduction {
 
 // Utility functions
-__host__ __device__ __forceinline__ size_t size(const dim3& d) {
+template <typename _dim3>
+__host__ __device__ __forceinline__ size_t size(const _dim3& d) {
   return (size_t)d.x * (size_t)d.y * (size_t)d.z;
 }
 
-__host__ __device__ __forceinline__ int isize(const dim3& d) {
-  return d.x * d.y * d.z;
-}
+#define isize(d) d.x * d.y * d.z
 
-__host__ __device__ __forceinline__ size_t offset(const dim3& pos, const dim3& dim) {
+template <typename _dim3pos, typename _dim3dim>
+__host__ __device__ __forceinline__ size_t offset(const _dim3pos& pos, const _dim3dim& dim) {
   return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
       (size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
 }
 
-__host__ __device__ __forceinline__ size_t ioffset(const dim3& pos, const dim3& dim) {
-  return pos.x + pos.y * dim.x + pos.z * dim.x * dim.y;
-}
+#define ioffset(pos, dim) pos.x + pos.y * dim.x + pos.z * dim.x * dim.y
 
 // Returns dim3 of each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
-__host__ __device__ dim3 dimension_of_reduction_segment(const dim3& grid_dim) {
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
+__host__ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) {
   return dim3{X_BLOCK ? grid_dim.x : 1,
               Y_BLOCK ? grid_dim.y : 1,
               Z_BLOCK ? grid_dim.z : 1};
 }
 
 // Returns the number of blocks in each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
-__host__ __device__ size_t size_of_reduction_segment(const dim3& grid_dim) {
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
+__host__ __device__ size_t size_of_reduction_segment(const _dim3& grid_dim) {
   return size(dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
 }
 
 // Returns the total number of reduction segments.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
-__host__ __device__ size_t number_of_reduction_segments(const dim3& grid_dim) {
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
+__host__ __device__ size_t number_of_reduction_segments(const _dim3& grid_dim) {
   return (X_BLOCK ? 1: grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) *
       (Z_BLOCK ? 1 : grid_dim.z);
 }
 
 // Returns the 1-D index of the segment of thread block of block_idx.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
-__host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx,
-                                                      const dim3& grid_dim) {
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3bi, typename _dim3gd>
+__host__ __device__ size_t index_of_reduction_segment(const _dim3bi& block_idx,
+                                                      const _dim3gd& grid_dim) {
   size_t seg_idx = 0;
   if (!Z_BLOCK)
     seg_idx += block_idx.z;
@@ -378,9 +382,9 @@ __host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx,
 }
 
 // Returns the offset of thread block in its reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
-__host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx,
-                                                       const dim3& grid_dim) {
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3bi, typename _dim3gd>
+__host__ __device__ size_t offset_in_reduction_segment(const _dim3bi& block_idx,
+                                                       const _dim3gd& grid_dim) {
   size_t offset = 0;
   if (Z_BLOCK)
     offset = offset * grid_dim.z + block_idx.z;
@@ -392,23 +396,24 @@ __host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx,
 }
 
 // Returns dim3 of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
-__host__ __device__ dim3 dimension_of_reduction_block(const dim3& block_dim) {
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
+__host__ __device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) {
   return dim3{X_THREAD ? block_dim.x : 1,
               Y_THREAD ? block_dim.y : 1,
               Z_THREAD ? block_dim.z : 1};
 }
 
 // Returns the number of threads of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
-__host__ __device__ int size_of_reduction_block(const dim3& block_dim) {
-  return isize(dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim));
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
+__host__ __device__ int size_of_reduction_block(const _dim3& block_dim) {
+  auto tmp_dim = dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim);
+  return isize(tmp_dim);
 }
 
 // Returns the linear offset of a thread in a reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
-__host__ __device__ int offset_in_reduction_block(const dim3& thread_idx,
-                                                  const dim3& block_dim) {
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3ti, typename _dim3bd>
+__host__ __device__ int offset_in_reduction_block(const _dim3ti& thread_idx,
+                                                  const _dim3bd& block_dim) {
   int offset = 0;
   if (Z_THREAD)
     offset += thread_idx.z;