Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class matmul_desc_t {
/// scale_type==float && a_type==float && b_type==float && c_type==float.
/// Currently, this function only supports beta==0 or beta==1.
/// Currently, this function only supports the relu, bias, gelu, gelu_bias,
/// gelu_aux, gelu_aux_bias and dgelu epilogue.
/// gelu_aux, gelu_aux_bias, dgelu and bgradb epilogue.
/// NOTE: Non-col-major matrix will be converted to col-major matrix before
/// multiplication and converted back after multiplication.
/// TODO: Implement row-major matmul without layout conversion.
Expand Down Expand Up @@ -331,10 +331,12 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
compute_desc->_epilogue != epilogue_t::gelu_bias &&
compute_desc->_epilogue != epilogue_t::gelu_aux &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias &&
compute_desc->_epilogue != epilogue_t::dgelu) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
"gelu_aux_bias and dgelu epilogue currently.");
compute_desc->_epilogue != epilogue_t::dgelu &&
compute_desc->_epilogue != epilogue_t::bgradb) {
throw std::runtime_error(
"dpct::blas_gemm::experimental::matmul() only "
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
"gelu_aux_bias, dgelu and bgradb epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -559,6 +561,28 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
#endif
}

::dnnl::memory *po_bias_bgradb_mem = nullptr;
auto po_bias_bgradb_md = ::dnnl::memory::desc(
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
? ::dnnl::memory::dims{N, 1}
: ::dnnl::memory::dims{1, N},
dpct::dnnl::memory_desc_ext::to_dnnl_data_type(
compute_desc->_bias_data_type),
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
? ::dnnl::memory::dims{1, N}
: ::dnnl::memory::dims{N, 1});
if (compute_desc->_epilogue == epilogue_t::bgradb) {
po_bias_bgradb_mem = new ::dnnl::memory(
po_bias_bgradb_md, handle->get_engine(), DNNL_MEMORY_NONE);
#ifdef DPCT_USM_LEVEL_NONE
detail::type_dispatch<detail::set_buffer_impl>(
compute_desc->_bias_data_type, po_bias_bgradb_mem,
compute_desc->_bias_pointer);
#else
po_bias_bgradb_mem->set_data_handle(compute_desc->_bias_pointer);
#endif
}

::dnnl::memory *po_aux_mem = nullptr;
auto po_aux_md = ::dnnl::memory::desc(
::dnnl::memory::dims{M, N},
Expand Down Expand Up @@ -660,6 +684,17 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
post_op_prim_event =
::dnnl::sycl_interop::execute(dgelu_prim, handle->get_engine_stream(),
dgelu_args, {matmul_prim_event});
} else if (compute_desc->_epilogue == epilogue_t::bgradb) {
auto reduction_pd = ::dnnl::reduction::primitive_desc(
handle->get_engine(), ::dnnl::algorithm::reduction_sum, weights_md,
po_bias_bgradb_md, 0.f, 0.f);
auto reduction_prim = ::dnnl::reduction(reduction_pd);
std::unordered_map<int, ::dnnl::memory> reduction_args;
reduction_args.insert({DNNL_ARG_SRC, *weights_mem});
reduction_args.insert({DNNL_ARG_DST, *po_bias_bgradb_mem});
post_op_prim_event = ::dnnl::sycl_interop::execute(
reduction_prim, handle->get_engine_stream(), reduction_args,
{matmul_prim_event});
}

// end of calling oneDNN
Expand Down Expand Up @@ -700,6 +735,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
delete dst_mem;
if (po_bias_mem)
delete po_bias_mem;
if (po_bias_bgradb_mem)
delete po_bias_bgradb_mem;
if (po_aux_mem)
delete po_aux_mem;
::dpct::cs::free((void *)new_a, *q_ptr);
Expand Down