40 changes: 14 additions & 26 deletions ggml/src/ggml-sycl/gemm.hpp
@@ -32,39 +32,28 @@ class DnnlGemmWrapper {
else static_assert(0);
}

// matrix A has m rows, k columns
// matrix B has k rows, n columns
// nra - number of elements to skip when moving into next row in A
// nrb - number of elements to skip when moving into next row in B
// nca - number of elements to skip when moving into next column in A
// ncb - number of elements to skip when moving into next column in B
// stride_a - number of elements to skip when moving to next A matrix
// stride_b - number of elements to skip when moving to next B matrix
// batches_a - number of A matrices
// batches_b - number of B matrices
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {

auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);

// { # strides, # rows, # columns }
dnnl::memory::dims a_dims = { batches_a, m, k };
dnnl::memory::dims b_dims = { batches_b, k, n };
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };

// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
dnnl::memory::dims a_strides = { stride_a, nra, nca };
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };

dnnl::memory::dims a_dims = {batches_a, m, k };
dnnl::memory::dims a_strides = {stra2, stra1, stra0};
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);

dnnl::memory::dims b_dims = {batches_b, k, n };
dnnl::memory::dims b_strides = {strb2, strb0, strb1};
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);

dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
dnnl::memory::dims c_strides = {m*n, 1, m };
const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

#ifdef GGML_SYCL_F16
primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
#endif
@@ -76,24 +65,23 @@ class DnnlGemmWrapper {

auto scratchpad_md = matmul_pd.scratchpad_desc();
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);

auto matmul_prim = dnnl::matmul(matmul_pd);

std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });

matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });

matmul_prim.execute(stream, matmul_args);
}

// matrices A and B are column major, both having k rows
// matrix A has m columns, matrix B has n columns
// output: column major matrix C = A transposed * B
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {

gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
}
};
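
As a cross-check of the new interface, a minimal, self-contained sketch (not part of the diff): the reworked gemm() now takes raw element strides per tensor dimension (nb[i] / type_size) instead of row/column skip counts, and row_gemm() maps its column-major convention onto those strides. The reference loop below reproduces the same addressing on plain float buffers, assuming oneDNN consumes the descriptors as dst[b,i,j] = sum_l src[b,i,l] * weights[b,l,j]; the helper name ref_gemm_strided is illustrative only.

// Sketch only, not part of the PR: a reference loop using the same stride convention as
// the reworked gemm() above. stra0/stra1/stra2 (strb0/strb1/strb2) are the element
// strides of dims 0/1/2 of the A (B) tensor. The output follows c_strides = {m*n, 1, m}.
#include <cstdint>

static void ref_gemm_strided(int m, int n, int k,
                             const float * a, int64_t stra0, int64_t stra1, int64_t stra2,
                             const float * b, int64_t strb0, int64_t strb1, int64_t strb2,
                             float * c, int64_t batches) {
    for (int64_t p = 0; p < batches; ++p) {
        for (int i = 0; i < m; ++i) {       // rows of C
            for (int j = 0; j < n; ++j) {   // columns of C
                float acc = 0.0f;
                for (int l = 0; l < k; ++l) {
                    acc += a[p * stra2 + i * stra1 + l * stra0]   // A(i, l)
                         * b[p * strb2 + l * strb0 + j * strb1];  // B(l, j)
                }
                c[p * (int64_t) m * n + (int64_t) j * m + i] = acc; // column-major m x n
            }
        }
    }
}
// row_gemm(ctx, m, n, k, a, at, b, bt, c, ct, q) corresponds to
// ref_gemm_strided(m, n, k, a, 1, k, k*m, b, 1, k, n*k, c, 1).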

165 changes: 119 additions & 46 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32(

static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
const int row_stride_x, const int channel_stride_x, const int channel_stride_y, const int channel_x_divisor,
const sycl::nd_item<3> &item_ct1) {

const sycl::half *x = (const sycl::half *)vx;
@@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
item_ct1.get_local_id(0);
const int channel_x = channel / channel_x_divisor;

const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;

@@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const int row_y = col_x;

const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const int iy = channel*nrows_y + row_y;
const int iy = channel * channel_stride_y + row_y;

const float xi =
sycl::vec<sycl::half, 1>(x[ix])
@@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
static void ggml_mul_mat_vec_nc_f16_f32_sycl(
const void *vx, const float *y, float *dst, const int ncols_x,
const int nrows_x, const int row_stride_x, const int nchannels_x,
const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {

const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x,
row_stride_x, channel_stride_x, channel_stride_y,
nchannels_y / nchannels_x, item_ct1);
});
}
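
For context, a host-side sketch with assumed toy sizes (not code from the PR): with the new channel_stride_y parameter the kernel no longer assumes the y vectors are packed ncols_x elements apart; the step between consecutive y channels now comes from nb11. The snippet reproduces the index math for one output element; all sizes and values are made up for illustration.

// Host-side sketch of the indexing used by mul_mat_vec_nc_f16_f32 after the change:
// the step between consecutive y channels is channel_stride_y = nb11 / sizeof(float)
// instead of the previously hard-coded ncols_x.
#include <cstdio>
#include <vector>

int main() {
    const int ncols_x           = 4;   // ne00
    const int row_stride_x      = 6;   // nb01 / sizeof(half); rows of x may be padded
    const int channel_stride_x  = 24;  // nb02 / sizeof(half)
    const int channel_stride_y  = 5;   // nb11 / sizeof(float); equals ncols_x only if y is contiguous
    const int channel_x_divisor = 2;   // nchannels_y / nchannels_x (broadcast factor)

    std::vector<float> x(2 * channel_stride_x, 1.0f); // 2 x-channels
    std::vector<float> y(4 * channel_stride_y, 2.0f); // 4 y-channels, broadcast over x

    const int channel   = 3, row_x = 1;
    const int channel_x = channel / channel_x_divisor;

    float tmp = 0.0f;
    for (int col_x = 0; col_x < ncols_x; ++col_x) {
        const int row_y = col_x;
        const int ix = channel_x * channel_stride_x + row_x * row_stride_x + col_x;
        const int iy = channel   * channel_stride_y + row_y;  // new: honours nb11
        tmp += x[ix] * y[iy];
    }
    std::printf("dst[channel=%d][row=%d] = %g\n", channel, row_x, tmp); // 4 * 1 * 2 = 8
    return 0;
}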
@@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(

#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
@@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(

#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
@@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
const int64_t nb02 = src0->nb[2];

const int64_t ne12 = src1->ne[2];
const int64_t nb11 = src1->nb[1];

SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml

const int64_t row_stride_x = nb01 / sizeof(sycl::half);
const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
const int64_t channel_stride_y = nb11 / sizeof(float);

ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, channel_stride_y, main_stream);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
float * dst_ddf = static_cast<float *>(dst->data);

const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
const size_t type_size_src0 = ggml_type_size(src0->type);
const size_t type_size_src1 = ggml_type_size(src1->type);
GGML_ASSERT(nb10 == type_size_src1);

// SRC1 strides
int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2855,32 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
if (src1->type != GGML_TYPE_F16) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
" : converting src1 to fp16");
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
const int64_t ne_src1 = ggml_nelements(src1);

// iterate tensor dims and find the slowest moving dim and stride
int64_t last_dim=0;
int64_t last_str=0;
int64_t largest_str=0;
for(int i = 0; i< 4; i++){
// last stride is always the largest
if(src1->nb[i] == largest_str){
if(src1->ne[last_dim] == 1){
last_str = i;
last_dim = i;
}
}
if(src1->nb[i] > largest_str){
largest_str = src1->nb[i];
last_str = i;
last_dim = i;
}

}
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
src1_f16_alloc.alloc(ne_src1);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);

const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
GGML_ASSERT(to_fp16_sycl != nullptr);
to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);

src1_f16 = src1_f16_alloc.get();
s11 = ne10;
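
A self-contained sketch with assumed toy values (not from the PR) of why the allocation size above is derived from the slowest-moving dimension: for a permuted or sliced src1 the converted fp16 copy must cover the whole memory span of the view, which can exceed ggml_nelements(src1). The shapes and strides below are invented for illustration.

// Sketch: size of the fp16 scratch buffer as nb[slowest_dim] * ne[slowest_dim] / type_size,
// where slowest_dim is the dimension with the largest byte stride (ties resolved toward
// later dims whose extent is 1), mirroring the loop above.
#include <cstdint>
#include <cstdio>

int main() {
    // a view of a contiguous f32 parent of shape {64, 32, 4, 1}, keeping 16 of the 32 rows
    const int64_t ne[4] = { 64, 16, 4, 1 };
    const int64_t nb[4] = { 4, 256, 8192, 32768 };   // byte strides of the parent
    const int64_t type_size = 4;                     // sizeof(float)

    int64_t last_dim = 0, last_str = 0, largest_str = 0;
    for (int i = 0; i < 4; i++) {
        if (nb[i] == largest_str && ne[last_dim] == 1) {
            last_str = i;
            last_dim = i;
        }
        if (nb[i] > largest_str) {
            largest_str = nb[i];
            last_str = i;
            last_dim = i;
        }
    }
    const int64_t ne_src1 = nb[last_str] * ne[last_dim] / type_size;
    std::printf("elements in view: %lld, elements to allocate: %lld\n",
                (long long) (ne[0] * ne[1] * ne[2] * ne[3]), (long long) ne_src1); // 4096 vs 8192
    return 0;
}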
@@ -2892,38 +2914,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons

#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {

DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
};

if (r2 == 1 && r3 == 1) {
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
}
else {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
int64_t str_a0 = nb00 / type_size_src0;
int64_t str_a1 = nb01 / type_size_src0;
int64_t str_a2 = nb02 / type_size_src0;

int64_t str_b0 = nb10 / type_size_src1;
int64_t str_b1 = nb11 / type_size_src1;
int64_t str_b2 = nb12 / type_size_src1;

auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
const sycl::half *src1, float *dst,
int64_t a0, int64_t a1, int64_t batcha,
int64_t b0, int64_t b1, int64_t batchb,
int64_t sa0, int64_t sa1, int64_t sa2,
int64_t sb0, int64_t sb1, int64_t sb2,
int64_t sd2) {
bool supported_broadcast = batchb == batcha ? true
: batchb == 1 || batcha == 1 ? true
: false;
if (supported_broadcast) {
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
} else {
// iterate over batches from smaller set of matrices (matrix 0)
int64_t batches0 = batcha;
int64_t batches1 = batchb;

if (batches0 > batches1) {
int64_t num_mul_mats = batches1;
int64_t sub_batch = batches0 / num_mul_mats;
// src0 is batched and bigger, shift and multiply with src1
for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
const sycl::half *src1_shifted = src1 + (sb2 * i0);
float *dst_shifted = dst + (sd2 * i0 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, sub_batch, 1);
}
} else {
int64_t num_mul_mats = batches0;
int64_t sub_batch = batches1 / num_mul_mats;
// src1 is batched and bigger, shift and multiply with src0
for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
const sycl::half *src0_shifted = src0 + (sa2 * i1);
const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
float *dst_shifted = dst + (sd2 * i1 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, 1, sub_batch);
}
}
}
}
} else {
// iterate over batches from smaller set of matrices (matrix 0)
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
};

bool cont_batches_a = nb02 * ne02 == nb03;
bool cont_batches_b = nb12 * ne12 == nb13;
if (cont_batches_a && cont_batches_b) {
int64_t batches0 = ne02 * ne03;
int64_t batches1 = ne12 * ne13;
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
str_b2, nb2 / sizeof(float));
} else {
for (int64_t b_a = 0; b_a < ne03; b_a++) {
const sycl::half *src0_f16_shifted
= src0_f16 + (nb03 * b_a / type_size_src0);
const sycl::half *src1_f16_shifted
= src1_f16 + (nb13 * b_a / type_size_src1);
float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
int64_t batches0 = ne02;
int64_t batches1 = ne12;
launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
}
}
}

}
else
#endif
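
To summarise the dispatch implemented by launch_gemm_for_batches (a standalone sketch with illustrative batch counts, not code from the PR): one oneDNN call is issued when the two batch counts match or one of them is 1, otherwise the code loops over the smaller batch set and issues one strided GEMM of sub_batch matrices per iteration. The planner below only prints which calls would be made.

// Sketch of the dispatch rule in launch_gemm_for_batches; function name and batch
// counts are illustrative only.
#include <cstdint>
#include <cstdio>

static void plan_batched_gemm(int64_t batcha, int64_t batchb) {
    const bool supported_broadcast = batchb == batcha || batchb == 1 || batcha == 1;
    if (supported_broadcast) {
        std::printf("1 gemm call: batches_a=%lld, batches_b=%lld\n",
                    (long long) batcha, (long long) batchb);
    } else if (batcha > batchb) {
        const int64_t sub_batch = batcha / batchb;   // src0 is the bigger batch
        for (int64_t i = 0; i < batchb; i++) {
            std::printf("gemm call %lld: %lld src0 matrices x 1 src1 matrix\n",
                        (long long) i, (long long) sub_batch);
        }
    } else {
        const int64_t sub_batch = batchb / batcha;   // src1 is the bigger batch
        for (int64_t i = 0; i < batcha; i++) {
            std::printf("gemm call %lld: 1 src0 matrix x %lld src1 matrices\n",
                        (long long) i, (long long) sub_batch);
        }
    }
}

int main() {
    plan_batched_gemm(8, 8); // equal batches: a single call
    plan_batched_gemm(1, 6); // oneDNN broadcast: a single call
    plan_batched_gemm(2, 8); // general broadcast: 2 calls, each 1 x 4 matrices
    return 0;
}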
@@ -3263,10 +3336,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
// KQ + KQV multi-batch
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {