diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu
index fff2445297..6fc9528289 100644
--- a/mlx/backend/cuda/conv/gemm_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_conv.cu
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
 
 #include <nvtx3/nvtx3.hpp>
 
@@ -136,8 +137,8 @@ void gemm_conv_nd(
     ConvParams<NDIM>& params,
     Stream s) {
   // Get gemm shapes.
-  int mat_M = out.size() / params.O; // N * H_out * W_out
-  int mat_K = wt.size() / params.O; // C * H_wt * W_wt
+  int mat_M = safe_cast(out.size() / params.O, "conv"); // N * H_out * W_out
+  int mat_K = safe_cast(wt.size() / params.O, "conv"); // C * H_wt * W_wt
   int mat_N = params.O; // O
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
index f2688b3096..4060d744d7 100644
--- a/mlx/backend/cuda/conv/gemm_grouped_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
 
 #include <nvtx3/nvtx3.hpp>
 
@@ -141,8 +142,9 @@ void gemm_grouped_conv_nd(
   // Get gemm shapes.
   int C_per_group = params.C / params.groups;
   int O_per_group = params.O / params.groups;
-  int mat_M = out.size() / params.O; // N * H_out * W_out
-  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
+  int mat_M = safe_cast(out.size() / params.O, "conv"); // N * H_out * W_out
+  int mat_K =
+      safe_cast(wt.size() / params.O, "conv"); // C_per_group * H_wt * W_wt
   int mat_N = O_per_group; // O_per_group
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 5d032779d3..ce6f718448 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -39,8 +39,8 @@ void explicit_gemm_conv_ND_gpu(
     array& out,
     const MLXConvParams<N>& conv_params) {
   // Get gemm shapes
-  int implicit_M = out.size() / conv_params.O;
-  int implicit_K = wt.size() / conv_params.O;
+  int implicit_M = safe_cast(out.size() / conv_params.O, "conv");
+  int implicit_K = safe_cast(wt.size() / conv_params.O, "conv");
   int implicit_N = conv_params.O;
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K};
@@ -113,8 +113,8 @@ void explicit_gemm_conv_group_ND_gpu(
   const int C_per_group = conv_params.C / conv_params.groups;
   const int O_per_group = conv_params.O / conv_params.groups;
   // Get gemm shapes
-  const int implicit_M = out.size() / conv_params.O;
-  const int implicit_K = wt.size() / conv_params.O;
+  const int implicit_M = safe_cast(out.size() / conv_params.O, "conv");
+  const int implicit_K = safe_cast(wt.size() / conv_params.O, "conv");
   const int implicit_N = O_per_group;
 
   int kernel_size = 1;
@@ -200,7 +200,10 @@ void implicit_gemm_conv_2D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  const int implicit_M = safe_cast(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1],
+      "conv");
   const int implicit_N = O_per_group;
   const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
 
@@ -329,7 +332,10 @@ void implicit_gemm_conv_2D_general_gpu(
     array& out,
     const MLXConvParams<2>& conv_params) {
   // Deduce implicit gemm size
-  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  int implicit_M = safe_cast(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1],
+      "conv");
   int implicit_N = conv_params.O;
   int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
 
@@ -512,8 +518,10 @@ void implicit_gemm_conv_3D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M =
-      conv_params.N * conv_params.oS[0] * conv_params.oS[1] * conv_params.oS[2];
+  const int implicit_M = safe_cast(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1] * conv_params.oS[2],
+      "conv");
   const int implicit_N = O_per_group;
   const int implicit_K =
       conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] * C_per_group;
@@ -1001,11 +1009,11 @@ void dispatch_conv_2D_gpu(
   }
 
   // Direct to winograd conv
