From f35779d62a0b3a2e0f6be79a647b1e3acf02129b Mon Sep 17 00:00:00 2001
From: "Gouicem, Mourad"
Date: Fri, 9 Mar 2018 08:39:19 -0800
Subject: [PATCH] api: rnn: Adding the new RNN API (experimental)

this closes #46 and closes #154
---
 doc/mainpage.md                    |  20 +-
 include/mkldnn.h                   |  69 +++
 include/mkldnn.hpp                 | 508 ++++++++++++++-
 include/mkldnn_types.h             | 110 ++++
 src/common/c_types_map.hpp         |  18 +
 src/common/memory.cpp              |   6 +
 src/common/memory_desc_wrapper.cpp |  46 ++
 src/common/memory_desc_wrapper.hpp |   4 +-
 src/common/mkldnn_debug.cpp        |  12 +
 src/common/mkldnn_traits.hpp       |   1 +
 src/common/rnn.cpp                 | 254 ++++++++
 src/common/rnn_pd.hpp              | 248 ++++++++
 src/common/verbose.hpp             |  19 +
 src/cpu/cpu_engine.cpp             |   5 +
 src/cpu/cpu_rnn_pd.hpp             | 242 ++++++++
 src/cpu/ref_rnn.cpp                | 950 +++++++++++++++++++++++++++++
 src/cpu/ref_rnn.hpp                | 377 ++++++++++++
 17 files changed, 2883 insertions(+), 6 deletions(-)
 create mode 100644 src/common/rnn.cpp
 create mode 100644 src/common/rnn_pd.hpp
 create mode 100644 src/cpu/cpu_rnn_pd.hpp
 create mode 100644 src/cpu/ref_rnn.cpp
 create mode 100644 src/cpu/ref_rnn.hpp

diff --git a/doc/mainpage.md b/doc/mainpage.md
index 9df3d93e6b4..31a06be7f4a 100644
--- a/doc/mainpage.md
+++ b/doc/mainpage.md
@@ -5,8 +5,9 @@ The Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN) is
 open source performance library for Deep Learning (DL) applications intended
 for acceleration of DL frameworks on Intel(R) architecture. Intel MKL-DNN
 includes highly vectorized and threaded building blocks for implementation of
-convolutional neural networks (CNN) with C and C++ interfaces. This
-project is created to help the DL community innovate on the Intel(R) processor family.
+convolutional neural networks (CNN) and recurrent neural networks (RNN) with
+C and C++ interfaces. This project is created to help the DL community innovate
+on the Intel(R) processor family.
 
 The library supports the most commonly used primitives necessary to accelerate
 bleeding edge image recognition topologies, including Cifar*, AlexNet*, VGG*,
@@ -26,13 +27,21 @@ operations. The library includes the following classes of functions:
 
 * Activation
   - rectified linear unit neuron activation (ReLU)
-  - softmax
+  - softmax
 
 * Data manipulation
   - reorder (multi-dimensional transposition/conversion),
   - sum,
   - concat
-  - view
+  - view
+
+The library also contains experimental support for RNN primitives to accelerate
+speech recognition and neural machine translation topologies. The experimental
+support includes the following classes of functions:
+
+* RNN
+  - RNN cell
+  - LSTM cell
 
 Intel MKL DNN primitives implement a plain C/C++ application programming
 interface (API) that can be used in the existing C/C++ DNN frameworks, as well
@@ -128,6 +137,9 @@ The following examples are available in the /examples directory and provide more
   - C: simple_training.c
   - C++: simple_training_net.cpp
 
+* Creation of forward propagation of the GNMT topology (experimental support)
+  - C++: simple_rnn.cpp
+
 ### Performance Considerations
 
 * Convolution and inner product primitives choose the memory format when you create them with the unspecified memory
diff --git a/include/mkldnn.h b/include/mkldnn.h
index 5f5a5c04a14..955fcf2fc30 100644
--- a/include/mkldnn.h
+++ b/include/mkldnn.h
@@ -939,6 +939,75 @@ mkldnn_status_t MKLDNN_API mkldnn_convolution_relu_desc_init(
 
 /** @} */
 
+/** @addtogroup c_api_rnn RNN
+ * A primitive to compute a common recurrent layer.
+ * @todo add additional description for the group
+ * @{ */
+
+/**
+ * Initializes a recurrent cell descriptor @p rnn_cell_desc
+ * using @p kind (possible values are
+ * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru),
+ * @p f (possible values are #mkldnn_eltwise_relu,
+ * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(
+        mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f,
+        unsigned int flags, float alpha, float clipping);
+
+/** Returns the number of gates of a particular @p rnn_cell_desc. */
+int MKLDNN_API mkldnn_rnn_cell_get_gates_count(
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
+
+/** Returns the number of states of a particular @p rnn_cell_desc. */
+int MKLDNN_API mkldnn_rnn_cell_get_states_count(
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
+
+/** Initializes an RNN descriptor @p rnn_desc for forward propagation
+ * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
+ * @note If @p prop_kind equals #mkldnn_forward_training, you need to query a
+ * workspace memory descriptor before creating the primitive.
+ *
+ * @note All memory descriptors except @p src_iter_desc are allowed to be
+ * initialized with the #mkldnn_any value of @p format_kind. */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
+        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        const mkldnn_rnn_direction_t direction,
+        const mkldnn_memory_desc_t *src_layer_desc,
+        const mkldnn_memory_desc_t *src_iter_desc,
+        const mkldnn_memory_desc_t *weights_layer_desc,
+        const mkldnn_memory_desc_t *weights_iter_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_layer_desc,
+        const mkldnn_memory_desc_t *dst_iter_desc);
+
+/** Initializes an RNN descriptor @p rnn_desc for backward propagation
+ * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
+ * @note All memory descriptors are allowed to be initialized with the
+ * #mkldnn_any value of @p format_kind. */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
+        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        const mkldnn_rnn_direction_t direction,
+        const mkldnn_memory_desc_t *src_layer_desc,
+        const mkldnn_memory_desc_t *src_iter_desc,
+        const mkldnn_memory_desc_t *weights_layer_desc,
+        const mkldnn_memory_desc_t *weights_iter_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_layer_desc,
+        const mkldnn_memory_desc_t *dst_iter_desc,
+        const mkldnn_memory_desc_t *diff_src_layer_desc,
+        const mkldnn_memory_desc_t *diff_src_iter_desc,
+        const mkldnn_memory_desc_t *diff_weights_layer_desc,
+        const mkldnn_memory_desc_t *diff_weights_iter_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_layer_desc,
+        const mkldnn_memory_desc_t *diff_dst_iter_desc);
+
+/** @} */
+
 /** @} */
 
 /** @addtogroup c_api_engine Engine operations
diff --git a/include/mkldnn.hpp b/include/mkldnn.hpp
index a751bbd7787..6cff80d0f06 100644
--- a/include/mkldnn.hpp
+++ b/include/mkldnn.hpp
@@ -123,6 +123,7 @@ class primitive: public handle<mkldnn_primitive_t> {
         batch_normalization = mkldnn_batch_normalization,
         inner_product = mkldnn_inner_product,
         convolution_relu = mkldnn_convolution_relu,
+        rnn = mkldnn_rnn,
     };
 
     /// A wrapper structure to specify a particular output of a primitive.
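For orientation, here is a minimal sketch (not part of the patch) of how the C entry points declared above fit together. It only builds a cell descriptor and checks its gate/state counts; the `alpha` and `clipping` values are arbitrary illustration values, not defaults:

```cpp
// Sketch only: exercises the new C entry points, compiled as C++.
#include <cassert>
#include "mkldnn.h"

int main() {
    mkldnn_rnn_cell_desc_t cell;
    mkldnn_status_t s = mkldnn_rnn_cell_desc_init(&cell,
            mkldnn_vanilla_rnn, mkldnn_eltwise_relu,
            mkldnn_rnn_cell_with_relu | mkldnn_rnn_cell_with_clipping,
            /*alpha=*/0.1f, /*clipping=*/10.f);
    assert(s == mkldnn_success);
    (void)s;

    // A vanilla RNN cell has one gate and one state (an LSTM has 4 and 2).
    assert(mkldnn_rnn_cell_get_gates_count(&cell) == 1);
    assert(mkldnn_rnn_cell_get_states_count(&cell) == 1);
    return 0;
}
```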
@@ -247,6 +248,7 @@ inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
 }
 
 enum algorithm {
+    algorithm_undef = mkldnn_alg_kind_undef,
     convolution_direct = mkldnn_convolution_direct,
     convolution_winograd = mkldnn_convolution_winograd,
     eltwise_relu = mkldnn_eltwise_relu,
@@ -264,7 +266,10 @@ enum algorithm {
     pooling_max = mkldnn_pooling_max,
     pooling_avg = mkldnn_pooling_avg,
     pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
-    pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding
+    pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
+    vanilla_rnn = mkldnn_vanilla_rnn,
+    vanilla_lstm = mkldnn_vanilla_lstm,
+    vanilla_gru = mkldnn_vanilla_gru,
 };
 
 inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
@@ -283,6 +288,18 @@ inline mkldnn_batch_normalization_flag_t convert_to_c(
     return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
 }
 
+enum rnn_direction {
+    unidirectional_left2right = mkldnn_unidirectional_left2right,
+    unidirectional_right2left = mkldnn_unidirectional_right2left,
+    unidirectional = mkldnn_unidirectional,
+    bidirectional_concat = mkldnn_bidirectional_concat,
+    bidirectional_sum = mkldnn_bidirectional_sum,
+};
+
+inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
+    return static_cast<mkldnn_rnn_direction_t>(adir);
+}
+
 enum query {
     undef = mkldnn_query_undef,
 
@@ -307,6 +324,7 @@ enum query {
     batch_normalization_d = mkldnn_query_batch_normalization_d,
     inner_product_d = mkldnn_query_inner_product_d,
     convolution_relu_d = mkldnn_query_convolution_relu_d,
+    rnn_d = mkldnn_query_rnn_d,
 
     input_pd = mkldnn_query_input_pd,
     output_pd = mkldnn_query_output_pd,
@@ -603,6 +621,14 @@ struct memory: public primitive {
         ncdhw = mkldnn_ncdhw,
         oidhw = mkldnn_oidhw,
         goidhw = mkldnn_goidhw,
+        ntc = mkldnn_ntc,
+        tnc = mkldnn_tnc,
+        ldsnc = mkldnn_ldsnc,
+        ldigo = mkldnn_ldigo,
+        ldigo_p = mkldnn_ldigo_p,
+        ldgoi = mkldnn_ldgoi,
+        ldgoi_p = mkldnn_ldgoi_p,
+        ldgo = mkldnn_ldgo,
     };
 
     /// A memory descriptor.
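The memory formats added above are what the RNN primitives expect for their inputs and outputs. A minimal C++ sketch (not from the patch; all sizes are hypothetical placeholders) that builds the descriptors for a single-layer unidirectional LSTM and creates its forward primitive descriptor:

```cpp
#include "mkldnn.hpp"
using namespace mkldnn;

int main() {
    const int T = 10, N = 32;                  // time steps, batch
    const int L = 1, D = 1;                    // layers, directions
    const int SLC = 512, SIC = 512, DIC = 512; // channel sizes
    const int G = 4, S = 2;                    // LSTM gates and states

    auto src_layer_md = memory::desc({ T, N, SLC },
            memory::data_type::f32, memory::format::tnc);
    auto src_iter_md = memory::desc({ L, D, S, N, SIC },
            memory::data_type::f32, memory::format::ldsnc);
    // `any` lets the implementation pick the weights layout (e.g. ldigo_p).
    auto wei_layer_md = memory::desc({ L, D, SLC, G, DIC },
            memory::data_type::f32, memory::format::any);
    auto wei_iter_md = memory::desc({ L, D, SIC, G, DIC },
            memory::data_type::f32, memory::format::any);
    auto bias_md = memory::desc({ L, D, G, DIC },
            memory::data_type::f32, memory::format::ldgo);
    auto dst_layer_md = memory::desc({ T, N, DIC },
            memory::data_type::f32, memory::format::tnc);
    auto dst_iter_md = memory::desc({ L, D, S, N, DIC },
            memory::data_type::f32, memory::format::ldsnc);

    auto cell = rnn_cell::desc(algorithm::vanilla_lstm);
    auto fwd_desc = rnn_forward::desc(prop_kind::forward_inference, cell,
            rnn_direction::unidirectional_left2right, src_layer_md,
            src_iter_md, wei_layer_md, wei_iter_md, bias_md, dst_layer_md,
            dst_iter_md);

    auto eng = engine(engine::cpu, 0);
    auto fwd_pd = rnn_forward::primitive_desc(fwd_desc, eng);
    // The implementation-chosen weights layout can now be queried:
    auto wei_layer_pd = fwd_pd.weights_layer_primitive_desc();
    (void)wei_layer_pd;
    return 0;
}
```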
@@ -2898,6 +2924,486 @@ struct inner_product_backward_weights: public primitive {
 
 /// @}
 
+/// @addtogroup cpp_api_rnn RNN
+/// @{
+
+struct rnn_cell {
+    struct desc {
+        mkldnn_rnn_cell_desc_t c_rnn_cell_;
+
+        desc(algorithm kind, algorithm activation_f) {
+            error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
+                        mkldnn::convert_to_c(kind),
+                        mkldnn::convert_to_c(activation_f), 0U, 0, 0),
+                    "could not init an rnn cell descriptor");
+        }
+        desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
+
+        operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
+
+        algorithm get_cell_kind() const
+        { return algorithm(c_rnn_cell_.cell_kind); }
+        algorithm get_activation() const
+        { return algorithm(c_rnn_cell_.activation_kind); }
+
+        float get_alpha() const { return c_rnn_cell_.alpha; }
+        void set_alpha(float alpha) {
+            c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
+            c_rnn_cell_.alpha = alpha;
+        }
+
+        float get_clipping() const { return c_rnn_cell_.clipping; }
+        void set_clipping(float clipping) {
+            c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
+            c_rnn_cell_.clipping = clipping;
+        }
+
+        int get_gates_count() const {
+            return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
+        }
+        int get_state_count() const {
+            return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
+        }
+    };
+};
+
+struct rnn_forward : public primitive {
+    struct desc {
+        mkldnn_rnn_desc_t data;
+        desc(prop_kind aprop_kind, rnn_cell::desc cell,
+                const rnn_direction direction,
+                const memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc
+                ) {
+            error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), cell,
+                        mkldnn::convert_to_c(direction),
+                        &src_layer_desc.data, &src_iter_desc.data,
+                        &weights_layer_desc.data, &weights_iter_desc.data,
+                        &bias_desc.data,
+                        &dst_layer_desc.data, &dst_iter_desc.data),
+                    "could not create an RNN forward descriptor");
+        }
+
+    };
+    struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+        primitive_desc(const desc &adesc, const engine &aengine) {
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_primitive_desc_create(
+                        &result, &adesc.data, aengine.get(), nullptr),
+                    "could not create an RNN forward primitive descriptor");
+            reset(result);
+        }
+
+        memory::primitive_desc src_layer_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(src_pd), 0);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone an src layer primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc src_iter_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(src_pd), 1);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a src iter primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc weights_layer_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(weights_pd), 0);
+
error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc weights_src_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(weights_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc workspace_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t ldesc; + const_mkldnn_primitive_desc_t const_ldesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(workspace_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc), + "could not clone a workspace primitive descriptor"); + adesc.reset(ldesc); + return adesc; + } + + memory::primitive_desc dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(dst_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last iteration primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + engine get_engine() { return engine::query(*this); } + }; + + rnn_forward(const primitive_desc &aprimitive_desc, + const primitive::at &src_layer, const primitive::at &src_iter, + const primitive::at &weights_layer, + const primitive::at &weights_iter, const primitive::at &bias, + const memory &dst_layer, const memory &dst_iter, + const memory &workspace) { + mkldnn_primitive_t result; + mkldnn_primitive_at_t inputs[5]; + const_mkldnn_primitive_t outputs[3]; + int idx=0; + inputs[idx++] = src_layer.data; + if (!is_null_memory(src_iter.data.primitive)) + inputs[idx++] = src_iter.data; + inputs[idx++] = weights_layer.data; + inputs[idx++] = weights_iter.data; + if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data; + + idx=0; + outputs[idx++] = dst_layer.get(); + if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get(); + if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get(); + + error::wrap_c_api(mkldnn_primitive_create(&result, + aprimitive_desc.get(), inputs, outputs), + "could not create an RNN forward primitive"); + reset(result); + } +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const 
memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc,
+                const memory::desc &diff_src_layer_desc,
+                const memory::desc &diff_src_iter_desc,
+                const memory::desc &diff_weights_layer_desc,
+                const memory::desc &diff_weights_iter_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_layer_desc,
+                const memory::desc &diff_dst_iter_desc) {
+            error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), cell,
+                        mkldnn::convert_to_c(direction),
+                        &src_layer_desc.data, &src_iter_desc.data,
+                        &weights_layer_desc.data, &weights_iter_desc.data,
+                        &bias_desc.data,
+                        &dst_layer_desc.data, &dst_iter_desc.data,
+                        &diff_src_layer_desc.data, &diff_src_iter_desc.data,
+                        &diff_weights_layer_desc.data,
+                        &diff_weights_iter_desc.data, &diff_bias_desc.data,
+                        &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
+                    "could not create an RNN backward descriptor");
+        }
+
+    };
+    struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+        primitive_desc(const desc &adesc, const engine &aengine) {
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_primitive_desc_create(
+                        &result, &adesc.data, aengine.get(), nullptr),
+                    "could not create an RNN backward primitive descriptor");
+            reset(result);
+        }
+
+        memory::primitive_desc src_layer_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(src_pd), 0);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone an src layer primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc src_iter_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(src_pd), 1);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a src iter primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc weights_layer_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(weights_pd), 0);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a weights primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc weights_iter_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(weights_pd), 1);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a weights primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc bias_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t cdesc;
+            const_mkldnn_primitive_desc_t const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(weights_pd), 2);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a bias primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc
dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(dst_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last iteration primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_src_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_src_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone an src_layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_src_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_src_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a src iter primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_weights_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_weights_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_weights_pd), 1); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a weights primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_bias_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_weights_pd), 2); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a bias primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_layer_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t const_cdesc = + mkldnn_primitive_desc_query_pd(get(), + mkldnn::convert_to_c(diff_dst_pd), 0); + error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc), + "could not clone a dst last layer primitive descriptor"); + adesc.reset(cdesc); + return adesc; + } + + memory::primitive_desc diff_dst_iter_primitive_desc() const { + memory::primitive_desc adesc; + mkldnn_primitive_desc_t cdesc; + const_mkldnn_primitive_desc_t 
const_cdesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(diff_dst_pd), 1);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                    "could not clone a dst last iteration primitive descriptor");
+            adesc.reset(cdesc);
+            return adesc;
+        }
+
+        memory::primitive_desc workspace_primitive_desc() const {
+            memory::primitive_desc adesc;
+            mkldnn_primitive_desc_t ldesc;
+            const_mkldnn_primitive_desc_t const_ldesc =
+                    mkldnn_primitive_desc_query_pd(get(),
+                            mkldnn::convert_to_c(workspace_pd), 0);
+            error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
+                    "could not clone a workspace primitive descriptor");
+            adesc.reset(ldesc);
+            return adesc;
+        }
+
+        engine get_engine() { return engine::query(*this); }
+    };
+    // With last-iteration outputs (with or without the src_iter input)
+    rnn_backward(const primitive_desc &aprimitive_desc,
+            const primitive::at &src_layer,
+            const primitive::at &src_iter,
+            const primitive::at &weights_layer,
+            const primitive::at &weights_iter,
+            const primitive::at &bias,
+            const primitive::at &dst_layer,
+            const primitive::at &dst_iter,
+            const memory &diff_src_layer,
+            const memory &diff_src_iter,
+            const memory &diff_weights_layer,
+            const memory &diff_weights_iter,
+            const memory &diff_bias,
+            const primitive::at &diff_dst_layer,
+            const primitive::at &diff_dst_iter,
+            const primitive::at &workspace) {
+        mkldnn_primitive_t result;
+        mkldnn_primitive_at_t inputs[10];
+        const_mkldnn_primitive_t outputs[5];
+        int idx = 0;
+        inputs[idx++] = src_layer.data;
+        if (!is_null_memory(src_iter.data.primitive))
+            inputs[idx++] = src_iter.data;
+        inputs[idx++] = weights_layer.data;
+        inputs[idx++] = weights_iter.data;
+        if (!is_null_memory(bias.data.primitive))
+            inputs[idx++] = bias.data;
+        inputs[idx++] = dst_layer.data;
+        if (!is_null_memory(dst_iter.data.primitive))
+            inputs[idx++] = dst_iter.data;
+        inputs[idx++] = diff_dst_layer.data;
+        if (!is_null_memory(diff_dst_iter.data.primitive))
+            inputs[idx++] = diff_dst_iter.data;
+        inputs[idx++] = workspace.data;
+
+        idx = 0;
+        outputs[idx++] = diff_src_layer.get();
+        if (!is_null_memory(diff_src_iter.get()))
+            outputs[idx++] = diff_src_iter.get();
+        outputs[idx++] = diff_weights_layer.get();
+        outputs[idx++] = diff_weights_iter.get();
+        if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
+        error::wrap_c_api(mkldnn_primitive_create(&result,
+                aprimitive_desc.get(), inputs, outputs),
+                "could not create an RNN backward primitive");
+        reset(result);
+    }
+};
+
+/// @}
 /// @} Primitives
 
 /// @addtogroup cpp_api_stream Stream
diff --git a/include/mkldnn_types.h b/include/mkldnn_types.h
index 5fc89e8e993..8a7d717b2f4 100644
--- a/include/mkldnn_types.h
+++ b/include/mkldnn_types.h
@@ -252,6 +252,26 @@ typedef enum {
     /** 6D weight tensor in the @c goidhw format with extra dimension for
      * groups */
     mkldnn_goidhw,
+    /** 3D data tensor in the format (batch, seq_length, input channels). */
+    mkldnn_ntc,
+    /** 3D data tensor in the format (seq_length, batch, input channels). */
+    mkldnn_tnc,
+    /** 5D states tensor in the format (num_layers, num_directions, num_states,
+     * batch, state channels). */
+    mkldnn_ldsnc,
+    /** 5D weights tensor in the format (num_layers, num_directions,
+     * input_channels, num_gates, output_channels). */
+    mkldnn_ldigo,
+    /** 5D weights tensor in the blocked format. */
+    mkldnn_ldigo_p,
+    /** 5D weights tensor in the format (num_layers, num_directions, num_gates,
+     * output_channels, input_channels). */
+    mkldnn_ldgoi,
+    /** 5D weights tensor in the blocked format. */
+    mkldnn_ldgoi_p,
+    /** 4D bias tensor in the format (num_layers, num_directions, num_gates,
+     * output_channels). */
+    mkldnn_ldgo,
     /** 4D weights tensor in the oihw format with input channels data laid out
      * in memory in 8-element blocks. */
     mkldnn_oIhw8i = mkldnn_nChw8c,
@@ -327,10 +347,13 @@ typedef enum {
     mkldnn_inner_product,
     /** A convolution primitive merged with relu */
     mkldnn_convolution_relu,
+    /** An RNN primitive. */
+    mkldnn_rnn,
 } mkldnn_primitive_kind_t;
 
 /** Kinds of algorithms. */
 typedef enum {
+    mkldnn_alg_kind_undef,
     /** Direct convolution */
     mkldnn_convolution_direct = 1,
     /** Winograd convolution */
@@ -366,6 +389,12 @@ typedef enum {
     mkldnn_lrn_across_channels = 65,
     /** LRN within a single channel */
     mkldnn_lrn_within_channel = 66,
+    /** RNN cell */
+    mkldnn_vanilla_rnn = 80,
+    /** LSTM cell */
+    mkldnn_vanilla_lstm = 81,
+    /** GRU cell */
+    mkldnn_vanilla_gru = 82,
 } mkldnn_alg_kind_t;
 
 /** Flags for batch-normalization primititve. */
@@ -714,6 +743,86 @@ typedef struct {
     float negative_slope;
 } mkldnn_convolution_relu_desc_t;
 
+/** Flags for RNN cell. */
+typedef enum {
+    mkldnn_rnn_cell_with_relu = 0x1U,
+    mkldnn_rnn_cell_with_clipping = 0x2U,
+} mkldnn_rnn_cell_flags_t;
+
+typedef struct {
+    /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
+     * #mkldnn_vanilla_lstm, or #mkldnn_vanilla_gru. */
+    mkldnn_alg_kind_t cell_kind;
+    /** Activation function used. Must be one of #mkldnn_eltwise_relu or
+     * #mkldnn_eltwise_tanh. */
+    mkldnn_alg_kind_t activation_kind;
+    /** RNN cell flags */
+    unsigned int flags;
+    /** alpha is a negative slope parameter (used only if
+     * (flags & #mkldnn_rnn_cell_with_relu) != 0) */
+    float alpha;
+    /** clipping parameter (used only if
+     * (flags & #mkldnn_rnn_cell_with_clipping) != 0) */
+    float clipping;
+} mkldnn_rnn_cell_desc_t;
+
+/** A direction of RNN primitive execution. */
+typedef enum {
+    /** Unidirectional execution of the RNN primitive from left to right. */
+    mkldnn_unidirectional_left2right,
+    /** Unidirectional execution of the RNN primitive from right to left. */
+    mkldnn_unidirectional_right2left,
+    /** Bidirectional execution of the RNN primitive with concatenation of the
+     * results. */
+    mkldnn_bidirectional_concat,
+    /** Bidirectional execution of the RNN primitive with summation of the
+     * results. */
+    mkldnn_bidirectional_sum,
+    mkldnn_unidirectional = mkldnn_unidirectional_left2right,
+} mkldnn_rnn_direction_t;
+
+/** A descriptor for an RNN operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_rnn. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward. */
+    mkldnn_prop_kind_t prop_kind;
+    /** The RNN cell desc. */
+    mkldnn_rnn_cell_desc_t cell_desc;
+    /** The direction of RNN primitive execution. */
+    mkldnn_rnn_direction_t direction;
+    /** Source layer memory descriptor. */
+    mkldnn_memory_desc_t src_layer_desc;
+    /** Source iteration memory descriptor. */
+    mkldnn_memory_desc_t src_iter_desc;
+    /** Weights layer memory descriptor. */
+    mkldnn_memory_desc_t weights_layer_desc;
+    /** Weights iteration memory descriptor. */
+    mkldnn_memory_desc_t weights_iter_desc;
+    /** Bias memory descriptor. */
+    mkldnn_memory_desc_t bias_desc;
+    /** Destination layer memory descriptor. */
+    mkldnn_memory_desc_t dst_layer_desc;
+    /** Destination iteration memory descriptor. */
+    mkldnn_memory_desc_t dst_iter_desc;
+    /** Source gradient layer memory descriptor.
*/ + mkldnn_memory_desc_t diff_src_layer_desc; + /** Source gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_src_iter_desc; + /** Weights gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_weights_layer_desc; + /** Weights gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_weights_iter_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_dst_layer_desc; + /** Destination gradient iteration memory descriptor. */ + mkldnn_memory_desc_t diff_dst_iter_desc; +} mkldnn_rnn_desc_t; + /** @} */ /** @addtogroup c_api_engine_types Engine @@ -898,6 +1007,7 @@ typedef enum { mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ mkldnn_query_inner_product_d, /**< inner product descriptor */ mkldnn_query_convolution_relu_d, /**< convolution-relu descriptor */ + mkldnn_query_rnn_d, /**< rnn descriptor */ /* (memory) primitive descriptor section */ mkldnn_query_some_pd = 128, /**< stub */ diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 878e54044fe..5c3583f9d5a 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -59,6 +59,7 @@ namespace prop_kind { using alg_kind_t = mkldnn_alg_kind_t; namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; const alg_kind_t convolution_direct = mkldnn_convolution_direct; const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; @@ -77,6 +78,9 @@ namespace alg_kind { const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; } using data_type_t = mkldnn_data_type_t; @@ -146,6 +150,14 @@ namespace memory_format { const memory_format_t ncdhw = mkldnn_ncdhw; const memory_format_t oidhw = mkldnn_oidhw; const memory_format_t goidhw = mkldnn_goidhw; + const memory_format_t ntc = mkldnn_ntc; + const memory_format_t tnc = mkldnn_tnc; + const memory_format_t ldsnc = mkldnn_ldsnc; + const memory_format_t ldigo = mkldnn_ldigo; + const memory_format_t ldigo_p = mkldnn_ldigo_p; + const memory_format_t ldgoi = mkldnn_ldgoi; + const memory_format_t ldgoi_p = mkldnn_ldgoi_p; + const memory_format_t ldgo = mkldnn_ldgo; } using padding_kind_t = mkldnn_padding_kind_t; @@ -176,6 +188,7 @@ namespace primitive_kind { const primitive_kind_t batch_normalization = mkldnn_batch_normalization; const primitive_kind_t inner_product = mkldnn_inner_product; const primitive_kind_t convolution_relu = mkldnn_convolution_relu; + const primitive_kind_t rnn = mkldnn_rnn; } using query_t = mkldnn_query_t; @@ -203,6 +216,7 @@ namespace query { const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; const query_t inner_product_d = mkldnn_query_inner_product_d; const query_t convolution_relu_d = mkldnn_query_convolution_relu_d; + const query_t rnn_d = mkldnn_query_rnn_d; const query_t some_pd = mkldnn_query_some_pd; const query_t input_pd = mkldnn_query_input_pd; @@ -228,6 +242,10 @@ using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; using inner_product_desc_t = mkldnn_inner_product_desc_t; using convolution_relu_desc_t = mkldnn_convolution_relu_desc_t; +using 
rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + /* C op_desc_t, which eventually are just (void*) */ using c_op_desc_t = mkldnn_op_desc_t; using const_c_op_desc_t = const_mkldnn_op_desc_t; diff --git a/src/common/memory.cpp b/src/common/memory.cpp index 21ba893c5d5..376943dd3df 100644 --- a/src/common/memory.cpp +++ b/src/common/memory.cpp @@ -102,6 +102,12 @@ status_t mkldnn_memory_desc_init(memory_desc_t *memory_desc, int ndims, case ncdhw: case goidhw: case oidhw: + case ntc: + case tnc: + case ldsnc: + case ldigo: + case ldgoi: + case ldgo: status = memory_desc_wrapper::compute_blocking(md); break; /* not enough information */ diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp index e0bb2584e88..7e96560923f 100644 --- a/src/common/memory_desc_wrapper.cpp +++ b/src/common/memory_desc_wrapper.cpp @@ -493,6 +493,46 @@ status_t fill_gOhIw16o4i(memory_desc_t &md) { return fill_contiguous_blocked(md, block_dims, perm); } +status_t fill_ntc(memory_desc_t &md) { + if (md.ndims != 3) return invalid_arguments; + + const int perm[3] = { 1, 0, 2 }; + return fill_nonblocked(md, perm); +} + +status_t fill_tnc(memory_desc_t &md) { + if (md.ndims != 3) return invalid_arguments; + const int perm[3] = { 0, 1, 2 }; + return fill_nonblocked(md, perm); +} + +status_t fill_ldsnc(memory_desc_t &md) { + if (md.ndims != 5) return invalid_arguments; + const int perm[5] = { 0, 1, 2, 3, 4 }; + return fill_nonblocked(md, perm); +} + +status_t fill_ldigo(memory_desc_t &md) { + if (md.ndims != 5) return invalid_arguments; + + const int perm[5] = { 0, 1, 2, 3, 4 }; + return fill_nonblocked(md, perm); +} + +status_t fill_ldgoi(memory_desc_t &md) { + if (md.ndims != 5) return invalid_arguments; + + const int perm[5] = { 0, 1, 3, 4, 2 }; + return fill_nonblocked(md, perm); +} + +status_t fill_ldgo(memory_desc_t &md) { + if (md.ndims != 4) return invalid_arguments; + + const int perm[4] = { 0, 1, 2, 3 }; + return fill_nonblocked(md, perm); +} + } status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc) @@ -544,6 +584,12 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc) case ncdhw: return fill_ncdhw(memory_desc); case oidhw: return fill_oidhw(memory_desc); case goidhw: return fill_goidhw(memory_desc); + case ntc: return fill_ntc(memory_desc); + case tnc: return fill_tnc(memory_desc); + case ldsnc: return fill_ldsnc(memory_desc); + case ldigo: return fill_ldigo(memory_desc); + case ldgoi: return fill_ldgoi(memory_desc); + case ldgo: return fill_ldgo(memory_desc); default: break; } diff --git a/src/common/memory_desc_wrapper.hpp b/src/common/memory_desc_wrapper.hpp index 243a405621f..83da2479edf 100644 --- a/src/common/memory_desc_wrapper.hpp +++ b/src/common/memory_desc_wrapper.hpp @@ -70,6 +70,7 @@ struct memory_desc_wrapper: public c_compatible { size_t size() const { using namespace mkldnn::impl::memory_format; if (is_zero() || format() == memory_format::any) return 0; + assert(utils::one_of(format(), blocked, x, nc, nchw, nhwc, chwn, nChw8c, nChw16c, oi, io, oihw, ihwo, hwio, hwigo, oIhw8i, oIhw16i, OIhw8i8o, OIhw16i16o, OIhw8i16o2i, OIhw8o16i2o, @@ -77,7 +78,8 @@ struct memory_desc_wrapper: public c_compatible { OhIw16o4i, OIhw4i16o4i, goihw, gOIhw8i8o, gOIhw16i16o, gOIhw8i16o2i, gOIhw8o16i2o, gOIhw8o8i, gOIhw16o16i, gOihw8o, gOihw16o, gOhwi8o, gOhwi16o, gOhIw16o4i, IOhw16o16i, - gIOhw16o16i, gOIhw4i16o4i, ncdhw, oidhw, goidhw)); + gIOhw16o16i, 
gOIhw4i16o4i, ncdhw, oidhw, goidhw, + ntc, tnc, ldsnc, ldigo, ldgoi, ldgo)); if (blocking_desc().offset_padding != 0) return 0; diff --git a/src/common/mkldnn_debug.cpp b/src/common/mkldnn_debug.cpp index e08440fdbef..5c6737a19ac 100644 --- a/src/common/mkldnn_debug.cpp +++ b/src/common/mkldnn_debug.cpp @@ -103,6 +103,14 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) { if (v == mkldnn_ncdhw) return "ncdhw"; if (v == mkldnn_oidhw) return "oidhw"; if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldigo_p) return "ldigo_p"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgoi_p) return "ldgoi_p"; + if (v == mkldnn_ldgo) return "ldgo"; assert(!"unknown fmt"); return "unknown fmt"; } @@ -138,6 +146,7 @@ const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { if (v == mkldnn_batch_normalization) return "batch_normalization"; if (v == mkldnn_inner_product) return "inner_product"; if (v == mkldnn_convolution_relu) return "convolution_relu"; + if (v == mkldnn_rnn) return "rnn"; assert(!"unknown prim_kind"); return "unknown prim_kind"; } @@ -161,6 +170,9 @@ const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { if (v == mkldnn_pooling_avg) return "pooling_avg"; if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; assert(!"unknown alg_kind"); return "unknown alg_kind"; } diff --git a/src/common/mkldnn_traits.hpp b/src/common/mkldnn_traits.hpp index b73121a0ab2..0fda280d396 100644 --- a/src/common/mkldnn_traits.hpp +++ b/src/common/mkldnn_traits.hpp @@ -64,6 +64,7 @@ PKIND_TRAITS_INST(lrn); PKIND_TRAITS_INST(batch_normalization); PKIND_TRAITS_INST(inner_product); PKIND_TRAITS_INST(convolution_relu); +PKIND_TRAITS_INST(rnn); #undef PKIND_TRAITS_INST } diff --git a/src/common/rnn.cpp b/src/common/rnn.cpp new file mode 100644 index 00000000000..b069f849ad0 --- /dev/null +++ b/src/common/rnn.cpp @@ -0,0 +1,254 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? 
*md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + rnn_desc_t rd = {}; + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru) + && implication(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh)); + if (!args_ok) + return status::invalid_arguments; + + mkldnn_rnn_cell_desc_t rcd = {}; + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; + + *rnn_cell_desc = rcd; + + return status::success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + int DIC = 0, L = 0; + if (weights_layer_desc && weights_layer_desc->ndims) { + DIC = weights_layer_desc->dims[4]; + L = weights_layer_desc->dims[0]; + } else if (weights_iter_desc && weights_iter_desc->ndims) { + DIC = weights_iter_desc->dims[4]; + L = weights_iter_desc->dims[0]; + } else { + assert(!"cannot query cell state size"); + return unimplemented; + } + + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + const int DLC = (direction == mkldnn_bidirectional_concat ? 
2 : 1) * DIC; + + args_ok = args_ok && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && DLC == dst_layer_desc->dims[2] && L == weights_iter_desc->dims[0] + && implication(!is_zero_md(dst_iter_desc), true + && DIC == dst_iter_desc->dims[4] + && L == dst_iter_desc->dims[0]) + && implication(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && implication( + !is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]); + if (!args_ok) + return invalid_arguments; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_layer, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer); + if (!args_ok) + return invalid_arguments; + + int DIC = 0, L = 0; + if (weights_layer_desc && weights_layer_desc->ndims) { + DIC = weights_layer_desc->dims[4]; + L = weights_layer_desc->dims[0]; + } else if (weights_iter_desc && weights_iter_desc->ndims) { + DIC = weights_iter_desc->dims[4]; + L = weights_iter_desc->dims[0]; + } else { + assert(!"cannot query cell state size"); + return unimplemented; + } + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int DLC = (direction == mkldnn_bidirectional_concat ? 
2 : 1) * DIC; + + args_ok = args_ok && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && DLC == dst_layer_desc->dims[2] && L == weights_iter_desc->dims[0] + && implication(!is_zero_md(dst_iter_desc), true + && DIC == dst_iter_desc->dims[4] + && L == dst_iter_desc->dims[0]) + && implication(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && implication( + !is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]); + if (!args_ok) + return invalid_arguments; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/src/common/rnn_pd.hpp b/src/common/rnn_pd.hpp new file mode 100644 index 00000000000..cb2027ccb01 --- /dev/null +++ b/src/common/rnn_pd.hpp @@ -0,0 +1,248 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#ifndef RNN_PD_HPP
+#define RNN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_pd.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+// struct rnn_fwd_pd_t;
+
+struct rnn_pd_t : public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::rnn;
+
+    rnn_pd_t(mkldnn::impl::engine_t *engine, const rnn_desc_t *adesc,
+            const primitive_attr_t *attr, const rnn_pd_t *hint_pd)
+        : primitive_desc_t(engine, attr, primitive_kind::rnn)
+        , desc_(*adesc)
+        , hint_pd_(hint_pd) {}
+    virtual ~rnn_pd_t() {}
+
+    const rnn_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override {
+        return reinterpret_cast<const op_desc_t *>(this->desc());
+    }
+    virtual void init_info() override { init_info_rnn(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    inline bool is_training() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::backward);
+    }
+
+    inline int ws_states_size() {
+        int n_layer = L();
+        int n_direction = D();
+        int n_iter = T();
+        int n_states = S();
+        int batch = MB();
+        int s_size = DIC();
+        /// @todo handle the case where !(state_size == hidden_size ==
+        /// input_size)
+        assert(SLC() == DIC());
+
+        return (n_layer + 1) * n_direction * (n_iter + 1) * n_states * batch
+                * s_size;
+    }
+
+    inline int ws_diff_states_size() {
+        int n_layer = L();
+        int n_direction = D();
+        int n_iter = T();
+        int n_states = S();
+        int batch = MB();
+        int s_size = DIC();
+        /// @todo handle the case where !(state_size == hidden_size ==
+        /// input_size)
+        assert(SLC() == DIC());
+
+        return (n_layer + 1) * n_direction * (n_iter + 1) * (n_states + 1)
+                * batch * s_size;
+    }
+
+    inline int ws_gates_size() {
+        int n_layer = L();
+        int n_direction = D();
+        int n_iter = T();
+        int n_gates = G();
+        int batch = MB();
+        int s_size = DIC();
+        /// @todo handle the case where !(state_size == hidden_size ==
+        /// input_size)
+        assert(SLC() == DIC());
+
+        return n_layer * n_direction * n_iter * batch * n_gates * s_size;
+    }
+
+    inline void set_ws_offsets(int &ws_gates_offset, int &ws_states_offset,
+            int &ws_diff_states_offset) {
+        const int page_size = 4096; // 2097152;
+        ws_gates_offset
+                = 0; // assumes the workspace base pointer is page aligned
+        ws_states_offset = utils::rnd_up(ws_gates_size(), page_size);
+        ws_diff_states_offset
+                = utils::rnd_up(ws_states_offset + ws_states_size(), page_size);
+    }
+
+    inline int get_ws_size() {
+        int ws_gates_offset, ws_states_offset, ws_diff_states_offset;
+        set_ws_offsets(
+                ws_gates_offset, ws_states_offset, ws_diff_states_offset);
+        return ws_diff_states_offset + ws_diff_states_size();
+    }
+
+    int T() const { return desc_.src_layer_desc.dims[0]; }
+    int MB() const { return desc_.src_layer_desc.dims[1]; }
+
+    int L() const { return desc_.weights_layer_desc.dims[0]; }
+    int D() const { return desc_.weights_layer_desc.dims[1]; }
+
+    int SIC() const { return desc_.weights_iter_desc.dims[2]; }
+
+    int SLC() const { return desc_.weights_layer_desc.dims[2]; }
+    int G() const { return desc_.weights_layer_desc.dims[3]; }
+    int DIC() const { return desc_.weights_layer_desc.dims[4]; }
+
+    int DLC() const { return desc_.dst_layer_desc.dims[2]; }
+
+    int S() const
+    { return mkldnn_rnn_cell_get_states_count(&desc_.cell_desc); }
+
+    bool with_bias() const {
+        return !memory_desc_wrapper(desc_.bias_desc).is_zero();
+    }
+
+    bool with_src_iter() const {
+        return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero());
+    }
+
+    bool with_dst_iter() const {
+        return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero();
+    }
+
+    mkldnn::impl::alg_kind_t cell_kind() const {
+        return desc_.cell_desc.cell_kind;
+    }
+    mkldnn::impl::alg_kind_t activation_kind() const {
+        return desc_.cell_desc.activation_kind;
+    }
+    mkldnn_rnn_direction_t direction() const { return desc_.direction; }
+
+protected:
+    rnn_desc_t desc_;
+    const rnn_pd_t *hint_pd_;
+};
+
+struct rnn_fwd_pd_t : public rnn_pd_t {
+    typedef rnn_fwd_pd_t base_class;
+    typedef rnn_fwd_pd_t hint_class;
+
+    using rnn_pd_t::rnn_pd_t;
+    virtual ~rnn_fwd_pd_t() {}
+
+    virtual const memory_pd_t *input_pd(int index = 0) const override {
+        switch (index) {
+        case 0: return src_pd(0);
+        case 1: return src_pd(1);
+        case 2: return weights_pd(0);
+        case 3: return weights_pd(1);
+        case 4: return weights_pd(2);
+        default: return nullptr;
+        }
+    }
+
+    virtual const memory_pd_t *output_pd(int index = 0) const override {
+        switch (index) {
+        case 0: return dst_pd(0);
+        case 1: return dst_pd(1);
+        case 2: return workspace_pd();
+        default: return nullptr;
+        }
+    }
+
+    virtual int n_inputs() const override {
+        return 3 + with_bias() + with_src_iter();
+    }
+
+    virtual int n_outputs() const override {
+        return 1 + with_dst_iter() + is_training();
+    }
+
+    int ws_idx() const { return 1 + with_dst_iter(); }
+};
+
+struct rnn_bwd_pd_t : public rnn_pd_t {
+    typedef rnn_bwd_pd_t base_class;
+    typedef rnn_bwd_pd_t hint_class;
+
+    using rnn_pd_t::rnn_pd_t;
+    virtual ~rnn_bwd_pd_t() {}
+
+    virtual const memory_pd_t *input_pd(int index = 0) const override {
+        switch (index) {
+        case 0: return src_pd(0);
+        case 1: return src_pd(1);
+        case 2: return weights_pd(0);
+        case 3: return weights_pd(1);
+        case 4: return weights_pd(2);
+        case 5: return dst_pd(0);
+        case 6: return dst_pd(1);
+        case 7: return diff_dst_pd(0);
+        case 8: return diff_dst_pd(1);
+        case 9: return workspace_pd();
+        default: return nullptr;
+        }
+    }
+
+    virtual const memory_pd_t *output_pd(int index = 0) const override {
+        switch (index) {
+        case 0: return diff_src_pd(0);
+        case 1: return diff_src_pd(1);
+        case 2: return diff_weights_pd(0);
+        case 3: return diff_weights_pd(1);
+        case 4: return diff_weights_pd(2);
+        default: return nullptr;
+        }
+    }
+
+    virtual int n_inputs() const override {
+        return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter();
+    }
+    virtual int n_outputs() const override {
+        return 3 + with_src_iter() + with_bias();
+    }
+
+    int ws_idx() const {
+        return 5 + with_src_iter() + with_bias() + 2 * with_dst_iter();
+    }
+};
+}
+}
+
+#endif
diff --git a/src/common/verbose.hpp b/src/common/verbose.hpp
index e0ed8a0d98c..c6a6f0c4f79 100644
--- a/src/common/verbose.hpp
+++ b/src/common/verbose.hpp
@@ -269,6 +269,24 @@ template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) {
             aux_str, prb_str);
 }
 
+/// @todo print meaningful data
+template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) {
+    DECL_DAT_AUX_PRB_STRS();
+
+    alg_kind_t alg_kind = s->desc()->cell_desc.cell_kind;
+    snprintf(aux_str, MKLDNN_VERBOSE_AUX_LEN,
+            "alg:%s", mkldnn_alg_kind2str(alg_kind));
+
+    snprintf(prb_str, MKLDNN_VERBOSE_PRB_LEN,
+            "l%dd%dmb%dt%d_ic%dsc%doc%d_wi%dws%d",
+            s->L(), s->D(), s->MB(), s->T(),
+            s->SLC(), s->DIC(), s->DIC(),
+            s->SLC(), s->SIC());
+
+    
verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + #else /* !defined(DISABLE_VERBOSE) */ #define MKLDNN_VERBOSE_BUF_LEN 1 @@ -285,6 +303,7 @@ DEFINE_STUB(lrn); DEFINE_STUB(mem); DEFINE_STUB(pool); DEFINE_STUB(softmax); +DEFINE_STUB(rnn); #undef DEFINE_STUB #endif /* !defined(DISABLE_VERBOSE) */ diff --git a/src/cpu/cpu_engine.cpp b/src/cpu/cpu_engine.cpp index db742e18552..7c1b4b125b8 100644 --- a/src/cpu/cpu_engine.cpp +++ b/src/cpu/cpu_engine.cpp @@ -24,6 +24,8 @@ #include "cpu_concat.hpp" #include "cpu_sum.hpp" +#include "cpu/ref_rnn.hpp" + #include "cpu/jit_avx512_core_u8s8s32x_1x1_convolution.hpp" #include "cpu/jit_avx512_common_1x1_convolution.hpp" #include "cpu/jit_avx512_common_convolution_winograd.hpp" @@ -82,6 +84,9 @@ using namespace mkldnn::impl::data_type; #define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t> static const pd_create_f cpu_impl_list[] = { + /* RNN */ + INSTANCE(ref_rnn_fwd_t), + INSTANCE(ref_rnn_bwd_t), /* conv 3d */ INSTANCE(ref_convolution_3d_fwd_t), INSTANCE(ref_convolution_3d_fwd_t), diff --git a/src/cpu/cpu_rnn_pd.hpp b/src/cpu/cpu_rnn_pd.hpp new file mode 100644 index 00000000000..effa0dd720f --- /dev/null +++ b/src/cpu/cpu_rnn_pd.hpp @@ -0,0 +1,242 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "cpu_engine.hpp" +#include "cpu_memory.hpp" +#include "cpu_primitive.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using cpu_memory_pd_t = cpu_memory_t::pd_t; + + cpu_rnn_fwd_pd_t(engine_t *engine, const rnn_desc_t *adesc, + const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_layer_pd_(engine, &desc_.src_layer_desc) + , src_iter_pd_(engine, &desc_.src_iter_desc) + , weights_layer_pd_(engine, &desc_.weights_layer_desc) + , weights_iter_pd_(engine, &desc_.weights_iter_desc) + , bias_pd_(engine, &desc_.bias_desc) + , dst_layer_pd_(engine, &desc_.dst_layer_desc) + , dst_iter_pd_(engine, &desc_.dst_iter_desc) + , ws_pd_(engine_) {} + virtual ~cpu_rnn_fwd_pd_t() {} + + virtual const cpu_memory_pd_t *src_pd(int index = 0) const override { + if (index == 0) + return &src_layer_pd_; + if (index == 1 && this->with_src_iter()) + return &src_iter_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override { + if (index == 0) + return &weights_layer_pd_; + if (index == 1) + return &weights_iter_pd_; + if (index == 2 && this->with_bias()) + return &bias_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override { + if (index == 0) + return &dst_layer_pd_; + if (index == 1 && this->with_dst_iter()) + return &dst_iter_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override { + return (index == 0 && !ws_pd_.is_zero()) ? 
&ws_pd_ : nullptr; + } + +protected: + cpu_memory_pd_t src_layer_pd_; + cpu_memory_pd_t src_iter_pd_; + cpu_memory_pd_t weights_layer_pd_; + cpu_memory_pd_t weights_iter_pd_; + cpu_memory_pd_t bias_pd_; + cpu_memory_pd_t dst_layer_pd_; + cpu_memory_pd_t dst_iter_pd_; + cpu_memory_pd_t ws_pd_; + + virtual status_t set_default_params() { + using namespace memory_format; + if (src_layer_pd_.desc()->format == any) + CHECK(src_layer_pd_.set_format(tnc)); + if (weights_layer_pd_.desc()->format == any) + CHECK(weights_layer_pd_.set_format(ldigo)); + if (weights_iter_pd_.desc()->format == any) + CHECK(weights_iter_pd_.set_format(ldigo)); + if (dst_layer_pd_.desc()->format == any) + CHECK(dst_layer_pd_.set_format(tnc)); + + // Optional parameters + if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any)) + CHECK(src_iter_pd_.set_format(ldsnc)); + if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any)) + CHECK(bias_pd_.set_format(ldgo)); + if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any)) + CHECK(dst_iter_pd_.set_format(ldsnc)); + + return status::success; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using cpu_memory_pd_t = cpu_memory_t::pd_t; + + cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc, + const primitive_attr_t *attr, const rnn_bwd_pd_t *hint_bwd_pd) + : rnn_bwd_pd_t(engine, adesc, attr, hint_bwd_pd) + , src_layer_pd_(engine, &desc_.src_layer_desc) + , src_iter_pd_(engine, &desc_.src_iter_desc) + , weights_layer_pd_(engine, &desc_.weights_layer_desc) + , weights_iter_pd_(engine, &desc_.weights_iter_desc) + , bias_pd_(engine, &desc_.bias_desc) + , dst_layer_pd_(engine, &desc_.dst_layer_desc) + , dst_iter_pd_(engine, &desc_.dst_iter_desc) + , diff_src_layer_pd_(engine, &desc_.diff_src_layer_desc) + , diff_states_pd_(engine, &desc_.diff_src_iter_desc) + , diff_weights_layer_pd_(engine, &desc_.diff_weights_layer_desc) + , diff_weights_iter_pd_(engine, &desc_.diff_weights_iter_desc) + , diff_bias_pd_(engine, &desc_.diff_bias_desc) + , diff_dst_layer_pd_(engine, &desc_.diff_dst_layer_desc) + , diff_dst_iter_pd_(engine, &desc_.diff_dst_iter_desc) + , ws_pd_(engine_) {} + virtual ~cpu_rnn_bwd_pd_t() {} + + virtual const cpu_memory_pd_t *src_pd(int index = 0) const override { + if (index == 0) + return &src_layer_pd_; + if (index == 1 && this->with_src_iter()) + return &src_iter_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override { + if (index == 0) + return &weights_layer_pd_; + if (index == 1) + return &weights_iter_pd_; + if (index == 2 && this->with_bias()) + return &bias_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override { + if (index == 0) + return &dst_layer_pd_; + if (index == 1 && this->with_dst_iter()) + return &dst_iter_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override { + if (index == 0) + return &diff_src_layer_pd_; + if (index == 1 && this->with_src_iter()) + return &diff_states_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *diff_weights_pd( + int index = 0) const override { + if (index == 0) + return &diff_weights_layer_pd_; + if (index == 1) + return &diff_weights_iter_pd_; + if (index == 2 && this->with_bias()) + return &diff_bias_pd_; + return nullptr; + } + virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override { + if (index == 0) + return &diff_dst_layer_pd_; + if (index == 1 && this->with_dst_iter()) + return &diff_dst_iter_pd_; + 
return nullptr;
+    }
+    virtual const cpu_memory_pd_t *workspace_pd(int index = 0) const override {
+        return (index == 0 && !ws_pd_.is_zero()) ? &ws_pd_ : nullptr;
+    }
+
+protected:
+    cpu_memory_pd_t src_layer_pd_;
+    cpu_memory_pd_t src_iter_pd_;
+    cpu_memory_pd_t weights_layer_pd_;
+    cpu_memory_pd_t weights_iter_pd_;
+    cpu_memory_pd_t bias_pd_;
+    cpu_memory_pd_t dst_layer_pd_;
+    cpu_memory_pd_t dst_iter_pd_;
+    cpu_memory_pd_t diff_src_layer_pd_;
+    cpu_memory_pd_t diff_states_pd_;
+    cpu_memory_pd_t diff_weights_layer_pd_;
+    cpu_memory_pd_t diff_weights_iter_pd_;
+    cpu_memory_pd_t diff_bias_pd_;
+    cpu_memory_pd_t diff_dst_layer_pd_;
+    cpu_memory_pd_t diff_dst_iter_pd_;
+    cpu_memory_pd_t ws_pd_;
+
+    virtual status_t set_default_params() {
+        using namespace memory_format;
+        if (src_layer_pd_.desc()->format == any)
+            CHECK(src_layer_pd_.set_format(tnc));
+        if (diff_src_layer_pd_.desc()->format == any)
+            CHECK(diff_src_layer_pd_.set_format(tnc));
+        if (weights_layer_pd_.desc()->format == any)
+            CHECK(weights_layer_pd_.set_format(ldgoi));
+        if (diff_weights_layer_pd_.desc()->format == any)
+            CHECK(diff_weights_layer_pd_.set_format(ldigo));
+        if (weights_iter_pd_.desc()->format == any)
+            CHECK(weights_iter_pd_.set_format(ldgoi));
+        if (diff_weights_iter_pd_.desc()->format == any)
+            CHECK(diff_weights_iter_pd_.set_format(ldigo));
+        if (dst_layer_pd_.desc()->format == any)
+            CHECK(dst_layer_pd_.set_format(tnc));
+        if (diff_dst_layer_pd_.desc()->format == any)
+            CHECK(diff_dst_layer_pd_.set_format(tnc));
+
+        // Optional parameters
+        if ((!src_iter_pd_.is_zero()) && (src_iter_pd_.desc()->format == any))
+            CHECK(src_iter_pd_.set_format(ldsnc));
+        if ((!diff_states_pd_.is_zero())
+                && (diff_states_pd_.desc()->format == any))
+            CHECK(diff_states_pd_.set_format(ldsnc));
+        if ((!bias_pd_.is_zero()) && (bias_pd_.desc()->format == any))
+            CHECK(bias_pd_.set_format(ldgo));
+        if ((!diff_bias_pd_.is_zero()) && (diff_bias_pd_.desc()->format == any))
+            CHECK(diff_bias_pd_.set_format(ldgo));
+        if ((!dst_iter_pd_.is_zero()) && (dst_iter_pd_.desc()->format == any))
+            CHECK(dst_iter_pd_.set_format(ldsnc));
+        if ((!diff_dst_iter_pd_.is_zero())
+                && (diff_dst_iter_pd_.desc()->format == any))
+            CHECK(diff_dst_iter_pd_.set_format(ldsnc));
+
+        return status::success;
+    }
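+
+    // A note on the format tags chosen above (dimension-letter conventions,
+    // summarized here for readability): t = time step, n = minibatch,
+    // c = channels, l = layer, d = direction, s = state, g = gate,
+    // i = input channels, o = output channels. Hence:
+    //   tnc   - layer input/output   (iter, batch, channels)
+    //   ldsnc - iteration states     (layer, dir, state, batch, channels)
+    //   ldigo - fwd weights          (layer, dir, input, gate, output)
+    //   ldgoi - bwd weights          (layer, dir, gate, output, input)
+    //   ldgo  - bias                 (layer, dir, gate, output)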
+};
+}
+}
+}
+
+#endif
diff --git a/src/cpu/ref_rnn.cpp b/src/cpu/ref_rnn.cpp
new file mode 100644
index 00000000000..741401f6b3d
--- /dev/null
+++ b/src/cpu/ref_rnn.cpp
@@ -0,0 +1,950 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+  General architecture
+
+  For the diff states we have n_states + 1 entries: n_states diffs to
+  propagate to the previous iteration, and one diff to propagate to the
+  previous layer.
+  Index 0 is dh for cell(t-1, l) to consume,
+  index 1 is dc for cell(t-1, l) to consume, and
+  index 2 is dh for cell(t, l-1) to consume.
+  This keeps the state indexing identical in the elementwise functions,
+  so only the cell execution functions are affected.
+ */
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_traits.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace prop_kind;
+using namespace alg_kind;
+
+#define AOC array_offset_calculator
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::forward>(
+        float dd, float s, float alpha, float clipping) {
+    return relu_fwd(s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::backward>(
+        float dd, float s, float alpha, float clipping) {
+    return relu_bwd(dd, s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
+        float dd, float s, float alpha, float clipping) {
+    return tanh_fwd(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
+        float dd, float s, float alpha, float clipping) {
+    return tanh_bwd(dd, s);
+}
+
+//************************* Cell execution *************************//
+/// @todo should this be templated on the activation function to enable
+/// SVML calls in particular?
+template <>
+elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::rnn_elemwise) {
+    AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, s_size);
+    AOC<const float, 2> bias(bias_, n_gates, s_size);
+    AOC<float, 3> states_t_l(states_t_l_, n_states, batch, s_size);
+#pragma omp parallel for
+    for (int i = 0; i < batch; i++) {
+        for (int j = 0; j < s_size; j++) {
+            const float h
+                    = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
+            ws_gates(i, 0, j) = states_t_l(0, i, j) = h;
+        }
+    }
+}
+
+template <>
+elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::rnn_elemwise) {
+    AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, s_size);
+    AOC<float, 3> diff_states_tp1_l(
+            diff_states_tp1_l_, n_states + 1, batch, s_size);
+    AOC<float, 3> diff_states_t_lp1(
+            diff_states_t_lp1_, n_states + 1, batch, s_size);
+#pragma omp parallel for
+    for (int i = 0; i < batch; ++i) {
+        for (int j = 0; j < s_size; ++j) {
+            const float dH = diff_states_t_lp1(n_states, i, j)
+                    + diff_states_tp1_l(0, i, j);
+            auto g = ws_gates(i, 0, j);
+            ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
+        }
+    }
+}
+
+template <>
+elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::lstm_elemwise) {
+    AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, s_size);
+    AOC<const float, 2> bias(bias_, n_gates, s_size);
+    AOC<float, 3> states_t_l(states_t_l_, n_states, batch, s_size);
+    AOC<float, 3> states_tm1_l(states_tm1_l_, n_states, batch, s_size);
+
+#pragma omp parallel for
+    for (int i = 0; i < batch; i++) {
+#pragma omp simd
+        for (int j = 0; j < s_size; j++) {
+            ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
+            ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
+            ws_gates(i, 2, j) = logistic_fwd(ws_gates(i, 2, j) + bias(2, j));
+            ws_gates(i, 3, j) = tanh_fwd(ws_gates(i, 3, j) + bias(3, j));
+
+            float tmp = ws_gates(i, 0, j) * states_tm1_l(1, i, j)
+                    + ws_gates(i, 1, j) * ws_gates(i, 3, j);
+            states_t_l(0, i, j) = ws_gates(i, 2, j) * tanh_fwd(tmp);
+            states_t_l(1, i, j) = tmp;
+        }
+    }
+}
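+
+// For reference, this is the loop above read per element; the gate-role
+// names below are an inference from the arithmetic, the code itself only
+// numbers the gates 0..3.
+static inline void lstm_fwd_step_sketch(float g0, float g1, float g2, float g3,
+        float b0, float b1, float b2, float b3, float c_prev, float &h_t,
+        float &c_t) {
+    const float forget_g = logistic_fwd(g0 + b0); // scales c_prev
+    const float input_g = logistic_fwd(g1 + b1); // scales the candidate
+    const float output_g = logistic_fwd(g2 + b2); // scales tanh(c_t)
+    const float candidate = tanh_fwd(g3 + b3);
+    c_t = forget_g * c_prev + input_g * candidate;
+    h_t = output_g * tanh_fwd(c_t);
+}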
+
+#if 0
+/// @todo GRU mixes gemms and elemwise with gate 1.
+template <>
+elemwise_sig(_ref_rnn_common_t<prop_kind::forward>::gru_elemwise) {
+    AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, s_size);
+    AOC<const float, 2> bias(bias_, n_gates, s_size);
+    AOC<float, 3> states_t_l(states_t_l_, n_states, batch, s_size);
+    AOC<float, 3> states_tm1_l(states_tm1_l_, n_states, batch, s_size);
+
+    auto sigmoid = [=](float a) {
+        float tmp = expf(-a);
+        return 1.0f / (1.0f + tmp);
+    };
+#pragma omp parallel for
+    for (int i = 0; i < batch; i++) {
+        /* ... */
+    }
+}
+#endif
+
+template <>
+elemwise_sig(_ref_rnn_common_t<prop_kind::backward>::lstm_elemwise) {
+    AOC<float, 3> ws_gates(ws_gates_, batch, n_gates, s_size);
+    AOC<const float, 2> bias(bias_, n_gates, s_size);
+    AOC<float, 3> states_t_l(states_t_l_, n_states, batch, s_size);
+    AOC<float, 3> states_tm1_l(states_tm1_l_, n_states, batch, s_size);
+    AOC<float, 3> diff_states_t_l(
+            diff_states_t_l_, n_states + 1, batch, s_size);
+    AOC<float, 3> diff_states_tp1_l(
+            diff_states_tp1_l_, n_states + 1, batch, s_size);
+    AOC<float, 3> diff_states_t_lp1(
+            diff_states_t_lp1_, n_states + 1, batch, s_size);
+
+    auto one_m_square = [](float a) -> float { return 1.0f - a * a; };
+
+#pragma omp parallel for
+    for (int i = 0; i < batch; i++) {
+#pragma omp simd
+        for (int j = 0; j < s_size; j++) {
+            float Ct = states_t_l(1, i, j);
+            /// @todo save tanh(Ct) in the workspace on the fwd pass, or
+            /// recompute it here to save bandwidth
+            float tanhCt = tanh_fwd(Ct);
+            // we have two incoming diffs on Ht
+            float dHt = diff_states_tp1_l(0, i, j)
+                    + diff_states_t_lp1(n_states, i, j);
+            float dCt = diff_states_tp1_l(1, i, j)
+                    + one_m_square(tanhCt) * ws_gates(i, 2, j) * dHt;
+
+            float dG0 = states_tm1_l(1, i, j)
+                    * logistic_bwd(dCt, ws_gates(i, 0, j));
+            float dG1
+                    = ws_gates(i, 3, j) * logistic_bwd(dCt, ws_gates(i, 1, j));
+            float dG2 = logistic_bwd(tanhCt * dHt, ws_gates(i, 2, j));
+            float dG3 = ws_gates(i, 1, j) * tanh_bwd(dCt, ws_gates(i, 3, j));
+
+            diff_states_t_l(1, i, j) = dCt * ws_gates(i, 0, j);
+            ws_gates(i, 0, j) = dG0;
+            ws_gates(i, 1, j) = dG1;
+            ws_gates(i, 2, j) = dG2;
+            ws_gates(i, 3, j) = dG3;
+        }
+    }
+}
+
+template <prop_kind_t aprop>
+gemm_sig(_ref_rnn_common_t<aprop>::packed_gemm) {
+#if USE_MKL_PACKED_GEMM
+    cblas_sgemm_compute(CblasColMajor, CblasPacked,
+            is_B_trans ? CblasTrans : CblasNoTrans, m, n, k, a_, m, b_,
+            is_B_trans ? n : k, beta, c_, m);
+#else
+    UNUSED(m);
+    UNUSED(n);
+    UNUSED(k);
+    UNUSED(a_);
+    UNUSED(b_);
+    UNUSED(c_);
+    UNUSED(is_B_trans);
+    UNUSED(beta);
+    assert(!"packed gemm is disabled");
+#endif
+}
+
+template <prop_kind_t aprop>
+gemm_sig(_ref_rnn_common_t<aprop>::gemm) {
+    cblas_sgemm(CblasColMajor, CblasNoTrans,
+            is_B_trans ? CblasTrans : CblasNoTrans, m, n, k, 1.0f, a_, m, b_,
+            is_B_trans ?
n : k, beta, c_, m); +} + +/// @todo template this function on fwd or bwd, if the overhead +/// to pass argument for empty function is too big +template <> +cell_execution_sig(_ref_rnn_common_t::cell_execution) { + (this->*gemm_input_func)(n_gates * s_size, batch, x_size, w_input_, + states_t_lm1_, ws_gates_, false, 0.0f); + (this->*gemm_state_func)(n_gates * s_size, batch, h_size, w_state_, + states_tm1_l_, ws_gates_, false, 1.0f); + + (this->*elemwise_func)(s_size, batch, n_states, n_gates, ws_gates_, + states_t_l_, states_t_lm1_, states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_); +} + +template <> +cell_execution_sig(_ref_rnn_common_t::cell_execution) { + (this->*elemwise_func)(s_size, batch, n_states, n_gates, ws_gates_, + states_t_l_, states_t_lm1_, states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_); + + /// bwd by data on the cell + (this->*gemm_state_func)(h_size, batch, n_gates * s_size, w_state_, + ws_gates_, diff_states_t_l_, false, 0.0f); + (this->*gemm_input_func)(h_size, batch, n_gates * s_size, w_input_, + ws_gates_, diff_states_t_l_ + n_states * (batch * s_size), false, + 0.0f); + + /// bwd by weights on the cell + gemm(n_gates * x_size, s_size, batch, ws_gates_, states_t_lm1_, + diff_w_input_, true, 1.0f); + gemm(n_gates * h_size, s_size, batch, ws_gates_, states_tm1_l_, + diff_w_state_, true, 1.0f); + +/// bwd by bias we just accumulate diffs from the gates +#if (_OPENMP == 201307) +#pragma omp parallel for simd collapse(2) +#else +#pragma omp parallel for collapse(2) ///@todo block k on simd-width +#endif + for (int i = 0; i < n_gates; i++) + for (int k = 0; k < s_size; k++) + for (int j = 0; j < batch; j++) + diff_bias_[i * s_size + k] + += ws_gates_[(j * n_gates + i) * s_size + k]; +} + +//*************** Grid computations strategy: linear ***************// +template +grid_execution_sig(_ref_rnn_common_t::linear_execution) { + AOC ws_states(ws_states_, n_layer + 1, n_direction, n_iter + 1, + n_states * batch * s_size); + AOC ws_diff_states(ws_diff_states_, n_layer + 1, n_direction, + n_iter + 1, (n_states + 1) * batch * s_size); + AOC ws_gates( + ws_gates_, n_layer, n_direction, n_iter, n_gates * batch * s_size); + AOC weights_input(weights_input_, n_layer, n_direction); + AOC weights_states(weights_states_, n_layer, n_direction); + AOC bias(bias_, n_layer, n_direction, n_gates * s_size); + AOC diff_weights_layer(diff_weights_layer_, n_layer, n_direction, + x_size * n_gates * s_size); + AOC diff_weights_iter(diff_weights_iter_, n_layer, n_direction, + h_size * n_gates * s_size); + AOC diff_bias(diff_bias_, n_layer, n_direction, n_gates * s_size); + + // We run the grid of computation + for (int dir = 0; dir < n_direction; dir++) { + for (int j = 0; j < n_layer; j++) { + for (int i = 0; i < n_iter; i++) { + int lay, iter; + if (aprop == prop_kind::forward) { + lay = j; + iter = i; + } else { // backward + lay = n_layer - j - 1; + iter = n_iter - i - 1; + } + cell_execution(s_size, x_size, h_size, batch, n_gates, n_states, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, iter, 0)), + weights_input(lay, dir), weights_states(lay, dir), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay, dir, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0))); + } + } + } +} + +#if 0 
+//************* Grid computations strategy: wavefront **************//
+
+/*
+  // To cover n_iter > n_layer and n_iter < n_layer
+  min_dim = min(n_layer, n_iter)
+  max_dim = max(n_layer, n_iter)
+  and we assume that i refers to the max_dim dimension and j to the min_dim
+  dimension.
+
+  We compute the wavefront using 3 loop nests, each nest having 2 loops:
+  - one for the head, of the form: loop on n_layer, then loop on the number
+    of elements in the wave
+    for (int i = 0; i < min_dim - 1; i++)
+        for (int j = 0; j < i + 1; j++)
+  - one for the body:
+    for (int i = 0; i < max_dim - min_dim + 1; i++)
+        for (int j = 0; j < min_dim; j++)
+  - one for the tail:
+    for (int i = min_dim; i > 0; i--)
+        for (int j = 0; j < i; j++)
+  Here, we define classes for each of the wavefront directions to compute
+  the coordinates of the recurrent cells when running a wavefront execution.
+ */
+
+typedef enum wavefront_loop_index_ {
+    head,
+    body,
+    tail
+} wavefront_loop_index;
+
+struct wavefront_indexer {
+    wavefront_indexer(int dim) : dim_(dim) {}
+    virtual int get(wavefront_loop_index idx, int i, int j) const;
+protected:
+    int dim_;
+};
+
+// bottom to top or left to right, maxdim
+struct wi_b2t_l2r_maxdim : wavefront_indexer {
+    int get(wavefront_loop_index idx, int i, int j) const override {
+        switch (idx) {
+        case head: return i - j;
+        case body: return i - j;
+        case tail: return dim_ - 1 - j;
+        default: return -1;
+        }
+    }
+};
+
+// bottom to top or left to right, mindim
+struct wi_b2t_l2r_mindim : wavefront_indexer {
+    int get(wavefront_loop_index idx, int i, int j) const override {
+        switch (idx) {
+        case head: return j;
+        case body: return j;
+        case tail: return dim_ - i + j;
+        default: return -1;
+        }
+    }
+};
+
+template <typename original_indexer>
+struct reversed_indexer : wavefront_indexer {
+    reversed_indexer(int dim) : wavefront_indexer(dim),
+        wd(original_indexer(dim)) {}
+
+    int get(wavefront_loop_index idx, int i, int j) const override {
+        switch (idx) {
+        case head: return dim_ - 1 - wd.get(head, i, j);
+        case body: return dim_ - 1 - wd.get(body, i, j);
+        case tail: return dim_ - 1 - wd.get(tail, i, j);
+        default: return -1;
+        }
+    }
+
+private:
+    original_indexer wd;
+};
+
+// top to bottom or right to left, maxdim and mindim
+using wi_t2b_r2l_maxdim = reversed_indexer<wi_b2t_l2r_maxdim>;
+using wi_t2b_r2l_mindim = reversed_indexer<wi_b2t_l2r_mindim>;
+
+template <prop_kind_t aprop>
+grid_execution_sig(_ref_rnn_common_t<aprop>::wavefront_execution) {
+    // (int s_size, int x_size,
+    //  int h_size, int batch,
+    //  int n_layer, int n_direction, int n_iter,
+    //  int n_gates, int n_states,
+    //  const float **weights_input_, //[n_gates*s_size][x_size],
+    //  const float **weights_states_, //[n_gates*s_size][s_size],
+    //  const float *bias_, //[n_gates][s_size],
+    //  float *ws_, //[n_layer+1][n_direction][n_iter+1][n_states][batch][s_size],
+    //  float *gates_) //[n_layer][n_direction][n_iter][batch][n_gates][s_size]
+
+    AOC<float, 4> ws(ws_, n_layer + 1, n_direction, n_iter + 1,
+            n_states * batch * s_size);
+    AOC<float, 4> gates(gates_, n_layer, n_direction, n_iter,
+            n_gates * batch * s_size);
+    AOC<float *, 2> weights_input(weights_input_, n_layer, n_direction);
+    AOC<float *, 2> weights_states(weights_states_, n_layer, n_direction);
+    AOC<const float, 2> bias(bias_, n_layer, n_gates * s_size);
+    // Set up the indexers: we have to check the direction and whether
+    // n_iter or n_layer is max_dim or min_dim
+    bool is_niter_maxdim = n_iter >= n_layer;
+    wavefront_indexer wi_maxdim = (is_niter_maxdim)
+            ? (((exec_dir == b2t_l2r) || (exec_dir == t2b_l2r)) // n_iter is maxdim, we look for l2r
+                    ?
(wavefront_indexer) wi_b2t_l2r_maxdim(n_iter) + : (wavefront_indexer) wi_t2b_r2l_maxdim(n_iter)) + : (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is maxdim, we look for b2t + ? (wavefront_indexer) wi_b2t_l2r_maxdim(n_layer) + : (wavefront_indexer) wi_t2b_r2l_maxdim(n_layer)); + + wavefront_indexer wi_mindim = (!is_niter_maxdim) + ? (((exec_dir == b2t_l2r) || (exec_dir == t2b_l2r)) //niter is mindim, we look for l2r + ? (wavefront_indexer) wi_b2t_l2r_mindim(n_iter) + : (wavefront_indexer) wi_t2b_r2l_mindim(n_iter)) + : (((exec_dir == b2t_l2r) || (exec_dir == b2t_r2l)) //nlayer is mindim, we look for b2t + ? (wavefront_indexer) wi_b2t_l2r_mindim(n_layer) + : (wavefront_indexer) wi_t2b_r2l_mindim(n_layer)); + + // auto get_offset = [=](wavefront_loop_index idx, int i, int j){ + // int dim_min = wi_mindim.get(idx, i,j); + // int dim_max = wi_maxdim.get(idx, i,j); + // int offset = (is_niter_maxdim) + // ? dim_min*n_iter + dim_max + // : dim_max*n_iter + dim_min; + // }; + +#define get_lay_n_iter(idx, i, j) \ + do { \ + int dim_min = wi_mindim.get(idx, i, j); \ + int dim_max = wi_maxdim.get(idx, i, j); \ + if (is_niter_maxdim) { \ + lay = dim_min; \ + iter = dim_max; \ + } else { \ + lay = dim_max; \ + iter = dim_min; \ + } \ + } while (0) + + int min_dim = is_niter_maxdim ? n_layer : n_iter; + int max_dim = is_niter_maxdim ? n_iter :n_layer; + int lay, iter; + for (int i = 0; i < min_dim - 1; i++) + for(int j = 0; j < i+1; j++){ + get_lay_n_iter(head,i,j); + cell_execution(s_size, x_size, h_size, batch, + n_gates, n_states, + &(ws(lay, iter, 0)), weights_input(lay - 1, 0), + weights_states(lay - 1, 0), &(bias(lay-1, 0)), + &(ws(lay - 1, iter, 0)), &(ws(lay, iter - 1, 0)), &(gates(lay-1, iter-1, 0))); + } + for (int i = min_dim - 1; i < max_dim; i++) + for(int j = 0; j < min_dim; j++){ + get_lay_n_iter(body,i,j); + } + for (int i = min_dim - 1; i > 0 ; i--) + for(int j = 0; j < i; j++){ + get_lay_n_iter(tail,i,j); + } + +#undef get_lay_n_iter +} +#endif +//********* GRID computations strategy: utility functions **********// + +template <> +void _ref_rnn_common_t::copy_init_layer(bool lr, bool rl, + int n_layer, int n_direction, int n_iter, int batch, int x_size, + int n_states, float *ws_states_, float *ws_diff_states_, + const float *xt_, const float *diff_dst_layer_) { + AOC ws_states( + ws_states_, n_direction, n_iter + 1, n_states * batch * x_size); + auto xt_d = memory_desc_wrapper(conf_.src_pd(0)); + +#pragma omp parallel for + for (int it = 0; it < n_iter; it++) { + auto xxt = xt_ + xt_d.blk_off(it); + if (lr) + array_copy(&(ws_states(0, it + 1, 0)), xxt, batch * x_size); + if (rl) + array_copy(&(ws_states(n_direction - 1, n_iter - it, 0)), xxt, + batch * x_size); + } +} + +template <> +void _ref_rnn_common_t::copy_init_layer(bool lr, bool rl, + int n_layer, int n_direction, int n_iter, int batch, int x_size, + int n_states, float *ws_states_, float *ws_diff_states_, + const float *xt_, const float *diff_dst_layer_) { + AOC ws_diff_states(ws_diff_states_, n_layer + 1, n_direction, + n_iter + 1, (n_states + 1), batch, x_size); + auto diff_dst_layer_d = memory_desc_wrapper(conf_.diff_dst_pd(0)); + + switch (conf_.direction()) { + case mkldnn_bidirectional_concat: +#pragma omp parallel for collapse(2) + for (int it = 0; it < n_iter; it++) { + for (int b = 0; b < batch; b++) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < x_size; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; + 
ws_diff_states(n_layer, 1, it, n_states, b, s) + = diff_dst_layer_x[x_size + s]; + } + } + } + break; + case mkldnn_bidirectional_sum: +#pragma omp parallel for collapse(2) + for (int it = 0; it < n_iter; it++) { + for (int b = 0; b < batch; b++) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < x_size; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; + ws_diff_states(n_layer, 1, it, n_states, b, s) + = diff_dst_layer_x[s]; + } + } + } + break; + default: // assumes default is always unidirectional +#pragma omp parallel for collapse(2) + for (int it = 0; it < n_iter; it++) { + for (int b = 0; b < batch; b++) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < x_size; s++) { + ws_diff_states(n_layer, 0, it, n_states, b, s) + = diff_dst_layer_x[s]; + } + } + } + break; + } +} + +template <> +void _ref_rnn_common_t::copy_init_iter(int n_layer, + int n_direction, int n_states, int batch, int h_size, int n_iter, + float *ws_states_, float *ws_diff_states_, const float *firstit_states_, + const float *diff_dst_iter_) { + AOC ws_states(ws_states_, n_layer + 1, n_direction, n_iter + 1, + n_states, batch, h_size); + auto firstit_states_d = memory_desc_wrapper(conf_.src_pd(1)); + if (firstit_states_) { +#pragma omp parallel for collapse(2) + for (int lay = 0; lay < n_layer; lay++) + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int b = 0; b < batch; ++b) { + array_copy(&(ws_states(lay + 1, dir, 0, state, b, 0)), + firstit_states_ + + firstit_states_d.blk_off( + lay, dir, state, b), + h_size); + } + } else { +#pragma omp parallel for collapse(2) + for (int lay = 0; lay < n_layer; lay++) + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int i = 0; i < batch; i++) + for (int j = 0; j < h_size; j++) + ws_states(lay + 1, dir, 0, state, i, j) = 0.0f; + } +} + +template <> +void _ref_rnn_common_t::copy_init_iter(int n_layer, + int n_direction, int n_states, int batch, int h_size, int n_iter, + float *ws_states_, float *ws_diff_states_, const float *firstit_states_, + const float *diff_dst_iter_) { + AOC ws_diff_states(ws_diff_states_, n_layer + 1, n_direction, + n_iter + 1, n_states + 1, batch, h_size); + auto diff_dst_iter_d = memory_desc_wrapper(conf_.diff_dst_pd(1)); + if (diff_dst_iter_) { +#pragma omp parallel for collapse(4) + for (int lay = 0; lay < n_layer; lay++) + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int b = 0; b < batch; ++b) { + array_copy(&(ws_diff_states( + lay, dir, n_iter, state, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + h_size); + } + } else { +#pragma omp parallel for collapse(4) + for (int lay = 0; lay < n_layer; lay++) + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int i = 0; i < batch; i++) + for (int j = 0; j < h_size; j++) + ws_diff_states(lay, dir, n_iter, state, i, j) + = 0.0f; + } +} + +template <> +void _ref_rnn_common_t::copy_res_layer(bool lr, bool rl, + int n_layer, int n_direction, int n_iter, int batch, + int n_output_features, int s_size, int n_states, + mkldnn_rnn_direction_t direction, float *dst_layer_, + float *diff_src_layer, const float *ws_states_, + const float *ws_diff_states_) { + auto dst_layer_d = memory_desc_wrapper(conf_.dst_pd(0)); + AOC ws_states(ws_states_, n_layer + 
1, n_direction, + n_iter + 1, n_states, batch, s_size); +#pragma omp parallel for collapse(2) + for (int it = 0; it < n_iter; it++) { + for (int b = 0; b < batch; b++) { + int dir = 0; + if (lr) { + for (int s = 0; s < s_size; s++) + dst_layer_[dst_layer_d.blk_off(it, b, dir * s_size + s)] + = ws_states(n_layer, dir, it + 1, 0, b, s); + dir = 1; + } + if (rl) { + for (int s = 0; s < s_size; s++) + switch (direction) { + case mkldnn_bidirectional_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] += ws_states( + n_layer, dir, n_iter - it, 0, b, s); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * s_size + s)] + = ws_states(n_layer, dir, n_iter - it, 0, b, s); + } + } + } + } +} + +template <> +void _ref_rnn_common_t::copy_res_layer(bool lr, bool rl, + int n_layer, int n_direction, int n_iter, int batch, + int n_output_features, int s_size, int n_states, + mkldnn_rnn_direction_t direction, float *dst_layer_, + float *diff_src_layer_, const float *ws_states_, + const float *ws_diff_states_) { + auto diff_src_layer_d = memory_desc_wrapper(conf_.diff_src_pd(0)); + AOC ws_diff_states(ws_diff_states_, n_layer + 1, + n_direction, n_iter + 1, n_states + 1, batch, s_size); +#pragma omp parallel for collapse(2) + for (int it = 0; it < n_iter; it++) { + for (int b = 0; b < batch; b++) { + int dir = 0; + for (int s = 0; s < s_size; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (direction + == mkldnn_unidirectional_right2left) ? + n_iter - 1 - it : + it, + b, dir * s_size + s); + float res = ws_diff_states(0, 0, it, n_states, b, s); + if (n_direction - 1) + res += ws_diff_states( + 0, 1, n_iter - 1 - it, n_states, b, s); + dst_addr[0] = res; + } + } + } +} + +template <> +void _ref_rnn_common_t::copy_res_iter(int n_layer, + int n_direction, int n_states, int batch, int s_size, int n_iter, + float *dst_iter_, float *diff_src_iter_, const float *ws_states_, + const float *ws_diff_states_) { + auto dst_iter_d = memory_desc_wrapper(conf_.dst_pd(1)); + AOC ws_states(ws_states_, n_layer + 1, n_direction, + n_iter + 1, n_states, batch, s_size); + if (dst_iter_) { +#pragma omp parallel for collapse(4) + for (int lay = 0; lay < n_layer; lay++) { + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int b = 0; b < batch; b++) + for (int s = 0; s < s_size; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, state, b, s)] + = ws_states( + lay + 1, dir, n_iter, state, b, s); + } + } + } +} + +template <> +void _ref_rnn_common_t::copy_res_iter(int n_layer, + int n_direction, int n_states, int batch, int s_size, int n_iter, + float *dst_iter_, float *diff_src_iter_, const float *ws_states_, + const float *ws_diff_states_) { + auto diff_src_iter_d = memory_desc_wrapper(conf_.diff_src_pd(1)); + AOC ws_diff_states(ws_diff_states_, n_layer + 1, + n_direction, n_iter + 1, n_states + 1, batch, s_size); + if (diff_src_iter_) { +#pragma omp parallel for collapse(4) + for (int lay = 0; lay < n_layer; lay++) { + for (int dir = 0; dir < n_direction; dir++) + for (int state = 0; state < n_states; state++) + for (int b = 0; b < batch; b++) + for (int s = 0; s < s_size; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, 0, state, b, s); + } + } + } +} + +template +packing_sig(_ref_rnn_common_t::pack_weights) { +#if USE_MKL_PACKED_GEMM + AOC w( + w_, n_layer, n_direction, n_gates * OC_size * IC_size); + AOC weights(weights_, n_layer, n_direction); + int m = 0, n = 0, k = 0, ldA = 0; + auto 
transA = CblasNoTrans; + if (aprop == prop_kind::forward) { + m = n_gates * OC_size; + n = batch; + k = IC_size; + transA = CblasNoTrans; + ldA = m; + } + if (aprop == prop_kind::backward) { + m = IC_size; + n = batch; + k = n_gates * OC_size; + transA = CblasTrans; + ldA = k; + } + for (int i = 0; i < n_layer; i++) { + for (int d = 0; d < n_direction; d++) { + weights(i, d) = cblas_sgemm_alloc(CblasAMatrix, m, n, k); + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, transA, m, n, k, 1.0f, + &(w(i, d, 0)), ldA, weights(i, d)); + } + } +#else + UNUSED(n_layer); + UNUSED(n_direction); + UNUSED(n_weights); + UNUSED(n_gates); + UNUSED(batch); + UNUSED(OC_size); + UNUSED(IC_size); + UNUSED(weights_); + UNUSED(w_); + assert(!"packed gemm is disabled"); +#endif +} + +template +packing_sig(_ref_rnn_common_t::no_pack_weights) { + AOC w( + w_, n_layer, n_direction, n_gates * OC_size * IC_size); + AOC weights(weights_, n_layer, n_direction); + for (int i = 0; i < n_layer; i++) { + for (int d = 0; d < n_direction; d++) { + weights(i, d) = (float *)&(w(i, d, 0)); + } + } +} + +template +free_packed_sig(_ref_rnn_common_t::free_packed_weights) { +#if USE_MKL_PACKED_GEMM + for (int i = 0; i < n_layer; i++) { + cblas_sgemm_free(weights_[i]); + } +#else + UNUSED(n_layer); + UNUSED(weights_); + assert(!"packed gemm is disabled"); +#endif +} + +template +free_packed_sig(_ref_rnn_common_t::free_no_packed_weights) { + UNUSED(n_layer); + UNUSED(weights_); +} + +//********************* Execution function *********************// +template +void _ref_rnn_common_t::execute_() { + int n_layer = conf_.L(); + int n_direction = conf_.D(); + int n_iter = conf_.T(); + int n_gates = conf_.G(); + int n_states = conf_.S(); + int n_weights_input = conf_.SLC(); + int n_weights_state = conf_.SIC(); + int batch = conf_.MB(); + int x_size = conf_.SLC(); + int h_size = conf_.DIC(); + int s_size = conf_.DIC(); + + bool is_fwd = aprop == prop_kind::forward; + + int input_idx = 0; + int output_idx = 0; + auto input + = reinterpret_cast(this->input_memory(input_idx++)); + auto states = conf_.with_src_iter() ? + reinterpret_cast(this->input_memory(input_idx++)) : + nullptr; + auto w_input + = reinterpret_cast(this->input_memory(input_idx++)); + auto w_state + = reinterpret_cast(this->input_memory(input_idx++)); + auto bias = conf_.with_bias() ? + reinterpret_cast(this->input_memory(input_idx++)) : + nullptr; + + auto dst_last_layer = is_fwd ? + reinterpret_cast(this->memory(output_idx++)) : + const_cast(reinterpret_cast( + this->input_memory(input_idx++))); + auto dst_last_iter = conf_.with_dst_iter() ? + (is_fwd ? reinterpret_cast(this->memory(output_idx++)) : + const_cast(reinterpret_cast( + this->input_memory(input_idx++)))) : + nullptr; + + auto diff_dst_layer = is_fwd ? + nullptr : + reinterpret_cast(this->input_memory(input_idx++)); + auto diff_dst_iter = is_fwd || !conf_.with_src_iter() ? + nullptr : + reinterpret_cast(this->input_memory(input_idx++)); + + // if no workspace was provided we use the scratchpad + if (use_scratchpad_) { + ws_gates_ = ((float *)scratchpad_->get()); + ws_states_ = ((float *)scratchpad_->get()) + ws_states_offset_; + ws_diff_states_ + = ((float *)scratchpad_->get()) + ws_diff_states_offset_; + } else { + float *ws_ptr = is_fwd ? 
+ reinterpret_cast(this->memory(output_idx++)) : + const_cast(reinterpret_cast( + this->input_memory(input_idx++))); + ws_gates_ = ws_ptr + ws_gates_offset_; + ws_states_ = ws_ptr + ws_states_offset_; + ws_diff_states_ = ws_ptr + ws_diff_states_offset_; + } + + auto diff_src_layer = is_fwd ? + nullptr : + reinterpret_cast(this->memory(output_idx++)); + auto diff_src_iter = is_fwd || !conf_.with_src_iter() ? + nullptr : + reinterpret_cast(this->memory(output_idx++)); + auto diff_weights_layer = is_fwd ? + nullptr : + reinterpret_cast(this->memory(output_idx++)); + auto diff_weights_iter = is_fwd ? + nullptr : + reinterpret_cast(this->memory(output_idx++)); + auto diff_bias = is_fwd || !conf_.with_bias() ? + nullptr : + reinterpret_cast(this->memory(output_idx++)); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states_, 0.0f, conf_.ws_diff_states_size()); + + // TODO: implement without copies + bool is_lr = !one_of(exec_dir, b2t_r2l, t2b_r2l); + bool is_rl = !one_of(exec_dir, b2t_l2r, t2b_l2r); + + // we pack the weights if we are using the packed API + (this->*weights_state_pack_func)(n_layer, n_direction, n_weights_state, + n_gates, batch, s_size, h_size, ptr_wei_state_, w_state); + (this->*weights_input_pack_func)(n_layer, n_direction, n_weights_input, + n_gates, batch, s_size, x_size, ptr_wei_input_, w_input); + + // we first need to copy the initial states and input into ws + copy_init_layer(is_lr, is_rl, n_layer, n_direction, n_iter, batch, x_size, + n_states, ws_states_, ws_diff_states_, input, diff_dst_layer); + copy_init_iter(n_layer, n_direction, n_states, batch, h_size, n_iter, + ws_states_, ws_diff_states_, states, diff_dst_iter); + + // run the execution on the grid + (this->*grid_computation)(s_size, x_size, h_size, batch, n_layer, + n_direction, n_iter, n_gates, n_states, ptr_wei_input_, + ptr_wei_state_, (float *)bias, ws_states_, ws_diff_states_, + ws_gates_, diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + copy_res_layer(is_lr, is_rl, n_layer, n_direction, n_iter, batch, + n_output_features, s_size, n_states, conf_.direction(), + dst_last_layer, diff_src_layer, ws_states_, ws_diff_states_); + copy_res_iter(n_layer, n_direction, n_states, batch, s_size, n_iter, + dst_last_iter, diff_src_iter, ws_states_, ws_diff_states_); + + // We free the packed weights if they were packed internally + (this->*weights_state_free_packed_func)(n_layer, ptr_wei_state_); + (this->*weights_input_free_packed_func)(n_layer, ptr_wei_input_); +}; + +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; +} +} +} diff --git a/src/cpu/ref_rnn.hpp b/src/cpu/ref_rnn.hpp new file mode 100644 index 00000000000..2bee7be9ecb --- /dev/null +++ b/src/cpu/ref_rnn.hpp @@ -0,0 +1,377 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#ifndef CPU_REF_RNN_HPP
+#define CPU_REF_RNN_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "cpu_engine.hpp"
+#include "cpu_rnn_pd.hpp"
+#include "scratchpad.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "os_blas.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define elemwise_sig(f) \
+    void f(int s_size, int batch, int n_states, int n_gates, float *ws_gates_, \
+            float *states_t_l_, float *states_t_lm1_, float *states_tm1_l_, \
+            float *diff_states_t_l_, float *diff_states_t_lp1_, \
+            float *diff_states_tp1_l_, const float *bias_)
+
+#define cell_execution_sig(f) \
+    void f(int s_size, int x_size, int h_size, int batch, int n_gates, \
+            int n_states, float *states_t_l_, float *diff_states_t_l_, \
+            const float *w_input_, const float *w_state_, const float *bias_, \
+            float *states_t_lm1_, float *states_tm1_l_, \
+            float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
+            float *diff_w_input_, float *diff_w_state_, float *diff_bias_, \
+            float *ws_gates_)
+
+#define grid_execution_sig(f) \
+    void f(int s_size, int x_size, int h_size, int batch, int n_layer, \
+            int n_direction, int n_iter, int n_gates, int n_states, \
+            float **weights_input_, float **weights_states_, \
+            const float *bias_, float *ws_states_, float *ws_diff_states_, \
+            float *ws_gates_, float *diff_weights_layer_, \
+            float *diff_weights_iter_, float *diff_bias_)
+
+#define gemm_sig(f) \
+    void f(int m, int n, int k, const float *a_, float *b_, float *c_, \
+            bool is_B_trans, float beta)
+
+#define packing_sig(f) \
+    void f(int n_layer, int n_direction, int n_weights, int n_gates, \
+            int batch, int OC_size, int IC_size, float **weights_, \
+            const float *w_)
+
+#define free_packed_sig(f) void f(int n_layer, float **weights_)
+
+template <alg_kind_t alg_kind, prop_kind_t aprop>
+float activation(float dd, float s, float alpha, float clipping);
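+
+// For readability: inside the class below, e.g. elemwise_sig(rnn_elemwise)
+// expands (mechanically, from the macro above) to the member declaration
+//
+//   void rnn_elemwise(int s_size, int batch, int n_states, int n_gates,
+//           float *ws_gates_, float *states_t_l_, float *states_t_lm1_,
+//           float *states_tm1_l_, float *diff_states_t_l_,
+//           float *diff_states_t_lp1_, float *diff_states_tp1_l_,
+//           const float *bias_);
+//
+// and typedef elemwise_sig((class_name::*elemwise_f)) declares the matching
+// pointer-to-member type used for dispatch.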
+
+template <prop_kind_t aprop>
+struct _ref_rnn_common_t : public cpu_primitive_t {
+    using class_name = _ref_rnn_common_t;
+    typedef enum execution_direction_ {
+        b2t_l2r,
+        b2t_r2l,
+        b2t_bi_concat,
+        b2t_bi_sum,
+        t2b_l2r,
+        t2b_r2l,
+        t2b_bi_concat,
+        t2b_bi_sum
+    } execution_direction;
+    typedef elemwise_sig((class_name::*elemwise_f));
+    typedef cell_execution_sig((class_name::*cell_execution_f));
+    typedef grid_execution_sig((class_name::*grid_execution_f));
+
+    typedef gemm_sig((class_name::*gemm_t));
+    typedef packing_sig((class_name::*packing_t));
+    typedef free_packed_sig((class_name::*free_packed_t));
+
+    using base_pd_t = typename utils::conditional<aprop == prop_kind::forward,
+            rnn_fwd_pd_t, rnn_bwd_pd_t>::type;
+
+    struct pd_t : public base_pd_t {
+        pd_t(engine_t *engine, const rnn_desc_t *adesc,
+                const primitive_attr_t *attr,
+                const typename pd_t::base_class *hint_pd)
+            : base_pd_t(engine, adesc, attr, hint_pd) {}
+
+        DECLARE_COMMON_PD_T("ref:any", class_name);
+
+        status_t init() {
+            using namespace prop_kind;
+            using namespace utils;
+            using namespace memory_format;
+            assert(this->engine()->kind() == engine_kind::cpu);
+            const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
+
+            bool ok = true
+#if !defined(USE_CBLAS)
+                    && false
+#endif
+                    && one_of(cell_kind, alg_kind::vanilla_rnn,
+                               alg_kind::vanilla_lstm, alg_kind::vanilla_gru)
+                    && implication(aprop == prop_kind::forward,
+                               one_of(this->desc()->prop_kind, forward_training,
+                                       forward_inference))
+                    && implication(aprop == backward,
+                               one_of(this->desc()->prop_kind, backward))
+                    && this->set_default_params() == status::success;
+            if (!ok)
+                return status::unimplemented;
+
+            /// @todo check the data layouts for all input tensors
+            ok = ok && this->desc()->src_layer_desc.format == tnc
+                    && this->desc()->dst_layer_desc.format == tnc;
+
+            ok = ok && this->with_bias();
+            switch (aprop) {
+            case (prop_kind::forward):
+                ok = ok && utils::one_of(this->desc()->prop_kind,
+                                   forward_training, forward_inference);
+                ok = ok && utils::one_of(
+                                   this->desc()->weights_layer_desc.format,
+                                   any, ldigo, ldigo_p)
+                        && utils::one_of(
+                                   this->desc()->weights_iter_desc.format,
+                                   any, ldigo, ldigo_p);
+                break;
+            case (prop_kind::backward):
+                ok = ok && utils::one_of(this->desc()->prop_kind, backward);
+                ok = ok && utils::one_of(
+                                   this->desc()->weights_layer_desc.format,
+                                   any, ldgoi, ldgoi_p)
+                        && utils::one_of(
+                                   this->desc()->weights_iter_desc.format,
+                                   any, ldgoi, ldgoi_p);
+                break;
+            default: ok = false;
+            }
+
+            // initialize the workspace_pd
+            dims_t ws_dims = { (dim_t)this->get_ws_size() };
+            memory_desc_t ws_d;
+            mkldnn_memory_desc_init(
+                    &ws_d, 1, ws_dims, impl::data_type::f32, memory_format::x);
+            this->ws_pd_ = cpu_memory_t::pd_t(this->engine(), &ws_d);
+            return ok ? status::success : status::unimplemented;
+        }
+    };
+
+    _ref_rnn_common_t(const pd_t *pd, const input_vector &inputs,
+            const output_vector &outputs)
+        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
+        /// @todo set max_feature_size, assuming that we limit the number of
+        /// iterations and layers to one when x_size != s_size and
+        /// h_size != s_size, respectively
+
+        memory_format_t packed_format;
+        switch (aprop) {
+        case prop_kind::forward_inference:
+        case prop_kind::forward_training:
+            packed_format = memory_format::ldigo_p;
+            break;
+        case prop_kind::backward:
+            packed_format = memory_format::ldgoi_p;
+            break;
+        default: assert(false);
+        }
+
+        auto set_pack_funcs = [this](bool packed_gemm, gemm_t &g, bool pack_w,
+                packing_t &p, free_packed_t &f) {
+            g = packed_gemm ? &class_name::packed_gemm : &class_name::gemm;
+            p = pack_w ? &class_name::pack_weights :
+                         &class_name::no_pack_weights;
+            f = pack_w ? &class_name::free_packed_weights :
+                         &class_name::free_no_packed_weights;
+        };
+
+        const bool weights_pack_cond = USE_MKL_PACKED_GEMM && conf_.T() > 1;
+        const bool is_weights_state_packed = USE_MKL_PACKED_GEMM
+                && conf_.desc()->weights_iter_desc.format == packed_format;
+
+        set_pack_funcs(weights_pack_cond || is_weights_state_packed,
+                gemm_state_func, weights_pack_cond && !is_weights_state_packed,
+                weights_state_pack_func, weights_state_free_packed_func);
+
+        const bool is_weights_input_packed = USE_MKL_PACKED_GEMM
+                && conf_.desc()->weights_layer_desc.format == packed_format;
+
+        set_pack_funcs(weights_pack_cond || is_weights_input_packed,
+                gemm_input_func, weights_pack_cond && !is_weights_input_packed,
+                weights_input_pack_func, weights_input_free_packed_func);
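+
+        // Packing policy, restating the conditions above: the packed GEMM
+        // path is taken when the caller supplied pre-packed weights
+        // (ldigo_p / ldgoi_p) or when T > 1 makes packing once per execution
+        // worthwhile; internal packing happens only in the latter case, when
+        // the weights did not arrive pre-packed.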
+
+        switch (conf_.cell_kind()) {
+        case alg_kind::vanilla_lstm:
+            elemwise_func = &class_name::lstm_elemwise;
+            break;
+        case alg_kind::vanilla_rnn: // @todo switch on the cell kind
+            elemwise_func = &class_name::rnn_elemwise;
+            switch (conf_.activation_kind()) {
+            case alg_kind::eltwise_relu:
+                activation_func = &activation<alg_kind::eltwise_relu, aprop>;
+                break;
+            case alg_kind::eltwise_tanh:
+                activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
+                break;
+            default: break;
+            }
+            break;
+
+        // case alg_kind::vanilla_gru:
+        //     elemwise_func = &class_name::gru_elemwise; break;
+        default: break;
+        }
+
+        n_output_features
+                = (conf_.direction() == mkldnn_bidirectional_concat) ? 2 : 1;
+        switch (conf_.direction()) {
+        case mkldnn_unidirectional_left2right: exec_dir = b2t_l2r; break;
+        case mkldnn_unidirectional_right2left: exec_dir = b2t_r2l; break;
+        case mkldnn_bidirectional_concat: exec_dir = b2t_bi_concat; break;
+        case mkldnn_bidirectional_sum: exec_dir = b2t_bi_sum; break;
+        default: break;
+        }
+
+        /// @todo add a heuristic to choose between linear execution and
+        /// wavefront
+        grid_computation = &class_name::linear_execution;
+
+        conf_.set_ws_offsets(
+                ws_gates_offset_, ws_states_offset_, ws_diff_states_offset_);
+
+        // We need to allocate memory for:
+        // - the states, to compute a pass;
+        // - the intermediate results from the gates;
+        // - the diff_states, to compute the backward pass (training only).
+        // These should be allocated on the scratchpad for forward inference,
+        // or on a workspace provided by the user for training.
+        /// @todo should we require the workspace for training or make it
+        /// optional?
+
+        // If no workspace is provided on forward, we use a scratchpad.
+        // NOTE: here we use a large workspace for simplicity:
+        // - for states:
+        //   - TODO: allocate only n_iter * s_size + s_size for linear
+        //     execution (inference)
+        //   - TODO: allocate only n_layer_wav * (2 * s_size) for wavefront
+        //     execution (inference)
+        // - for gates:
+        //   - TODO: allocate only batch * n_gates * s_size for linear
+        //     execution (inference)
+        //   - TODO: allocate only n_layer_wav * batch * n_gates * s_size for
+        //     wavefront execution (inference)
+
+        switch (conf_.desc()->prop_kind) {
+        case prop_kind::forward_inference:
+            use_scratchpad_ = (memory(conf_.ws_idx()) == nullptr);
+            break;
+        case prop_kind::forward_training:
+            use_scratchpad_ = (memory(conf_.ws_idx()) == nullptr);
+            assert(use_scratchpad_ == false);
+            break;
+        case prop_kind::backward:
+            use_scratchpad_ = (input_memory(conf_.ws_idx()) == nullptr);
+            assert(use_scratchpad_ == false);
+            break;
+        default: assert(!"invalid prop_kind");
+        }
+
+        if (use_scratchpad_) {
+            scratchpad_
+                    = create_scratchpad(conf_.get_ws_size() * sizeof(float));
+        }
+
+        int ptr_wei_sz = conf_.L() * conf_.D();
+        ptr_wei_input_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
+        ptr_wei_state_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
+    }
+    ~_ref_rnn_common_t() {
+        if (use_scratchpad_)
+            delete scratchpad_;
+        free(ptr_wei_input_);
+        free(ptr_wei_state_);
+    }
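+
+    // The fallback policy above: only forward inference may silently use an
+    // internal scratchpad; forward training and backward expect the workspace
+    // to be provided (the asserts on use_scratchpad_ document this contract).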
+
+    // typedef typename prec_traits<data_type::f32>::type data_t;
+
+    virtual void execute(event_t *e) {
+        execute_();
+        e->set_state(event_t::ready);
+    }
+
+private:
+    void execute_();
+    grid_execution_sig(linear_execution);
+    // grid_execution_sig(wavefront_execution);
+    cell_execution_sig(cell_execution);
+    elemwise_sig(rnn_elemwise);
+    elemwise_sig(lstm_elemwise);
+    // elemwise_sig(gru_elemwise);
+    gemm_sig(gemm);
+    gemm_sig(packed_gemm);
+    packing_sig(pack_weights);
+    packing_sig(no_pack_weights);
+    free_packed_sig(free_packed_weights);
+    free_packed_sig(free_no_packed_weights);
+
+    float (*activation_func)(float dd, float s, float alpha, float clipping);
+
+    void copy_init_layer(bool lr, bool rl, int n_layer, int n_direction,
+            int n_iter, int batch, int x_size, int n_states, float *ws_states_,
+            float *ws_diff_states_, const float *xt_,
+            const float *diff_dst_layer_);
+    void copy_init_iter(int n_layer, int n_direction, int n_states, int batch,
+            int h_size, int n_iter, float *ws_states_, float *ws_diff_states_,
+            const float *firstit_states_, const float *diff_dst_iter_);
+    void copy_res_layer(bool lr, bool rl, int n_layer, int n_direction,
+            int n_iter, int batch, int n_output_features, int s_size,
+            int n_states, mkldnn_rnn_direction_t direction, float *dst_layer_,
+            float *diff_src_layer_, const float *ws_states_,
+            const float *ws_diff_states_);
+    void copy_res_iter(int n_layer, int n_direction, int n_states, int batch,
+            int s_size, int n_iter, float *dst_iter_, float *diff_src_iter_,
+            const float *ws_states_, const float *ws_diff_states_);
+
+    pd_t conf_;
+    bool use_scratchpad_;
+    scratchpad_t *scratchpad_;
+
+    int ws_gates_offset_;
+    int ws_states_offset_;
+    int ws_diff_states_offset_;
+
+    float *ws_gates_;
+    float *ws_states_;
+    float *ws_diff_states_;
+    int n_output_features;
+
+    float **ptr_wei_input_;
+    float **ptr_wei_state_;
+
+    execution_direction exec_dir;
+    grid_execution_f grid_computation;
+    // cell_execution_f cell_execution;
+
+    packing_t weights_input_pack_func;
+    packing_t weights_state_pack_func;
+
+    gemm_t gemm_input_func;
+    gemm_t gemm_state_func;
+    elemwise_f elemwise_func;
+
+    free_packed_t weights_input_free_packed_func;
+    free_packed_t weights_state_free_packed_func;
+};
+
+using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>;
+using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>;
+}
+}
+}
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
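For readers of ref_rnn.cpp above: AOC (array_offset_calculator) is defined in the library's utils and is only used here as AOC<T, ndims> a(base, d0, ..., dn) followed by a(i0, ..., in). A minimal self-contained sketch of the indexing behavior assumed throughout (illustrative only, not the library's implementation):

    #include <array>
    #include <cstddef>

    template <typename T, int ndims>
    struct aoc_sketch {
        template <typename... Dims>
        aoc_sketch(T *base, Dims... dims)
            : base_(base), dims_{{ static_cast<ptrdiff_t>(dims)... }} {}

        // Row-major offset: the last index varies fastest, so e.g.
        // ws_gates(i, g, j) -> base[(i * n_gates + g) * s_size + j].
        template <typename... Idx>
        T &operator()(Idx... idx) const {
            std::array<ptrdiff_t, ndims> is{{ static_cast<ptrdiff_t>(idx)... }};
            ptrdiff_t off = 0;
            for (int d = 0; d < ndims; ++d)
                off = off * dims_[d] + is[d];
            return base_[off];
        }

    private:
        T *base_;
        std::array<ptrdiff_t, ndims> dims_;
    };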