Commit b9664e9: WIP pipe fused params through subgraph and operators

PiperOrigin-RevId: 460321945
ngzhian authored and xnnpack-bot committed Jul 22, 2022
1 parent 62a70d5 commit b9664e9
Showing 16 changed files with 907 additions and 21 deletions.
12 changes: 12 additions & 0 deletions BUILD.bazel
@@ -13014,6 +13014,18 @@ xnnpack_unit_test(
],
)

xnnpack_unit_test(
name = "f32_fused_gemm_minmax_test",
srcs = [
"test/f32-fused-gemm-minmax.cc",
],
deps = MICROKERNEL_TEST_DEPS + [
":XNNPACK",
":gemm_microkernel_tester",
":jit_test_mode",
],
)

xnnpack_unit_test(
name = "f32_vhswish_test",
srcs = [
44 changes: 44 additions & 0 deletions include/xnnpack.h
@@ -1656,6 +1656,50 @@ enum xnn_status xnn_create_convolution2d_nhwc_f32(
xnn_caches_t caches,
xnn_operator_t* convolution_op_out);

// Operators that can be fused (currently only into convolution).
enum xnn_fused_operator_type {
xnn_fused_operator_type_abs,
xnn_fused_operator_type_add,
xnn_fused_operator_type_negate,
xnn_fused_operator_type_hardswish,
};

// Struct representing a fused operator and its associated data. For example, an addition with a constant specifies the
// constant in the arg field.
struct xnn_fused_operator {
enum xnn_fused_operator_type op_type;
float arg;
};

// Create a convolution operator with a number of fused operations that are applied after the convolution and after
// clamping to [output_min, output_max].
// This is only supported on JIT platforms, as the code that applies the fused operations is generated by the JIT.
enum xnn_status xnn_create_convolution2d_nhwc_f32_fused(
uint32_t input_padding_top,
uint32_t input_padding_right,
uint32_t input_padding_bottom,
uint32_t input_padding_left,
uint32_t kernel_height,
uint32_t kernel_width,
uint32_t subsampling_height,
uint32_t subsampling_width,
uint32_t dilation_height,
uint32_t dilation_width,
uint32_t groups,
size_t group_input_channels,
size_t group_output_channels,
size_t input_channel_stride,
size_t output_channel_stride,
const float* kernel,
const float* bias,
float output_min,
float output_max,
size_t num_fused_operations,
struct xnn_fused_operator* fused_operators,
uint32_t flags,
xnn_caches_t caches,
xnn_operator_t* convolution_op_out);

enum xnn_status xnn_setup_convolution2d_nhwc_f32(
xnn_operator_t convolution_op,
size_t batch_size,
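
For reference, here is a minimal sketch of how the new fused entry point might be called, assuming xnn_initialize has already run and that kernel and bias are caller-owned float arrays for a 3x3, stride-1, 32-to-64-channel convolution; all shapes, constants, and the helper name below are illustrative, not part of this change:

  #include <math.h>     // INFINITY
  #include <xnnpack.h>

  // Hypothetical helper: create a convolution with an add-constant followed by hardswish fused in.
  static enum xnn_status create_fused_conv(const float* kernel, const float* bias,
                                           xnn_operator_t* conv_op_out) {
    struct xnn_fused_operator fused_ops[2] = {
      {.op_type = xnn_fused_operator_type_add, .arg = 0.5f},        // add 0.5 to every output
      {.op_type = xnn_fused_operator_type_hardswish, .arg = 0.0f},  // then apply hardswish
    };
    return xnn_create_convolution2d_nhwc_f32_fused(
      /*input_padding_top=*/1, /*input_padding_right=*/1,
      /*input_padding_bottom=*/1, /*input_padding_left=*/1,
      /*kernel_height=*/3, /*kernel_width=*/3,
      /*subsampling_height=*/1, /*subsampling_width=*/1,
      /*dilation_height=*/1, /*dilation_width=*/1,
      /*groups=*/1, /*group_input_channels=*/32, /*group_output_channels=*/64,
      /*input_channel_stride=*/32, /*output_channel_stride=*/64,
      kernel, bias,
      /*output_min=*/-INFINITY, /*output_max=*/INFINITY,
      /*num_fused_operations=*/2, fused_ops,
      /*flags=*/0, /*caches=*/NULL, conv_op_out);
  }

Setup and execution would presumably still go through the existing xnn_setup_convolution2d_nhwc_f32 and xnn_run_operator calls.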
88 changes: 82 additions & 6 deletions src/f32-gemm/upto6x8-aarch64-neonfma-cortex-a75.cc
@@ -18,7 +18,9 @@ namespace {
class Generator : public Assembler {
using Assembler::Assembler;
public:
void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max);
void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params);
void perform_fused_ops(VRegister dst, VRegister src, size_t num_fused_operators,
const xnn_fused_operator* fused_operators);
};

// void xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_prfm_cortex_a75(
@@ -68,13 +70,62 @@ class Generator : public Assembler {
// C v30 v31
// Clamp v6 v7

void Generator::perform_fused_ops(VRegister dst, VRegister src, size_t num_fused_operators,
const xnn_fused_operator* fused_operators) {
mov(x15, x8); // backup x8
for (size_t i = 0; i < num_fused_operators; i++) {
switch (fused_operators[i].op_type) {
case xnn_fused_operator_type_add: {
ld1r({v6.v4s()}, mem[x8]);
fadd(dst, src, v6.v4s());
break;
}
case xnn_fused_operator_type_abs: {
fabs(dst, src);
break;
}
case xnn_fused_operator_type_negate: {
fneg(dst, src);
break;
}
case xnn_fused_operator_type_hardswish: {
auto sixth = v6.v4s();
auto three = v7.v4s();
auto six = v8.v4s();
auto zero = v9.v4s();
// src and dst are the same, back up src.
auto src_backup = v10.v4s();
mov(src_backup, src);
ld3r({sixth, three, six}, mem[x8]);
movi(zero, 0);
fadd(dst, src, three.v4s());
fmul(src_backup, src_backup, sixth.v4s());
fmax(dst, dst, zero);
fmin(dst, dst, six.v4s());
// maybe don't need v10
fmul(dst, dst, src_backup);
break;
}
default:
XNN_UNREACHABLE;
}
add(x8, x8, sizeof(union xnn_fused_operator_params));
}
mov(x8, x15); // restore x8
}
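
// For reference, the scalar semantics implemented above for each lane x (an editorial
// sketch, not emitted code). c and the hardswish constants are assumed to be read from
// the params block that x8 points at, one xnn_fused_operator_params entry per operator
// (the hardswish entry presumably holding 1/6, 3, and 6 for the ld3r above):
//   add:       y = x + c
//   abs:       y = fabsf(x)
//   negate:    y = -x
//   hardswish: y = min(max(x + 3.0f, 0.0f), 6.0f) * (x * (1.0f/6.0f))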

// Converted from: src/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max)
void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const jit_gemm_params* jit_gemm_params)
{
assert(max_mr <= 6);
assert(nc_mod_nr < 8);
assert(kc != 0);
assert(kc % sizeof(float) == 0);
const float min = jit_gemm_params->f32_minmax.min;
const float max = jit_gemm_params->f32_minmax.max;
const size_t num_fused_operators = jit_gemm_params->num_fused_operators;
const xnn_fused_operator* fused_operators = jit_gemm_params->fused_operators;
// const xnn_fused_operator_params* fused_operator_params = jit_gemm_params->fused_operator_params;

Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
const bool clamp_min = min != -std::numeric_limits<float>::infinity();
@@ -913,6 +964,33 @@ void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t
}
}

// TODO(zhin): emitting the fused ops per accumulator register this way prevents software pipelining and requires us to
// back up and restore x8 (the pointer to params) for every register. We should instead emit the instructions for a
// single fused op across all registers. That trivially pipelines fused operations that emit a single instruction, but
// does not work for operations like hardswish, which emit multiple instructions with dependency chains.
perform_fused_ops(v20.v4s(), v20.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v21.v4s(), v21.v4s(), num_fused_operators, fused_operators);
if (max_mr > 1) {
perform_fused_ops(v22.v4s(), v22.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v23.v4s(), v23.v4s(), num_fused_operators, fused_operators);
}
if (max_mr > 2) {
perform_fused_ops(v24.v4s(), v24.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v25.v4s(), v25.v4s(), num_fused_operators, fused_operators);
}
if (max_mr > 3) {
perform_fused_ops(v26.v4s(), v26.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v27.v4s(), v27.v4s(), num_fused_operators, fused_operators);
}
if (max_mr > 4) {
perform_fused_ops(v28.v4s(), v28.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v29.v4s(), v29.v4s(), num_fused_operators, fused_operators);
}
if (max_mr > 5) {
perform_fused_ops(v30.v4s(), v30.v4s(), num_fused_operators, fused_operators);
perform_fused_ops(v31.v4s(), v31.v4s(), num_fused_operators, fused_operators);
}

// Store full 6 x 8
b_lo(l7);

@@ -1354,8 +1432,7 @@ xnn_status xnn_generate_f32_gemm_ukernel_upto6x8__aarch64_neonfma_cortex_a75(xnn
using namespace xnnpack::aarch64;
Generator g(code);
assert(params != nullptr);
const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
g.generate(false, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
g.generate(false, max_mr, nc_mod_nr, kc, static_cast<const jit_gemm_params*>(params));
g.finalize();
if (g.error() != xnnpack::Error::kNoError) {
return xnn_status_invalid_state;
@@ -1368,8 +1445,7 @@ xnn_status xnn_generate_f32_gemm_ukernel_upto6x8__aarch64_neonfma_prfm_cortex_a7
using namespace xnnpack::aarch64;
Generator g(code);
assert(params != nullptr);
const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
g.generate(true, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
g.generate(true, max_mr, nc_mod_nr, kc, static_cast<const jit_gemm_params*>(params));
g.finalize();
if (g.error() != xnnpack::Error::kNoError) {
return xnn_status_invalid_state;
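
For context, the fields that generate() reads above suggest jit_gemm_params has been extended roughly along these lines (a hypothetical reconstruction from this diff alone; the actual definition is not shown in this excerpt):

  struct jit_gemm_params {
    struct {
      float min;
      float max;
    } f32_minmax;
    size_t num_fused_operators;
    struct xnn_fused_operator* fused_operators;
    // plus the xnn_fused_operator_params referenced by the commented-out line in generate()
  };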
10 changes: 10 additions & 0 deletions src/init.c
@@ -3009,6 +3009,13 @@ static void init(void) {
xnn_params.f32.gemm.init.f32 = xnn_init_f32_minmax_scalar_params;
xnn_params.f32.gemm.mr = 6;
xnn_params.f32.gemm.nr = 8;
#if XNN_ENABLE_JIT
// TODO(zhin): remove this after testing
xnn_params.f32.gemm.generator.gemm = xnn_init_hmp_gemm_codegen(xnn_generate_f32_gemm_ukernel_upto6x8__aarch64_neonfma_prfm_cortex_a75);
xnn_params.f32.gemm.generator.igemm = xnn_init_hmp_igemm_codegen(xnn_generate_f32_igemm_ukernel_upto6x8__aarch64_neonfma_prfm_cortex_a75);
xnn_params.f32.gemm.generator.gemm1 = xnn_init_hmp_gemm_codegen(xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75);
xnn_params.f32.gemm.generator.igemm1 = xnn_init_hmp_igemm_codegen(xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75);
#endif
break;
}
#if XNN_MAX_UARCH_TYPES > 1
@@ -3278,6 +3285,9 @@ static void init(void) {
.init.f32_minmax = xnn_init_f32_minmax_scalar_params,
.element_tile = 8,
};
xnn_params.f32.fused_vadd = (struct constant_parameters) {
.init.f32_constant = xnn_init_f32_constant_params,
};
xnn_params.f32.vdiv = (struct vbinary_parameters) {
.minmax.op_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdiv_minmax_ukernel__neon_x8,
.minmax.opc_ukernel = (xnn_vbinary_ukernel_function) xnn_f32_vdivc_minmax_ukernel__neon_x8,
36 changes: 25 additions & 11 deletions src/operator-run.c
@@ -1212,17 +1212,31 @@ void xnn_compute_vmulcaddc(
const size_t a_stride = context->a_stride;
const size_t cm_stride = context->cm_stride;

context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->k_scaled,
(const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
a_stride,
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
&context->params);
if (context->num_fused_params == 0) {
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->k_scaled,
(const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
a_stride,
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
&context->params);
} else {
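// Fused micro-kernels receive the fused-operator params (context->fused_params) instead of the default minmax params.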
context->ukernel.function[uarch_index](
mr_block_size,
nr_block_size,
context->k_scaled,
(const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
a_stride,
(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
cm_stride,
context->cn_stride,
context->fused_params);
}
}

void xnn_compute_hmp_grouped_batch_igemm(
