diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7e02abfda77..67ad1ccdb99 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3800,7 +3800,7 @@ IF(XNNPACK_BUILD_BENCHMARKS)
 
   ADD_EXECUTABLE(softmax-bench bench/softmax.cc)
   TARGET_INCLUDE_DIRECTORIES(softmax-bench PRIVATE .)
-  TARGET_LINK_LIBRARIES(softmax-bench PRIVATE XNNPACK benchmark::benchmark bench-utils)
+  TARGET_LINK_LIBRARIES(softmax-bench PRIVATE XNNPACK fp16 benchmark::benchmark bench-utils)
 
   ADD_EXECUTABLE(square-bench bench/square.cc)
   TARGET_INCLUDE_DIRECTORIES(square-bench PRIVATE .)
diff --git a/bench/softmax.cc b/bench/softmax.cc
index 250a38681b4..89e29a194aa 100644
--- a/bench/softmax.cc
+++ b/bench/softmax.cc
@@ -11,6 +11,8 @@
 #include 
 #include 
+#include 
+
 #include 
 #include 
@@ -171,6 +173,79 @@ static void xnnpack_softmax_f32(benchmark::State& state) {
     benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
 }
 
+static void xnnpack_softmax_f16(benchmark::State& state) {
+  const size_t batch_size = static_cast<size_t>(state.range(0));
+  const size_t channels = static_cast<size_t>(state.range(1));
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto f32rng = std::bind(std::uniform_real_distribution<float>(-100.0f, 100.0f), std::ref(rng));
+  auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
+
+  std::vector<uint16_t> input(batch_size * channels + XNN_EXTRA_BYTES / sizeof(uint16_t));
+  std::vector<uint16_t> output(batch_size * channels);
+  std::generate(input.begin(), input.end(), std::ref(f16rng));
+  std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
+
+  xnn_status status = xnn_initialize(nullptr /* allocator */);
+  if (status != xnn_status_success) {
+    state.SkipWithError("failed to initialize XNNPACK");
+    return;
+  }
+
+  xnn_operator_t softmax_op = nullptr;
+  status = xnn_create_softmax_nc_f16(0 /* flags */, &softmax_op);
+  if (status != xnn_status_success || softmax_op == nullptr) {
+    state.SkipWithError("failed to create SoftMax operator");
+    return;
+  }
+
+  status = xnn_reshape_softmax_nc_f16(
+    softmax_op,
+    channels, channels /* input stride */, channels /* output stride */,
+    batch_size,
+    /*threadpool=*/nullptr);
+  if (status != xnn_status_success) {
+    state.SkipWithError("failed to reshape SoftMax operator");
+    return;
+  }
+
+  status = xnn_setup_softmax_nc_f16(
+    softmax_op,
+    input.data(), output.data());
+  if (status != xnn_status_success) {
+    state.SkipWithError("failed to setup SoftMax operator");
+    return;
+  }
+
+  for (auto _ : state) {
+    status = xnn_run_operator(softmax_op, /*threadpool=*/nullptr);
+    if (status != xnn_status_success) {
+      state.SkipWithError("failed to run SoftMax operator");
+      return;
+    }
+  }
+
+  status = xnn_delete_operator(softmax_op);
+  if (status != xnn_status_success) {
+    state.SkipWithError("failed to delete SoftMax operator");
+    return;
+  }
+
+  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
+  if (cpu_frequency != 0) {
+    state.counters["cpufreq"] = cpu_frequency;
+  }
+
+  const size_t elements_per_iteration = batch_size * channels;
+  state.counters["elements"] =
+    benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
+
+  const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(uint16_t);
+  state.counters["bytes"] =
+    benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
+}
+
 #ifdef BENCHMARK_TENSORFLOW_LITE
 static void tflite_softmax_f32(benchmark::State& state) {
   const size_t batch_size = state.range(0);
@@ -308,8 +383,9 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b)
   b->Args({257 * 257, 151});
 }
 
-BENCHMARK(xnnpack_softmax_qu8)->Apply(CharacteristicArguments)->UseRealTime();
 BENCHMARK(xnnpack_softmax_f32)->Apply(CharacteristicArguments)->UseRealTime();
+BENCHMARK(xnnpack_softmax_f16)->Apply(CharacteristicArguments)->UseRealTime();
+BENCHMARK(xnnpack_softmax_qu8)->Apply(CharacteristicArguments)->UseRealTime();
 
 #ifdef BENCHMARK_TENSORFLOW_LITE
 BENCHMARK(tflite_softmax_f32)->Apply(CharacteristicArguments)->UseRealTime();
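
Note on the new fp16 dependency (not part of the patch itself): the f16 benchmark generates its half-precision inputs by converting random floats with fp16_ieee_from_fp32_value, which is why the softmax-bench target now links the fp16 library and bench/softmax.cc gains an extra include (presumably the FP16 conversion header). Below is a minimal standalone sketch of that conversion round-trip; the <fp16/fp16.h> include path and the use of fp16_ieee_to_fp32_value for the reverse conversion are illustrative assumptions, not code from this diff.

// Sketch only: round-trips a float through the IEEE half-precision helpers
// that the benchmark uses to generate its uint16_t input data.
// Assumes the FP16 headers are reachable as <fp16/fp16.h>.
#include <cstdint>
#include <cstdio>

#include <fp16/fp16.h>

int main() {
  const float x = 100.0f;                           // same magnitude as the benchmark's f32rng range
  const uint16_t h = fp16_ieee_from_fp32_value(x);  // float -> IEEE fp16 bit pattern
  const float y = fp16_ieee_to_fp32_value(h);       // IEEE fp16 bit pattern -> float
  std::printf("%.1f -> 0x%04X -> %.1f\n", x, static_cast<unsigned>(h), y);
  return 0;
}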