From 05d07645e1ae5eeeff15abda31a6ba5806dd2bb2 Mon Sep 17 00:00:00 2001
From: Xia Weiwen
Date: Wed, 7 Feb 2024 10:46:20 +0800
Subject: [PATCH] WOQ: Optimize quantization of activation (#2584)

* WOQ: Optimize quantization per-tensor/per-block of activation for lowp-mode=INT8

* Refine threshold of activation size to parallelize quantization
---
 csrc/cpu/aten/kernels/WoqTppKrnl.cpp | 670 +++++++++++++++++++++------
 1 file changed, 534 insertions(+), 136 deletions(-)

diff --git a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
index a4a2dc40a..3e5b46d2b 100644
--- a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
+++ b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
@@ -133,6 +133,7 @@ float dequantize_nf4_scalar(uint8_t val) {
 // the #else part
 #if defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)

+#define QUANT_A_THRESHOLD 30720
 #define SMALL_BATCH_THRESHOLD 32
 #define PARALLEL_M_THRESHOLD 128
 constexpr long PREFETCH_K_DIST = 64; // TODO(jgong5): do not hard-code
@@ -1666,15 +1667,14 @@ void qlinear_woq_affine_impl(
     int32_t* zps_a_ptr = nullptr) {
   const bool is_4bit_flag = is_4bit(qw_type);
   const bool sym_quant = is_sym_quant(qw_type);
-  auto x_sizes = x.sizes();
   auto w_sizes = qw_packed.sizes();
-  auto M = x_sizes[0];
   auto Nc = w_sizes[0];
   auto Nb = is_4bit_flag ? w_sizes[3] * 2 : w_sizes[3];
   auto Kc = w_sizes[1];
   auto Kb = w_sizes[2];
   auto N = Nc * Nb;
   auto K = Kc * Kb;
+  auto M = x.numel() / K;
   assert(quant_block_k % Kb == 0);
   auto quant_block_multiple = quant_block_k == 0 ? 1 : quant_block_k / Kb;
   auto quant_k_blocks =
@@ -2514,6 +2514,23 @@ at::Tensor dequantize_nf4(const at::Tensor& qt) {
   return t;
 }

+template <typename scalar_t>
+inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
+  if (at::_isnan(a)) {
+    return a;
+  }
+  return a > b ? a : b;
+}
+
+template <typename scalar_t>
+inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
+  if (at::_isnan(a)) {
+    return a;
+  }
+  return a < b ?
a : b; +} + +template void compute_int8_qparams_per_tensor( const at::Tensor& t, float* scale, @@ -2527,20 +2544,142 @@ void compute_int8_qparams_per_tensor( *zp = (int32_t)(-std::nearbyint(min / *scale)); } -template -inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { - if (at::_isnan(a)) { - return a; +template <> +void compute_int8_qparams_per_tensor( + const at::Tensor& t, + float* scale, + int32_t* zp) { + auto in_ptr0 = t.data_ptr(); + auto n = t.numel(); + auto K = t.size(-1); + auto M = t.numel() / K; + auto vecsize = at::vec::Vectorized::size(); + auto compute_block = [&](float* in_ptr, int start, int end) { + float min_val = std::numeric_limits::infinity(); + float max_val = -std::numeric_limits::infinity(); + auto min_vec = at::vec::Vectorized(min_val); + auto max_vec = at::vec::Vectorized(max_val); + int i1; + for (i1 = start; i1 < end / n * n; i1 += vecsize) { + auto tmp0 = at::vec::Vectorized::loadu(in_ptr + i1, vecsize); + min_vec = at::vec::minimum(min_vec, tmp0); + max_vec = at::vec::maximum(tmp0, max_vec); + } + for (; i1 < end; i1++) { + auto tmp0 = in_ptr[i1]; + min_val = std::min(min_val, tmp0); + max_val = std::max(tmp0, max_val); + } + min_val = min_propagate_nan( + min_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::minimum(x, y); + }, + min_vec)); + max_val = max_propagate_nan( + max_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + max_vec)); + return std::make_pair(min_val, max_val); + }; + if (n > QUANT_A_THRESHOLD) { + int num_threads = omp_get_max_threads(); + int vec_per_thread = std::ceil((float)n / vecsize / num_threads); + int thread_used = std::ceil((float)n / vecsize / vec_per_thread); + float min_vals[thread_used]; + float max_vals[thread_used]; +#pragma omp parallel for + for (int i0 = 0; i0 < n; i0 += vec_per_thread * vecsize) { + auto vec_start = i0; + auto vec_end = std::min(i0 + vec_per_thread * vecsize, (int)n); + auto [min_val, max_val] = compute_block(in_ptr0, vec_start, vec_end); + min_vals[i0 / vec_per_thread / vecsize] = min_val; + max_vals[i0 / vec_per_thread / vecsize] = max_val; + } + auto min_elem_ptr = std::min_element(min_vals, min_vals + thread_used); + auto max_elem_ptr = std::max_element(max_vals, max_vals + thread_used); + *scale = (*max_elem_ptr - *min_elem_ptr) / 255.0f; + *zp = (int32_t)(-std::nearbyint(*min_elem_ptr / *scale)); + } else { + auto [min_val, max_val] = compute_block(in_ptr0, 0, n); + *scale = (max_val - min_val) / 255.0f; + *zp = (int32_t)(-std::nearbyint(min_val / *scale)); } - return a > b ? 
a : b; } -template -inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) { - if (at::_isnan(a)) { - return a; +template <> +void compute_int8_qparams_per_tensor( + const at::Tensor& t, + float* scale, + int32_t* zp) { + auto in_ptr0 = t.data_ptr(); + auto n = t.numel(); + auto K = t.size(-1); + auto M = t.numel() / K; + auto vecsize = at::vec::Vectorized::size(); + auto compute_block = [&](at::BFloat16* in_ptr, int start, int end) { + float min_val = std::numeric_limits::infinity(); + float max_val = -std::numeric_limits::infinity(); + auto min_vec = at::vec::Vectorized(min_val); + auto max_vec = at::vec::Vectorized(max_val); + int i1; + for (i1 = start; i1 < end / n * n; i1 += vecsize) { + auto tmp0 = + at::vec::Vectorized::loadu(in_ptr + i1, vecsize); + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_bfloat16_float(tmp0); + min_vec = at::vec::minimum(min_vec, res_vec1); + max_vec = at::vec::maximum(res_vec1, max_vec); + } + for (; i1 < end; i1++) { + auto tmp0 = in_ptr[i1]; + min_val = std::min(min_val, (float)tmp0); + max_val = std::max((float)tmp0, max_val); + } + min_val = min_propagate_nan( + min_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::minimum(x, y); + }, + min_vec)); + max_val = max_propagate_nan( + max_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + max_vec)); + return std::make_pair(min_val, max_val); + }; + if (n > QUANT_A_THRESHOLD) { + int num_threads = omp_get_max_threads(); + int vec_per_thread = std::ceil((float)n / vecsize / num_threads); + int thread_used = std::ceil((float)n / vecsize / vec_per_thread); + float min_vals[thread_used]; + float max_vals[thread_used]; +#pragma omp parallel for + for (int i0 = 0; i0 < n; i0 += vec_per_thread * vecsize) { + auto vec_start = i0; + auto vec_end = std::min(i0 + vec_per_thread * vecsize, (int)n); + auto [min_val, max_val] = compute_block(in_ptr0, vec_start, vec_end); + min_vals[i0 / vec_per_thread / vecsize] = min_val; + max_vals[i0 / vec_per_thread / vecsize] = max_val; + } + auto min_elem_ptr = std::min_element(min_vals, min_vals + thread_used); + auto max_elem_ptr = std::max_element(max_vals, max_vals + thread_used); + *scale = (*max_elem_ptr - *min_elem_ptr) / 255.0f; + *zp = (int32_t)(-std::nearbyint(*min_elem_ptr / *scale)); + } else { + auto [min_val, max_val] = compute_block(in_ptr0, 0, n); + *scale = (max_val - min_val) / 255.0f; + *zp = (int32_t)(-std::nearbyint(min_val / *scale)); } - return a < b ? 
a : b; } template @@ -2548,11 +2687,13 @@ std::pair compute_int8_qparams_per_block( const at::Tensor& t, int quant_block_k, int quant_a_mode) { - int M = t.size(0); - int K = t.size(1); + auto K = t.size(-1); + auto n = t.numel(); + auto M = n / K; + auto t_reshape = t.reshape({M, K}); if (quant_a_mode == QUANT_A_PER_M) { - auto grouped_min = std::get<0>(t.min(-1)); - auto grouped_max = std::get<0>(t.max(-1)); + auto grouped_min = std::get<0>(t_reshape.min(-1)); + auto grouped_max = std::get<0>(t_reshape.max(-1)); auto zeros = at::zeros_like(grouped_min); auto min = at::minimum(grouped_min, zeros); auto max = at::maximum(grouped_max, zeros); @@ -2564,7 +2705,8 @@ std::pair compute_int8_qparams_per_block( int k_rem = K % quant_block_k; int block_k = quant_block_k; auto grouped = - t.index({at::indexing::Slice(), at::indexing::Slice(0, K - k_rem)}) + t_reshape + .index({at::indexing::Slice(), at::indexing::Slice(0, K - k_rem)}) .view({M, K / quant_block_k, quant_block_k}); at::Tensor grouped_min, grouped_max; if (quant_a_mode == QUANT_A_PER_K_BLOCK) { @@ -2581,7 +2723,8 @@ std::pair compute_int8_qparams_per_block( auto zps = -at::round(min / scales); if (k_rem) { auto grouped_rem = - t.index({at::indexing::Slice(), at::indexing::Slice(K - k_rem, K)}) + t_reshape + .index({at::indexing::Slice(), at::indexing::Slice(K - k_rem, K)}) .view({M, 1, k_rem}); at::Tensor grouped_rem_min, grouped_rem_max; if (quant_a_mode == QUANT_A_PER_K_BLOCK) { @@ -2608,8 +2751,9 @@ std::pair compute_int8_qparams_per_block( int quant_block_k, int quant_a_mode) { auto in_ptr = t.data_ptr(); - int M = t.size(0); - int K = t.size(1); + int K = t.size(-1); + int n = t.numel(); + int M = n / K; int Kc = (K + quant_block_k - 1) / quant_block_k; auto vecsize = at::vec::Vectorized::size(); at::Tensor scales, zps; @@ -2704,6 +2848,105 @@ std::pair compute_int8_qparams_per_block( std::move(scales), std::move(zps)); } +template <> +std::pair compute_int8_qparams_per_block( + const at::Tensor& t, + int quant_block_k, + int quant_a_mode) { + auto in_ptr = t.data_ptr(); + int K = t.size(-1); + int n = t.numel(); + int M = n / K; + int Kc = (K + quant_block_k - 1) / quant_block_k; + auto vecsize = at::vec::Vectorized::size(); + at::Tensor scales, zps; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { + scales = at::empty({Kc}, t.options().dtype(at::kFloat)); + zps = at::empty({Kc}, t.options().dtype(at::kInt)); + } else if (quant_a_mode == QUANT_A_PER_M) { + scales = at::empty({M}, t.options().dtype(at::kFloat)); + zps = at::empty({M}, t.options().dtype(at::kInt)); + } else { + scales = at::empty({M, Kc}, t.options().dtype(at::kFloat)); + zps = at::empty({M, Kc}, t.options().dtype(at::kInt)); + } + auto scales_ptr = scales.data_ptr(); + auto zps_ptr = zps.data_ptr(); + auto compute_minmax = [vecsize, scales_ptr, zps_ptr]( + float* ptr, + int M, + int K, + int scale_offset, + int zp_offset, + int ld) { + float min_val = std::numeric_limits::infinity(); + float max_val = -std::numeric_limits::infinity(); + auto in_ptr_ = ptr; + auto min_vec = at::vec::Vectorized(min_val); + auto max_vec = at::vec::Vectorized(max_val); + for (int m = 0; m < M; m++) { + auto in_ptr0 = in_ptr_; + int k; + for (k = 0; k < K / vecsize * vecsize; k += vecsize) { + auto tmp0 = at::vec::Vectorized::loadu(in_ptr0, vecsize); + min_vec = at::vec::minimum(min_vec, tmp0); + max_vec = at::vec::maximum(tmp0, max_vec); + in_ptr0 += vecsize; + } + for (; k < K; k++) { + auto tmp0 = in_ptr0[k]; + min_val = std::min(min_val, tmp0); + max_val = std::max(max_val, tmp0); + } 
+ in_ptr_ += ld; + } + min_val = min_propagate_nan( + min_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::minimum(x, y); + }, + min_vec)); + max_val = max_propagate_nan( + max_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + max_vec)); + scales_ptr[scale_offset] = (max_val - min_val) / 255.0f; + zps_ptr[zp_offset] = + (int32_t)(-std::nearbyint(min_val / scales_ptr[scale_offset])); + }; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { +#pragma omp parallel for + for (int kc = 0; kc < Kc; kc++) { + int offset = kc * quant_block_k; + int block_k = std::min(quant_block_k, K - offset); + compute_minmax(in_ptr + offset, M, block_k, kc, kc, K); + } + } else if (quant_a_mode == QUANT_A_PER_M) { +#pragma omp parallel for + for (int m = 0; m < M; m++) { + int offset = m * K; + compute_minmax(in_ptr + offset, 1, K, m, m, K); + } + } else { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto scale_offset = m * Kc + kc; + auto zp_offset = m * Kc + kc; + int block_k = std::min(quant_block_k, K - kc * quant_block_k); + compute_minmax(in_ptr0, 1, block_k, scale_offset, zp_offset, K); + } + } + } + return std::make_pair( + std::move(scales), std::move(zps)); +} + template at::Tensor quantize_per_tensor(const at::Tensor& t, float scale, int32_t zp) { // TODO(jgong5): optimize me @@ -2712,106 +2955,154 @@ at::Tensor quantize_per_tensor(const at::Tensor& t, float scale, int32_t zp) { return t_q.to(at::kByte); } +template <> +at::Tensor quantize_per_tensor( + const at::Tensor& t, + float scale, + int32_t zp) { +#ifdef __AVX512F__ + at::Tensor out = at::empty_like(t, at::kByte); + auto in_ptr0 = t.data_ptr(); + auto out_ptr0 = out.data_ptr(); + auto n = t.numel(); + auto K = t.size(-1); + auto M = t.numel() / K; + auto vecsize = at::vec::Vectorized::size(); + auto quantize_block = + [vecsize, scale, zp]( + float* in_ptr, int start, int end, uint8_t* out_ptr) { + int i1; + for (i1 = start; i1 < end / vecsize * vecsize; i1 += vecsize) { + auto tmp0 = at::vec::Vectorized::loadu(in_ptr + i1, vecsize); + auto tmp1 = + tmp0 / at::vec::Vectorized(static_cast(scale)); + auto tmp2 = tmp1 + at::vec::Vectorized(static_cast(zp)); + auto tmp3 = tmp2.round(); + auto tmp4 = (tmp3); + auto tmp5 = at::vec::Vectorized(static_cast(0.0)); + auto tmp6 = at::vec::maximum(tmp4, tmp5); + auto tmp7 = at::vec::Vectorized(static_cast(255.0)); + auto tmp8 = at::vec::minimum(tmp6, tmp7); + auto tmp9 = (tmp8); + auto tmp10 = at::vec::convert_float_to_uint8(tmp9); + tmp10.store(out_ptr + i1, vecsize); + } + for (; i1 < end; i1++) { + auto tmp0 = in_ptr[i1]; + auto tmp1 = tmp0 / static_cast(scale); + auto tmp2 = tmp1 + static_cast(zp); + auto tmp3 = std::nearbyint(tmp2); + auto tmp4 = static_cast(tmp3); + auto tmp5 = static_cast(0.0); + auto tmp6 = 0; + if (at::_isnan(tmp4)) { + tmp6 = tmp4; + } + tmp6 = tmp4 > tmp5 ? tmp4 : tmp5; + auto tmp7 = static_cast(255.0); + auto tmp8 = 0; + if (at::_isnan(tmp6)) { + tmp8 = tmp6; + } + tmp8 = tmp6 < tmp7 ? 
tmp6 : tmp7; + auto tmp9 = static_cast(tmp8); + auto tmp10 = static_cast(tmp9); + out_ptr[i1] = tmp10; + } + }; + if (n > QUANT_A_THRESHOLD) { + int num_threads = omp_get_max_threads(); + int vec_per_thread = std::ceil((float)n / vecsize / num_threads); +#pragma omp parallel for + for (int i0 = 0; i0 < n; i0 += vec_per_thread * vecsize) { + auto vec_start = i0; + auto vec_end = std::min(i0 + vec_per_thread * vecsize, (int)n); + quantize_block(in_ptr0, vec_start, vec_end, out_ptr0); + } + } else { + quantize_block(in_ptr0, 0, n, out_ptr0); + } + return out; +#else + return at::quantize_per_tensor(t, scale, zp, c10::kQUInt8); +#endif +} + template <> at::Tensor quantize_per_tensor( const at::Tensor& t, float scale, int32_t zp) { #ifdef __AVX512F__ - // modified based on inductor codegen... - auto convert_float_to_uint8 = - [](at::vec::Vectorized src) -> at::vec::Vectorized { - // Convert from float32 to int32 - __m512i x_values_int32 = _mm512_cvtps_epi32(src); - - // Convert from int32 to int16 using signed saturation - __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32); - - constexpr auto min_val = std::numeric_limits::min(); - constexpr auto max_val = std::numeric_limits::max(); - - // Convert from int16 to uint8 using unsigned saturation - __m512i packed_and_sat = _mm512_packus_epi16(xy_packed_v, xy_packed_v); - __m512i xyzw_clamped_v = _mm512_max_epu8( - _mm512_set1_epi8(min_val), - _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); - __m512i permute_mask_v = _mm512_set_epi32( - 0x0f, - 0x0b, - 0x07, - 0x03, - 0x0e, - 0x0a, - 0x06, - 0x02, - 0x0d, - 0x09, - 0x05, - 0x01, - 0x0c, - 0x08, - 0x04, - 0x00); - return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); - }; at::Tensor out = at::empty_like(t, at::kByte); auto in_ptr0 = t.data_ptr(); auto out_ptr0 = out.data_ptr(); auto n = t.numel(); + auto K = t.size(-1); + auto M = t.numel() / K; auto vecsize = at::vec::Vectorized::size(); - auto vec_end = 0; + auto quantize_block = + [vecsize, scale, zp]( + at::BFloat16* in_ptr, int start, int end, uint8_t* out_ptr) { + int i1; + for (i1 = start; i1 < end / vecsize * vecsize; i1 += vecsize) { + auto tmp0 = + at::vec::Vectorized::loadu(in_ptr + i1, vecsize); + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_bfloat16_float(tmp0); + auto tmp1 = res_vec1; + auto tmp2 = at::vec::Vectorized(static_cast(scale)); + auto tmp3 = tmp1 / tmp2; + auto tmp4 = at::vec::Vectorized(static_cast(zp)); + auto tmp5 = tmp3 + tmp4; + auto tmp6 = tmp5.round(); + auto tmp7 = (tmp6); + auto tmp8 = at::vec::Vectorized(static_cast(0.0)); + auto tmp9 = at::vec::maximum(tmp7, tmp8); + auto tmp10 = at::vec::Vectorized(static_cast(255.0)); + auto tmp11 = at::vec::minimum(tmp9, tmp10); + auto tmp12 = (tmp11); + auto tmp13 = at::vec::convert_float_to_uint8(tmp12); + tmp13.store(out_ptr + i1, vecsize); + } + for (; i1 < end; i1++) { + auto tmp0 = in_ptr[i1]; + auto tmp1 = static_cast(tmp0); + auto tmp2 = static_cast(scale); + auto tmp3 = tmp1 / tmp2; + auto tmp4 = static_cast(zp); + auto tmp5 = tmp3 + tmp4; + auto tmp6 = std::nearbyint(tmp5); + auto tmp7 = static_cast(tmp6); + auto tmp8 = static_cast(0.0); + auto tmp9 = 0; + if (at::_isnan(tmp7)) { + tmp9 = tmp7; + } + tmp9 = tmp7 > tmp8 ? tmp7 : tmp8; + auto tmp10 = static_cast(255.0); + auto tmp11 = 0; + if (at::_isnan(tmp9)) { + tmp11 = tmp9; + } + tmp11 = tmp9 < tmp10 ? 
tmp9 : tmp10; + auto tmp12 = static_cast(tmp11); + auto tmp13 = static_cast(tmp12); + out_ptr[i1] = tmp13; + } + }; + if (n > QUANT_A_THRESHOLD) { + auto num_threads = omp_get_max_threads(); + int vec_per_thread = std::ceil((float)n / vecsize / num_threads); #pragma omp parallel for - for (long i0 = 0; i0 < static_cast(n) / vecsize * vecsize; - i0 += static_cast(vecsize)) { - auto tmp0 = at::vec::Vectorized::loadu( - in_ptr0 + static_cast(i0), vecsize); - at::vec::Vectorized res_vec1(0); - at::vec::Vectorized res_vec2(0); - std::tie(res_vec1, res_vec2) = at::vec::convert_bfloat16_float(tmp0); - auto tmp1 = res_vec1; - // auto tmp1 = cvt_bf16_to_fp32(tmp0); - auto tmp2 = at::vec::Vectorized(static_cast(scale)); - auto tmp3 = tmp1 / tmp2; - auto tmp4 = at::vec::Vectorized(static_cast(zp)); - auto tmp5 = tmp3 + tmp4; - auto tmp6 = tmp5.round(); - auto tmp7 = (tmp6); - auto tmp8 = at::vec::Vectorized(static_cast(0.0)); - auto tmp9 = at::vec::maximum(tmp7, tmp8); - auto tmp10 = at::vec::Vectorized(static_cast(255.0)); - auto tmp11 = at::vec::minimum(tmp9, tmp10); - auto tmp12 = (tmp11); - auto tmp13 = convert_float_to_uint8(tmp12); - tmp13.store(out_ptr0 + static_cast(i0), vecsize); - } - for (long i0 = static_cast(n) / vecsize * vecsize; - i0 < static_cast(n); - i0 += static_cast(1)) { - auto tmp0 = in_ptr0[static_cast(i0)]; - auto tmp1 = static_cast(tmp0); - auto tmp2 = static_cast(scale); - auto tmp3 = tmp1 / tmp2; - auto tmp4 = static_cast(zp); - auto tmp5 = tmp3 + tmp4; - auto tmp6 = std::nearbyint(tmp5); - auto tmp7 = static_cast(tmp6); - auto tmp8 = static_cast(0.0); - // auto tmp9 = max_propagate_nan(tmp7, tmp8); - auto tmp9 = 0; - if (at::_isnan(tmp7)) { - tmp9 = tmp7; + for (int i0 = 0; i0 < n; i0 += vec_per_thread * vecsize) { + auto vec_start = i0; + auto vec_end = std::min(i0 + vec_per_thread * vecsize, (int)n); + quantize_block(in_ptr0, vec_start, vec_end, out_ptr0); } - tmp9 = tmp7 > tmp8 ? tmp7 : tmp8; - auto tmp10 = static_cast(255.0); - auto tmp11 = 0; - if (at::_isnan(tmp9)) { - tmp11 = tmp9; - } - tmp11 = tmp9 < tmp10 ? tmp9 : tmp10; - // auto tmp11 = min_propagate_nan(tmp9, tmp10); - auto tmp12 = static_cast(tmp11); - auto tmp13 = static_cast(tmp12); - out_ptr0[static_cast(i0)] = tmp13; + } else { + quantize_block(in_ptr0, 0, n, out_ptr0); } return out; #else @@ -2826,8 +3117,19 @@ at::Tensor quantize_per_block( const at::Tensor& zp, int quant_block_k, int quant_a_mode) { - int block_k = quant_block_k; - auto grouped = t.view({-1, t.size(-1) / block_k, block_k}); + auto K = t.size(-1); + auto n = t.numel(); + auto M = n / K; + auto k_rem = K % quant_block_k; + bool has_rem = k_rem != 0; + auto K_padded = has_rem ? K + quant_block_k - k_rem : K; + auto t_padded = has_rem + ? 
at::cat( + {t.reshape({M, K}), + at::zeros({M, quant_block_k - k_rem}, t.options())}, + -1) + : t; + auto grouped = t_padded.view({-1, K_padded / quant_block_k, quant_block_k}); at::Tensor out; if (quant_a_mode == QUANT_A_PER_K_BLOCK) { out = at::clamp( @@ -2842,7 +3144,9 @@ at::Tensor quantize_per_block( out = at::clamp( at::round(grouped / scale.unsqueeze(-1)) + zp.unsqueeze(-1), 0, 255); } - return out.to(at::kByte); + out = out.view({-1, K_padded}) + .index({at::indexing::Slice(), at::indexing::Slice(0, K)}); + return out.to(at::kByte).contiguous(); } template <> @@ -2852,9 +3156,9 @@ at::Tensor quantize_per_block( const at::Tensor& zp, int quant_block_k, int quant_a_mode) { - // t is shape of [M, K] and contiguous tensor - int64_t M = t.size(0); - int64_t K = t.size(1); + int K = t.size(-1); + int n = t.numel(); + int M = n / K; at::Tensor out = at::empty_like(t, at::kByte); int Kc = (K + quant_block_k - 1) / quant_block_k; auto scale_ptr = scale.data_ptr(); @@ -2957,6 +3261,108 @@ at::Tensor quantize_per_block( return out; } +template <> +at::Tensor quantize_per_block( + const at::Tensor& t, + const at::Tensor& scale, + const at::Tensor& zp, + int quant_block_k, + int quant_a_mode) { + int K = t.size(-1); + int n = t.numel(); + int M = n / K; + at::Tensor out = at::empty_like(t, at::kByte); + int Kc = (K + quant_block_k - 1) / quant_block_k; + auto scale_ptr = scale.data_ptr(); + auto zp_ptr = zp.data_ptr(); + auto in_ptr = t.data_ptr(); + auto out_ptr = out.data_ptr(); + auto vecsize = at::vec::Vectorized::size(); + auto quantize_block = + [vecsize]( + float* in_ptr, uint8_t* out_ptr, int block_k, float scale_, int zp_) { + int k; + for (k = 0; k < block_k / vecsize * vecsize; k += vecsize) { + auto in_ptr0 = in_ptr + k; + auto out_ptr0 = out_ptr + k; + auto tmp0 = at::vec::Vectorized::loadu(in_ptr0, vecsize); + auto tmp1 = + tmp0 / at::vec::Vectorized(static_cast(scale_)); + auto tmp2 = + tmp1 + at::vec::Vectorized(static_cast(zp_)); + auto tmp3 = tmp2.round(); + auto tmp4 = (tmp3); + auto tmp5 = at::vec::Vectorized(static_cast(0.0)); + auto tmp6 = at::vec::maximum(tmp4, tmp5); + auto tmp7 = at::vec::Vectorized(static_cast(255.0)); + auto tmp8 = at::vec::minimum(tmp6, tmp7); + auto tmp9 = (tmp8); + auto tmp10 = at::vec::convert_float_to_uint8(tmp9); + tmp10.store(out_ptr0, vecsize); + } + for (; k < block_k; k++) { + auto tmp0 = in_ptr[k]; + auto tmp1 = tmp0 / static_cast(scale_); + auto tmp2 = tmp1 + static_cast(zp_); + auto tmp3 = std::nearbyint(tmp2); + auto tmp4 = static_cast(tmp3); + auto tmp5 = static_cast(0.0); + auto tmp6 = 0; + if (at::_isnan(tmp4)) { + tmp6 = tmp4; + } + tmp6 = tmp4 > tmp5 ? tmp4 : tmp5; + auto tmp7 = static_cast(255.0); + auto tmp8 = 0; + if (at::_isnan(tmp6)) { + tmp8 = tmp6; + } + tmp8 = tmp6 < tmp7 ? 
tmp6 : tmp7; + auto tmp9 = static_cast(tmp8); + auto tmp10 = static_cast(tmp9); + out_ptr[k] = tmp10; + } + }; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[kc]; + auto zp_ = zp_ptr[kc]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } else if (quant_a_mode == QUANT_A_PER_M) { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[m]; + auto zp_ = zp_ptr[m]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } else { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[m * Kc + kc]; + auto zp_ = zp_ptr[m * Kc + kc]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } + return out; +} + /** * @brief quantized linear with weight in affine quantized format (scale + * zero-point) but activation in floating point format. @@ -3008,7 +3414,6 @@ at::Tensor qlinear_woq_affine( auto out_sizes = x.sizes().vec(); out_sizes.back() = N; auto y = at::empty(out_sizes, x.options()); - auto x_reshape = x.reshape({M, K}); product_dispatcher< std::tuple, std::tuple< @@ -3037,7 +3442,7 @@ at::Tensor qlinear_woq_affine( half, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp16_idx], biases[fp16_idx], @@ -3058,7 +3463,7 @@ at::Tensor qlinear_woq_affine( half, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp16_idx], biases[fp16_idx], @@ -3082,7 +3487,7 @@ at::Tensor qlinear_woq_affine( float, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp32_idx], biases[fp32_idx], @@ -3103,7 +3508,7 @@ at::Tensor qlinear_woq_affine( float, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp32_idx], biases[fp32_idx], @@ -3132,7 +3537,7 @@ at::Tensor qlinear_woq_affine( bfloat16, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[bf16_idx], biases[fp32_idx], @@ -3153,7 +3558,7 @@ at::Tensor qlinear_woq_affine( bfloat16, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[bf16_idx], biases[fp32_idx], @@ -3177,7 +3582,7 @@ at::Tensor qlinear_woq_affine( float, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp32_idx], biases[fp32_idx], @@ -3198,7 +3603,7 @@ at::Tensor qlinear_woq_affine( float, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[fp32_idx], biases[fp32_idx], @@ -3227,7 +3632,7 @@ at::Tensor qlinear_woq_affine( bfloat16, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[bf16_idx], biases[fp32_idx], @@ -3248,7 +3653,7 @@ at::Tensor qlinear_woq_affine( bfloat16, UNQUANT_A, quant_w_mode_>( - x_reshape, + x, qw, scales_list[bf16_idx], biases[fp32_idx], @@ -3273,11 +3678,9 @@ at::Tensor qlinear_woq_affine( if (quant_a_mode == QUANT_A_PER_TENSOR) { float scale_a; int32_t zp_a; - auto x_reshape_contig = x_reshape.contiguous(); - 
compute_int8_qparams_per_tensor(
-              x_reshape_contig, &scale_a, &zp_a);
-          auto x_quantized = quantize_per_tensor(
-              x_reshape_contig, scale_a, zp_a);
+          compute_int8_qparams_per_tensor(x, &scale_a, &zp_a);
+          auto x_quantized =
+              quantize_per_tensor(x, scale_a, zp_a);
           qlinear_woq_affine_impl<
               uint8_t,
               uint8_t,
@@ -3305,16 +3708,11 @@ at::Tensor qlinear_woq_affine(
           auto block_k = w_sizes[2];
           if (quant_block_k <= 0) quant_block_k = block_k;
-          auto x_reshape_contig = x_reshape.contiguous();
           auto [scale_a, zp_a] = compute_int8_qparams_per_block(
-              x_reshape_contig, quant_block_k, quant_a_mode);
+              x, quant_block_k, quant_a_mode);
           auto x_quantized = quantize_per_block(
-              x_reshape_contig,
-              scale_a,
-              zp_a,
-              quant_block_k,
-              quant_a_mode);
+              x, scale_a, zp_a, quant_block_k, quant_a_mode);
           float* scale_a_ptr = (float*)scale_a.data_ptr();
           int32_t* zp_a_ptr = (int32_t*)zp_a.data_ptr();
           range_dispatcher<
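
For readers who want the scheme without the AVX512/ATen plumbing, here is a minimal standalone C++/OpenMP sketch of the per-tensor activation quantization this patch optimizes: a threshold-gated parallel min/max reduction, then scale = (max - min) / 255 and zp = -nearbyint(min / scale), then q = clamp(round(x / scale) + zp, 0, 255). The names kQuantAThreshold, per_tensor_qparams and quantize_per_tensor_u8 are illustrative only, and OpenMP reductions stand in for the kernel's hand-written per-thread partial min/max buffers and vectorized loops.

// Simplified sketch (not the patched kernel): asymmetric uint8 per-tensor
// dynamic quantization of an activation buffer of n floats.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Same role as QUANT_A_THRESHOLD above: below this element count the
// serial path is cheaper than forking threads.
constexpr long kQuantAThreshold = 30720;

void per_tensor_qparams(const float* x, long n, float* scale, int32_t* zp) {
  float min_val = std::numeric_limits<float>::infinity();
  float max_val = -std::numeric_limits<float>::infinity();
  // Threshold-gated parallel min/max reduction over the whole tensor.
#pragma omp parallel for if (n > kQuantAThreshold) \
    reduction(min : min_val) reduction(max : max_val)
  for (long i = 0; i < n; ++i) {
    min_val = std::min(min_val, x[i]);
    max_val = std::max(max_val, x[i]);
  }
  // Map [min, max] onto the uint8 range [0, 255], same formula as the kernel.
  *scale = (max_val - min_val) / 255.0f;
  *zp = (int32_t)(-std::nearbyint(min_val / *scale));
}

void quantize_per_tensor_u8(const float* x, uint8_t* out, long n,
                            float scale, int32_t zp) {
  // q = clamp(round(x / scale) + zp, 0, 255)
#pragma omp parallel for if (n > kQuantAThreshold)
  for (long i = 0; i < n; ++i) {
    float q = std::nearbyint(x[i] / scale) + (float)zp;
    out[i] = (uint8_t)std::min(255.0f, std::max(0.0f, q));
  }
}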
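
Similarly, a simplified sketch of the per-block path (the per-M-and-K-block mode): one scale/zero-point pair per row m and K-block kc, with the trailing block allowed to be shorter than quant_block_k. BlockQParams and the function names are hypothetical; the real compute_int8_qparams_per_block / quantize_per_block specializations above also handle the per-K-block and per-M modes and use vectorized min/max and rounding.

// Simplified sketch (not the patched kernel): per-[m, k-block] uint8
// quantization of an M x K float activation, with a ragged trailing block.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

struct BlockQParams {
  std::vector<float> scales; // [M * Kc]
  std::vector<int32_t> zps;  // [M * Kc]
};

BlockQParams per_block_qparams(const float* x, long M, long K, long quant_block_k) {
  const long Kc = (K + quant_block_k - 1) / quant_block_k;
  BlockQParams p{std::vector<float>(M * Kc), std::vector<int32_t>(M * Kc)};
#pragma omp parallel for collapse(2)
  for (long m = 0; m < M; ++m) {
    for (long kc = 0; kc < Kc; ++kc) {
      const long k0 = kc * quant_block_k;
      const long block_k = std::min(quant_block_k, K - k0); // ragged tail
      float mn = std::numeric_limits<float>::infinity();
      float mx = -std::numeric_limits<float>::infinity();
      for (long k = 0; k < block_k; ++k) {
        mn = std::min(mn, x[m * K + k0 + k]);
        mx = std::max(mx, x[m * K + k0 + k]);
      }
      p.scales[m * Kc + kc] = (mx - mn) / 255.0f;
      p.zps[m * Kc + kc] =
          (int32_t)(-std::nearbyint(mn / p.scales[m * Kc + kc]));
    }
  }
  return p;
}

void quantize_per_block_u8(const float* x, uint8_t* out, long M, long K,
                           long quant_block_k, const BlockQParams& p) {
  const long Kc = (K + quant_block_k - 1) / quant_block_k;
#pragma omp parallel for collapse(2)
  for (long m = 0; m < M; ++m) {
    for (long kc = 0; kc < Kc; ++kc) {
      const long k0 = kc * quant_block_k;
      const long block_k = std::min(quant_block_k, K - k0);
      const float scale = p.scales[m * Kc + kc];
      const int32_t zp = p.zps[m * Kc + kc];
      for (long k = 0; k < block_k; ++k) {
        // q = clamp(round(x / scale) + zp, 0, 255), per block.
        float q = std::nearbyint(x[m * K + k0 + k] / scale) + (float)zp;
        out[m * K + k0 + k] = (uint8_t)std::min(255.0f, std::max(0.0f, q));
      }
    }
  }
}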