Skip to content

Commit

Permalink
cpu: aarch64: allow blocked weight format for primitive creation
Browse files Browse the repository at this point in the history
with weights pre-packing enabled in torch.compile(),
the weights come already reorderd and in oneDNN format,
so, allowing format_kind::blocked as one of the supported
formats for acl inner product primitive.
  • Loading branch information
snadampal authored and vpirogov committed Mar 13, 2024
1 parent 0c922e0 commit 8aacc8f
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions src/cpu/aarch64/acl_inner_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,25 @@ struct acl_inner_product_fwd_t : public primitive_t {

status_t init(engine_t *engine) {
using namespace data_type;
const format_kind_t weights_format_kind_received
= weights_md_.format_kind;
const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f16);
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
const bool is_weights_md_format_ok
= utils::one_of(weights_format_kind_received,
format_kind::any, format_kind::blocked);
const bool ok = is_fwd() && !has_zero_dim_memory()
&& utils::one_of(true, is_fp16_ok, is_fp32_ok)
&& weights_md_.format_kind == format_kind::any
&& set_default_params() == status::success;
&& is_weights_md_format_ok
&& set_default_params(true) == status::success;

if (!ok) return status::unimplemented;

CHECK(init_conf_ip(engine));
CHECK(init_conf_ip(engine, weights_format_kind_received));

return status::success;
}
Expand All @@ -115,7 +120,8 @@ struct acl_inner_product_fwd_t : public primitive_t {

acl_post_ops_t post_ops;

status_t init_conf_ip(engine_t *engine) {
status_t init_conf_ip(
engine_t *engine, format_kind_t weights_format_kind_received) {

ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims,
"source and weights dimensions must match");
Expand Down Expand Up @@ -257,10 +263,19 @@ struct acl_inner_product_fwd_t : public primitive_t {
return status::unimplemented;
}

const memory_desc_t weights_md_received = weights_md_;
acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
weights_md_, expected_weight_format, inner_dim, o_dim,
remaining_dims, {});

ACL_CHECK_SUPPORT(
(weights_format_kind_received == format_kind::blocked)
&& !(dnnl_memory_desc_equal(
&weights_md_received, &weights_md_)),
"specific blocked format not supported by ACL, use "
"format_kind_t::any to find a supported blocked format for "
"your platform");

// clang-format off

// Validate fully connected layer manually to check for return status
Expand Down

0 comments on commit 8aacc8f

Please sign in to comment.