First pass in simplifying and improving qmm #1030
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK, 1, WM * WN * SIMD_SIZE, group_size, bits>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
threadgroup T Ws[BN * BK];
Changing the leading dimension of the threadgroup memory should help with bank conflicts.
Doing so should be as simple as:
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
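To see why padding the leading dimension can help, here is a rough host-side C++ sketch, not Metal code and not part of this PR. It assumes a hypothetical bank model (32 banks, 4 bytes each); the actual threadgroup memory banking of the GPU is not stated in this thread. The names `kNumBanks`, `kBankBytes`, `padded_ld`, `bank_of`, and `distinct_banks` are made up for illustration. It counts how many distinct banks are touched when consecutive rows of a row-major tile are read at the same column.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical bank model, assumed for illustration only.
constexpr int kNumBanks = 32;  // assumption: number of banks
constexpr int kBankBytes = 4;  // assumption: width of one bank in bytes

// Leading dimension after padding each row by 16 bytes, mirroring
// BK_padded = BK + 16 / sizeof(T) from the suggestion above.
template <typename T>
constexpr int padded_ld(int bk) {
  return bk + 16 / static_cast<int>(sizeof(T));
}

// Bank hit by element (row, col) of a row-major tile with leading dimension ld.
template <typename T>
constexpr int bank_of(int row, int col, int ld) {
  return static_cast<int>(
      (static_cast<std::size_t>(row * ld + col) * sizeof(T) / kBankBytes) %
      kNumBanks);
}

// Number of distinct banks touched when `rows` threads each read column 0
// of consecutive rows. More distinct banks means fewer conflicts.
template <typename T>
int distinct_banks(int rows, int ld) {
  assert(rows > 0 && rows <= kNumBanks && ld > 0);
  bool seen[kNumBanks] = {};
  int count = 0;
  for (int r = 0; r < rows; ++r) {
    const int b = bank_of<T>(r, 0, ld);
    if (!seen[b]) {
      seen[b] = true;
      ++count;
    }
  }
  return count;
}
```

Under this assumed model, with 2-byte elements and BK = 32, eight rows read at column 0 land in only 2 banks when the leading dimension is BK, and in 8 different banks when it is BK_padded = 40.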
Wow 20% bump for prompt and QLoRA! Awesome!!
After @jagrit06's improvement to avoid memory bank conflicts and move the checks out of the loop, we now get
for a total improvement of 30%-40%!
Incredible!! 🔥
Let’s gooo
WOW, just WOW
Extracted the loading and dequantizing logic to a reusable QuantizedBlockLoader à la steel. Also removed the reading of scales/biases in smem and reduced the tile size, which provides a pretty big speedup.
Before on Mistral 7B 4bit 64 group size:
After:
All in all, a pretty consistent 20%-25% speedup.
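For readers unfamiliar with the step that QuantizedBlockLoader factors out, here is a minimal host-side C++ sketch of group-wise dequantization. It is not the kernel code from this PR. It assumes an affine scheme, w ≈ scale * q + bias, with one scale and bias per `group_size` consecutive values, and it assumes values are packed into 32-bit words with the first element in the lowest-order bits. The function name `dequantize` and its signature are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Dequantize `count` values stored `bits` wide inside 32-bit words.
// Assumptions (not taken from the PR): low-order bits hold the first value,
// and each run of `group_size` values shares one scale and one bias.
std::vector<float> dequantize(const std::vector<std::uint32_t>& packed,
                              const std::vector<float>& scales,
                              const std::vector<float>& biases,
                              int count, int group_size, int bits) {
  assert(bits > 0 && bits < 32 && 32 % bits == 0 && group_size > 0);
  const int per_word = 32 / bits;               // values per 32-bit word
  const std::uint32_t mask = (1u << bits) - 1u; // selects one packed value
  std::vector<float> out(count);
  for (int i = 0; i < count; ++i) {
    const std::uint32_t word = packed[i / per_word];
    const std::uint32_t q = (word >> ((i % per_word) * bits)) & mask;
    const int g = i / group_size;               // group index for value i
    out[i] = scales[g] * static_cast<float>(q) + biases[g];
  }
  return out;
}
```

In a kernel, the same arithmetic would run while a tile is copied into threadgroup memory, so the matrix multiply only ever sees dequantized values.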