diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu
index 97ebfbaa0a13..91b8abf1a162 100644
--- a/csrc/custom_quickreduce.cu
+++ b/csrc/custom_quickreduce.cu
@@ -10,6 +10,8 @@ quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) {
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
+  if (world_size == 6)
+    throw std::invalid_argument("world size == 6 is not supported");
   if (world_size % 2 != 0)
     throw std::invalid_argument("Odd num gpus is not supported for now");
   if (rank < 0 || rank >= world_size)
@@ -20,9 +22,11 @@ quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) {
 }
 
 void qr_destroy(quickreduce::fptr_t _fa) {
-  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
-  fa->destroy();
-  delete fa;
+  if (_fa) {
+    auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+    fa->destroy();
+    delete fa;
+  }
 }
 
 torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h
index 7abbc18aea73..8138eda82875 100644
--- a/csrc/quickreduce/base.h
+++ b/csrc/quickreduce/base.h
@@ -104,6 +104,10 @@ __quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
   }
 #endif
 }
+union bf162_int_union {
+  int i;
+  nv_bfloat162 bf2;
+};
 
 template <typename T>
 __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
@@ -152,10 +156,11 @@ __quickreduce_device_inline__ int packed_max(int a, int b) {
 
 template <>
 __quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
-  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
-  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
-  nv_bfloat162 tR = __hmax2(*tA, *tB);
-  return *(reinterpret_cast<int*>(&tR));
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hmax2(A.bf2, B.bf2);
+  return R.i;
 }
 
 template <typename T>
@@ -170,10 +175,11 @@ __quickreduce_device_inline__ int packed_min(int a, int b) {
 
 template <>
 __quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
-  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
-  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
-  nv_bfloat162 tR = __hmin2(*tA, *tB);
-  return *(reinterpret_cast<int*>(&tR));
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hmin2(A.bf2, B.bf2);
+  return R.i;
 }
 
 template <typename T>
@@ -194,15 +200,12 @@ __quickreduce_device_inline__ int packed_abs_max(int a, int b) {
 
 template <>
 __quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
-  nv_bfloat162 wmaxh2 = *(reinterpret_cast<nv_bfloat162*>(&a));
-  nv_bfloat162 wminh2 = *(reinterpret_cast<nv_bfloat162*>(&b));
-  nv_bfloat162 wblockmaxh2;
-  wblockmaxh2.x =
-      __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
-  wblockmaxh2.y =
-      __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
-
-  return *(reinterpret_cast<int*>(&wblockmaxh2));
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
+  R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
+  return R.i;
 }
 
 template <typename T>
@@ -217,10 +220,11 @@ __quickreduce_device_inline__ int packed_add(int a, int b) {
 
 template <>
 __quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
-  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
-  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
-  nv_bfloat162 tR = __hadd2(*tA, *tB);
-  return *(reinterpret_cast<int*>(&tR));
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hadd2(A.bf2, B.bf2);
+  return R.i;
 }
 
 template <>
@@ -246,10 +250,11 @@ __quickreduce_device_inline__ int packed_sub(int a, int b) {
 
 template <>
 __quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
-  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
-  nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
-  nv_bfloat162 tR = __hsub2(*tA, *tB);
-  return *(reinterpret_cast<int*>(&tR));
+  bf162_int_union A, B, R;
+  A.i = a;
+  B.i = b;
+  R.bf2 = __hsub2(A.bf2, B.bf2);
+  return R.i;
 }
 
 template <typename T>
@@ -280,78 +285,21 @@ __quickreduce_device_inline__ int packed_rcp(int a) {
 
 template <>
 __quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
-  nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
-  nv_bfloat162 tR = h2rcp(*tA);
-  return *(reinterpret_cast<int*>(&tR));
+  bf162_int_union A, R;
+  A.i = a;
+  R.bf2 = h2rcp(A.bf2);
+  return R.i;
 }
 
-template <typename T>
-__quickreduce_device_inline__ T float2T_cast(float a);
-
-template <>
-__quickreduce_device_inline__ half float2T_cast<half>(float a) {
-  return __float2half(a);
-}
-
-template <>
-__quickreduce_device_inline__ nv_bfloat16 float2T_cast<nv_bfloat16>(float a) {
-  return __float2bfloat16(a);
-}
-
-template <typename T>
-__quickreduce_device_inline__ float T2float_cast(T a);
-
-template <>
-__quickreduce_device_inline__ float T2float_cast(half a) {
+// changes dtype
+__quickreduce_device_inline__ float T2float_cast(half a) {
   return __half2float(a);
 }
 
-template <>
-__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
+__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
   return __bfloat162float(a);
 }
 
-template <typename T>
-__quickreduce_device_inline__ unsigned char T2uchar_cast(T a);
-
-template <>
-__quickreduce_device_inline__ unsigned char T2uchar_cast(half a) {
-  return static_cast<unsigned char>(__half2ushort_rz(a));
-}
-
-template <>
-__quickreduce_device_inline__ unsigned char T2uchar_cast(
-    nv_bfloat16 a) {
-  return static_cast<unsigned char>(__bfloat16_as_ushort(a));
-}
-
-template <typename T>
-__quickreduce_device_inline__ T uchar2T_cast(unsigned char a);
-
-template <>
-__quickreduce_device_inline__ half uchar2T_cast<half>(unsigned char a) {
-  return __ushort2half_rz(static_cast<unsigned short>(a));
-}
-
-template <>
-__quickreduce_device_inline__ nv_bfloat16
-uchar2T_cast<nv_bfloat16>(unsigned char a) {
-  return __ushort_as_bfloat16(static_cast<unsigned short>(a));
-}
-
-template <typename T>
-__quickreduce_device_inline__ int T2int_cast(T a);
-
-template <>
-__quickreduce_device_inline__ int T2int_cast(half a) {
-  return __half2int_rz(a);
-}
-
-template <>
-__quickreduce_device_inline__ int T2int_cast(nv_bfloat16 a) {
-  return static_cast<int>(__bfloat16_as_ushort(a));
-}
-
 template <typename T>
 __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
   const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
@@ -384,45 +332,6 @@ __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
   return wblockmax;
 }
 
-template <typename T>
-__quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax,
-                                                 int& wblockmin,
-                                                 int valid_data) {
-  const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
-  static constexpr int FP_MAX =
-      std::is_same<T, half>::value ? 0x7BFF7BFF : 0x7F7F7F7F;
-  static constexpr int FP_MIN =
-      std::is_same<T, half>::value ? 0xFBFFFBFF : 0xFF7FFF7F;
-
-  int wmax, wmin;
-  int a, b;
-  a = packed_max<T>(atom[0], atom[1]);
-  b = packed_max<T>(atom[2], atom[3]);
-  // In case the data was loaded out of range (and initialized to 0)
-  // we set max min values to sentinel values
-  // so that they do not spoil the group max min values
-  wmax = valid_data * packed_max<T>(a, b) + (!valid_data) * FP_MIN;
-
-  a = packed_min<T>(atom[0], atom[1]);
-  b = packed_min<T>(atom[2], atom[3]);
-  wmin = valid_data * packed_min<T>(a, b) + (!valid_data) * FP_MAX;
-
-  // Reduce the max and min among a group of threads
-  // Note: This is basically 2 blocks of values setup as the
-  // upper/lower halves of the f16x2_t
-  for (int i = 1; i < kThreadGroupSize; i <<= 1) {
-    int x = __shfl_down(wmax, i);
-    wmax = packed_max<T>(wmax, x);
-
-    int y = __shfl_down(wmin, i);
-    wmin = packed_min<T>(wmin, y);
-  }
-
-  // Share with the cohort
-  wblockmax = __shfl(wmax, group_leader);
-  wblockmin = __shfl(wmin, group_leader);
-}
-
 __quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
                                                  uint32_t flag) {
   __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
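// --- Editor's illustrative sketch (not part of the patch) ---------------------
// The bf162_int_union introduced above replaces reinterpret_cast round trips
// between a packed 32-bit lane and an nv_bfloat162 pair; punning through a
// union is the pattern compilers reliably honor, whereas casting between
// unrelated pointer types risks strict-aliasing trouble. The stand-in struct
// below is an assumption used only so this host-side sketch compiles without
// HIP headers; on device the real code calls __hmax2 and friends on .bf2.
#include <cstdint>
#include <cstdio>

struct bf16x2_stub {   // two 16-bit lanes packed into one 32-bit word
  uint16_t x, y;
};

union bf162_int_union_stub {
  int32_t i;        // the packed value the packed_* helpers pass around
  bf16x2_stub bf2;  // the same 32 bits viewed as two 16-bit lanes
};

int main() {
  bf162_int_union_stub a;
  a.i = 0x3F80BF80;  // arbitrary packed value: two 16-bit bit patterns
  // Per-lane access without casting between unrelated pointer types
  // (which lane is x vs. y depends on endianness; we only print them):
  std::printf("lane0=0x%04x lane1=0x%04x\n", (unsigned)a.bf2.x, (unsigned)a.bf2.y);
  // Rebuild the packed int from the lanes, mirroring "return R.i;"
  bf162_int_union_stub r;
  r.bf2 = a.bf2;
  std::printf("roundtrip=0x%08x\n", (unsigned)r.i);
  return 0;
}
// ------------------------------------------------------------------------------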
diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h
index 2cd5ecf25657..0a345772bd3c 100644
--- a/csrc/quickreduce/quick_reduce.h
+++ b/csrc/quickreduce/quick_reduce.h
@@ -94,8 +94,9 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks,
 
 enum QuickReduceQuantLevel {
   FP16 = 0,
-  INT8,
-  INT4,
+  INT8 = 1,
+  INT6 = 2,
+  INT4 = 3,
 };
 
 struct DeviceComms {
@@ -103,7 +104,7 @@ struct DeviceComms {
   static long constexpr kTileSize = 256 * 16 * 8;
 
   // Max problem size is 2GB (in bytes) or half of uint32_t max value.
-  static int64_t constexpr kMaxProblemSize = 2147483647;
+  static int64_t constexpr kMaxProblemSize = 2147483648;
   static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize;
 
   // Max TP-8
@@ -220,6 +221,9 @@ struct DeviceComms {
       case QuickReduceQuantLevel::INT8:
         TWOSHOT_DISPATCH(CodecQ8)
         break;
+      case QuickReduceQuantLevel::INT6:
+        TWOSHOT_DISPATCH(CodecQ6)
+        break;
      case QuickReduceQuantLevel::INT4:
         TWOSHOT_DISPATCH(CodecQ4)
         break;
diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh
index 813f9bd620fa..89a07629d713 100644
--- a/csrc/quickreduce/quick_reduce_impl.cuh
+++ b/csrc/quickreduce/quick_reduce_impl.cuh
@@ -56,55 +56,57 @@ struct CodecFP : public CodecBase {
 // We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
 // kThreadGroupSize.
 template <typename T, int world_size>
-struct CodecQ8 : public CodecBase {
+struct CodecQ4 : public CodecBase {
   static constexpr int kWorldSize = world_size;
 
   // Codec tile size process by this workgroup.
-  // Each threads processes a fragment of f16x8_t (16B),
-  // into a int8x8_t (8B) and a f16 scale shared among 32 values.
-  static int constexpr kRankAtoms = kAtoms / kWorldSize;
-  static int constexpr kRankTileStride = 2176;
-  static int constexpr kRankTileScaleOffset = 2048;
-  static int constexpr kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
+  // Each threads processes a fragment of fp16x8_t (16B),
+  // into a int4x8_t (4B) and a fp16 scale shared among 32 values.
+  static constexpr int kRankAtoms = kAtoms / kWorldSize;
+  static constexpr int kRankTileStride = 1152;
+  static constexpr int kRankTileScaleOffset = 1024;
+  static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
   static_assert(kRankTransmittedTileSize % 16 == 0,
-                "kRankTileSize must be 16B aligned.");
+                "kRankTransmittedTileSize must be 16B aligned.");
 
-  static int constexpr kRankBufferTileStride =
+  static constexpr int kRankBufferTileStride =
       kRankTileStride / sizeof(int32x4_t);
 
   // Total tile size for the collective communication.
-  static int constexpr kTransmittedTileSize =
+  static constexpr int kTransmittedTileSize =
       kRankTransmittedTileSize * kWorldSize;
 
   // Constants configuration
 
-  // {-1/128.0h, -1/128.0h}, f16x2_t
-  static int constexpr kScaleFactor =
-      std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
+  // {-1/8.0h, -1/8.0h}, f16x2_t
+  static constexpr int kScaleFactor =
+      std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
 
   // {1e-7, 1e-7}, f16x2_t
-  static int constexpr kScaleEpsilon =
+  static constexpr int kScaleEpsilon =
       std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
 
-  // {-128, -128}, f16x2_t
-  static int constexpr kRangeMin =
-      std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
-  // {+127, +127}, f16x2_t
-  static int constexpr kRangeMax =
-      std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
+  // {-8, -8}, f16x2_t
+  static constexpr int kRangeMin =
+      std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
 
-  // {+128, +128}, int16x2_t
-  static int constexpr kRangeBias = 0x00800080;
+  // {+7, +7}, f16x2_t
+  static constexpr int kRangeMax =
+      std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
 
-  __quickreduce_device_inline__ CodecQ8(int thread, int rank)
+  // {+8, +8}, int16x2_t
+  static constexpr int kRangeBias = 0x00080008;
+
+  __quickreduce_device_inline__ CodecQ4(int thread, int rank)
       : CodecBase(thread, rank) {
     set_fp16_ovfl(true);
   }
 
   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
-                                          int32x4_t const* __restrict__ data) {
+                                          const int32x4_t* __restrict__ data) {
     for (int k = 0; k < kRankAtoms; k++) {
       int32x4_t const atom = data[k];
+
       // Compute the absolute maximum of the atom in the thread group
       // In 2 blocks of values, upper/lower halves of the f16x2_t
       int wblockmax = group_abs_max<T>(atom);
@@ -129,24 +131,21 @@ struct CodecQ8 : public CodecBase {
      {
        int16_t* qi = reinterpret_cast<int16_t*>(&q);
        T* wh = reinterpret_cast<T*>(&w);
-        for (int i = 0; i < 8; i++)
-          qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
+        for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
 
        for (int i = 0; i < 4; i++) {
          q[i] = packed_add(q[i], kRangeBias);
        }
      }
 
-      // Pack 8 x q8 into int32x2_t
-      int32x2_t qw;
-      qw[0] = q[0] | (q[1] << 8);
-      qw[1] = q[2] | (q[3] << 8);
+      // Pack 8 x q4 into int32_t
+      int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
 
      // Write quantized atom to send_buffer
      // note: only the group leader stores the scale
      uint8_t* atom_ptr =
          reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
-      int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+      int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
                    (thread / 8);
@@ -162,34 +161,30 @@ struct CodecQ8 : public CodecBase {
    for (int k = 0; k < kRankAtoms; k++) {
      // Directly read quantized atom from recv_buffer
      uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
-      int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+      int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
                    (thread / 8);
 
-      int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
+      int32_t qw = __builtin_nontemporal_load(qw_ptr);
      int qs = __builtin_nontemporal_load(qs_ptr);
 
      *recv_buffer += kRankBufferTileStride;
 
-      // Unpack q8 into fp16x8_t
+      // Unpack q4 into f16x8_t
      int32x4_t w;
      {
-        static uint constexpr kMask00FF = 0x00FF00FF;
-
-        // {1024.0, 1024.0}, fp16x2_t
-        static uint constexpr kHalf2_1024 = 0x64006400;
-
-        // {-1152.0, -1152.0}, fp16x2_t
-        static uint constexpr kHalf2_1152 = 0xE480E480;
+        static constexpr uint kMask000F = 0x000F000F;
+        static constexpr uint kHalf2_1024 =
+            0x64006400;  // {1024.0, 1024.0}, fp16x2_t
+        static uint constexpr kHalf2_1032 =
+            0xE408E408;  // {-1032.0, -1032.0}, fp16x2_t
 
-#pragma unroll
        for (int i = 0; i < 4; i++) {
          if constexpr (std::is_same<T, half>::value) {
-            int32_t q8 =
-                ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
-            w[i] = packed_add<T>(q8, kHalf2_1152);
+            int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
+            w[i] = packed_add<T>(q4, kHalf2_1032);
          } else {
-            int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
+            int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
            int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
            int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
            nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
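// --- Editor's illustrative sketch (not part of the patch) ---------------------
// Why ORing a 4-bit code into 0x6400 decodes it: fp16 1024.0 is 0x6400 with an
// all-zero mantissa, so 0x6400 | q (q in 0..15) is exactly the half value
// 1024 + q, and adding the -1032.0 constant (0xE408, i.e. kHalf2_1032 per lane)
// recovers q - 8, the signed int4 value before the block scale is applied.
// This host-side check uses a tiny fp16-bits-to-float helper (normal numbers
// only), which is an assumption of the sketch, not code from the patch.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 1;
  int exp = (h >> 10) & 0x1F;  // biased exponent
  int mant = h & 0x3FF;        // 10-bit mantissa
  float value = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -value : value;
}

int main() {
  const uint16_t kHalf1024 = 0x6400;   // 1024.0 in fp16
  const uint16_t kHalfM1032 = 0xE408;  // -1032.0 in fp16
  for (int q = 0; q < 16; ++q) {
    float biased = half_bits_to_float(kHalf1024 | q);
    float decoded = biased + half_bits_to_float(kHalfM1032);
    std::printf("q=%2d -> %6.1f - 1032 = %4.0f (expected %d)\n", q, biased,
                decoded, q - 8);
  }
  return 0;
}
// ------------------------------------------------------------------------------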
@@ -211,19 +206,20 @@ struct CodecQ8 : public CodecBase {
   }
 };
 
-// Int4 symmetric quantization codec.
-// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
+// Int6 symmetric quantization codec.
+// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
 // kThreadGroupSize.
 template <typename T, int world_size>
-struct CodecQ4 : public CodecBase {
+struct CodecQ6 : public CodecBase {
   static constexpr int kWorldSize = world_size;
 
   // Codec tile size process by this workgroup.
   // Each threads processes a fragment of fp16x8_t (16B),
-  // into a int4x8_t (4B) and a fp16 scale shared among 32 values.
+  // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
   static constexpr int kRankAtoms = kAtoms / kWorldSize;
-  static constexpr int kRankTileStride = 1152;
-  static constexpr int kRankTileScaleOffset = 1024;
+  static constexpr int kRankTileStride = 1664;
+  static constexpr int kRankTileQ2Offset = 1024;
+  static constexpr int kRankTileScaleOffset = 1536;
   static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
   static_assert(kRankTransmittedTileSize % 16 == 0,
                 "kRankTransmittedTileSize must be 16B aligned.");
@@ -237,29 +233,27 @@ struct CodecQ4 : public CodecBase {
 
   // Constants configuration
 
-  // {-1/8.0h, -1/8.0h}, f16x2_t
+  // {-1/32.0h, -1/32.0h}, fp16x2_t
   static constexpr int kScaleFactor =
-      std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
+      std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
 
-  // {1e-7, 1e-7}, f16x2_t
+  // {1e-7, 1e-7}, fp16x2_t
   static constexpr int kScaleEpsilon =
       std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
 
-  // {-8, -8}, f16x2_t
+  // {-32, -32}, fp16x2_t
   static constexpr int kRangeMin =
-      std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
+      std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
 
-  // {+7, +7}, f16x2_t
+  // {+31, +31}, fp16x2_t
   static constexpr int kRangeMax =
-      std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
+      std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
 
-  // {+8, +8}, int16x2_t
-  static constexpr int kRangeBias = 0x00080008;
+  // {+32, +32}, int16x2_t
+  static constexpr int kRangeBias = 0x00200020;
 
-  __quickreduce_device_inline__ CodecQ4(int thread, int rank)
-      : CodecBase(thread, rank) {
-    set_fp16_ovfl(true);
-  }
+  __quickreduce_device_inline__ CodecQ6(int thread, int rank)
+      : CodecBase(thread, rank) {}
 
   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
                                           const int32x4_t* __restrict__ data) {
@@ -290,26 +284,37 @@ struct CodecQ4 : public CodecBase {
      {
        int16_t* qi = reinterpret_cast<int16_t*>(&q);
        T* wh = reinterpret_cast<T*>(&w);
-        for (int i = 0; i < 8; i++)
-          qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
+        for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
 
        for (int i = 0; i < 4; i++) {
          q[i] = packed_add(q[i], kRangeBias);
        }
      }
 
-      // Pack 8 x q4 into int32_t
-      int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
-
+      // Pack 8 x q6 into int32_t + int16_t
+      uint32_t q4w;
+      uint16_t q2w = 0;
+      q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
+            ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
+      {
+        int16_t* tw = reinterpret_cast<int16_t*>(&q);
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+          q2w |= (tw[i] >> 4) << (i * 2);
+        }
+      }
      // Write quantized atom to send_buffer
      // note: only the group leader stores the scale
      uint8_t* atom_ptr =
          reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
-      int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
+      uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
+      uint16_t* q2w_ptr =
+          reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
                    (thread / 8);
 
-      __builtin_nontemporal_store(qw, qw_ptr);
+      __builtin_nontemporal_store(q4w, q4w_ptr);
+      __builtin_nontemporal_store(q2w, q2w_ptr);
      if (threadIdx.x == group_leader) {
        __builtin_nontemporal_store(decoding_scale, qs_ptr);
      }
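// --- Editor's illustrative sketch (not part of the patch) ---------------------
// Round-trip of the Q6 bit layout: each thread's 8 biased 6-bit codes (0..63)
// live in the 16-bit lanes of q[0..3]; send() keeps the low 4 bits of every
// code in one 32-bit word (q4w) and the top 2 bits in one 16-bit word (q2w).
// This host-side check reuses the same pack/unpack expressions; the test codes
// and little-endian lane ordering are assumptions of the sketch.
#include <cstdint>
#include <cstdio>

int main() {
  uint16_t codes[8] = {0, 7, 8, 31, 32, 45, 62, 63};  // arbitrary 6-bit values

  // q[i] holds codes[2*i] (low lane) and codes[2*i+1] (high lane).
  uint32_t q[4];
  for (int i = 0; i < 4; ++i)
    q[i] = (uint32_t)codes[2 * i] | ((uint32_t)codes[2 * i + 1] << 16);

  // Pack: low nibbles into q4w, top two bits into q2w (as in send()).
  uint32_t q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
                 ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
  uint16_t q2w = 0;
  for (int i = 0; i < 8; ++i) {
    uint16_t lane = (uint16_t)((i % 2 == 0) ? (q[i / 2] & 0xFFFF) : (q[i / 2] >> 16));
    q2w |= (uint16_t)((lane >> 4) << (i * 2));
  }

  // Unpack with the recv() expressions and verify the round trip.
  for (int i = 0; i < 4; ++i) {
    uint32_t q4 = q4w & 0x000F000F;
    uint32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
    uint32_t pair = q4 | (q2 << 4);  // two reassembled 6-bit codes per int
    q4w >>= 4;
    q2w >>= 4;
    std::printf("lane %d: low=%u (want %u)  high=%u (want %u)\n", i,
                (unsigned)(pair & 0xFFFF), (unsigned)codes[2 * i],
                (unsigned)(pair >> 16), (unsigned)codes[2 * i + 1]);
  }
  return 0;
}
// ------------------------------------------------------------------------------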
"v"(kHalf2_1056)); } else { - int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int32_t int16_2 = q4 | (q2 << 4); int16_t low = static_cast(int16_2 & 0xFFFF); int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); @@ -361,111 +378,166 @@ struct CodecQ4 : public CodecBase { w[i] = packed_mul(w[i], qs); } + // That's pretty much it... data[k] = w; } } }; -// Oneshot AllReduce +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. template -struct AllReduceOneshot { - static_assert(sizeof(T) == 2); +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; - __device__ static void run( - T const* __restrict__ A, // input - T* __restrict__ B, // output - uint32_t const N, // number of elements - uint32_t const rank, // rank index - uint8_t** __restrict__ buffer_list, // communication buffers - long const data_offset, // offset to start of the data buffer - uint32_t flag_color) { - BufferResource src_buffer(const_cast(A), N * sizeof(T)); - BufferResource dst_buffer(B, N * sizeof(T)); + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); - uint8_t* rank_buffer = buffer_list[rank]; + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); - const int block_size = blockDim.x; - const int thread = threadIdx.x; - const int block = blockIdx.x; - const uint32_t problem_size = (N + 3) / 4; + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; - int32x4_t tA, tB; - long grid = gridDim.x; - long data_stride = grid * block_size * sizeof(int32x4_t); - long comm_flags0_offset = block * (world_size * sizeof(int)); - long comm_flags1_offset = - comm_flags0_offset + grid * (world_size * sizeof(int)); + // Constants configuration - for (int idx = block * block_size + thread; idx < problem_size; - idx += grid * block_size) { - // load values - tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t), - 0, 0); + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA000A000 : 0xBC00BC00; - // Write rank data into this rank segment of every rank's communication - // buffer. -#pragma unroll - for (int r = 0; r < world_size; r++) { - int32x4_t* send_buffer = reinterpret_cast( - buffer_list[r] + data_offset + rank * data_stride + - idx * sizeof(int32x4_t)); - __builtin_nontemporal_store(tA, send_buffer); - } - } + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 
@@ -361,111 +378,166 @@ struct CodecQ4 : public CodecBase {
        w[i] = packed_mul<T>(w[i], qs);
      }
 
+      // That's pretty much it...
      data[k] = w;
    }
  }
 };
 
-// Oneshot AllReduce
+// Int8 symmetric quantization codec.
+// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
+// kThreadGroupSize.
 template <typename T, int world_size>
-struct AllReduceOneshot {
-  static_assert(sizeof(T) == 2);
+struct CodecQ8 : public CodecBase {
+  static constexpr int kWorldSize = world_size;
 
-  __device__ static void run(
-      T const* __restrict__ A,             // input
-      T* __restrict__ B,                   // output
-      uint32_t const N,                    // number of elements
-      uint32_t const rank,                 // rank index
-      uint8_t** __restrict__ buffer_list,  // communication buffers
-      long const data_offset,  // offset to start of the data buffer
-      uint32_t flag_color) {
-    BufferResource src_buffer(const_cast<T*>(A), N * sizeof(T));
-    BufferResource dst_buffer(B, N * sizeof(T));
+  // Codec tile size process by this workgroup.
+  // Each threads processes a fragment of f16x8_t (16B),
+  // into a int8x8_t (8B) and a f16 scale shared among 32 values.
+  static constexpr int kRankAtoms = kAtoms / kWorldSize;
+  static constexpr int kRankTileStride = 2176;
+  static constexpr int kRankTileScaleOffset = 2048;
+  static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
+  static_assert(kRankTransmittedTileSize % 16 == 0,
+                "kRankTileSize must be 16B aligned.");
 
-    uint8_t* rank_buffer = buffer_list[rank];
+  static constexpr int kRankBufferTileStride =
+      kRankTileStride / sizeof(int32x4_t);
 
-    const int block_size = blockDim.x;
-    const int thread = threadIdx.x;
-    const int block = blockIdx.x;
-    const uint32_t problem_size = (N + 3) / 4;
+  // Total tile size for the collective communication.
+  static constexpr int kTransmittedTileSize =
+      kRankTransmittedTileSize * kWorldSize;
 
-    int32x4_t tA, tB;
-    long grid = gridDim.x;
-    long data_stride = grid * block_size * sizeof(int32x4_t);
-    long comm_flags0_offset = block * (world_size * sizeof(int));
-    long comm_flags1_offset =
-        comm_flags0_offset + grid * (world_size * sizeof(int));
+  // Constants configuration
 
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
-      // load values
-      tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t),
-                               0, 0);
+  // {-1/128.0h, -1/128.0h}, f16x2_t
+  static constexpr int kScaleFactor =
+      std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
 
-      // Write rank data into this rank segment of every rank's communication
-      // buffer.
-#pragma unroll
-      for (int r = 0; r < world_size; r++) {
-        int32x4_t* send_buffer = reinterpret_cast<int32x4_t*>(
-            buffer_list[r] + data_offset + rank * data_stride +
-            idx * sizeof(int32x4_t));
-        __builtin_nontemporal_store(tA, send_buffer);
-      }
-    }
+  // {1e-7, 1e-7}, f16x2_t
+  static constexpr int kScaleEpsilon =
+      std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
 
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags0_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags0_offset + r * sizeof(int));
+  // {-128, -128}, f16x2_t
+  static constexpr int kRangeMin =
+      std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
+  // {+127, +127}, f16x2_t
+  static constexpr int kRangeMax =
+      std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
 
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) {
+  // {+128, +128}, int16x2_t
+  static constexpr int kRangeBias = 0x00800080;
+
+  __quickreduce_device_inline__ CodecQ8(int thread, int rank)
+      : CodecBase(thread, rank) {
+    set_fp16_ovfl(true);
+  }
+
+  __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
+                                          int32x4_t const* __restrict__ data) {
+    for (int k = 0; k < kRankAtoms; k++) {
+      int32x4_t const atom = data[k];
+      // Compute the absolute maximum of the atom in the thread group
+      // In 2 blocks of values, upper/lower halves of the f16x2_t
+      int wblockmax = group_abs_max<T>(atom);
+
+      // Derive scales
+      int decoding_scale;
+      int encoding_scale;
+      decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
+      encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
+      encoding_scale = packed_rcp<T>(encoding_scale);
+
+      // Apply scales to get quantized values
+      int32x4_t w;
+      for (int i = 0; i < 4; i++) {
+        w[i] = packed_mul<T>(atom[i], encoding_scale);
+        w[i] = packed_max<T>(w[i], kRangeMin);
+        w[i] = packed_min<T>(w[i], kRangeMax);
      }
-    }
-    __syncthreads();
 
-    for (int idx = block * block_size + thread; idx < problem_size;
-         idx += grid * block_size) {
+      // Convert from f16x2_t to uint16x2_t
+      int32x4_t q;
      {
-        int r = 0;
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tA = __builtin_nontemporal_load(recv_buffer);
-      }
-#pragma unroll
-      for (int r = 1; r < world_size; r++) {
-        // Read posted data from the rank's communication buffer.
-        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
-            rank_buffer + data_offset + r * data_stride +
-            idx * sizeof(int32x4_t));
-        tB = __builtin_nontemporal_load(recv_buffer);
+        int16_t* qi = reinterpret_cast<int16_t*>(&q);
+        T* wh = reinterpret_cast<T*>(&w);
+        for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
 
-        // Reduce the local data with the read data
-        packed_assign_add<T>(&tA, &tB);
+        for (int i = 0; i < 4; i++) {
+          q[i] = packed_add(q[i], kRangeBias);
+        }
      }
 
-      buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t),
-                           0, 0);
+      // Pack 8 x q8 into int32x2_t
+      int32x2_t qw;
+      qw[0] = q[0] | (q[1] << 8);
+      qw[1] = q[2] | (q[3] << 8);
+
+      // Write quantized atom to send_buffer
+      // note: only the group leader stores the scale
+      uint8_t* atom_ptr =
+          reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
+      int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+                    (thread / 8);
+
+      __builtin_nontemporal_store(qw, qw_ptr);
+      if (threadIdx.x == group_leader) {
+        __builtin_nontemporal_store(decoding_scale, qs_ptr);
+      }
    }
+  }
 
-    __syncthreads();
-    if (thread < world_size) {
-      int r = thread;
-      int* peer_flag_ptr = reinterpret_cast<int*>(
-          buffer_list[r] + comm_flags1_offset + rank * sizeof(int));
-      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED);
-      int* self_flag_ptr = reinterpret_cast<int*>(
-          rank_buffer + comm_flags1_offset + r * sizeof(int));
+  __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+                                          int32x4_t* __restrict__ data) {
+    for (int k = 0; k < kRankAtoms; k++) {
+      // Directly read quantized atom from recv_buffer
+      uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
+      int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+                    (thread / 8);
 
-      // Wait for the flags to be set.
-      while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) {
+      int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
+      int qs = __builtin_nontemporal_load(qs_ptr);
+
+      *recv_buffer += kRankBufferTileStride;
+
+      // Unpack q8 into fp16x8_t
+      int32x4_t w;
+      {
+        static uint constexpr kMask00FF = 0x00FF00FF;
+
+        // {1024.0, 1024.0}, fp16x2_t
+        static uint constexpr kHalf2_1024 = 0x64006400;
+
+        // {-1152.0, -1152.0}, fp16x2_t
+        static uint constexpr kHalf2_1152 = 0xE480E480;
+
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+          if constexpr (std::is_same<T, half>::value) {
+            int32_t q8 =
+                ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
+            w[i] = packed_add<T>(q8, kHalf2_1152);
+          } else {
+            int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
+            int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
+            int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
+            nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
+            nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
+            nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
+            int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
+            w[i] = packed_add<T>(packed_bf16, kRangeMin);
+          }
+        }
+      }
+
+      // Apply decoding scales
+      for (int i = 0; i < 4; i++) {
+        w[i] = packed_mul<T>(w[i], qs);
      }
+
+      data[k] = w;
    }
  }
 };
@@ -605,4 +677,108 @@ struct AllReduceTwoshot {
   }
 };
 
+// Oneshot AllReduce
+template <typename T, int world_size>
+struct AllReduceOneshot {
+  static_assert(sizeof(T) == 2);
+
+  __device__ static void run(
+      T const* __restrict__ A,             // input
+      T* __restrict__ B,                   // output
+      uint32_t const N,                    // number of elements
+      uint32_t const rank,                 // rank index
+      uint8_t** __restrict__ buffer_list,  // communication buffers
+      long const data_offset,  // offset to start of the data buffer
+      uint32_t flag_color) {
+    BufferResource src_buffer(const_cast<T*>(A), N * sizeof(T));
+    BufferResource dst_buffer(B, N * sizeof(T));
+
+    uint8_t* rank_buffer = buffer_list[rank];
+
+    const int block_size = blockDim.x;
+    const int thread = threadIdx.x;
+    const int block = blockIdx.x;
+    const uint32_t problem_size = (N + 3) / 4;
+
+    int32x4_t tA, tB;
+    long grid = gridDim.x;
+    long data_stride = grid * block_size * sizeof(int32x4_t);
+    long comm_flags0_offset = block * (world_size * sizeof(int));
+    long comm_flags1_offset =
+        comm_flags0_offset + grid * (world_size * sizeof(int));
+
+    for (int idx = block * block_size + thread; idx < problem_size;
+         idx += grid * block_size) {
+      // load values
+      tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t),
+                               0, 0);
+
+      // Write rank data into this rank segment of every rank's communication
+      // buffer.
+#pragma unroll
+      for (int r = 0; r < world_size; r++) {
+        int32x4_t* send_buffer = reinterpret_cast<int32x4_t*>(
+            buffer_list[r] + data_offset + rank * data_stride +
+            idx * sizeof(int32x4_t));
+        __builtin_nontemporal_store(tA, send_buffer);
+      }
+    }
+
+    __syncthreads();
+    if (thread < world_size) {
+      int r = thread;
+      int* peer_flag_ptr = reinterpret_cast<int*>(
+          buffer_list[r] + comm_flags0_offset + rank * sizeof(int));
+      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE);
+      int* self_flag_ptr = reinterpret_cast<int*>(
+          rank_buffer + comm_flags0_offset + r * sizeof(int));
+
+      // Wait for the flags to be set.
+      while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) {
+      }
+    }
+    __syncthreads();
+
+    for (int idx = block * block_size + thread; idx < problem_size;
+         idx += grid * block_size) {
+      {
+        int r = 0;
+        // Read posted data from the rank's communication buffer.
+        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
+            rank_buffer + data_offset + r * data_stride +
+            idx * sizeof(int32x4_t));
+        tA = __builtin_nontemporal_load(recv_buffer);
+      }
+#pragma unroll
+      for (int r = 1; r < world_size; r++) {
+        // Read posted data from the rank's communication buffer.
+        int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(
+            rank_buffer + data_offset + r * data_stride +
+            idx * sizeof(int32x4_t));
+        tB = __builtin_nontemporal_load(recv_buffer);
+
+        // Reduce the local data with the read data
+        packed_assign_add<T>(&tA, &tB);
+      }
+
+      buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t),
+                           0, 0);
+    }
+
+    __syncthreads();
+    if (thread < world_size) {
+      int r = thread;
+      int* peer_flag_ptr = reinterpret_cast<int*>(
+          buffer_list[r] + comm_flags1_offset + rank * sizeof(int));
+      __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED);
+      int* self_flag_ptr = reinterpret_cast<int*>(
+          rank_buffer + comm_flags1_offset + r * sizeof(int));
+
+      // Wait for the flags to be set.
+      while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) {
+      }
+    }
+  }
+};
+
 }  // namespace quickreduce
\ No newline at end of file
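// --- Editor's illustrative sketch (not part of the patch) ---------------------
// Host-side analogy of the flag_color handshake used by AllReduceOneshot: after
// publishing its data, each rank release-stores the current color into every
// peer's flag slot for this rank, then acquire-spins until all peers have done
// the same. Using a fresh color per invocation avoids resetting flags between
// calls. Thread count, flag layout and printed output are assumptions of this
// sketch, not the GPU code.
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int kWorldSize = 4;
// flags[rank][peer]: written by `peer` to tell `rank` its data is visible.
// Static storage, so the atomics start out zeroed; colors begin at 1.
std::atomic<uint32_t> flags[kWorldSize][kWorldSize];

void color_barrier(int rank, uint32_t flag_color) {
  for (int r = 0; r < kWorldSize; ++r)
    flags[r][rank].store(flag_color, std::memory_order_release);
  for (int r = 0; r < kWorldSize; ++r)
    while (flags[rank][r].load(std::memory_order_acquire) != flag_color) {
    }
}

int main() {
  std::vector<std::thread> ranks;
  for (int rank = 0; rank < kWorldSize; ++rank)
    ranks.emplace_back([rank] {
      for (uint32_t color = 1; color <= 3; ++color) {
        // ... a real rank would write its tile into every peer's buffer here ...
        color_barrier(rank, color);
        std::printf("rank %d passed color %u\n", rank, color);
      }
    });
  for (auto& t : ranks) t.join();
  return 0;
}
// ------------------------------------------------------------------------------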
diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py
index e6b5debc0184..322633c220a4 100644
--- a/vllm/distributed/device_communicators/quick_all_reduce.py
+++ b/vllm/distributed/device_communicators/quick_all_reduce.py
@@ -24,8 +24,9 @@
 class QuickReduceRegime(Enum):
     FP = 0
     INT8 = 1
-    INT4 = 2
-    NONE = 3
+    INT6 = 2
+    INT4 = 3
+    NONE = 4
 
 
 class QuickAllReduce:
@@ -50,8 +51,8 @@ def __init__(self, group: ProcessGroup,
                 "Supported levels: "
                 f"{list(QuickReduceRegime.__members__.keys())}")
         if regime_str == "NONE":
-            logger.debug("Custom quickreduce is disabled based on "
-                         "env variable VLLM_ROCM_CA_QUANT_REGIME")
+            logger.debug("Custom quick allreduce is disabled based "
+                         "on env variable VLLM_ROCM_CA_QUANT_REGIME")
             return
         self.quant_level = QuickReduceRegime[regime_str]
         # On RocM bfloat16 kernels are slower than fp16
@@ -106,7 +107,7 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
         Performs an out-of-place all reduce.
         """
         inp_size = inp.numel() * inp.element_size()
-        if inp_size >= self.max_size:
+        if inp_size > self.max_size:
             return None
         inp_dtype = inp.dtype
diff --git a/vllm/envs.py b/vllm/envs.py
index 2827703e3502..56443666e584 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -672,8 +672,8 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
              ("true", "1")),
 
-    # Custom quick allreduce kernel for MI3* cards.
-    # Choice of quantization level: FP, INT8, INT4 or NONE
+    # Custom quick allreduce kernel for MI3* cards
+    # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
     # Recommended for large models to get allreduce
     "VLLM_ROCM_CA_QUANT_REGIME":
     lambda: os.getenv("VLLM_ROCM_CA_QUANT_REGIME", "FP").upper(),