-  bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
+  bool inp_large = (static_cast<int64_t>(conv_params.N) * conv_params.iS[0] *
+                    conv_params.iS[1]) >= 4096;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  bool out_large =
-      (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
+  bool out_large = (static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+                    conv_params.oS[1]) >= 256;
   if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
       conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
index 850ec15be6..f559596b73 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
@@ -135,7 +135,7 @@ implicit_gemm_conv_2d(
   C += tid.z * N;
 
   B += c_col * K;
-  C += c_row * (N * params->groups) + c_col;
+  C += static_cast<int64_t>(c_row) * N * params->groups + c_col;
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
index d2fbac0fc7..f2ccc1c03d 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
@@ -94,7 +94,7 @@ implicit_gemm_conv_3d(
   C += tid.z * N;
 
   B += c_col * K;
-  C += c_row * (N * params->groups) + c_col;
+  C += static_cast<int64_t>(c_row) * N * params->groups + c_col;
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 1241f77357..38250f9b81 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -200,14 +200,14 @@ implicit_gemm_conv_2d_general(
         (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
     if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-      int offset_cm = n * params->out_strides[0] +
+      size_t offset_cm = static_cast<size_t>(n) * params->out_strides[0] +
           oh * params->out_strides[1] + ow * params->out_strides[2];
 
       STEEL_PRAGMA_UNROLL
       for (int j = 0; j < mma_t::TN; j++) {
         // Get accumulated result and associated offset in C
         thread const auto& accum = mma_op.Ctile.frag_at(i, j);
-        int offset = offset_cm + (j * mma_t::TN_stride);
+        size_t offset = offset_cm + (j * mma_t::TN_stride);
 
         constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
 
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index defcc2f6e0..6ad41e2e38 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -67,8 +67,10 @@ array indices_or_default(
   }
 
   Shape shape(x.shape().begin(), x.shape().end() - 2);
-  int total =
-      std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  int total = safe_cast(
+      std::reduce(
+          shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>{}),
+      "gather");
   return reshape(arange(total, uint32, s), std::move(shape), s);
 }
 
@@ -433,7 +435,7 @@ array unflatten(
     }
   }
   if (infer_idx >= 0) {
-    shape[infer_idx] = a.shape(ax) / size;
+    shape[infer_idx] = safe_cast(a.shape(ax) / size, "unflatten");
     size *= shape[infer_idx];
   }
   if (size != a.shape(ax)) {
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index f3acec574b..62460a3d1b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2136,12 +2136,12 @@ bool Flatten::is_equivalent(const Primitive& other) const {
 
 Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
   Shape shape = input.shape();
-  auto flat_size = input.shape(start_axis);
+  int64_t flat_size = input.shape(start_axis);
   for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
     flat_size *= input.shape(ax);
   }
   shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
-  shape[start_axis] = flat_size;
+  shape[start_axis] = safe_cast(flat_size, "flatten");
   return shape;
 }
 
@@ -3913,7 +3913,7 @@ Shape Reshape::output_shape(const array& input, Shape shape) {
 
   // Infer the shape
   if (size > 0 && infer_idx >= 0) {
-    shape[infer_idx] = input.size() / size;
+    shape[infer_idx] = safe_cast(input.size() / size, "reshape");
     size *= shape[infer_idx];
   } else if (infer_idx >= 0) {
     throw std::invalid_argument(
diff --git a/mlx/utils.h b/mlx/utils.h
index 7835a97028..d8b4c7ac99 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -3,6 +3,10 @@
 #pragma once
 
 #include <complex>
+#include <cstdint>
+#include <limits>
+#include <sstream>
+#include <string_view>
 #include <variant>
 
 #include "mlx/api.h"
@@ -96,6 +100,24 @@ MLX_API Dtype result_type(const std::vector<array>& arrays);
 
 MLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2);
 
+template <typename T>
+inline ShapeElem safe_cast(T dim, std::string_view op = "") {
+  constexpr int64_t lo = std::numeric_limits<ShapeElem>::min();
+  constexpr int64_t hi = std::numeric_limits<ShapeElem>::max();
+  auto v = static_cast<int64_t>(dim);
+  if (v < lo || v > hi) {
+    std::ostringstream msg;
+    if (!op.empty()) {
+      msg << "[" << op << "] ";
+    }
+    msg << "Shape dimension " << v << " is outside the supported range [" << lo
+        << ", " << hi
+        << "]. MLX currently uses 32-bit integers for shape dimensions.";
+    throw std::overflow_error(msg.str());
+  }
+  return static_cast<ShapeElem>(v);
+}
+
 /**
  * Returns the axis normalized to be in the range [0, ndim).
  */
diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp
index 58cca348e5..52b5a3f3a6 100644
--- a/tests/gpu_tests.cpp
+++ b/tests/gpu_tests.cpp
@@ -477,6 +477,38 @@ TEST_CASE("test gpu validation") {
   eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
 }
 
+TEST_CASE("test gpu int32 shape overflow errors") {
+  // (2^30, 2).flatten() — product 2^31 doesn't fit in ShapeElem.
+  // Issue #2681 reported wrapped shape (-2147483648,) and a
+  // 2^64 - X reported size. The lazy graph is never evaluated.
+  auto a = zeros({1 << 30, 2});
+  CHECK_THROWS_AS(flatten(a), std::overflow_error);
+
+  // conv_general output > 2^31 elements with each per-dim < 2^31.
+  // Total elements 524290 * 64 * 64 = 2,147,491,840.
+  int n = static_cast<int>((int64_t{1} << 31) / (64 * 64) + 2);
+  auto x = ones({n, 8, 8, 1}, float16);
+  auto w = ones({1, 1, 1, 1}, float16);
+  auto y = conv_general(
+      /* input = */ x,
+      /* weight = */ w,
+      /* stride = */ {1, 1},
+      /* padding_lo = */ {0, 0},
+      /* padding_hi = */ {0, 0},
+      /* kernel_dilation = */ {1, 1},
+      /* input_dilation = */ {9, 9},
+      /* groups = */ 1,
+      /* flip = */ false);
+  CHECK_EQ(y.shape(), Shape{n, 64, 64, 1});
+
+  // reshape with inferred dim that won't fit in ShapeElem — issue #3327.
+  CHECK_THROWS_AS(reshape(y, {-1}), std::overflow_error);
+
+  // take(a, idx) routes through an internal flatten — overflows on flatten.
+  auto idx = array({0u}, uint32);
+  CHECK_THROWS_AS(take(y, idx), std::overflow_error);
+}
+
 TEST_CASE("test memory info") {
   // Test cache limits
   {