diff --git a/include/caffe/layers/mkldnn_layers.hpp b/include/caffe/layers/mkldnn_layers.hpp index aca64c362..dae314208 100644 --- a/include/caffe/layers/mkldnn_layers.hpp +++ b/include/caffe/layers/mkldnn_layers.hpp @@ -141,7 +141,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer, public Layer { shared_ptr input_primitive, bwd_top_diff_primitive; - int32_t num_, width_, height_, channels_; + vector shape_; Dtype eps_, moving_average_fraction_; bool use_weight_bias_, bias_term_, use_global_stats_; int num_stats_batches_; @@ -402,7 +402,7 @@ class MKLDNNReLULayer : public MKLDNNLayer , public NeuronLayer { , reluFwd_pd(), reluBwd_pd() , fwd_top_data_memory(), bwd_bottom_diff_memory() , fwd_bottom_data_primitive(), bwd_top_diff_primitive() - , num_(0), width_(0), height_(0), channels_(0) + , shape_(0) { PERFORMANCE_EVENT_ID_RESET(perf_id_fw_); PERFORMANCE_EVENT_ID_RESET(perf_id_bw_); @@ -431,7 +431,7 @@ class MKLDNNReLULayer : public MKLDNNLayer , public NeuronLayer { MKLDNNPrimitive reluFwd, reluBwd; shared_ptr fwd_top_data_memory, bwd_bottom_diff_memory; shared_ptr fwd_bottom_data_primitive, bwd_top_diff_primitive, bwd_bottom_data_primitive; - int32_t num_, width_, height_, channels_; + vector shape_; PERFORMANCE_EVENT_ID_DECL(perf_id_fw_); PERFORMANCE_EVENT_ID_DECL(perf_id_bw_); @@ -480,7 +480,8 @@ class MKLDNNConcatLayer : public MKLDNNLayer , public Layer { vector split_dims; bool in_place_; - int32_t num_, width_, height_, channels_, num_concats_; + int32_t num_concats_; + vector shape_; int concat_dimension; PERFORMANCE_EVENT_ID_DECL(perf_id_fw_); @@ -537,7 +538,7 @@ class MKLDNNEltwiseLayer : public MKLDNNLayer , public Layer { , eltwiseFwd_pd() , fwd_top_data_memory() , fwd_bottom_data_primitives_() - , num_(0), width_(0), height_(0), channels_(0) + , shape_(0) , num_bottoms_(0) { PERFORMANCE_EVENT_ID_RESET(perf_id_fw_); @@ -573,7 +574,7 @@ class MKLDNNEltwiseLayer : public MKLDNNLayer , public Layer { EltwiseParameter_EltwiseOp op_; vector coeffs_; Blob max_idx_; - int32_t num_, width_, height_, channels_; + vector shape_; int32_t num_bottoms_; bool stable_prod_grad_; diff --git a/src/caffe/layers/mkldnn_batch_norm_layer.cpp b/src/caffe/layers/mkldnn_batch_norm_layer.cpp index d3f76bfbd..b2620dda0 100644 --- a/src/caffe/layers/mkldnn_batch_norm_layer.cpp +++ b/src/caffe/layers/mkldnn_batch_norm_layer.cpp @@ -64,10 +64,9 @@ void MKLDNNBatchNormLayer::LayerSetUp(const vector*>& bottom Layer::LayerSetUp(bottom, top); - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - num_ = bottom[0]->num(); + shape_ = bottom[0]->shape(); + + const int channels = shape_[1]; eps_ = this->layer_param_.batch_norm_param().eps(); use_weight_bias_ = this->layer_param_.batch_norm_param().use_weight_bias(); @@ -77,12 +76,12 @@ void MKLDNNBatchNormLayer::LayerSetUp(const vector*>& bottom if (this->layer_param_.batch_norm_param().has_use_global_stats()) use_global_stats_ = this->layer_param_.batch_norm_param().use_global_stats(); - InitStatsBatchVars(num_); + InitStatsBatchVars(shape_[0]); this->blobs_.resize(3 + (use_weight_bias_ ? 1:0) + (use_weight_bias_ && bias_term_ ? 1:0)); vector sz; - sz.push_back(channels_); + sz.push_back(channels); this->blobs_[0].reset(new Blob(sz)); this->blobs_[1].reset(new Blob(sz)); sz[0]=1; @@ -96,7 +95,7 @@ void MKLDNNBatchNormLayer::LayerSetUp(const vector*>& bottom //Optimization: use the temp blob to combine the scale and shift together. Avoid the additional copies. // Initialize scale and shift combination blob vector scaleshift_blob_shape(1); - scaleshift_blob_shape[0] = 2*channels_; + scaleshift_blob_shape[0] = 2*channels; scaleshift_blob_.reset(new Blob(scaleshift_blob_shape)); //Should initialize the scaleshift_blob_ buffer to 0, because when bias_term_ == false, need to pass zero bias to MKLDNN caffe_set(scaleshift_blob_shape[0], static_cast(0), @@ -111,8 +110,8 @@ void MKLDNNBatchNormLayer::LayerSetUp(const vector*>& bottom if (use_weight_bias_) { // Initialize scale and shift vector scaleshift_shape(1); - scaleshift_shape[0] = channels_; - VLOG(1) << "MKLDNNBatchNormLayer::LayerSetUp: channels_ = " << channels_; + scaleshift_shape[0] = channels; + VLOG(1) << "MKLDNNBatchNormLayer::LayerSetUp: channels = " << channels; this->blobs_[3].reset(new Blob(scaleshift_shape)); this->blobs_[3]->set_cpu_data(scaleshift_blob_->mutable_cpu_data()); @@ -128,8 +127,8 @@ void MKLDNNBatchNormLayer::LayerSetUp(const vector*>& bottom if (bias_term_) { this->blobs_[4].reset(new Blob(scaleshift_shape)); - this->blobs_[4]->set_cpu_data(scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels_)); - this->blobs_[4]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff() + scaleshift_blob_->offset(channels_)); + this->blobs_[4]->set_cpu_data(scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels)); + this->blobs_[4]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff() + scaleshift_blob_->offset(channels)); FillerParameter bias_filler_param(this->layer_param_.batch_norm_param().bias_filler()); if (!this->layer_param_.batch_norm_param().has_bias_filler()) { bias_filler_param.set_type("constant"); @@ -161,17 +160,9 @@ void MKLDNNBatchNormLayer::Reshape(const vector*>& bottom { VLOG(1) << "MKLDNNBatchNormLayer::Reshape: " << this->layer_param_.name(); - this->reshape = (this->width_ == bottom[0]->width() && - this->height_ == bottom[0]->height() && - this->channels_ == bottom[0]->channels() && - this->num_ == bottom[0]->num()) ? false : true; - - this->width_ = bottom[0]->width(); - this->height_ = bottom[0]->height(); - this->num_ = bottom[0]->num(); - this->channels_ = bottom[0]->channels(); + this->reshape = (this->shape_ == bottom[0]->shape()) ? false : true; - InitStatsBatchVars(this->num_); + InitStatsBatchVars(this->shape_[0]); //Fix: should reshape the top blob with the real size of bottom blob //top[0]->Reshape(this->num_, this->channels_, this->height_, this->width_); @@ -194,10 +185,15 @@ void MKLDNNBatchNormLayer::InitBatchNorm(const vector*>& bott if (use_weight_bias_) flags |= use_scale_shift; if (use_global_stats_) flags |= use_global_stats; - int32_t n = this->num_; - int32_t iw = this->width_; - int32_t ih = this->height_; - int32_t ic = this->channels_; + memory::format src_mfmt; + auto tensor_size = this->shape_.size(); + if(tensor_size == 4) { + src_mfmt = memory::format::nchw; + } else if(tensor_size == 5) { + src_mfmt = memory::format::ncdhw; + } + + const int channels = this->shape_[1]; bool bottom_data_is_prv = (const_cast(bottom[0]->prv_data()) != NULL); @@ -216,13 +212,13 @@ void MKLDNNBatchNormLayer::InitBatchNorm(const vector*>& bott usr_mpd = mem_descr->usr_memory_pd(); prv_mpd = mem_descr->prv_memory_pd(); } else { - input_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw)); //MKLDNN batch norm only support 4D memory descriptor! + input_md.reset(new memory::desc({this->shape_}, mpcsn, src_mfmt)); usr_mpd.reset(new memory::primitive_desc(*input_md, cpu_engine)); } output_md = input_md; input_stats_md.reset(new memory::desc(*input_md)); CHECK(input_stats_md->data.ndims > 0 && - input_stats_md->data.dims[0] == this->num_); + input_stats_md->data.dims[0] == this->shape_[0]); input_stats_md->data.dims[0] = stats_batch_size_; // ---- Initialize BatchNorm primitive descriptor ------------- @@ -262,7 +258,7 @@ void MKLDNNBatchNormLayer::InitBatchNorm(const vector*>& bott if (use_weight_bias_) { //For test in train, memory address of blobs_[3] and blobs_[4] will be changed when share data from train net. If the address // of blobs_[3] and blobs_[4] are continued, we will use them immediately, otherwise we will copy them to scaleshift_blob_ in Forward. - if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels_)) == this->blobs_[4]->mutable_cpu_data()){ + if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels)) == this->blobs_[4]->mutable_cpu_data()){ scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->blobs_[3]->mutable_cpu_data())); }else { scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_blob_->mutable_cpu_data())); @@ -309,8 +305,8 @@ void MKLDNNBatchNormLayer::InitBatchNorm(const vector*>& bott LOG(INFO) << "MKLDNN batch norm only support 4D memory descriptor! Use 4D for calculation and reshape to 2D for output!"; #endif vector top_shape; - top_shape.push_back(bottom[0]->num()); - top_shape.push_back(bottom[0]->channels()); + top_shape.push_back(bottom[0]->shape(0)); + top_shape.push_back(bottom[0]->shape(1)); top[0]->Reshape(top_shape); } } @@ -319,12 +315,15 @@ template template shared_ptr MKLDNNBatchNormLayer::GetStatsBatchMemory( shared_ptr > mkldnn_mem, int idx) { - long data_offset = - idx * stats_batch_size_ * this->channels_ * this->width_ * this->height_; + int length = this->shape_[1]; + for(int i=2;ishape_.size();i++) + length *= this->shape_[i]; + + long data_offset = idx * stats_batch_size_ * length; engine cpu_engine = CpuEngine::Instance().get_engine(); shared_ptr stats_md = mkldnn_mem->get_memory_desc(); CHECK(stats_md->data.ndims > 0 && - stats_md->data.dims[0] == this->num_); + stats_md->data.dims[0] == this->shape_[0]); stats_md->data.dims[0] = stats_batch_size_; shared_ptr stats_mpd( new memory::primitive_desc(*stats_md, cpu_engine)); @@ -338,6 +337,8 @@ void MKLDNNBatchNormLayer::InitBatchNormFwdPrimitive(int idx) { input_stats[idx] = GetStatsBatchMemory(fwd_bottom_data, idx); output_stats[idx] = GetStatsBatchMemory(fwd_top_data, idx); + const int channels = this->shape_[1]; + // ---- Create BatchNorm -------------------- if (this->phase_ == TEST && !use_global_stats_) { if (use_weight_bias_) { @@ -353,9 +354,9 @@ void MKLDNNBatchNormLayer::InitBatchNormFwdPrimitive(int idx) { variance_memory[idx].reset(new memory(BatchNormFwd_pd->variance_primitive_desc())); if (use_global_stats_) { - caffe_copy(this->channels_, this->blobs_[0]->cpu_data(), + caffe_copy(channels, this->blobs_[0]->cpu_data(), static_cast(mean_memory[idx]->get_data_handle())); - caffe_copy(this->channels_, this->blobs_[1]->cpu_data(), + caffe_copy(channels, this->blobs_[1]->cpu_data(), static_cast(variance_memory[idx]->get_data_handle())); if (use_weight_bias_) { BatchNormFwd[idx].reset(new batch_normalization_forward(*BatchNormFwd_pd, @@ -398,9 +399,11 @@ void MKLDNNBatchNormLayer::Forward_cpu(const vector*>& bottom // update top that head at prv fwd_top_data->sync_before_write(); - if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels_)) != this->blobs_[4]->mutable_cpu_data()){ - caffe_copy(channels_, this->blobs_[3]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data()); - caffe_copy(channels_, this->blobs_[4]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels_)); + const int channels = this->shape_[1]; + + if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels)) != this->blobs_[4]->mutable_cpu_data()){ + caffe_copy(channels, this->blobs_[3]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data()); + caffe_copy(channels, this->blobs_[4]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels)); } for (int stats_batch_idx = 0; stats_batch_idx < num_stats_batches_; stats_batch_idx++) { @@ -429,11 +432,11 @@ void MKLDNNBatchNormLayer::Forward_cpu(const vector*>& bottom Dtype *variance_buffer_ = (Dtype *)(variance_memory[stats_batch_idx]->get_data_handle()); this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_; this->blobs_[2]->mutable_cpu_data()[0] += 1; - caffe_cpu_axpby(this->channels_, Dtype(1), mean_buffer_, + caffe_cpu_axpby(channels, Dtype(1), mean_buffer_, moving_average_fraction_, this->blobs_[0]->mutable_cpu_data()); - int m = bottom[0]->count()/num_stats_batches_/channels_; + int m = bottom[0]->count()/num_stats_batches_/channels; Dtype bias_correction_factor = m > 1 ? Dtype(m)/(m-1) : 1; - caffe_cpu_axpby(this->channels_, bias_correction_factor, + caffe_cpu_axpby(channels, bias_correction_factor, variance_buffer_, moving_average_fraction_, this->blobs_[1]->mutable_cpu_data()); } @@ -450,10 +453,14 @@ void MKLDNNBatchNormLayer::InitBatchNormBwd( { if (std::is_same::value) NOT_IMPLEMENTED; - int32_t n = this->num_; - int32_t w = this->width_; - int32_t h = this->height_; - int32_t c = this->channels_; + memory::format src_mfmt; + auto tensor_size = this->shape_.size(); + if(tensor_size == 4) { + src_mfmt = memory::format::nchw; + } else if(tensor_size == 5) { + src_mfmt = memory::format::ncdhw; + } + unsigned flags = 0; if (use_weight_bias_) flags |= use_scale_shift; @@ -475,16 +482,16 @@ void MKLDNNBatchNormLayer::InitBatchNormBwd( usr_diff_mpd = mem_descr->usr_memory_pd(); prv_diff_mpd = mem_descr->prv_memory_pd(); } else { - top_diff_md.reset(new memory::desc({{n, c, h, w}}, mpcsn, memory::format::nchw)); //MKLDNN batch norm only support 4D memory descriptor! + top_diff_md.reset(new memory::desc({this->shape_}, mpcsn, src_mfmt)); usr_diff_mpd.reset(new memory::primitive_desc(*top_diff_md, cpu_engine)); } top_diff_stats_md.reset(new memory::desc(*top_diff_md)); CHECK(top_diff_stats_md->data.ndims > 0 && - top_diff_stats_md->data.dims[0] == this->num_); + top_diff_stats_md->data.dims[0] == this->shape_[0]); top_diff_stats_md->data.dims[0] = stats_batch_size_; output_stats_md.reset(new memory::desc(output_memory->get_primitive_desc().desc())); CHECK(output_stats_md->data.ndims > 0 && - output_stats_md->data.dims[0] == this->num_); + output_stats_md->data.dims[0] == this->shape_[0]); output_stats_md->data.dims[0] = stats_batch_size_; // ---- Initialize bnrm primitive descriptor ------------- diff --git a/src/caffe/layers/mkldnn_concat_layer.cpp b/src/caffe/layers/mkldnn_concat_layer.cpp index 9299f9cea..f8f90609c 100644 --- a/src/caffe/layers/mkldnn_concat_layer.cpp +++ b/src/caffe/layers/mkldnn_concat_layer.cpp @@ -74,95 +74,28 @@ void MKLDNNConcatLayer::LayerSetUp(const vector*>& bottom, } for (auto i = 1; i < num_concats_; ++i) { - if (concat_dimension == 0) - { - CHECK_EQ(bottom[0]->channels(), bottom[i]->channels()); - CHECK_EQ(bottom[0]->height(), bottom[i]->height()); - CHECK_EQ(bottom[0]->width(), bottom[i]->width()); - break; - } - else if (concat_dimension == 1) - { - CHECK_EQ(bottom[0]->num(), bottom[i]->num()); - if (!concat_param.per_fla_fuse()){ - CHECK_EQ(bottom[0]->height(), bottom[i]->height()); - CHECK_EQ(bottom[0]->width(), bottom[i]->width()); - } - break; - } - else if (concat_dimension == 2) - { - CHECK_EQ(bottom[0]->num(), bottom[i]->num()); - CHECK_EQ(bottom[0]->channels(), bottom[i]->channels()); - CHECK_EQ(bottom[0]->width(), bottom[i]->width()); - break; - } - else if (concat_dimension == 3) - { - CHECK_EQ(bottom[0]->num(), bottom[i]->num()); - CHECK_EQ(bottom[0]->channels(), bottom[i]->channels()); - CHECK_EQ(bottom[0]->height(), bottom[i]->height()); - break; - } + vector bottom0_shape = bottom[0]->shape(); + bottom0_shape[concat_dimension] = 0; + vector bottom_i_shape = bottom[i]->shape(); + bottom_i_shape[concat_dimension] = 0; + CHECK_EQ(bottom0_shape == bottom_i_shape,true); } split_dims.reserve(num_concats_); - if (concat_dimension == 0) - { - num_ = 0; - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - for (auto i = 0; i < num_concats_; ++i) { - CHECK_EQ(dim_src, bottom[i]->shape().size()); - split_dims[i] = bottom[i]->num(); - num_ += split_dims[i]; - } - } - else if (concat_dimension == 1) - { - num_ = bottom[0]->num(); - channels_ = 0; - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - if (concat_param.per_fla_fuse()){ - height_ = 1; - width_ = 1; + shape_ = bottom[0]->shape(); + shape_[concat_dimension] = 0; + if (concat_dimension == 1 && concat_param.per_fla_fuse()) { + for(int i=concat_dimension+1;ishape().size()); - split_dims[i] = bottom[i]->channels()*bottom[i]->height()*bottom[i]->width(); - channels_ += split_dims[i]; + split_dims[i] = bottom[i]->count(concat_dimension); + shape_[concat_dimension] += split_dims[i]; } - } else{ - for (auto i = 0; i < num_concats_; ++i) { - CHECK_EQ(dim_src, bottom[i]->shape().size()); - split_dims[i] = bottom[i]->channels(); - channels_ += split_dims[i]; - } - } - } - else if (concat_dimension == 2) - { - num_ = bottom[0]->num(); - channels_ = bottom[0]->channels(); - height_ = 0; - width_ = bottom[0]->width(); - for (auto i = 0; i < num_concats_; ++i) { - CHECK_EQ(dim_src, bottom[i]->shape().size()); - split_dims[i] = bottom[i]->height(); - height_ += split_dims[i]; - } - } - else if (concat_dimension == 3) - { - num_ = bottom[0]->num(); - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = 0; + } else { for (auto i = 0; i < num_concats_; ++i) { CHECK_EQ(dim_src, bottom[i]->shape().size()); - split_dims[i] = bottom[i]->width(); - width_ += split_dims[i]; + split_dims[i] = bottom[i]->shape(concat_dimension); + shape_[concat_dimension] += split_dims[i]; } } } @@ -172,100 +105,26 @@ void MKLDNNConcatLayer::Reshape(const vector*>& bottom, const vector*>& top) { VLOG(1) << "MKLDNNConcatLayer::Reshape: " << this->layer_param_.name(); const ConcatParameter& concat_param = this->layer_param_.concat_param(); - if (concat_dimension == 0) - { - //Need to re-calculate the shape duo to the change of batch size - num_ = 0; - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - //Also need to reshape the concat dim, in case the concat dim is just be reshaped by batch size - for (auto i = 0; i < num_concats_; ++i) { - split_dims[i] = bottom[i]->num(); - num_ += split_dims[i]; - } - - if (this->channels_ == bottom[0]->channels() && - this->height_ == bottom[0]->height() && - this->width_ == bottom[0]->width()) { - this->reshape = false; - } else { - this->reshape = true; - } - } - else if (concat_dimension == 1) - { - num_ = bottom[0]->num(); - channels_ = 0; - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - if (concat_param.per_fla_fuse()){ - height_ = 1; - width_ = 1; - for (auto i = 0; i < num_concats_; ++i) { - split_dims[i] = bottom[i]->channels()*bottom[i]->height()*bottom[i]->width(); - channels_ += split_dims[i]; - } - if (this->num_ == bottom[0]->num()) { - this->reshape = false; - } else { - this->reshape = true; - } + this->shape_ = bottom[0]->shape(); + this->shape_[concat_dimension] = 0; + this->reshape = false; - } else{ + if (concat_dimension == 1 && concat_param.per_fla_fuse()){ + for(int i=2;ichannels(); - channels_ += split_dims[i]; - } - if (this->num_ == bottom[0]->num() && - this->height_ == bottom[0]->height() && - this->width_ == bottom[0]->width()) { - this->reshape = false; - } else { - this->reshape = true; + split_dims[i] = bottom[i]->count(1); + shape_[concat_dimension] += split_dims[i]; } - } - } - else if (concat_dimension == 2) - { - num_ = bottom[0]->num(); - channels_ = bottom[0]->channels(); - height_ = 0; - width_ = bottom[0]->width(); - for (auto i = 0; i < num_concats_; ++i) { - split_dims[i] = bottom[i]->height(); - height_ += split_dims[i]; - } - - if (this->num_ == bottom[0]->num() && - this->channels_ == bottom[0]->channels() && - this->width_ == bottom[0]->width()) { - this->reshape = false; - } else { - this->reshape = true; - } + this->reshape = (this->shape_[0] != bottom[0]->shape(0)); } - else if (concat_dimension == 3) - { - num_ = bottom[0]->num(); - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = 0; + else{ for (auto i = 0; i < num_concats_; ++i) { - split_dims[i] = bottom[i]->width(); - width_ += split_dims[i]; - } - - if (this->num_ == bottom[0]->num() && - this->channels_ == bottom[0]->channels() && - this->height_ == bottom[0]->height()) { - this->reshape = false; - } else { - this->reshape = true; + split_dims[i] = bottom[i]->shape(concat_dimension); + shape_[concat_dimension] += split_dims[i]; } } - top[0]->Reshape(num_, channels_, height_, width_); + top[0]->Reshape(shape_); } template @@ -301,9 +160,13 @@ void MKLDNNConcatLayer::InitConcatFwd(const vector*>& bottom, memory::data_type usr_dt = memory::data_type::f32; memory::data_type prv_dt = usr_dt; // memory::format mfmt_any = memory::format::any; - memory::format mfmt_nchw = memory::format::nchw; + memory::format mfmt_out; + if(this->shape_.size() == 4) + mfmt_out = memory::format::nchw; + else + mfmt_out = memory::format::ncdhw; - memory::dims output_tz = {num_, channels_, height_, width_}; + memory::dims output_tz = this->shape_; std::vector srcs_mpd; std::vector srcs; fwd_bottom_data.clear(); @@ -345,28 +208,13 @@ void MKLDNNConcatLayer::InitConcatFwd(const vector*>& bottom, fwd_bottom_data.push_back(boost::shared_ptr >()); mem_descr.push_back(boost::shared_ptr>()); - memory::dims input_tz = {0, 0, 0, 0}; - if (concat_dimension == 0) - { - input_tz = {split_dims[i], channels_, height_, width_}; - } - else if (concat_dimension == 1) - { - input_tz = {num_, split_dims[i], height_, width_}; - } - else if (concat_dimension == 2) - { - input_tz = {num_, channels_, split_dims[i], width_}; - } - else if (concat_dimension == 3) - { - input_tz = {num_, channels_, height_, split_dims[i]}; - } + memory::dims input_tz = this->shape_; + input_tz[concat_dimension] = split_dims[i]; - memory::format src_mfmt = mfmt_nchw; + memory::format src_mfmt = mfmt_out; shared_ptr prv_src_mpd; shared_ptr usr_src_mpd( - new memory::primitive_desc({input_tz, usr_dt, mfmt_nchw}, cpu_engine)); + new memory::primitive_desc({input_tz, usr_dt, mfmt_out}, cpu_engine)); if (const_cast(bottom[i]->prv_data()) != NULL) { scale = 1.; @@ -396,7 +244,7 @@ void MKLDNNConcatLayer::InitConcatFwd(const vector*>& bottom, } shared_ptr usr_dst_mpd(new memory::primitive_desc( - {output_tz, usr_dt, mfmt_nchw}, cpu_engine)); + {output_tz, usr_dt, mfmt_out}, cpu_engine)); concatFwd_pd.reset(new concat::primitive_desc(concat_dimension, srcs_mpd)); @@ -414,7 +262,7 @@ void MKLDNNConcatLayer::InitConcatFwd(const vector*>& bottom, fwd_output_memory = fwd_top_data->create_output_memory(); - memory::format base_mfmt = mfmt_nchw; + memory::format base_mfmt = mfmt_out; float base_scale = 1.; this->in_place_ = true; @@ -468,15 +316,20 @@ void MKLDNNConcatLayer::InitConcatBwd(const vector*>& top, engine cpu_engine = CpuEngine::Instance().get_engine(); memory::data_type data_type = memory::data_type::f32; // memory::format mfmt_any = memory::format::any; - memory::format mfmt_nchw = memory::format::nchw; - memory::format diff_dst_mfmt = mfmt_nchw; + memory::format mfmt_out; + if(this->shape_.size() == 4) + mfmt_out = memory::format::nchw; + else + mfmt_out = memory::format::ncdhw; - memory::dims input_tz = {num_, channels_, height_, width_}; + memory::format diff_dst_mfmt = mfmt_out; + + memory::dims input_tz = this->shape_; memory::dims offsets = {0, 0, 0, 0}; shared_ptr prv_diff_dst_mpd; shared_ptr usr_diff_dst_mpd( - new memory::primitive_desc({input_tz, data_type, mfmt_nchw}, + new memory::primitive_desc({input_tz, data_type, mfmt_out}, cpu_engine)); bool top_diff_is_prv = (const_cast(top[0]->prv_diff()) != NULL); @@ -502,26 +355,11 @@ void MKLDNNConcatLayer::InitConcatBwd(const vector*>& top, bwd_bottom_diff.push_back(boost::shared_ptr >()); reorders.push_back(MKLDNNPrimitive()); - memory::dims dims = {0, 0, 0, 0}; - if (concat_dimension == 0) - { - dims = {split_dims[i], channels_, height_, width_}; - } - else if (concat_dimension == 1) - { - dims = {num_, split_dims[i], height_, width_}; - } - else if (concat_dimension == 2) - { - dims = {num_, channels_, split_dims[i], width_}; - } - else if (concat_dimension == 3) - { - dims = {num_, channels_, height_, split_dims[i]}; - } + memory::dims dims = this->shape_; + dims[concat_dimension] = split_dims[i]; shared_ptr usr_diff_src_mpd( - new memory::primitive_desc({dims, data_type, mfmt_nchw}, + new memory::primitive_desc({dims, data_type, mfmt_out}, cpu_engine)); shared_ptr prv_diff_src_mpd( new memory::primitive_desc({dims, data_type, diff_dst_mfmt}, diff --git a/src/caffe/layers/mkldnn_eltwise_layer.cpp b/src/caffe/layers/mkldnn_eltwise_layer.cpp index aed8db34f..7c2df8b63 100644 --- a/src/caffe/layers/mkldnn_eltwise_layer.cpp +++ b/src/caffe/layers/mkldnn_eltwise_layer.cpp @@ -78,15 +78,10 @@ template void MKLDNNEltwiseLayer::Reshape(const vector*>& bottom, const vector*>& top) { VLOG(1) << "MKLDNNEltwiseLayer::Reshape: " << this->layer_param_.name(); - this->reshape = (this->width_ == bottom[0]->width() && - this->height_ == bottom[0]->height() && - this->channels_ == bottom[0]->channels() && - this->num_ == bottom[0]->num()) ? false : true; - - this->width_ = bottom[0]->width(); - this->height_ = bottom[0]->height(); - this->num_ = bottom[0]->num(); - this->channels_ = bottom[0]->channels(); + this->reshape = (this->shape_ == bottom[0]->shape()) ? false : true; + this->shape_ = bottom[0]->shape(); + CHECK_LE(this->shape_.size(), 5) + << "Tensor dimension must be less than 6"; switch (op_) { @@ -128,11 +123,6 @@ template void MKLDNNEltwiseLayer::InitEltwiseFwd(const vector*>& bottom, const vector*>& top) { if (std::is_same::value) NOT_IMPLEMENTED; - - int32_t n = this->num_; - int32_t iw = this->width_; - int32_t ih = this->height_; - int32_t ic = this->channels_; // If we just do simple adding, scale is 1.0 for all inputs we have std::vector scale(num_bottoms_, 1.0); @@ -144,7 +134,15 @@ void MKLDNNEltwiseLayer::InitEltwiseFwd(const vector*>& botto engine cpu_engine = CpuEngine::Instance().get_engine(); memory::data_type mpcsn = memory::data_type::f32; - memory::format mfmt_nchw = memory::format::nchw; + + + memory::format src_mfmt; + auto tensor_size = this->shape_.size(); + if(tensor_size == 4) { + src_mfmt = memory::format::nchw; + } else if(tensor_size == 5) { + src_mfmt = memory::format::ncdhw; + } // ---- Initialize memory descriptors ------------- std::vector prv_dt(num_bottoms_, memory::data_type::f32); @@ -177,10 +175,10 @@ void MKLDNNEltwiseLayer::InitEltwiseFwd(const vector*>& botto for (auto i = 0; i < num_bottoms_; i++) { fwd_bottom_data.push_back(boost::shared_ptr >()); - memory::format bottom_data_mfmt = mfmt_nchw; + memory::format bottom_data_mfmt = src_mfmt; shared_ptr prv_bottom_data_mpd; shared_ptr usr_bottom_data_mpd( - new memory::primitive_desc({{n, ic, ih, iw}, mpcsn, mfmt_nchw}, cpu_engine)); + new memory::primitive_desc({this->shape_, mpcsn, src_mfmt}, cpu_engine)); bool bottom_data_is_prv = (const_cast(bottom[i]->prv_data()) != NULL); if (bottom_data_is_prv) @@ -194,11 +192,11 @@ void MKLDNNEltwiseLayer::InitEltwiseFwd(const vector*>& botto mem_descr->prv_memory_pd()->desc().data.data_type); } prv_bottom_data_mpd.reset(new memory::primitive_desc( - {{n, ic, ih, iw}, bottom_data_dt, bottom_data_mfmt}, cpu_engine)); + {this->shape_, bottom_data_dt, bottom_data_mfmt}, cpu_engine)); } bottom_data_mpd.push_back(memory::primitive_desc( - {{n, ic, ih, iw}, bottom_data_dt, bottom_data_mfmt}, cpu_engine)); + {this->shape_, bottom_data_dt, bottom_data_mfmt}, cpu_engine)); fwd_bottom_data[i].reset(new MKLDNNData( usr_bottom_data_mpd, prv_bottom_data_mpd, bottom[i], this)); @@ -208,13 +206,13 @@ void MKLDNNEltwiseLayer::InitEltwiseFwd(const vector*>& botto } shared_ptr usr_top_data_mpd(new memory::primitive_desc( - {{n, ic, ih, iw}, mpcsn, mfmt_nchw}, cpu_engine)); + {this->shape_, mpcsn, src_mfmt}, cpu_engine)); // ---- Determining engine to use ----------------------- std::string subengines = this->layer_param_.engine(); if (subengines.find("MKLDNN") == std::string::npos || subengines == "MKLDNN") subengines = "MKLDNN:CPU"; - eltwiseFwd_pd.reset(new sum::primitive_desc({{n, ic, ih, iw}, bottom_data_dt, memory::format::any}, scale, bottom_data_mpd)); + eltwiseFwd_pd.reset(new sum::primitive_desc({this->shape_, bottom_data_dt, memory::format::any}, scale, bottom_data_mpd)); CHECK(eltwiseFwd_pd); shared_ptr prv_top_data_mpd(new memory::primitive_desc(eltwiseFwd_pd->dst_primitive_desc())); diff --git a/src/caffe/layers/mkldnn_relu_layer.cpp b/src/caffe/layers/mkldnn_relu_layer.cpp index 4c5f83456..80e10799e 100644 --- a/src/caffe/layers/mkldnn_relu_layer.cpp +++ b/src/caffe/layers/mkldnn_relu_layer.cpp @@ -60,14 +60,12 @@ void MKLDNNReLULayer::Reshape(const vector*>& bottom NeuronLayer::Reshape(bottom, top); - this->reshape = (this->width_ == bottom[0]->width() && - this->height_ == bottom[0]->height() && - this->channels_ == bottom[0]->channels() && - this->num_ == bottom[0]->num()) ? false : true; - this->width_ = bottom[0]->width(); - this->height_ = bottom[0]->height(); - this->num_ = bottom[0]->num(); - this->channels_ = bottom[0]->channels(); + this->reshape = (this->shape_ == bottom[0]->shape()) ? false : true; + + this->shape_ = bottom[0]->shape(); + + CHECK_LE(this->shape_.size(), 5) + << "Tensor dimension must be less than 6."; } @@ -76,10 +74,6 @@ void MKLDNNReLULayer::InitReLUFwd(const vector*>& bottom, con { if (std::is_same::value) NOT_IMPLEMENTED; auto propagation = this->phase_ == TEST ? prop_kind::forward_scoring : prop_kind::forward_training; - int32_t n = this->num_; - int32_t iw = this->width_; - int32_t ih = this->height_; - int32_t ic = this->channels_; Dtype negative_slope = this->layer_param_.relu_param().negative_slope(); bool bottom_data_is_prv = (const_cast(bottom[0]->prv_data()) != NULL); @@ -92,7 +86,15 @@ void MKLDNNReLULayer::InitReLUFwd(const vector*>& bottom, con shared_ptr usr_data_mpd(NULL), prv_data_mpd(NULL), top_data_mpd(NULL); memory::data_type src_dt = memory::data_type::f32; memory::data_type top_dt = memory::data_type::f32; - memory::format src_mfmt = memory::format::nchw; + + memory::format src_mfmt; + auto tensor_size = this->shape_.size(); + if(tensor_size == 4) { + src_mfmt = memory::format::nchw; + } else if(tensor_size == 5) { + src_mfmt = memory::format::ncdhw; + } + //bottom_data_is_prv = false; std::vector scale; if (bottom_data_is_prv) { @@ -105,13 +107,13 @@ void MKLDNNReLULayer::InitReLUFwd(const vector*>& bottom, con src_dt = static_cast(mem_descr->prv_memory_pd()->desc().data.data_type); src_mfmt = static_cast(mem_descr->prv_memory_pd()->desc().data.format); } else { - bottom_data_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw)); + bottom_data_md.reset(new memory::desc(this->shape_, mpcsn, src_mfmt)); usr_data_mpd.reset(new memory::primitive_desc(*bottom_data_md, cpu_engine)); prv_data_mpd.reset(new memory::primitive_desc(*bottom_data_md, cpu_engine)); scale.push_back(1.); } top_dt = src_dt; - top_data_mpd.reset(new memory::primitive_desc({{n,ic,ih,iw}, top_dt, src_mfmt}, cpu_engine)); + top_data_mpd.reset(new memory::primitive_desc({this->shape_, top_dt, src_mfmt}, cpu_engine)); // ---- Initialize relu primitive descriptor ------------- //relu_forward::desc reluFwd_desc(propagation, *bottom_data_md, negative_slope); @@ -195,11 +197,6 @@ void MKLDNNReLULayer::InitReLUBwd(const vector*>& top { if (std::is_same::value) NOT_IMPLEMENTED; - int32_t n = this->num_; - int32_t iw = this->width_; - int32_t ih = this->height_; - int32_t ic = this->channels_; - Dtype negative_slope = this->layer_param_.relu_param().negative_slope(); bool top_diff_is_prv = top[0]->prv_diff() != NULL; bool inplace = (bottom[0] == top[0]); @@ -226,7 +223,7 @@ void MKLDNNReLULayer::InitReLUBwd(const vector*>& top usr_data_mpd = mem_descr->usr_memory_pd(); prv_data_mpd = mem_descr->prv_memory_pd(); } else { - bottom_data_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw)); + bottom_data_md.reset(new memory::desc(this->shape_, mpcsn, this->shape_.size() == 4 ? memory::format::nchw : memory::format::ncdhw)); usr_data_mpd.reset(new memory::primitive_desc(*bottom_data_md, cpu_engine)); } @@ -276,7 +273,7 @@ void MKLDNNReLULayer::InitReLUBwd(const vector*>& top top[0]->set_prv_diff_descriptor(NULL); } - top_diff_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw)); + top_diff_md.reset(new memory::desc(this->shape_, mpcsn, this->shape_.size() == 4 ? memory::format::nchw : memory::format::ncdhw)); usr_diff_mpd.reset(new memory::primitive_desc(*top_diff_md, cpu_engine)); } diff --git a/src/caffe/layers/mvn_layer.cpp b/src/caffe/layers/mvn_layer.cpp index 09542b105..3f9f4804d 100644 --- a/src/caffe/layers/mvn_layer.cpp +++ b/src/caffe/layers/mvn_layer.cpp @@ -45,19 +45,21 @@ namespace caffe { template void MVNLayer::Reshape(const vector*>& bottom, const vector*>& top) { - top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); - mean_.Reshape(bottom[0]->num(), bottom[0]->channels(), - 1, 1); - variance_.Reshape(bottom[0]->num(), bottom[0]->channels(), - 1, 1); - temp_.Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); + top[0]->Reshape(bottom[0]->shape()); + vector temp_shape = bottom[0]->shape(); + for (int i = 2; i < temp_shape.size(); i++) + temp_shape[i] = 1; + mean_.Reshape(temp_shape); + variance_.Reshape(temp_shape); + temp_.Reshape(bottom[0]->shape()); + + vector shape = bottom[0]->shape(); + shape[0] = 1; if ( this->layer_param_.mvn_param().across_channels() ) { - sum_multiplier_.Reshape(1, bottom[0]->channels(), bottom[0]->height(), - bottom[0]->width()); + sum_multiplier_.Reshape(shape); } else { - sum_multiplier_.Reshape(1, 1, bottom[0]->height(), bottom[0]->width()); + shape[1] = 1; + sum_multiplier_.Reshape(shape); } Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data(); caffe_set(sum_multiplier_.count(), Dtype(1), multiplier_data); @@ -71,9 +73,9 @@ void MVNLayer::Forward_cpu(const vector*>& bottom, Dtype* top_data = top[0]->mutable_cpu_data(); int num; if (this->layer_param_.mvn_param().across_channels()) - num = bottom[0]->num(); + num = bottom[0]->shape(0); else - num = bottom[0]->num() * bottom[0]->channels(); + num = bottom[0]->shape(0) * bottom[0]->shape(1); int dim = bottom[0]->count() / num; @@ -118,9 +120,9 @@ void MVNLayer::Backward_cpu(const vector*>& top, int num; if (this->layer_param_.mvn_param().across_channels()) - num = bottom[0]->num(); + num = bottom[0]->shape(0); else - num = bottom[0]->num() * bottom[0]->channels(); + num = bottom[0]->shape(0) * bottom[0]->shape(1); int dim = bottom[0]->count() / num; diff --git a/src/caffe/layers/mvn_layer.cu b/src/caffe/layers/mvn_layer.cu index 739293be0..1df728d08 100644 --- a/src/caffe/layers/mvn_layer.cu +++ b/src/caffe/layers/mvn_layer.cu @@ -12,9 +12,9 @@ void MVNLayer::Forward_gpu(const vector*>& bottom, Dtype* top_data = top[0]->mutable_gpu_data(); int num; if (this->layer_param_.mvn_param().across_channels()) - num = bottom[0]->num(); + num = bottom[0]->shape(0); else - num = bottom[0]->num() * bottom[0]->channels(); + num = bottom[0]->shape(0) * bottom[0]->shape(1); int dim = bottom[0]->count() / num; @@ -60,9 +60,9 @@ void MVNLayer::Backward_gpu(const vector*>& top, int num; if (this->layer_param_.mvn_param().across_channels()) - num = bottom[0]->num(); + num = bottom[0]->shape(0); else - num = bottom[0]->num() * bottom[0]->channels(); + num = bottom[0]->shape(0) * bottom[0]->shape(1); int dim = bottom[0]->count() / num;