Better support consumer CUDA GPUs #3056

Merged
awni merged 1 commit into ml-explore:main from jessegross:consumer_gpu on Jan 27, 2026

Conversation

@jessegross (Contributor)

Proposed changes

Currently there are a few places where parameters are set based on checks for particular (primarily data center) GPUs. This extends some of those checks to cover consumer GPUs, generalizing where possible to avoid having to maintain lists of devices. Consumer GPUs tend to have more SKUs and variations, including within the same generation.

There are two places where this is an issue:

  • CUDA graph limits, which are based on hand-tuned values. The default limits aren't able to saturate Blackwell GPUs such as the RTX 6000 Pro and 5090, while older GPUs seem fine with the existing defaults (see the sketch after this list).
  • JIT use of non-forward-compatible architectural features. There is an existing bug that prevents this from ever triggering, since it compares the major compute-capability version against combined major-and-minor values. This fixes the issue and broadens the check to all GPUs that support non-forward-compatible features.
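
As a rough sketch of the graph-limit side of this change (illustrative only, not the exact code: the struct, function name, and the non-consumer cases are assumptions; only the cc 1200 values and the 20/100 and 50/500 pairs quoted later in this thread come from the PR and discussion):

```cpp
// Illustrative sketch, not MLX's actual implementation: select CUDA graph
// capture limits by GPU generation. Only the cc == 1200 values and the
// 20/100 and 50/500 pairs quoted in this thread come from the PR/discussion.
struct GraphLimits {
  int ops; // max operations batched into a single CUDA graph
  int mb;  // max memory (MB) captured per graph
};

inline GraphLimits graph_limits_for(int cc) {
  switch (cc) {
    case 1000: // assumed encoding for B200-class data center Blackwell
      return {50, 500};
    case 1200: // consumer Blackwell (RTX 6000 Pro, 5090), added by this PR
      return {100, 1000};
    default:   // hand-tuned defaults, still fine for older GPUs
      return {20, 100};
  }
}
```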

Benchmarks

I tried to find a generic calculation for tuning the CUDA graph limits, primarily focusing on memory bandwidth. In the end, I couldn't convince myself that any such calculation would be reliable across GPUs and models. As a result, I just extended the existing limits for consumer Blackwell devices, which is where the real performance differences were. Ideally, I would still like to generalize this better in the future.

Performance gains are most significant (42%) with larger models:
RTX Pro 6000 Blackwell
mlx_lm.benchmark --model Qwen/Qwen3-30B-A3B-Thinking-2507 --prompt-tokens 1024 -g 128 -b 1 -n 4

| | Generation TPS | Peak Memory |
| --- | --- | --- |
| Before | 88.192 | 61.768 |
| After | 125.281 | 64.475 |
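
(For reference, the 42% figure is the generation-rate ratio from this table: 125.281 / 88.192 ≈ 1.42.)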

A smaller but still decent improvement (6%) with a smaller model (Qwen/Qwen3-4B-Thinking-2507) on the same machine:

| | Generation TPS | Peak Memory |
| --- | --- | --- |
| Before | 155.513 | 8.898 |
| After | 164.661 | 11.192 |

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@zcbenz (Collaborator) left a comment:

Looks good to me, thanks!

bool use_sass = compiler_supports_device_sass(device);
auto cc = device.compute_capability_major();
// cc holds only the major version (e.g. 9, 10, 12), so the old comparison
// against combined major+minor values (90, 100, 121) could never be true.
- std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
+ std::string arch_tag = (cc >= 9) ? "a" : "";
Member:

Great catch!
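
For context on why the "a" tag matters: architecture-specific targets such as sm_90a, sm_100a, and sm_120a unlock instructions that are not forward compatible with later GPUs, and the tag is appended to the JIT compile target. A minimal illustration follows; the helper name and exact option construction are assumptions, not the PR's code.

```cpp
// Illustrative only, not MLX's actual code: appending the "a" tag to the
// NVRTC/nvcc target enables architecture-specific (non-forward-compatible)
// instructions for that exact GPU generation.
#include <string>

std::string jit_arch_option(int cc_major, int cc_minor, const std::string& arch_tag) {
  // e.g. cc 9.0 with tag "a" -> "--gpu-architecture=sm_90a"
  return "--gpu-architecture=sm_" + std::to_string(cc_major) +
         std::to_string(cc_minor) + arch_tag;
}
```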

Comment on lines +212 to +215
case 1200: // Consumer Blackwell
  ops = 100;
  mb = 1000;
  break;
Member:

It's interesting that these numbers are much larger than even a B200's. I'm wondering what benchmark you used to tune them?

Member:

Nvm, I see you used a generation benchmark. I would recommend also checking the compute-bound parts or a compute-bound workload, as these numbers seem a bit high to me.

So you could do a prefill with something like 4096 tokens, or maybe an image generation workload?

@jessegross (Contributor, Author):

Here are some additional numbers run on the same RTX Pro 6000 Blackwell as above. I increased both the prompt tokens and tokens generated as well as the number of runs to get better consistency:

Larger model:
mlx_lm.benchmark --model Qwen/Qwen3-30B-A3B-Thinking-2507 --prompt-tokens 4096 -g 1024 -b 1 -n 10

| Settings | Prompt TPS | Generation TPS | Peak Memory |
| --- | --- | --- | --- |
| Default (20 ops / 100 mb) | 12821.215 | 101.402 | 62.366 |
| B200 (50 / 500) | 12802.336 | 115.502 | 64.383 |
| CC 12 from this PR (100 / 1000) | 12827.187 | 129.529 | 66.743 |
| Further increase (200 / 2000) | 13413.259 | 129.897 | 71.234 |

Smaller model:
mlx_lm.benchmark --model Qwen/Qwen3-4B-Thinking-2507 --prompt-tokens 4096 -g 1024 -b 1 -n 10

| Settings | Prompt TPS | Generation TPS | Peak Memory |
| --- | --- | --- | --- |
| Default (20 ops / 100 mb) | 31278.861 | 151.717 | 9.693 |
| B200 (50 / 500) | 31747.421 | 157.268 | 12.578 |
| CC 12 from this PR (100 / 1000) | 33152.269 | 148.148 | 15.856 |
| Further increase (200 / 2000) | 32839.830 | 143.471 | 16.502 |

For the larger model, these settings are optimal for generation, and there is a fairly significant gain compared to the settings currently used for the B200. As a sanity check, I also see that GPU utilization in nvidia-smi is much lower with the smaller settings. The compute-bound prompt processing shows minimal gain as the graph size increases, as expected.

The smaller model peaks at a smaller graph size - I think you were using Llama 3.1 8B for your previous tests, so that probably explains the difference in experience. I don't have a B200 to test, but my guess is that for the same models and tuning objectives, it would have even higher numbers.

It's hard to tune for all cases given the current mechanisms, so I did prioritize larger/more modern models, since it seems more likely that's what would be used with these GPUs. Ideally we would extend the mechanisms to better reflect different situations, though there are quite a few variables.

Member:

Thanks for sharing more results. Indeed, it seems like we need to revisit tuning on the larger devices. Usually when I tune these, I tune with at least two pretty different workloads. I think you are right that I was using Llama 3.1 8B for inference tuning, and I can't recall exactly, but I think a small LM for pretraining as well. Let's merge what you have, as it's a very nice improvement, but let's keep an eye on how to improve this in the future!

@awni merged commit fed0fe3 into ml-explore:main on Jan 27, 2026 (16 checks passed).
@jessegross deleted the consumer_gpu branch on February 2, 2026, 19:32.