From 12c29f6d8d595f205a166393cff153452108e982 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 17 Apr 2026 04:05:31 +0000 Subject: [PATCH] Fix flaky TestVmap.test_vmap_masked_scatter --- mlx/backend/cuda/indexing.cpp | 2 +- mlx/backend/cuda/scan.cu | 2 +- mlx/backend/gpu/scan.h | 2 +- mlx/backend/metal/scan.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 8cbd443b4d..0bec840463 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -458,7 +458,7 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { } array mask_flat = flatten_in_eval(mask, 1, -1, s); - if (mask_flat.data() != mask.data()) { + if (gpu_ptr(mask_flat) != gpu_ptr(mask)) { encoder.add_temporary(mask_flat); } if (!mask_flat.flags().row_contiguous) { diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 206419e4b8..a7dd2eea25 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -364,7 +364,7 @@ constexpr bool supports_scan_op() { } void scan_gpu_inplace( - array in, + const array& in, array& out, Scan::ReduceType reduce_type, int axis, diff --git a/mlx/backend/gpu/scan.h b/mlx/backend/gpu/scan.h index dab79c50bf..a6dbc6f538 100644 --- a/mlx/backend/gpu/scan.h +++ b/mlx/backend/gpu/scan.h @@ -6,7 +6,7 @@ namespace mlx::core { void scan_gpu_inplace( - array in, + const array& in, array& out, Scan::ReduceType reduce_type, int axis, diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index ede0306c06..cd5184aa3e 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -13,7 +13,7 @@ namespace mlx::core { void scan_gpu_inplace( - array in, + const array& in, array& out, Scan::ReduceType reduce_type, int axis,