From c38028548a275a67ae81652e7914020a8db1d761 Mon Sep 17 00:00:00 2001 From: Valeriy Sofin Date: Fri, 29 May 2026 02:17:58 +0300 Subject: [PATCH] Fix CUDA all-reduce planning for large inputs --- mlx/backend/cuda/reduce/all_reduce.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 0659504c5b..167a282afb 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -68,12 +68,12 @@ void all_reduce( out.set_data(cu::malloc_async(out.nbytes(), encoder)); - auto get_args = [](int size, int N) { - int threads = std::min(512, (size + N - 1) / N); + auto get_args = [](size_t size, size_t N) { + int threads = + static_cast<int>(std::min(size_t{512}, cuda::ceil_div(size, N))); threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - int reductions_per_step = threads * N; - size_t steps_needed = - (size + reductions_per_step - 1) / reductions_per_step; + size_t reductions_per_step = threads * N; + size_t steps_needed = cuda::ceil_div(size, reductions_per_step); int blocks; if (steps_needed < 32) { @@ -88,7 +88,7 @@ void all_reduce( blocks = 1024; } - size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t steps_per_block = cuda::ceil_div(steps_needed, blocks); size_t block_step = steps_per_block * reductions_per_step; return std::make_tuple(blocks, threads, block_step);