
Conversation

@max-krasnyansky
Collaborator

Similar to #16829 and tested in tandem.

A very simple dynamic chunking mechanism for repack matmuls. It helps on platforms with a significant performance difference between CPU cores, and it distributes the work better under load in general.
I tested on an M4 Pro and a few Snapdragons, but it should work on all platforms.

See the details below.
I included a trace with instrumented matmuls that shows how the threads end up processing chunks.
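
For reference, here is a minimal sketch of the general idea, with hypothetical names (not the actual ggml code): the output rows of a repack matmul are split into roughly 4x as many chunks as there are threads, and each thread claims the next chunk from a shared atomic counter, so faster cores naturally end up processing more chunks.

```cpp
// Hypothetical sketch of dynamic chunking for a matmul (not the ggml code).
#include <algorithm>
#include <atomic>
#include <cstdint>

struct mulmat_job {
    int64_t nrows;                    // total output rows
    int64_t chunk_size;               // rows per chunk
    int64_t nchunks;                  // total number of chunks
    std::atomic<int64_t> next_chunk;  // index of the next unclaimed chunk
};

static void mulmat_job_init(mulmat_job & job, int64_t nrows, int nthreads) {
    job.nrows      = nrows;
    // Aim for ~4 chunks per thread so slower cores can fall behind
    // without stalling the whole matmul.
    job.nchunks    = std::max<int64_t>(1, (int64_t) nthreads * 4);
    job.chunk_size = (nrows + job.nchunks - 1) / job.nchunks;
    job.next_chunk.store(0);
}

// Called by every worker thread; compute_rows() stands in for the repack
// matmul kernel applied to a contiguous range of output rows.
template <typename F>
static void mulmat_worker(mulmat_job & job, F compute_rows) {
    for (;;) {
        const int64_t chunk = job.next_chunk.fetch_add(1, std::memory_order_relaxed);
        if (chunk >= job.nchunks) {
            break;
        }
        const int64_t row0 = chunk * job.chunk_size;
        const int64_t row1 = std::min(row0 + job.chunk_size, job.nrows);
        if (row0 >= row1) {
            break;
        }
        compute_rows(row0, row1);
    }
}
```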

## M4 Pro

Before (no other load)
| model                  |       size |     params | backend    | ngl | threads | fa | dev   |   test |            t/s |
| ---------------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | -----: |  ------------: |
| gpt-oss 20B MXFP4 MoE  |  11.27 GiB |    20.91 B | Metal      |   0 |       6 |  1 | none  |  pp256 |   75.67 ± 0.43 |
| gpt-oss 20B MXFP4 MoE  |  11.27 GiB |    20.91 B | Metal      |   0 |       6 |  1 | none  |   tg64 |   56.13 ± 0.26 |

| model                  |       size |     params | backend    | ngl | threads | fa | dev   |   test |            t/s |
| ---------------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | -----: | -------------: |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       2 |  1 | none  |  pp256 |  100.81 ± 2.22 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       2 |  1 | none  |   tg64 |   37.27 ± 1.06 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       4 |  1 | none  |  pp256 |  198.34 ± 0.21 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       4 |  1 | none  |   tg64 |   67.88 ± 0.63 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       6 |  1 | none  |  pp256 |  275.03 ± 8.60 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       6 |  1 | none  |   tg64 |   92.09 ± 1.40 |

After (no other load)
| model                  |       size |     params | backend    | ngl | threads | fa | dev   |   test |            t/s |
| ---------------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | -----: |  ------------: |
| gpt-oss 20B MXFP4 MoE  |  11.27 GiB |    20.91 B | Metal      |   0 |       6 |  1 | none  |  pp256 |   76.57 ± 0.15 |
| gpt-oss 20B MXFP4 MoE  |  11.27 GiB |    20.91 B | Metal      |   0 |       6 |  1 | none  |   tg64 |   55.66 ± 0.46 |

| model                  |       size |     params | backend    | ngl | threads | fa | dev   |   test |            t/s |
| ---------------------  | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | -----: | -------------: |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       2 |  1 | none  |  pp256 |  105.01 ± 0.33 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       2 |  1 | none  |   tg64 |   38.63 ± 0.10 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       4 |  1 | none  |  pp256 |  198.66 ± 0.19 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       4 |  1 | none  |   tg64 |   67.40 ± 0.29 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       6 |  1 | none  |  pp256 |  290.01 ± 1.31 |
| llama 3B Q4_0          |   1.78 GiB |     3.21 B | Metal      |   0 |       6 |  1 | none  |   tg64 |   89.92 ± 0.09 |


Chunking in action (no load)
```
thread-5: matmul ffn_up-0 nchunks 4 usec 7219
thread-3: matmul ffn_up-0 nchunks 4 usec 7221
thread-2: matmul ffn_up-0 nchunks 4 usec 7232
thread-1: matmul ffn_up-0 nchunks 4 usec 7247
thread-0: matmul ffn_up-0 nchunks 4 usec 7259
thread-4: matmul ffn_up-0 nchunks 4 usec 7260
thread-3: matmul ffn_out-0 nchunks 4 usec 7402
thread-1: matmul ffn_out-0 nchunks 4 usec 7423
thread-2: matmul ffn_out-0 nchunks 4 usec 7425
thread-4: matmul ffn_out-0 nchunks 4 usec 7402
thread-0: matmul ffn_out-0 nchunks 4 usec 7411
thread-5: matmul ffn_out-0 nchunks 4 usec 7402
```

Chunking in action (heavy other load)
```
thread-3: matmul ffn_up-6 nchunks 3 usec 8080
thread-1: matmul ffn_up-6 nchunks 5 usec 9055
thread-4: matmul ffn_up-6 nchunks 5 usec 9070
thread-5: matmul ffn_up-6 nchunks 5 usec 9428
thread-2: matmul ffn_up-6 nchunks 3 usec 9502
thread-0: matmul ffn_up-6 nchunks 3 usec 10552
thread-3: matmul ffn_out-6 nchunks 4 usec 8556
thread-0: matmul ffn_out-6 nchunks 3 usec 8612
thread-4: matmul ffn_out-6 nchunks 4 usec 8809
thread-1: matmul ffn_out-6 nchunks 5 usec 9275
thread-5: matmul ffn_out-6 nchunks 5 usec 9750
thread-2: matmul ffn_out-6 nchunks 3 usec 9963
```


## Snapdragon 8-Elite Gen5

### Llama3.2 1B Q4_0
```
llama_model_loader: - type  f32:   34 tensors
llama_model_loader: - type q4_0:  112 tensors
llama_model_loader: - type q6_K:    1 tensors
```

Before
| model          |       size |     params | backend    | ngl | threads | fa | dev   |   test |            t/s |
| -------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | -----: | -------------: |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       2 |  1 | none  |  pp128 |  384.94 ± 9.15 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       2 |  1 | none  |   tg64 |   65.17 ± 1.49 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       4 |  1 | none  |  pp128 |  351.52 ± 0.28 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       4 |  1 | none  |   tg64 |   71.00 ± 1.49 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       6 |  1 | none  |  pp128 |  512.93 ± 1.78 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       6 |  1 | none  |   tg64 |   77.26 ± 1.29 |


After
| model          |       size |     params | backend    | ngl | threads | fa | dev   |    test |            t/s |
| -------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | ------: |--------------: |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       2 |  1 | none  |   pp128 |  395.65 ± 7.81 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       2 |  1 | none  |    tg64 |   64.40 ± 0.85 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       4 |  1 | none  |   pp128 |  459.51 ± 1.04 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       4 |  1 | none  |    tg64 |   73.62 ± 0.67 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       6 |  1 | none  |   pp128 |  669.03 ± 1.41 |
| llama 1B Q4_0  | 727.75 MiB |     1.24 B | CPU        |   0 |       6 |  1 | none  |    tg64 |   79.75 ± 0.56 |


### Llama3.2 3B Q4_0
```
llama_model_loader: - type  f32:   58 tensors
llama_model_loader: - type q4_0:  196 tensors
llama_model_loader: - type q6_K:    1 tensors
```

Before
| model           |       size |     params | backend    | ngl | threads | fa | dev   |  test |             t/s |
| --------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | ----: | --------------: |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       2 |  1 | none  | pp128 |   127.73 ± 2.43 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       2 |  1 | none  |  tg64 |    27.91 ± 0.61 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       4 |  1 | none  | pp128 |   122.97 ± 0.02 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       4 |  1 | none  |  tg64 |    29.72 ± 1.09 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       6 |  1 | none  | pp128 |  159.59 ± 14.06 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       6 |  1 | none  |  tg64 |    30.33 ± 0.60 |


After
| model           |       size |     params | backend    | ngl | threads | fa | dev   |  test |             t/s |
| --------------- | ---------: | ---------: | ---------- | --: | ------: | -: | ----- | ----: | --------------: |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       2 |  1 | none  | pp128 |   128.16 ± 2.09 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       2 |  1 | none  |  tg64 |    27.46 ± 0.47 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       4 |  1 | none  | pp128 |   161.89 ± 0.30 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       4 |  1 | none  |  tg64 |    30.07 ± 0.65 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       6 |  1 | none  | pp128 |   227.02 ± 7.26 |
| llama 3B Q4_0   |   1.78 GiB |     3.21 B | CPU        |   0 |       6 |  1 | none  |  tg64 |    32.16 ± 0.58 |


### Llama3.2 1B chunking in action

```
thread-0: matmul ffn_up-11 nchunks 7 usec 143
thread-5: matmul ffn_up-11 nchunks 8 usec 147
thread-3: matmul ffn_up-11 nchunks 2 usec 150
thread-1: matmul ffn_up-11 nchunks 2 usec 150
thread-2: matmul ffn_up-11 nchunks 2 usec 152
thread-4: matmul ffn_up-11 nchunks 3 usec 158
thread-0: matmul ffn_out-11 nchunks 7 usec 124
thread-1: matmul ffn_out-11 nchunks 2 usec 125
thread-5: matmul ffn_out-11 nchunks 8 usec 128
thread-4: matmul ffn_out-11 nchunks 2 usec 129
thread-2: matmul ffn_out-11 nchunks 2 usec 139
thread-3: matmul ffn_out-11 nchunks 3 usec 150
```


## Galaxy S25+ (Snapdragon 8-Elite Gen4)

### Llama3.2 1B chunking in action

```
thread-2: matmul ffn_up-11 nchunks 6 usec 147
thread-3: matmul ffn_up-11 nchunks 3 usec 150
thread-0: matmul ffn_up-11 nchunks 6 usec 147
thread-5: matmul ffn_up-11 nchunks 3 usec 150
thread-1: matmul ffn_up-11 nchunks 3 usec 147
thread-4: matmul ffn_up-11 nchunks 3 usec 152
thread-4: matmul ffn_out-11 nchunks 3 usec 136
thread-2: matmul ffn_out-11 nchunks 6 usec 142
thread-5: matmul ffn_out-11 nchunks 3 usec 146
thread-1: matmul ffn_out-11 nchunks 3 usec 136
thread-0: matmul ffn_out-11 nchunks 6 usec 144
thread-3: matmul ffn_out-11 nchunks 3 usec 136
```

…ing on ARM64

Very similar implementation to the flash-attention chunking, with similar benefits.
@max-krasnyansky
Collaborator Author

@slaren any objections to merging this?

@max-krasnyansky
Collaborator Author

@slaren Thanks for approving the PR!
I'm thinking of updating the original MatMul and MatMul-ID chunking to use the same logic for computing the chunk size.
i.e. instead of the arbitrary "64 or 16 rows" chunk_size, we aim for 4x chunks per thread to take care of any imbalance between the CPUs, whether simply due to external load or to big.LITTLE.
Merging this and will iterate in another PR.
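
For illustration, here is a minimal, hypothetical sketch of that chunk-size heuristic (assumed helper name, not the actual ggml code):

```cpp
// Hypothetical sketch: instead of a fixed 64- or 16-row chunk_size, derive it
// from the thread count so each thread has roughly 4 chunks to claim,
// absorbing imbalance from external load or big.LITTLE cores.
#include <algorithm>
#include <cstdint>

static int64_t repack_chunk_size(int64_t nrows, int nthreads) {
    const int64_t nchunks = std::max<int64_t>(1, (int64_t) nthreads * 4);
    // Round up so nchunks * chunk_size covers all output rows.
    return (nrows + nchunks - 1) / nchunks;
}
```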

@max-krasnyansky merged commit 517b717 into ggml-org:master on Oct 30, 2025
72 checks passed
@max-krasnyansky deleted the repack-matmul-chunking branch on October 30, 2025