
INT8 quantization support #45

Open
casper-hansen opened this issue Sep 11, 2023 · 3 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

casper-hansen (Owner) commented Sep 11, 2023

The motivation for INT8 is to retain more accuracy while still gaining some inference speed. I experimented with implementing dequantization for INT8, but it needs more work before it will be usable.

Edit: Implement SmoothQuant instead. Here is a fork of SmoothQuant that supports LLaMA models; it should be integrated into AutoAWQ: https://github.com/AniZpZ/smoothquant/tree/llama-dev
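For context, the core of SmoothQuant is a per-channel smoothing factor that migrates activation outliers into the weights before INT8 quantization. Below is a minimal host-side sketch in plain C++; it is not the linked fork's API, and the function and parameter names are illustrative:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SmoothQuant smoothing factor per input channel j:
//   s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
// Activations are divided by s_j and the corresponding weight rows are
// multiplied by s_j, so the product X * W is unchanged while activation
// outliers shrink, making activations easier to quantize to INT8.
std::vector<float> smooth_scales(const std::vector<float>& act_absmax,
                                 const std::vector<float>& wgt_absmax,
                                 float alpha = 0.5f) {
    assert(act_absmax.size() == wgt_absmax.size());
    std::vector<float> s(act_absmax.size());
    for (std::size_t j = 0; j < s.size(); ++j) {
        s[j] = std::pow(act_absmax[j], alpha) /
               std::pow(wgt_absmax[j], 1.0f - alpha);
    }
    return s;
}
```

With the common choice alpha = 0.5, this reduces to s_j = sqrt(max|X_j| / max|W_j|): a channel with activation max 16 and weight max 1 gets scale 4, halving the gap between the two ranges.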

__device__ uint2 dequantize_s8_to_fp16x2(uint32_t const& source)
{
    // Adapted from FasterTransformer:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L54
    uint2 result;  // four fp16 values packed as two 32-bit registers

    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i8s = source;

    // Casper: Original was 0x64646464 = {1124, 1124} as fp16x2.
    // Changed to 0x64806480 because 1152 is divisible by 8, 16, 32, 64, 128.
    // NOTE: Test out {1280, 1280} since it is also divisible by 256.
    static constexpr uint32_t mask_for_elt_01     = 0x5250;
    static constexpr uint32_t mask_for_elt_23     = 0x5351;
    static constexpr uint32_t start_byte_for_fp16 = 0x64806480;
    // prmt places each int8 byte into the low byte of an fp16 with high
    // byte 0x64, which encodes the fp16 value 1024 + b for byte b.
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

    // Lastly, we subtract 1152 from our constructed number using fp16 math
    // to get our signed integer as fp16: (1024 + b) - 1152 = b - 128.
    // Casper: 0x64806480 = {1152, 1152}
    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

    return result;
}
@casper-hansen casper-hansen added enhancement New feature or request help wanted Extra attention is needed labels Sep 11, 2023
@casper-hansen casper-hansen mentioned this issue Sep 11, 2023
yunfeng-scale commented

How would you compare this with 8-bit bitsandbytes? I think bitsandbytes has minimal performance loss.

casper-hansen (Owner, Author) commented

> How would you compare this with 8-bit bitsandbytes? I think bitsandbytes has minimal performance loss.

It is not implemented yet, so I cannot speak to it.

casper-hansen (Owner, Author) commented

#71 is working on INT8 support. There are still things left to implement.
