diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5cd1e0d862db4..5a9e54721e463 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3639,9 +3639,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: case GGML_OP_SUM: - case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index cced0369d0226..523f9d71ba14e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_PAD_REFLECT_1D: case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c291b886cb7ee..5ab42f59e0635 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6567,6 +6567,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection) } for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {