Make each thread have its own default stream#3281

Merged
zcbenz merged 1 commit intoml-explore:mainfrom
zcbenz:default-stream-in-threads
Mar 25, 2026

Conversation

zcbenz (Collaborator) commented Mar 20, 2026:

Refs #3078, #3216.

Make sure each thread gets a different stream when using get_default_stream(), which makes multi-threaded code safe and lock-free by default.

Changes:

  • Move stream management out of the scheduler.
  • Store default streams in thread-local storage.
  • Still share all streams across all threads, but require locks to read and write them. Most apps only take the lock when creating new streams; getting the default stream is lock-free.
  • Remove get_stream(int index), which would require locks and is not a public API.
  • With multiple devices, each device now has its own default stream.

@zcbenz zcbenz force-pushed the default-stream-in-threads branch 5 times, most recently from d7f5291 to 46c181b Compare March 20, 2026 04:11
Thump604 added a commit to Thump604/mlx that referenced this pull request Mar 22, 2026
Metal completion handlers run on dispatch queues where C++ exceptions
cannot propagate — throwing causes std::terminate → SIGABRT, crashing
the process with no diagnostic information.

Instead, store the error message atomically in the CommandEncoder and
check it at the next synchronous point (commit, synchronize). This
converts fatal crashes into catchable runtime_error exceptions that
the application can handle gracefully.

Root cause analysis: the crash at 262K+ context reported as mlx#3216
was actually TWO separate issues:

1. Thread safety in stream management (fixed by PR ml-explore#3281)
2. C++ exceptions thrown from Metal completion handler callbacks
   (fixed by this commit)

The GPU watchdog error (kIOGPUCommandBufferCallbackErrorImpactingInteractivity)
is a separate concern — macOS kills command buffers that block the GPU
beyond the watchdog threshold. This commit ensures that error is reported
as a Python RuntimeError instead of SIGABRT.
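
The deferred-error pattern this commit describes can be sketched as follows (a hypothetical Python analogue; the real fix lives in the C++ CommandEncoder):

```python
import threading

class CommandEncoder:
    """Completion callbacks must not raise; park the error and re-raise later."""

    def __init__(self):
        self._error_lock = threading.Lock()
        self._error = None

    def _completion_handler(self, message):
        # Runs on a callback thread where raising would be fatal:
        # store the first error message instead of throwing.
        with self._error_lock:
            if self._error is None:
                self._error = message

    def commit(self):
        # Next synchronous point: surface the parked error to the caller,
        # where it can be caught normally.
        with self._error_lock:
            error, self._error = self._error, None
        if error is not None:
            raise RuntimeError(error)

enc = CommandEncoder()
t = threading.Thread(target=enc._completion_handler,
                     args=("command buffer failed",))
t.start(); t.join()
try:
    enc.commit()
    caught = None
except RuntimeError as e:
    caught = str(e)
```

The key property is that the raise happens on the caller's thread at a well-defined point, not inside the asynchronous callback.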
Thump604 commented:

Fused SDPA regression on this branch

While testing my chunked SDPA work (#3293, based on this branch), I discovered that mx.fast.scaled_dot_product_attention produces incorrect results on the thread-local-streams branch but works correctly on origin/main.

Reproduction

import mlx.core as mx
import numpy as np

B, H, qL, D = 1, 2, 256, 128
scale = 1.0 / np.sqrt(D)
mx.random.seed(42)
q = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)
k = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)
v = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)

# Fused SDPA
o_fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Manual reference
q32, k32, v32 = q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32)
o_ref = mx.softmax((q32 @ k32.swapaxes(-1, -2)) * scale, axis=-1) @ v32
mx.eval(o_fused, o_ref)

diff = (o_fused.astype(mx.float32) - o_ref).abs()
mx.eval(diff)
print(f"max_diff={diff.max().item():.4f}")  # ~14.4 on this branch, ~0.001 on main

Results

Branch                 mean|O| fused   mean|O| ref   max_diff
origin/main            0.0830          0.0830        0.001
thread-local-streams   2.1517          0.0830        14.42

The fused output is ~25x too large in magnitude. This is consistent across all head dims (64, 80, 128, 256) and both float16 and bfloat16. The full-attention path (the steel_attention kernel via sdpa_full_self_attention_metal) is affected; the vector path (decode, qL ≤ 8) is likely fine since it uses a different kernel.

The SDPA kernel code itself is identical between the two branches — the regression must come from the stream/device management changes. Possibly get_command_encoder(s.index) returning a different encoder, or buffer bindings being dispatched to the wrong stream.

Models using explicit matmul + softmax (e.g., mlx-lm's attention implementation) are unaffected since they don't use the fused SDPA path.

zcbenz (Collaborator, Author) commented Mar 22, 2026:

Hmm, the script produces the same result on this branch and on main, and our tests would have caught it if the result went wrong.

Thump604 replied:

Fair enough — I suspect this was a stale Metal JIT cache on my side. I was switching between branches (main, thread-local-streams, and my fix branch) with partial file checkouts and pip reinstalls, which likely left cached kernel binaries from one branch being used with host code from another. The buffer binding changes in my chunked SDPA work (adding write_partial function constants) would cause exactly this kind of mismatch.

Apologies for the noise — I should have done a clean build before reporting. I'll verify with a fresh clone if I see it again.

angeloskath (Member) left a review:

This looks great!

I left basically one comment that needs addressing, regarding the (natural) assumption that device indices will be less than the number of available devices. I think the fix should be in the Device constructor.

throw std::invalid_argument(
"[default_stream] Cannot get gpu stream without gpu backend.");
}
auto& s = default_stream_storage(d);
angeloskath (Member) commented:

This is not necessarily a bug in this code, but Device can weirdly have any index; i.e., I can construct Device(Device::gpu, 7) and pass it to default_stream, which will access out-of-bounds memory.

So for this code to be correct, I think the constructor of Device needs to check that 0 <= index < device_count(dev_type).

zcbenz (Collaborator, Author) replied:

Checking the index in Device::Device would break code like is_available(Device::gpu), which constructs an invalid device first and then checks it.

I changed default_stream_storage to do a bounds check by using default_streams.at(d.index).
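
The chosen approach, validating at the access site rather than in the constructor, can be illustrated like this (a hypothetical Python analogue of the `.at()` bounds check; names are invented):

```python
class Device:
    # Construction is deliberately unchecked so that probing code like
    # is_available(Device("gpu", 7)) can describe devices that may not exist.
    def __init__(self, type, index=0):
        self.type = type
        self.index = index

_default_streams = [object()]  # one entry per real device

def default_stream_storage(d):
    # Equivalent of default_streams.at(d.index): raise a catchable error
    # instead of reading out of bounds.
    if not (0 <= d.index < len(_default_streams)):
        raise IndexError(f"invalid device index {d.index}")
    return _default_streams[d.index]

default_stream_storage(Device("gpu", 0))  # valid index: returns storage
try:
    default_stream_storage(Device("gpu", 7))
    raised = False
except IndexError:
    raised = True
```

This keeps invalid Device objects constructible for availability probes while making any actual use of an out-of-range index fail loudly.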

angeloskath (Member) replied:

Yeah, that makes sense. I thought of that too, but then I figured it is generally weird that we can create arbitrary devices. Maybe that's fine.

zcbenz (Collaborator, Author) replied:

It is indeed a weird design. I think the API should be is_available(DeviceType type, int index), which would be compatible with most C++ code but would require an API change in mlx-c. Not sure if we should change it @andresy.
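
The suggested alternative, checking availability by type and index without ever constructing a Device, might look like this (a hypothetical sketch; the device counts are assumed for illustration):

```python
_device_counts = {"cpu": 1, "gpu": 1}  # assumed counts, for illustration only

def is_available(device_type, index=0):
    # No Device object is constructed here, so Device itself would be free
    # to validate its index strictly without breaking this probe.
    return 0 <= index < _device_counts.get(device_type, 0)
```

With this shape, the constructor could reject out-of-range indices outright, since availability probes no longer need to build an invalid Device first.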

angeloskath (Member) left a review:

Forgot to approve before. I still think we should fix the device index bug before merging this.

@zcbenz zcbenz force-pushed the default-stream-in-threads branch 3 times, most recently from f07eb26 to a6df64d Compare March 25, 2026 04:14
@zcbenz zcbenz force-pushed the default-stream-in-threads branch from a6df64d to adb6cdb Compare March 25, 2026 04:15
@zcbenz zcbenz merged commit df7f7db into ml-explore:main Mar 25, 2026
16 checks passed
@zcbenz zcbenz deleted the default-stream-in-threads branch March 25, 2026 06:48
chriscoey commented:

I was hitting the exact Metal assertion this fixes ("A command encoder is already encoding to this command buffer") while running concurrent reranker inference via MLX on M-series hardware. I upgraded to 0.31.1, but the fix isn't in that release. Is a 0.31.2 release planned soon? Happy to test a pre-release build if that's helpful.
