
First pass in simplifying and improving qmm #1030

Merged
merged 3 commits into main from qmm on Apr 24, 2024
Conversation

@angeloskath (Member)

Extracted the loading and dequantizing logic into a reusable QuantizedBlockLoader à la steel. Also removed the reading of scales/biases into smem and reduced the tile size, which provides a pretty big speedup.

Before, on Mistral 7B 4-bit with group size 64:

1000 token prompt: 880 tps
QLoRA validation: 12.8s
QLoRA training: ~533 tps

After:

1000 token prompt: 1100 tps
QLoRA validation: 9.97s
QLoRA training: ~653 tps

All in all, a pretty consistent 20%-25% speedup.
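
For context, here is a minimal CPU-side sketch of the affine dequantization the loader performs. This is a hypothetical helper, not the kernel code; it assumes MLX's w = scale * q + bias scheme with one scale/bias pair per group_size consecutive weights and 32 / bits values packed per uint32_t:

  #include <cstdint>
  #include <vector>

  // Hypothetical CPU-side sketch of the per-group dequantization that
  // the QuantizedBlockLoader performs in tiles on the GPU. Assumes the
  // affine scheme w = scale * q + bias, with one (scale, bias) pair per
  // `group_size` consecutive weights.
  template <typename T, int bits = 4, int group_size = 64>
  std::vector<T> dequantize_row(
      const std::vector<uint32_t>& packed, // 32 / bits values per element
      const std::vector<T>& scales,        // one scale per group
      const std::vector<T>& biases) {      // one bias per group
    constexpr int per_word = 32 / bits;          // e.g. 8 values for 4-bit
    constexpr uint32_t mask = (1u << bits) - 1;  // e.g. 0xF for 4-bit
    std::vector<T> out(packed.size() * per_word);
    for (size_t i = 0; i < out.size(); ++i) {
      uint32_t q = (packed[i / per_word] >> (bits * (i % per_word))) & mask;
      size_t g = i / group_size; // group this weight belongs to
      out[i] = scales[g] * static_cast<T>(q) + biases[g];
    }
    return out;
  }

For the 4-bit, group size 64 configuration benchmarked above, that is 8 values per uint32_t and one scale/bias pair per 64 weights; the extracted loader does roughly this arithmetic per thread while streaming BK-wide tiles of W into threadgroup memory.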

Comment on lines 647 to 655

// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK, 1, WM * WN * SIMD_SIZE, group_size, bits>;

threadgroup T Xs[BM * BK];
threadgroup T Ws[BN * BK];

@jagrit06 (Member)
Changing the leading dim of the threadgroup memory should help with bank conflicts. Doing so should be as simple as:

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];
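
(General shared-memory reasoning, not MLX-specific: with a power-of-two row stride like BK, the same column of consecutive rows maps to the same memory bank, so column-wise accesses from a SIMD group serialize. Padding each row by 16 bytes, i.e. 16 / sizeof(T) elements (8 halfs or 4 floats), staggers those addresses across banks.)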

@awni (Member) left a comment

Wow 20% bump for prompt and QLoRA! Awesome!!

@angeloskath (Member, Author)

After @jagrit06's improvements to avoid memory bank conflicts and hoist the checks out of the loop (sketched below), we now get:

1000 token prompt: 1230 tps
QLoRA validation: 8.99s
QLoRA training: ~725 tps

for a total improvement of 30%-40%!
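
A rough sketch of the check-hoisting idea (my illustration, not the actual kernel; load_unchecked/load_checked are hypothetical stand-ins for the loaders' fast and guarded paths, in the spirit of steel's load_unsafe/load_safe):

  #include <functional>

  // Hypothetical sketch of hoisting bounds checks out of the hot K-loop:
  // every complete BK-wide tile takes an unchecked fast path, and only
  // the single ragged tail tile pays for guarded loads.
  void k_loop(
      int K,
      int BK,
      const std::function<void(int)>& load_unchecked,      // fast path
      const std::function<void(int, int)>& load_checked) { // guarded path
    int k_full = (K / BK) * BK; // end of the last complete tile
    for (int k = 0; k < k_full; k += BK) {
      load_unchecked(k); // no per-element bounds checks in this loop
    }
    if (k_full < K) {
      load_checked(k_full, K - k_full); // only the tail is checked
    }
  }

The point is that the branch runs once per matrix edge rather than once per iteration, keeping the inner loop branch-free.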

@awni (Member) commented Apr 24, 2024

Incredible!! 🔥

@jagrit06 (Member) left a comment

Let’s gooo

angeloskath merged commit 20a01bb into main on Apr 24, 2024
3 checks passed
angeloskath deleted the qmm branch on April 24, 2024
Rifur13 pushed a commit to Rifur13/mlx that referenced this pull request on Apr 24, 2024
@ivanfioravanti
WOW, just WOW
