When compiling the FP_Quantize op, many warnings are shown that are valid and point to undefined behavior (UB) or worse.
/deepspeed/ops/csrc/fp_quantizer/fp_quantize_impl.cu(426): warning #62-D: shift count is negative
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
^
detected during:
instantiation of "void apply_selective_dequantization<T,q_mantisa_bits,total_q_bits,_mantisa_bits,_exponent_bits>(uint8_t *, T *, int32_t *, int, int) [with T=__half, q_mantisa_bits=10, total_q_bits=16, _mantisa_bits=3, _exponent_bits=4]" at line 520
instantiation of "void launch_selective_dequantization<T,mantisa>(uint8_t *, T *, int32_t *, int, int, int, int, int, cudaStream_t) [with T=__half, mantisa=10]" at line 532
/deepspeed/ops/csrc/fp_quantizer/fp_quantize_impl.cu(82): warning #68-D: integer conversion resulted in a change of sign
constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
^
detected during:
instantiation of "void apply_quantization<T,unroll,_mantisa_bits,_exponent_bits,total_q_bits,q_mantisa_bits,stochastic_rounding>(T *, uint8_t *, int, std::pair<uint64_t, uint64_t>, float) [with T=__half, unroll=5, _mantisa_bits=23, _exponent_bits=8, total_q_bits=4, q_mantisa_bits=1, stochastic_rounding=1]" at line 353
instantiation of "void launch_quantization<T,mantisa,exponent>(T *, uint8_t *, int, int, cudaStream_t, float, int, int, int) [with T=__half, mantisa=23, exponent=8]" at line 371
template <typename T,
int q_mantisa_bits,
int total_q_bits = 16,
int _mantisa_bits = 3,
int _exponent_bits = 4>
Describe the bug
Most of the warnings are like the two shown above.
q_mantisa_bits is 10 or 7 and _mantisa_bits is 1, 2, 3 or 7, so the shift count _mantisa_bits - q_mantisa_bits in the first warning is negative (zero at best). Maybe the argument order was confused at DeepSpeed/csrc/fp_quantizer/fp_quantize_impl.cu, line 388 (commit 0ba2352).
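Compare that with the template definition shown above, where q_mantisa_bits is the second parameter and _mantisa_bits is the fourth, defaulting to 3.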
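Because q_mantisa_bits and _mantisa_bits are both plain int template parameters, a swapped order compiles without error and only shows up as the shift-count warning. A minimal sketch of how a static_assert could turn this into a hard compile error is given here. The helper make_mantisa_mask, its parameter names, and the mask it builds are hypothetical, chosen for illustration and not taken from the DeepSpeed source:

```cpp
#include <cstdint>

// Hypothetical helper (not DeepSpeed code): returns a mask covering the low
// mantissa bits that are dropped when going from the wide to the narrow format.
template <int wide_mantisa_bits, int narrow_mantisa_bits>
constexpr uint32_t make_mantisa_mask()
{
    // Reject a swapped argument order at compile time instead of shifting by
    // a negative count, which is undefined behavior.
    static_assert(wide_mantisa_bits >= narrow_mantisa_bits,
                  "shift count would be negative: check template argument order");
    static_assert(wide_mantisa_bits - narrow_mantisa_bits < 32,
                  "shift count must be smaller than the width of uint32_t");
    // Unsigned arithmetic throughout, so no signed shift is involved.
    return (uint32_t{1} << (wide_mantisa_bits - narrow_mantisa_bits)) - 1u;
}

// Values from the first warning: 10 mantissa bits (__half) and 3 mantissa bits.
constexpr uint32_t mask_ok = make_mantisa_mask<10, 3>();  // low 7 bits set

// make_mantisa_mask<3, 10>() would fail to compile because of the first
// static_assert, rather than producing warning #62-D.
```

With a check like this, an instantiation with the arguments in the wrong order stops the build at the offending line instead of emitting a warning that is easy to miss.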
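The other warning (#68-D) is also valid: you shouldn't do bit shifts on signed types. Here the literal 1 is a signed int, and shifting it by 23 + 8 = 31 bits yields a negative value before it is converted to uint32_t.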
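One possible way to avoid the signed shift is to start from an unsigned one, so the whole expression stays in unsigned arithmetic. The sketch below uses a hypothetical sign_mask helper (an assumption for illustration, not the DeepSpeed implementation) with the template values from the warning:

```cpp
#include <cstdint>

// Hypothetical helper (not DeepSpeed code): the sign bit sits directly above
// the mantissa and exponent fields.
template <int mantisa_bits, int exponent_bits>
constexpr uint32_t sign_mask()
{
    static_assert(mantisa_bits >= 0 && exponent_bits >= 0,
                  "field widths must be non-negative");
    static_assert(mantisa_bits + exponent_bits < 32,
                  "sign bit must fit into 32 bits");
    // uint32_t{1} instead of the int literal 1: bit 31 can be set without
    // going through a negative signed value.
    return uint32_t{1} << (mantisa_bits + exponent_bits);
}

// fp32 layout from the warning: 23 mantissa bits + 8 exponent bits -> bit 31.
static_assert(sign_mask<23, 8>() == 0x80000000u, "fp32 sign bit");
// fp16 layout: 10 mantissa bits + 5 exponent bits -> bit 15.
static_assert(sign_mask<10, 5>() == 0x8000u, "fp16 sign bit");
```

In the original line, writing 1u (or uint32_t{1}) in place of 1 would have the same effect.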