chore: optimize metal backend performance #1669

Merged: AlpinDale merged 1 commit into main from chore/optimize-metal on May 5, 2026

@AlpinDale
Collaborator

Qwen3-0.6B, M4 Pro.

Before:

E2E time: 5.80s, TTFT: 0.80s, Prefill: 1906 tokens (2380.3 tokens/s), Decode: 289 tokens (58.0 tokens/s)

After:

E2E time: 2.94s, TTFT: 0.36s, Prefill: 1906 tokens (5277.9 tokens/s), Decode: 353 tokens (138.1 tokens/s)
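
For a quick read of the numbers above, the improvement factors can be computed directly. This is only arithmetic on the figures quoted in this comment; the dictionary keys and the `speedups` helper are names made up for the illustration. Note that the two runs decoded different numbers of tokens (289 vs 353), so the E2E ratio is not a like-for-like comparison.

```python
# Figures copied from the Before/After lines above (Qwen3-0.6B, M4 Pro).
before = {"e2e_s": 5.80, "ttft_s": 0.80, "prefill_tps": 2380.3, "decode_tps": 58.0}
after = {"e2e_s": 2.94, "ttft_s": 0.36, "prefill_tps": 5277.9, "decode_tps": 138.1}


def speedups(before, after):
    """Return the improvement factor per metric (a value above 1 means faster)."""
    factors = {}
    for key in before:
        if key.endswith("_s"):
            # Latencies in seconds: lower is better, so divide old by new.
            factors[key] = before[key] / after[key]
        else:
            # Throughputs in tokens/s: higher is better, so divide new by old.
            factors[key] = after[key] / before[key]
    return factors


for name, factor in speedups(before, after).items():
    print(f"{name}: {factor:.2f}x")
```

On these figures, TTFT, prefill throughput, and decode throughput each improve by a bit more than 2x.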

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a contiguous KV cache fast path for Metal (MLX) to optimize performance during low-concurrency dense serving. Key enhancements include native MLX sampling for greedy and random strategies, optimized logit projection to reduce computation, and one-token lookahead prefetching. The PR also implements caching for metadata arrays and block tables within the PagedAttentionContext. Feedback indicates that the native sampling path currently ignores active logits processors, which could bypass penalties or constraints. Additionally, a logic error was identified in the native random sampling implementation that incorrectly handles top_k masking in mixed batches.

Comment on lines +689 to +693
        if not (
            batch.can_use_native_greedy_for_batch()
            or batch.can_use_native_random_for_batch()
        ):
            return None

Severity: high

The native MLX sampling path (both greedy and random) currently bypasses any active logitsprocs. This means custom plugins, penalties, or constraints implemented as logits processors will be ignored when the fast path is taken. You should check if self._logitsprocs.all is empty before allowing the native path.

        if (
            self._logitsprocs.all
            or not (
                batch.can_use_native_greedy_for_batch()
                or batch.can_use_native_random_for_batch()
            )
        ):
            return None
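
A minimal, self-contained sketch of the gating order this comment asks for. `LogitsProcs`, `Batch`, and `choose_sampling_path` are hypothetical stand-ins invented for the example; only the `.all` attribute and the two `can_use_native_*_for_batch()` method names come from the diff, and the real classes almost certainly carry more state.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class LogitsProcs:
    """Hypothetical stand-in for the object behind self._logitsprocs."""
    all: List[Callable] = field(default_factory=list)


@dataclass
class Batch:
    """Hypothetical stand-in for the real batch object."""
    all_greedy: bool = False
    all_random: bool = False

    def can_use_native_greedy_for_batch(self) -> bool:
        return self.all_greedy

    def can_use_native_random_for_batch(self) -> bool:
        return self.all_random


def choose_sampling_path(batch: Batch, logitsprocs: LogitsProcs) -> Optional[str]:
    """Return the name of the native path, or None to use the general sampler."""
    # Any active logits processor forces the general path, because the
    # native path would never call it.
    if logitsprocs.all:
        return None
    if batch.can_use_native_greedy_for_batch():
        return "native_greedy"
    if batch.can_use_native_random_for_batch():
        return "native_random"
    return None


# A greedy batch takes the fast path only while no processor is registered.
print(choose_sampling_path(Batch(all_greedy=True), LogitsProcs()))
print(choose_sampling_path(Batch(all_greedy=True), LogitsProcs(all=[lambda x: x])))
```

Checking the processors first keeps the fast path opt-in: a newly registered processor disables it without the batch predicates needing to know about processors at all.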

Comment on lines +429 to +430
        if not self.all_random:
            return False

Severity: high

Native random sampling should only be used if there are no active logits processors, as the native path (_mlx_random_sample) does not apply them.

Suggested change
-        if not self.all_random:
-            return False
+        if self.logitsprocs.all or not self.all_random:
+            return False

Comment on lines +602 to +643
        if max_top_k < batch.vocab_size:
            topk_indices = mx.argpartition(-logits, max_top_k - 1, axis=-1)[
                :, :max_top_k
            ]
            logits = mx.take_along_axis(logits, topk_indices, axis=-1)
            if len(set(top_ks)) != 1:
                positions = mx.arange(max_top_k)[None, :]
                row_top_ks = mx.array(top_ks, dtype=mx.int32)[:, None]
                logits = mx.where(positions < row_top_ks, logits, -float("inf"))

            if not batch.no_top_p:
                sorted_positions = mx.argsort(-logits, axis=-1)
                sorted_logits = mx.take_along_axis(logits, sorted_positions, axis=-1)
                sorted_indices = mx.take_along_axis(
                    topk_indices, sorted_positions, axis=-1
                )
                sorted_probs = mx.softmax(sorted_logits, axis=-1)
                top_ps = mx.array(
                    [
                        sampling_params.top_p
                        for sampling_params in batch.sampling_params_list
                    ],
                    dtype=mx.float32,
                )[:, None]
                # Keep the first token that crosses top-p, matching nucleus
                # sampling's usual "cumulative probability before this token"
                # test.
                remove = (mx.cumsum(sorted_probs, axis=-1) - sorted_probs) > top_ps
                sorted_logits = mx.where(remove, -float("inf"), sorted_logits)
                sampled_positions = mx.random.categorical(sorted_logits, axis=-1)
                return mx.take_along_axis(
                    sorted_indices, sampled_positions[:, None], axis=-1
                )[:, 0]

            sampled_positions = mx.random.categorical(logits, axis=-1)
            return mx.take_along_axis(
                topk_indices, sampled_positions[:, None], axis=-1
            )[:, 0]

        topk_values = mx.topk(logits, max_top_k, axis=-1)
        topk_thresholds = mx.min(topk_values, axis=-1, keepdims=True)
        logits = mx.where(logits < topk_thresholds, -float("inf"), logits)
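
The "cumulative probability before this token" test described in the code comment of this hunk can be tried in isolation. The sketch below uses NumPy as a stand-in for MLX so it runs anywhere; `top_p_mask` is a hypothetical helper written for this illustration, not a function from the PR, and unlike the hunk it scatters the result back to vocabulary order instead of sampling.

```python
import numpy as np


def top_p_mask(logits: np.ndarray, top_ps: np.ndarray) -> np.ndarray:
    """Set logits outside each row's nucleus to -inf, keeping vocabulary order."""
    order = np.argsort(-logits, axis=-1)  # token ids, most likely first
    sorted_logits = np.take_along_axis(logits, order, axis=-1)

    # Numerically stable softmax over the sorted logits.
    shifted = sorted_logits - sorted_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Probability mass strictly before each token. A token is dropped only if
    # the tokens ahead of it already exceed top_p, so the token that crosses
    # the threshold is itself kept.
    mass_before = np.cumsum(probs, axis=-1) - probs
    remove = mass_before > top_ps[:, None]
    sorted_logits = np.where(remove, -np.inf, sorted_logits)

    # Undo the sort so the mask lines up with the original token ids.
    out = np.empty_like(sorted_logits)
    np.put_along_axis(out, order, sorted_logits, axis=-1)
    return out


# Token probabilities 0.15, 0.5, 0.05, 0.3 with top_p = 0.7: the 0.5 and 0.3
# tokens survive, because 0.3 is the token that crosses the 0.7 threshold.
logits = np.log(np.array([[0.15, 0.5, 0.05, 0.3]]))
masked = top_p_mask(logits, np.array([0.7]))
print(np.isfinite(masked))
```

Subtracting each token's own probability from the cumulative sum is what guarantees that at least one token always survives, even when `top_p` is very small.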

Severity: high

The fallback block (lines 641-643), which runs whenever the if check fails, is logically broken for mixed batches where some requests have top_k enabled and others have it disabled (set to vocab_size). In such cases, max_top_k becomes vocab_size, the if block is skipped, and the fallback computes its threshold from the minimum logit of the entire row, so no token is masked and top_k is effectively disabled for every request in the batch.

You should remove the if max_top_k < batch.vocab_size check and always use the argpartition logic if not batch.no_top_k. mx.argpartition handles k = vocab_size - 1 correctly, and the subsequent mx.where (line 610) will correctly apply the per-request top_k masking.

        topk_indices = mx.argpartition(-logits, max_top_k - 1, axis=-1)[
            :, :max_top_k
        ]
        logits = mx.take_along_axis(logits, topk_indices, axis=-1)
        if len(set(top_ks)) != 1 or max_top_k == batch.vocab_size:
            positions = mx.arange(max_top_k)[None, :]
            row_top_ks = mx.array(top_ks, dtype=mx.int32)[:, None]
            logits = mx.where(positions < row_top_ks, logits, -float("inf"))

        if not batch.no_top_p:
            sorted_positions = mx.argsort(-logits, axis=-1)
            sorted_logits = mx.take_along_axis(logits, sorted_positions, axis=-1)
            sorted_indices = mx.take_along_axis(
                topk_indices, sorted_positions, axis=-1
            )
            sorted_probs = mx.softmax(sorted_logits, axis=-1)
            top_ps = mx.array(
                [
                    sampling_params.top_p
                    for sampling_params in batch.sampling_params_list
                ],
                dtype=mx.float32,
            )[:, None]
            remove = (mx.cumsum(sorted_probs, axis=-1) - sorted_probs) > top_ps
            sorted_logits = mx.where(remove, -float("inf"), sorted_logits)
            sampled_positions = mx.random.categorical(sorted_logits, axis=-1)
            return mx.take_along_axis(
                sorted_indices, sampled_positions[:, None], axis=-1
            )[:, 0]

        sampled_positions = mx.random.categorical(logits, axis=-1)
        return mx.take_along_axis(
            topk_indices, sampled_positions[:, None], axis=-1
        )[:, 0]
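
The mixed-batch failure is easy to reproduce outside the sampler. The sketch below is a NumPy stand-in, not MLX code from the PR: `topk_threshold_mask` mimics the single-threshold fallback, `per_row_topk_mask` applies each request's own k, and both names are hypothetical. One deliberate difference from the suggestion above: the sketch sorts each row fully instead of partitioning, because `np.argpartition` does not guarantee any order inside the selected slice, and a `positions < row_top_ks` mask only selects the true top k when the candidates are in descending order. Whether `mx.argpartition` gives a stronger guarantee is worth checking before relying on it.

```python
import numpy as np


def topk_threshold_mask(logits: np.ndarray, top_ks: list) -> np.ndarray:
    """Fallback-style masking: one threshold per row derived from max(top_ks)."""
    max_top_k = max(top_ks)
    # The max_top_k largest values per row. With max_top_k == vocab_size this
    # is the whole row, so the threshold below is just the row minimum.
    topk_values = -np.partition(-logits, max_top_k - 1, axis=-1)[:, :max_top_k]
    thresholds = topk_values.min(axis=-1, keepdims=True)
    return np.where(logits < thresholds, -np.inf, logits)


def per_row_topk_mask(logits: np.ndarray, top_ks: list) -> np.ndarray:
    """Per-request masking: row i keeps exactly its top_ks[i] largest logits."""
    order = np.argsort(-logits, axis=-1)  # descending order per row
    sorted_logits = np.take_along_axis(logits, order, axis=-1)
    positions = np.arange(logits.shape[-1])[None, :]
    row_top_ks = np.asarray(top_ks)[:, None]
    sorted_logits = np.where(positions < row_top_ks, sorted_logits, -np.inf)
    # Scatter back to vocabulary order.
    out = np.empty_like(sorted_logits)
    np.put_along_axis(out, order, sorted_logits, axis=-1)
    return out


# Two requests over a 4-token vocabulary: the first wants top_k = 2, the
# second has top_k disabled (set to vocab_size).
logits = np.array([[1.0, 4.0, 2.0, 3.0], [1.0, 4.0, 2.0, 3.0]])
top_ks = [2, 4]

buggy = topk_threshold_mask(logits, top_ks)
fixed = per_row_topk_mask(logits, top_ks)
print(np.isfinite(buggy[0]).sum())  # candidates left for request 0, fallback
print(np.isfinite(fixed[0]).sum())  # candidates left for request 0, per-row
```

With the single threshold, the first request keeps all four tokens; with per-row masking it keeps only the two largest, while the second request is left untouched in both cases.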

@AlpinDale merged commit 643efa6 into main on May 5, 2026 (1 check failed)
@AlpinDale deleted the chore/optimize-metal branch on May 5, 2026 07:21