cpu: aarch64: matmul: Optimize (A^T)*(B^T) in acl_matmul
This PR computes (B*A)^T instead of (A^T)*(B^T) when transposing
the (B*A) result is cheaper than transposing both inputs. This
improves performance by ~1.25x for square matrices, and by even
more for tall-skinny/fat-short matrices.

It also reduces code duplication and moves allocation of the dst
accumulator from ACL to scratchpad memory in oneDNN.
annop-w authored and dzarukin committed May 5, 2024
1 parent a986231 commit 95c00ed
Showing 4 changed files with 179 additions and 205 deletions.
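As a sketch of the reasoning behind the change (shapes are assumed for illustration; A and B denote the src and weights tensors as stored, i.e. already transposed):

```latex
% A is the stored src buffer (k x m), B the stored weights buffer (n x k);
% the matmul needs dst (m x n) = A^T * B^T.  Shapes assumed for illustration.
\underbrace{A^{\top}}_{m \times k} \, \underbrace{B^{\top}}_{k \times n}
    \;=\; \bigl( \underbrace{B}_{n \times k} \, \underbrace{A}_{k \times m} \bigr)^{\top}
% Old path: transpose A, transpose B, then GEMM        -> two transpose kernels
% New path: GEMM on A and B as stored, transpose (B*A) -> one transpose kernel
```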
83 changes: 21 additions & 62 deletions src/cpu/aarch64/matmul/acl_matmul.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2023 Arm Ltd. and affiliates
* Copyright 2021-2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,20 +24,24 @@ namespace matmul {

using namespace data_type;

status_t acl_matmul_t::execute_forward_non_fixed_format(
const exec_ctx_t &ctx) const {
template <bool IsFixedFormat>
status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {

status_t status = status::success;
auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);

bool is_transA = pd()->amp_.is_transA;
bool is_transB = pd()->amp_.is_transB;
bool do_transC = pd()->amp_.do_transC;
bool use_dst_acc = pd()->amp_.use_dst_acc;

std::lock_guard<std::mutex> _lock {this->mtx};
auto *acl_resource = ctx.get_resource_mapper()->get<acl_resource_t>(this);
acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj();

const auto scratchpad = ctx.get_scratchpad_grantor();

// Run transpose kernel
if (is_transA && !is_transB) {
acl_obj.src_tensor.allocator()->allocate();
@@ -53,7 +57,7 @@ status_t acl_matmul_t::execute_forward_non_fixed_format(
acl_obj.transB.run();
acl_obj.src_tensor.allocator()->import_memory(
const_cast<data_t *>(src_base));
} else if (is_transA && is_transB) {
} else if (is_transA && is_transB && !do_transC) {
acl_obj.src_tensor.allocator()->allocate();
acl_obj.src_acc_tensor.allocator()->import_memory(
const_cast<data_t *>(src_base));
@@ -67,19 +71,20 @@ status_t acl_matmul_t::execute_forward_non_fixed_format(
const_cast<data_t *>(src_base));
acl_obj.wei_tensor.allocator()->import_memory(
const_cast<data_t *>(wei_base));
if (do_transC) { acl_obj.dst_acc_tensor.allocator()->allocate(); }
}

if (use_dst_acc) {
// Put the result in a new tensor, it will be accumulated to the dst
// during the post ops
acl_obj.dst_tensor.allocator()->allocate();
} else {
auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
acl_obj.dst_tensor.allocator()->import_memory(dst_base);
}
// Put the result in a new tensor, if we have a sum post op.
// Result will be accumulated to the dst during the post ops.
auto dst_base = use_dst_acc ? scratchpad.get<void>(
memory_tracking::names::key_matmul_dst_in_acc_dt)
: CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
acl_obj.dst_tensor.allocator()->import_memory(dst_base);

acl_obj.gemm.run();

if (do_transC) { acl_obj.transC.run(); }

acl_obj.src_tensor.allocator()->free();
acl_obj.wei_tensor.allocator()->free();
if (is_transA) acl_obj.src_acc_tensor.allocator()->free();
@@ -93,56 +98,10 @@ status_t acl_matmul_t::execute_forward_non_fixed_format(
return status;
}

status_t acl_matmul_t::execute_forward_fixed_format(
const exec_ctx_t &ctx) const {

status_t status = status::success;
auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);

bool is_transA = pd()->amp_.is_transA;
bool use_dst_acc = pd()->amp_.use_dst_acc;

std::lock_guard<std::mutex> _lock {this->mtx};
auto *acl_resource = ctx.get_resource_mapper()->get<acl_resource_t>(this);
acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj();
// Run transpose kernel
if (is_transA) {
acl_obj.src_tensor.allocator()->allocate();
acl_obj.src_acc_tensor.allocator()->import_memory(
const_cast<data_t *>(src_base));
acl_obj.transA.run();
acl_obj.wei_tensor.allocator()->import_memory(
const_cast<data_t *>(wei_base));
} else {
acl_obj.src_tensor.allocator()->import_memory(
const_cast<data_t *>(src_base));
acl_obj.wei_tensor.allocator()->import_memory(
const_cast<data_t *>(wei_base));
}

if (use_dst_acc) {
// Put the result in a new tensor, it will be accumulated to the dst
// during the post ops
acl_obj.dst_tensor.allocator()->allocate();
} else {
auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
acl_obj.dst_tensor.allocator()->import_memory(dst_base);
}

acl_obj.gemm.run();

acl_obj.src_tensor.allocator()->free();
acl_obj.wei_tensor.allocator()->free();
if (is_transA) acl_obj.src_acc_tensor.allocator()->free();

void *dst = acl_obj.dst_tensor.buffer();
pd()->post_ops.execute(ctx, dst);

acl_obj.dst_tensor.allocator()->free();

return status;
}
template status_t acl_matmul_t::execute_forward<true>(
const exec_ctx_t &ctx) const;
template status_t acl_matmul_t::execute_forward<false>(
const exec_ctx_t &ctx) const;

} // namespace matmul
} // namespace aarch64
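The equivalence the new do_transC path relies on can be checked in isolation. Below is a minimal, self-contained sketch in plain C++ (no ACL or oneDNN types; all names are local to the example): it multiplies the stored, already-transposed operands and transposes the result, then compares against transposing both inputs first.

```cpp
// Standalone illustration (not oneDNN/ACL code): verify that
// (B*A)^T == A^T * B^T for matrices stored in row-major order.
#include <cstdio>
#include <cstdlib>
#include <vector>

// C (rows x cols) = X (rows x inner) * Y (inner x cols), row-major, naive.
static void gemm(const std::vector<float> &X, const std::vector<float> &Y,
        std::vector<float> &C, int rows, int inner, int cols) {
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j) {
            float acc = 0.f;
            for (int p = 0; p < inner; ++p)
                acc += X[i * inner + p] * Y[p * cols + j];
            C[i * cols + j] = acc;
        }
}

int main() {
    // A is the stored src buffer (k x m), B is the stored wei buffer (n x k),
    // so the matmul actually required is dst (m x n) = A^T * B^T.
    const int m = 3, k = 4, n = 5;
    std::vector<float> A(k * m), B(n * k);
    for (auto &v : A) v = float(std::rand() % 7) - 3.f;
    for (auto &v : B) v = float(std::rand() % 7) - 3.f;

    // Path 1 (old): transpose both inputs, then GEMM.
    std::vector<float> At(m * k), Bt(k * n), dst_ref(m * n);
    for (int i = 0; i < k; ++i)
        for (int j = 0; j < m; ++j) At[j * k + i] = A[i * m + j];
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < k; ++j) Bt[j * n + i] = B[i * k + j];
    gemm(At, Bt, dst_ref, m, k, n);

    // Path 2 (new): GEMM on the stored operands, then one transpose of the result.
    std::vector<float> tmp(n * m), dst(m * n);
    gemm(B, A, tmp, n, k, m); // tmp = B * A, shape n x m
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j) dst[j * n + i] = tmp[i * m + j];

    for (int i = 0; i < m * n; ++i)
        if (dst[i] != dst_ref[i]) { std::printf("mismatch\n"); return 1; }
    std::printf("(B*A)^T == A^T * B^T\n");
    return 0;
}
```

Both paths produce the same dst; the do_transC path wins whenever transposing the single n x m result is cheaper than transposing both inputs, which is where the reported speedups come from.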
62 changes: 43 additions & 19 deletions src/cpu/aarch64/matmul/acl_matmul.hpp
@@ -36,23 +36,37 @@ struct acl_resource_t : public resource_t {
acl_obj_->src_tensor.allocator()->init(amp.src_tensor_info);
acl_obj_->wei_tensor.allocator()->init(amp.wei_tensor_info);
acl_obj_->dst_tensor.allocator()->init(amp.dst_tensor_info);
// Configure transpose kernel for src, wei or both
if (amp.is_transA) {

// Configure transpose kernel for src, wei, or dst
if (amp.is_transA && !amp.do_transC) {
acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info);
acl_obj_->transA.configure(
&acl_obj_->src_acc_tensor, &acl_obj_->src_tensor);
}

if (weights_format_kind_ != format_kind::any) {
if (amp.is_transB) {
acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info);
acl_obj_->transB.configure(
&acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor);
}
if (amp.is_transB && !amp.do_transC) {
acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info);
acl_obj_->transB.configure(
&acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor);
}

if (amp.do_transC) {
acl_obj_->dst_acc_tensor.allocator()->init(amp.dst_acc_info);
acl_obj_->transC.configure(
&acl_obj_->dst_acc_tensor, &acl_obj_->dst_tensor);
}

// Configure GEMM
acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor,
nullptr, &acl_obj_->dst_tensor, 1.0f, 0.0f, amp.gemm_info);
if (amp.do_transC) {
acl_obj_->gemm.configure(&acl_obj_->wei_tensor,
&acl_obj_->src_tensor, nullptr, &acl_obj_->dst_acc_tensor,
1.0f, 0.0f, amp.gemm_info);
} else {
acl_obj_->gemm.configure(&acl_obj_->src_tensor,
&acl_obj_->wei_tensor, nullptr, &acl_obj_->dst_tensor, 1.0f,
0.0f, amp.gemm_info);
}

return status::success;
}
acl_matmul_obj_t &get_acl_obj() const { return *acl_obj_; }
@@ -109,15 +123,15 @@ struct acl_matmul_t : public primitive_t {
VERBOSE_RUNTIMEDIM_UNSUPPORTED);

if (weights_format_kind_ == format_kind::any) {
CHECK(acl_matmul_utils::init_conf_matmul_fixed_format(
CHECK(acl_matmul_utils::init_conf_matmul<true>(
amp_, src_md_, weights_md_, dst_md_, *desc(), *attr()));
} else {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
// to avoid seg. fault in case threadpool is enabled and its pointer is null
if (threadpool_utils::get_active_threadpool() == nullptr)
return status::unimplemented;
#endif
CHECK(acl_matmul_utils::init_conf_matmul_non_fixed_format(
CHECK(acl_matmul_utils::init_conf_matmul<false>(
amp_, src_md_, weights_md_, dst_md_, *desc(), *attr()));
}

@@ -127,9 +141,19 @@
amp_.use_dst_acc = post_ops.has_sum();

// Validate ACL GEMM
ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&amp_.src_tensor_info,
&amp_.wei_tensor_info, nullptr, &amp_.dst_tensor_info,
amp_.alpha, 0.0f, amp_.gemm_info));
if (amp_.do_transC) {
ACL_CHECK_VALID(arm_compute::NEGEMM::validate(
&amp_.wei_tensor_info, &amp_.src_tensor_info, nullptr,
&amp_.dst_acc_info, amp_.alpha, 0.0f, amp_.gemm_info));
} else {
ACL_CHECK_VALID(arm_compute::NEGEMM::validate(
&amp_.src_tensor_info, &amp_.wei_tensor_info, nullptr,
&amp_.dst_tensor_info, amp_.alpha, 0.0f,
amp_.gemm_info));
}

auto scratchpad = scratchpad_registry().registrar();
CHECK(acl_matmul_utils::init_scratchpad(scratchpad, amp_, dst_md_));

return status::success;
}
@@ -166,17 +190,17 @@ struct acl_matmul_t : public primitive_t {

status_t execute(const exec_ctx_t &ctx) const override {
if (pd()->weights_format_kind_ == format_kind::any) {
return execute_forward_fixed_format(ctx);
return execute_forward<true>(ctx);
} else {
return execute_forward_non_fixed_format(ctx);
return execute_forward<false>(ctx);
}
}

private:
// To guard the const execute_forward(), the mutex must be 'mutable'
mutable std::mutex mtx;
status_t execute_forward_non_fixed_format(const exec_ctx_t &ctx) const;
status_t execute_forward_fixed_format(const exec_ctx_t &ctx) const;
template <bool IsFixedFormat>
status_t execute_forward(const exec_ctx_t &ctx) const;

const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
}; // acl_matmul_t
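The scratchpad registration referenced by pd_t::init above (acl_matmul_utils::init_scratchpad) lives in acl_matmul_utils.cpp, which is part of this commit but not shown in this view. A rough sketch of what it plausibly books for the sum post-op accumulator follows; the key name and the use_dst_acc flag come from the diff, while the signature, type names, and sizing are assumptions.

```cpp
// Hypothetical sketch only; the real implementation is in
// src/cpu/aarch64/matmul/acl_matmul_utils.cpp (not shown in this diff).
// It has to book the buffer that execute_forward() fetches with
// scratchpad.get<void>(memory_tracking::names::key_matmul_dst_in_acc_dt).
status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const acl_matmul_conf_t &amp, const memory_desc_t &dst_md) {
    if (amp.use_dst_acc) {
        const memory_desc_wrapper dst_d(&dst_md);
        // One scratchpad element per dst element, in the dst data type;
        // this replaces the ACL-side allocate() the old code used for
        // the temporary dst tensor.
        scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
                dst_d.nelems(), dst_d.data_type_size());
    }
    return status::success;
}
```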