12 changes: 6 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2113,7 +2113,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);

const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
@@ -2207,16 +2207,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}

@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
return;
}

if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
return;
}
12 changes: 9 additions & 3 deletions ggml/src/ggml-cuda/mmf.cu
@@ -119,15 +119,21 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
}
}

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
if (ggml_is_quantized(type)) {
return false;
}

if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
const size_t ts = ggml_type_size(type);
if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
return false;
}
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
if (src0_nb[i] % (2*ts) != 0) {
return false;
}
}
Collaborator:
I think this disables mmf for batch_size = 1. Is that expected?

Collaborator:
Before

Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| ------------------ | --------: | -----: | ------- | --: | ----- | --------------: |
| lfm2moe 8B.A1B F16 | 15.54 GiB | 8.34 B | CUDA    |  99 | pp512 | 7456.65 ± 45.82 |
| lfm2moe 8B.A1B F16 | 15.54 GiB | 8.34 B | CUDA    |  99 | tg128 |   146.77 ± 0.08 |

After

Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| ------------------ | --------: | -----: | ------- | --: | ----- | --------------: |
| lfm2moe 8B.A1B F16 | 15.54 GiB | 8.34 B | CUDA    |  99 | pp512 | 7405.42 ± 53.04 |
| lfm2moe 8B.A1B F16 | 15.54 GiB | 8.34 B | CUDA    |  99 | tg128 |   129.49 ± 0.68 |

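For reference, a minimal standalone sketch (not part of this PR) of the arithmetic behind this question: for a contiguous tensor, ggml sets nb[0] == ggml_type_size(type), so the i == 0 iteration of the new loop evaluates ts % (2*ts), which is never zero. The shape, type, and the small program below are illustrative assumptions, not taken from the PR.

```cpp
// Standalone illustration (not PR code): evaluate the new stride check for a
// hypothetical contiguous F16 matrix, using ggml's contiguous-layout convention
// nb[0] = ggml_type_size(type), nb[i] = nb[i-1] * ne[i-1].
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne[4] = {4096, 4096, 1, 1}; // assumed shape of a contiguous F16 weight
    const size_t  ts    = 2;                  // ggml_type_size(GGML_TYPE_F16)

    size_t nb[4];
    nb[0] = ts;
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1] * ne[i - 1];
    }

    for (int i = 0; i < 4; ++i) {
        const bool ok = nb[i] % (2*ts) == 0;
        printf("i=%d: nb[i]=%zu, nb[i] %% (2*ts)=%zu -> %s\n",
               i, nb[i], nb[i] % (2*ts), ok ? "ok" : "reject");
    }
    // i = 0 prints "reject" (2 % 4 == 2), so a loop that starts at i = 0 trips on
    // the nb[0] == ts case of contiguous tensors, not only on misaligned views;
    // a loop starting at i = 1 would skip that case.
    return 0;
}
```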
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmf.cuh
@@ -17,7 +17,7 @@ struct mmf_ids_data {

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);

template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
8 changes: 7 additions & 1 deletion ggml/src/ggml-cuda/mmvf.cu
@@ -716,10 +716,16 @@ void ggml_cuda_op_mul_mat_vec_f(
GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
}

bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
const size_t ts = ggml_type_size(type);
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
if (src0_nb[i] % (2*ts) != 0) {
return false;
}
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmvf.cuh
@@ -9,4 +9,4 @@ void ggml_cuda_op_mul_mat_vec_f(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);
2 changes: 2 additions & 0 deletions src/llama-context.cpp
@@ -21,6 +21,8 @@ llama_context::llama_context(
llama_context_params params) :
model(model),
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
// may need to be backend-dependent
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

t_start_us = model.t_start_us;
44 changes: 18 additions & 26 deletions tests/test-backend-ops.cpp
@@ -3377,11 +3377,11 @@ struct test_mul_mat : public test_case {
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions
const bool v; // whether a and b are non-contiguous views
const int64_t k_v; // size of k in memory, resulting in a non-contiguous view for k_v > k, no view for k_v == 0
const uint32_t o; // number of outputs

std::string vars() override {
return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, v, o);
return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, k_v, o);
}

double max_nmse_err() override {
@@ -3402,8 +3402,8 @@ struct test_mul_mat : public test_case {
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
std::array<int64_t, 4> per = {0, 1, 2, 3},
bool v = false, uint32_t o = 1)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v), o(o) {}
int64_t k_v = 0, uint32_t o = 1)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), k_v(k_v), o(o) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -3413,7 +3413,7 @@ struct test_mul_mat : public test_case {
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
if (npermuted > 0) {
GGML_ASSERT(npermuted == 2);
GGML_ASSERT(!v); // not handled
GGML_ASSERT(k_v == 0); // not handled
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);

@@ -3437,29 +3437,21 @@ struct test_mul_mat : public test_case {
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
if (v) {
a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
const int64_t k_physical = k_v == 0 ? k : k_v;
a = ggml_new_tensor_4d(ctx, type_a, k_physical, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k_physical, n, bs[0]*nr[0], bs[1]*nr[1]);

if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(a);
}
ggml_set_param(b);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(a);
}
ggml_set_param(b);
}

if (k_v != 0) {
GGML_ASSERT(k_v > k);
a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);

if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(a);
}
ggml_set_param(b);
}
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@@ -6886,7 +6878,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, 64, 3));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));

#if 0
@@ -6912,7 +6904,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (uint32_t k = 0; k < 2; ++k) {
for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, bs2}, {nr, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, bs2}, {nr, 1}, {0, 1, 2, 3}, true));
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, bs2}, {nr, 1}, {0, 1, 2, 3}, 2*1056 + k));
}
}
}
@@ -7405,7 +7397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));

test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));

for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {