Skip to content

Commit

Permalink
src: cpu: aarch64: re-enable fp16 post-ops for aarch64
Browse files Browse the repository at this point in the history
Perform the eltwise post-ops in fp32 instead of fp16 by
casting up (to fp32) before executing the eltwise op and
then casting back down (to fp16) after the operation completes.
This matches the oneDNN reference implementation, which runs
post-ops in fp32 and then casts down to fp16, avoiding the
accuracy loss that previously forced fp16 post-ops to be disabled.
With this change, all fp16 benchdnn tests pass on aarch64.
  • Loading branch information
fadara01 committed Mar 12, 2023
1 parent 01e8272 commit 4446cd3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
29 changes: 27 additions & 2 deletions src/cpu/aarch64/acl_post_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Arm Ltd. and affiliates
* Copyright 2022-2023 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@
* limitations under the License.
*******************************************************************************/

#include "common/float16.hpp"
#include "cpu/aarch64/acl_gemm_convolution.hpp"

namespace dnnl {
Expand Down Expand Up @@ -59,7 +60,31 @@ status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const {
= dynamic_cast<acl_eltwise_fwd_t *>(post_op.get());
if (eltwise_post_op == nullptr) return status::runtime_error;

CHECK(eltwise_post_op->execute_forward(ctx, src, src));
if (dst_data_type == data_type::f16) {
// in this case we want to cast the src tensor up to fp32
arm_compute::TensorInfo src_info
= eltwise_post_op->pd()->aep.data_info;
// new src tensor with fp32 datatype
arm_compute::Tensor src_tensor;
src_tensor.allocator()->init(src_info);
src_tensor.allocator()->allocate();
float *src_f32 = (float *)src_tensor.buffer();
// total_size gives the size in bytes, we divide by 4 because the src_tensor is fp32
size_t num_elements = src_tensor.info()->total_size() / 4;
// cast src up to fp32 and store the result in src_f32
cvt_float16_to_float(
src_f32, (dnnl::impl::float16_t *)src, num_elements);
// perform the operation in fp32
CHECK(eltwise_post_op->execute_forward(ctx, src_f32, src_f32));
// cast src_f32 down and store final result in src
cvt_float_to_float16(
(dnnl::impl::float16_t *)src, src_f32, num_elements);
src_tensor.allocator()->free();

} else {
CHECK(eltwise_post_op->execute_forward(ctx, src, src));
}

} else {
return status::runtime_error;
}
Expand Down
34 changes: 16 additions & 18 deletions src/cpu/aarch64/acl_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,8 @@ struct acl_post_ops_t {
status_t init(engine_t *engine, post_ops_t &post_ops,
const memory_desc_t &dst_md) {

// Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
// the post op in f32 and then casts down to f16 while ACL runs the post op in f16
// leading to a loss of accuracy compared to ref.
ACL_CHECK_SUPPORT(
post_ops.len() >= 1 && dst_md.data_type == data_type::f16,
"post ops cannot be executed in fp16");
CHECK(post_ops.set_default_formats(&dst_md));
dst_data_type = dst_md.data_type;

// Reset properties derived from post_ops
sum_index = -1;
Expand Down Expand Up @@ -105,8 +100,15 @@ struct acl_post_ops_t {
eltwise_desc.alg_kind = po.eltwise.alg;
eltwise_desc.alpha = po.eltwise.alpha;
eltwise_desc.beta = po.eltwise.beta;
eltwise_desc.src_desc = dst_md;
eltwise_desc.dst_desc = dst_md;
memory_desc_t temp_dst = dst_md;
// pass eltwise a desc with f32 datatype to perform the operation in fp32 rather than fp16
// since oneDNN requires all post-ops to run in fp32.
// we don't need to do that to the other post-ops as executing them in fp16 yields the same result.
if (dst_data_type == data_type::f16) {
temp_dst.data_type = data_type::f32;
}
eltwise_desc.src_desc = temp_dst;
eltwise_desc.dst_desc = temp_dst;
eltwise_desc.prop_kind = prop_kind_t::dnnl_forward;
auto empty_attr = dnnl_primitive_attr();
typename acl_eltwise_fwd_t::pd_t acl_eltwise_pd(
Expand Down Expand Up @@ -135,16 +137,12 @@ struct acl_post_ops_t {
const memory_desc_t &dst_md,
arm_compute::ActivationLayerInfo &act_info_to_fuse) {

// Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
// the post op in f32 and then casts down to f16 while ACL runs the post op in f16
// leading to a loss of accuracy compared to ref.
ACL_CHECK_SUPPORT(
base_post_ops.len() >= 1 && dst_md.data_type == data_type::f16,
"post ops cannot be executed in fp16");
CHECK(base_post_ops.set_default_formats(&dst_md));

// If the first entry is eltwise, we fuse it
if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise()) {
dst_data_type = dst_md.data_type;
// If the first entry is eltwise, we fuse it, except when the datatype
// is fp16 because in this case we want to execute the eltwise in fp32.
if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise()
&& dst_data_type != data_type::f16) {

const auto &first_po = base_post_ops.entry_[0].eltwise;
ACL_CHECK_SUPPORT(first_po.scale != 1.0f,
Expand Down Expand Up @@ -181,7 +179,7 @@ struct acl_post_ops_t {
private:
// Index of the sum post op if there is one, < 0 means no sum
int sum_index = -1;

data_type_t dst_data_type;
// Vector of primitives used to execute the post ops. They are constructed
// in init to be either acl_binary_t (for sum, add, sub, div, mul, min and
// max) or acl_eltwise_fwd_t (for relu, elu, tanh, square, abs etc)
Expand Down

0 comments on commit 4446cd3

Please sign in to comment.