Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mlx/backend/cuda/conv/gemm_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"

#include <cooperative_groups.h>

Expand Down Expand Up @@ -136,8 +137,10 @@ void gemm_conv_nd(
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
int mat_M = check_shape_dim(
static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
int mat_K = check_shape_dim(
static_cast<int64_t>(wt.size() / params.O), "conv"); // C * H_wt * W_wt
int mat_N = params.O; // O

// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
Expand Down
8 changes: 6 additions & 2 deletions mlx/backend/cuda/conv/gemm_grouped_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"

#include <cooperative_groups.h>

Expand Down Expand Up @@ -141,8 +142,11 @@ void gemm_grouped_conv_nd(
// Get gemm shapes.
int C_per_group = params.C / params.groups;
int O_per_group = params.O / params.groups;
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
int mat_M = check_shape_dim(
static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
int mat_K = check_shape_dim(
static_cast<int64_t>(wt.size() / params.O),
"conv"); // C_per_group * H_wt * W_wt
int mat_N = O_per_group; // O_per_group

// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
Expand Down
36 changes: 24 additions & 12 deletions mlx/backend/metal/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ void explicit_gemm_conv_ND_gpu(
array& out,
const MLXConvParams<N>& conv_params) {
// Get gemm shapes
int implicit_M = out.size() / conv_params.O;
int implicit_K = wt.size() / conv_params.O;
int implicit_M =
check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
int implicit_K =
check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
int implicit_N = conv_params.O;
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K};
Expand Down Expand Up @@ -113,8 +115,10 @@ void explicit_gemm_conv_group_ND_gpu(
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Get gemm shapes
const int implicit_M = out.size() / conv_params.O;
const int implicit_K = wt.size() / conv_params.O;
const int implicit_M =
check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
const int implicit_K =
check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
const int implicit_N = O_per_group;

int kernel_size = 1;
Expand Down Expand Up @@ -200,7 +204,10 @@ void implicit_gemm_conv_2D_gpu(
const int O_per_group = conv_params.O / conv_params.groups;

// Deduce implicit gemm size
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
const int implicit_M = check_shape_dim(
static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
conv_params.oS[1],
"conv");
const int implicit_N = O_per_group;
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;

Expand Down Expand Up @@ -329,7 +336,10 @@ void implicit_gemm_conv_2D_general_gpu(
array& out,
const MLXConvParams<2>& conv_params) {
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_M = check_shape_dim(
static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
conv_params.oS[1],
"conv");
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;

Expand Down Expand Up @@ -512,8 +522,10 @@ void implicit_gemm_conv_3D_gpu(
const int O_per_group = conv_params.O / conv_params.groups;

// Deduce implicit gemm size
const int implicit_M =
conv_params.N * conv_params.oS[0] * conv_params.oS[1] * conv_params.oS[2];
const int implicit_M = check_shape_dim(
static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
conv_params.oS[1] * conv_params.oS[2],
"conv");
const int implicit_N = O_per_group;
const int implicit_K =
conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] * C_per_group;
Expand Down Expand Up @@ -1001,11 +1013,11 @@ void dispatch_conv_2D_gpu(
}

// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
bool inp_large = (static_cast<int64_t>(conv_params.N) * conv_params.iS[0] *
conv_params.iS[1]) >= 4096;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
bool out_large =
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
bool out_large = (static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
conv_params.oS[1]) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ implicit_gemm_conv_2d(
C += tid.z * N;

B += c_col * K;
C += c_row * (N * params->groups) + c_col;
C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);

const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ implicit_gemm_conv_3d(
C += tid.z * N;

B += c_col * K;
C += c_row * (N * params->groups) + c_col;
C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);

const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,18 @@ implicit_gemm_conv_2d_general(
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
size_t offset_cm = static_cast<size_t>(n) *
static_cast<size_t>(params->out_strides[0]) +
static_cast<size_t>(oh) *
static_cast<size_t>(params->out_strides[1]) +
static_cast<size_t>(ow) *
static_cast<size_t>(params->out_strides[2]);

STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
int offset = offset_cm + (j * mma_t::TN_stride);
size_t offset = offset_cm + (j * mma_t::TN_stride);

constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

Expand Down
9 changes: 6 additions & 3 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ array indices_or_default(
}

Shape shape(x.shape().begin(), x.shape().end() - 2);
int total =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
int total = check_shape_dim(
std::reduce(
shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>{}),
"gather");
return reshape(arange(total, uint32, s), std::move(shape), s);
}

Expand Down Expand Up @@ -433,7 +435,8 @@ array unflatten(
}
}
if (infer_idx >= 0) {
shape[infer_idx] = a.shape(ax) / size;
shape[infer_idx] =
check_shape_dim(static_cast<int64_t>(a.shape(ax) / size), "unflatten");
size *= shape[infer_idx];
}
if (size != a.shape(ax)) {
Expand Down
7 changes: 4 additions & 3 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2136,12 +2136,12 @@ bool Flatten::is_equivalent(const Primitive& other) const {

Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
Shape shape = input.shape();
auto flat_size = input.shape(start_axis);
int64_t flat_size = input.shape(start_axis);
for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
flat_size *= input.shape(ax);
}
shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
shape[start_axis] = flat_size;
shape[start_axis] = check_shape_dim(flat_size, "flatten");
return shape;
}

Expand Down Expand Up @@ -3913,7 +3913,8 @@ Shape Reshape::output_shape(const array& input, Shape shape) {

// Infer the shape
if (size > 0 && infer_idx >= 0) {
shape[infer_idx] = input.size() / size;
shape[infer_idx] =
check_shape_dim(static_cast<int64_t>(input.size() / size), "reshape");
size *= shape[infer_idx];
} else if (infer_idx >= 0) {
throw std::invalid_argument(
Expand Down
20 changes: 20 additions & 0 deletions mlx/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#pragma once

#include <exception>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string_view>
#include <variant>

#include "mlx/api.h"
Expand Down Expand Up @@ -96,6 +100,22 @@ MLX_API Dtype result_type(const std::vector<array>& arrays);

MLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2);

// Narrow a 64-bit dimension to ShapeElem, throwing std::overflow_error when
// the value cannot be represented. `op` (optional) prefixes the error message
// with the operation name, e.g. "[conv] ...".
inline ShapeElem check_shape_dim(int64_t dim, std::string_view op = "") {
  constexpr int64_t min_dim = std::numeric_limits<ShapeElem>::min();
  constexpr int64_t max_dim = std::numeric_limits<ShapeElem>::max();
  if (min_dim <= dim && dim <= max_dim) {
    return static_cast<ShapeElem>(dim);
  }
  std::ostringstream msg;
  if (!op.empty()) {
    msg << "[" << op << "] ";
  }
  msg << "Shape dimension " << dim << " is outside the supported range ["
      << min_dim << ", " << max_dim
      << "]. MLX currently uses 32-bit integers for shape dimensions.";
  throw std::overflow_error(msg.str());
}

/**
* Returns the axis normalized to be in the range [0, ndim).
*/
Expand Down
96 changes: 96 additions & 0 deletions tests/gpu_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <array>

#include "doctest/doctest.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/mlx.h"

using namespace mlx::core;
Expand Down Expand Up @@ -477,6 +478,101 @@ TEST_CASE("test gpu validation") {
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}

TEST_CASE("test gpu int32 shape overflow errors") {
  // Flattening a (2^30, 2) array would need a 2^31 dimension, which does
  // not fit in ShapeElem. Issue #2681 reported the wrapped shape
  // (-2147483648,) and a 2^64 - X size. The lazy graph is never evaluated.
  auto big = zeros({1 << 30, 2});
  CHECK_THROWS_AS(flatten(big), std::overflow_error);

  // Build a conv_general whose output exceeds 2^31 elements while every
  // individual dimension stays below 2^31.
  // Total elements: 524290 * 64 * 64 = 2,147,491,840.
  int batch = static_cast<int>((int64_t{1} << 31) / (64 * 64) + 2);
  auto inp = ones({batch, 8, 8, 1}, float16);
  auto wts = ones({1, 1, 1, 1}, float16);
  auto out = conv_general(
      /* input = */ inp,
      /* weight = */ wts,
      /* stride = */ {1, 1},
      /* padding_lo = */ {0, 0},
      /* padding_hi = */ {0, 0},
      /* kernel_dilation = */ {1, 1},
      /* input_dilation = */ {9, 9},
      /* groups = */ 1,
      /* flip = */ false);
  CHECK_EQ(out.shape(), Shape{batch, 64, 64, 1});

  // An inferred reshape dimension past ShapeElem's range — issue #3327.
  CHECK_THROWS_AS(reshape(out, {-1}), std::overflow_error);

  // take(a, idx) flattens internally, so it overflows the same way.
  auto indices = array({0u}, uint32);
  CHECK_THROWS_AS(take(out, indices), std::overflow_error);

  // The conv dispatcher refuses outputs larger than 2^31 elements. eval
  // allocates the ~4 GB float16 output before that check runs, so only
  // attempt it when the device can hold the buffer.
  size_t bytes_needed = size_t(batch) * 64 * 64 * sizeof(float16_t);
  auto buf_limit = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
  if (buf_limit >= bytes_needed) {
    CHECK_THROWS_AS(eval(out), std::overflow_error);
  }
}

TEST_CASE("test gpu conv2d large output offset") {
  // Regression for the kernel-offset half of #3327 (originally PR #3294).
  // Pick an output shape (batch, 64, 64, O) where batch * 64 * 64 * O
  // exceeds 2^31 yet every dimension and `batch * 64 * 64` fit in int32:
  // the dispatcher accepts the work, but a thread's output offset
  // `c_row * O + c_col` overflows int32. Before steel_conv_general.h
  // promoted that arithmetic to size_t, the tail batches were written to
  // wrapped offsets and read back as zeros.
  constexpr int out_h = 64;
  constexpr int out_w = 64;
  constexpr int n_chan = 17;
  const int elems_per_batch = out_h * out_w * n_chan;
  const int n_batches =
      static_cast<int>((int64_t{1} << 31) / elems_per_batch + 2);

  // Bail out when the ~4.3 GB fp16 output exceeds the device buffer limit.
  size_t bytes_needed =
      size_t(n_batches) * out_h * out_w * n_chan * sizeof(float16_t);
  auto buf_limit = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
  if (buf_limit < bytes_needed) {
    return;
  }

  auto per_batch =
      astype(remainder(arange(n_batches, int32), array(251)), float16);
  per_batch = reshape(per_batch, {n_batches, 1, 1, 1});
  auto x = multiply(ones({n_batches, out_h, out_w, 1}, float16), per_batch);
  auto per_chan =
      divide(arange(1.0, double(n_chan + 1), float16), array(8.0f, float16));
  auto w = reshape(per_chan, {n_chan, 1, 1, 1});

  auto y = conv2d(x, w);

  // y[i, h, w, j] should equal (i % 251) * ((j + 1) / 8). Spot-check the
  // first and last batches — the last one lives past int32 max offsets.
  auto want_first = multiply(
      slice(x, {0, 0, 0, 0}, {1, out_h, out_w, 1}),
      reshape(per_chan, {1, 1, 1, n_chan}));
  auto want_last = multiply(
      slice(x, {n_batches - 1, 0, 0, 0}, {n_batches, out_h, out_w, 1}),
      reshape(per_chan, {1, 1, 1, n_chan}));
  CHECK(allclose(
            slice(y, {0, 0, 0, 0}, {1, out_h, out_w, n_chan}),
            want_first,
            /* rtol = */ 1e-3,
            /* atol = */ 1e-3)
            .item<bool>());
  CHECK(allclose(
            slice(y, {n_batches - 1, 0, 0, 0}, {n_batches, out_h, out_w, n_chan}),
            want_last,
            /* rtol = */ 1e-3,
            /* atol = */ 1e-3)
            .item<bool>());
}

TEST_CASE("test memory info") {
// Test cache limits
{
Expand Down