diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c01b98ac78f5a..5490173c99e9a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1528,7 +1528,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // supports 3D: a->ne[2] == b->ne[1] + // supports 4D a: + // a [n_embd, ne1, ne2, ne3] + // b I32 [n_rows, ne2, ne3, 1] + // + // return [n_embd, n_rows, ne2, ne3] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, // data diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0c01eb6fa8359..eafcefca8db1b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { + // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868 + if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) { + return false; + } switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f35c337952ec3..50dc1aa24fff5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == b->ne[2]); GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32);