Skip to content
This repository has been archived by the owner on Aug 5, 2022. It is now read-only.

Add 3dUnet Support #242

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/caffe/layers/mkldnn_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {

shared_ptr<primitive> input_primitive, bwd_top_diff_primitive;

int32_t num_, width_, height_, channels_;
vector<int> shape_;
Dtype eps_, moving_average_fraction_;
bool use_weight_bias_, bias_term_, use_global_stats_;
int num_stats_batches_;
Expand Down Expand Up @@ -402,7 +402,7 @@ class MKLDNNReLULayer : public MKLDNNLayer<Dtype> , public NeuronLayer<Dtype> {
, reluFwd_pd(), reluBwd_pd()
, fwd_top_data_memory(), bwd_bottom_diff_memory()
, fwd_bottom_data_primitive(), bwd_top_diff_primitive()
, num_(0), width_(0), height_(0), channels_(0)
, shape_(0)
{
PERFORMANCE_EVENT_ID_RESET(perf_id_fw_);
PERFORMANCE_EVENT_ID_RESET(perf_id_bw_);
Expand Down Expand Up @@ -431,7 +431,7 @@ class MKLDNNReLULayer : public MKLDNNLayer<Dtype> , public NeuronLayer<Dtype> {
MKLDNNPrimitive<Dtype> reluFwd, reluBwd;
shared_ptr<memory> fwd_top_data_memory, bwd_bottom_diff_memory;
shared_ptr<primitive> fwd_bottom_data_primitive, bwd_top_diff_primitive, bwd_bottom_data_primitive;
int32_t num_, width_, height_, channels_;
vector<int> shape_;

PERFORMANCE_EVENT_ID_DECL(perf_id_fw_);
PERFORMANCE_EVENT_ID_DECL(perf_id_bw_);
Expand Down Expand Up @@ -480,7 +480,8 @@ class MKLDNNConcatLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
vector<int> split_dims;
bool in_place_;

int32_t num_, width_, height_, channels_, num_concats_;
int32_t num_concats_;
vector<int> shape_;
int concat_dimension;

PERFORMANCE_EVENT_ID_DECL(perf_id_fw_);
Expand Down Expand Up @@ -537,7 +538,7 @@ class MKLDNNEltwiseLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
, eltwiseFwd_pd()
, fwd_top_data_memory()
, fwd_bottom_data_primitives_()
, num_(0), width_(0), height_(0), channels_(0)
, shape_(0)
, num_bottoms_(0)
{
PERFORMANCE_EVENT_ID_RESET(perf_id_fw_);
Expand Down Expand Up @@ -573,7 +574,7 @@ class MKLDNNEltwiseLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
EltwiseParameter_EltwiseOp op_;
vector<Dtype> coeffs_;
Blob<int> max_idx_;
int32_t num_, width_, height_, channels_;
vector<int> shape_;
int32_t num_bottoms_;
bool stable_prod_grad_;

Expand Down
103 changes: 55 additions & 48 deletions src/caffe/layers/mkldnn_batch_norm_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom

Layer<Dtype>::LayerSetUp(bottom, top);

channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
width_ = bottom[0]->width();
num_ = bottom[0]->num();
shape_ = bottom[0]->shape();

const int channels = shape_[1];

eps_ = this->layer_param_.batch_norm_param().eps();
use_weight_bias_ = this->layer_param_.batch_norm_param().use_weight_bias();
Expand All @@ -77,12 +76,12 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
if (this->layer_param_.batch_norm_param().has_use_global_stats())
use_global_stats_ = this->layer_param_.batch_norm_param().use_global_stats();

InitStatsBatchVars(num_);
InitStatsBatchVars(shape_[0]);

this->blobs_.resize(3 + (use_weight_bias_ ? 1:0) + (use_weight_bias_ && bias_term_ ? 1:0));

vector<int> sz;
sz.push_back(channels_);
sz.push_back(channels);
this->blobs_[0].reset(new Blob<Dtype>(sz));
this->blobs_[1].reset(new Blob<Dtype>(sz));
sz[0]=1;
Expand All @@ -96,7 +95,7 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
//Optimization: use the temp blob to combine the scale and shift together. Avoid the additional copies.
// Initialize scale and shift combination blob
vector<int> scaleshift_blob_shape(1);
scaleshift_blob_shape[0] = 2*channels_;
scaleshift_blob_shape[0] = 2*channels;
scaleshift_blob_.reset(new Blob<Dtype>(scaleshift_blob_shape));
//Should initialize the scaleshift_blob_ buffer to 0, because when bias_term_ == false, need to pass zero bias to MKLDNN
caffe_set(scaleshift_blob_shape[0], static_cast<Dtype>(0),
Expand All @@ -111,8 +110,8 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
if (use_weight_bias_) {
// Initialize scale and shift
vector<int> scaleshift_shape(1);
scaleshift_shape[0] = channels_;
VLOG(1) << "MKLDNNBatchNormLayer<Dtype>::LayerSetUp: channels_ = " << channels_;
scaleshift_shape[0] = channels;
VLOG(1) << "MKLDNNBatchNormLayer<Dtype>::LayerSetUp: channels = " << channels;

this->blobs_[3].reset(new Blob<Dtype>(scaleshift_shape));
this->blobs_[3]->set_cpu_data(scaleshift_blob_->mutable_cpu_data());
Expand All @@ -128,8 +127,8 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom

if (bias_term_) {
this->blobs_[4].reset(new Blob<Dtype>(scaleshift_shape));
this->blobs_[4]->set_cpu_data(scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels_));
this->blobs_[4]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff() + scaleshift_blob_->offset(channels_));
this->blobs_[4]->set_cpu_data(scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels));
this->blobs_[4]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff() + scaleshift_blob_->offset(channels));
FillerParameter bias_filler_param(this->layer_param_.batch_norm_param().bias_filler());
if (!this->layer_param_.batch_norm_param().has_bias_filler()) {
bias_filler_param.set_type("constant");
Expand Down Expand Up @@ -161,17 +160,9 @@ void MKLDNNBatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom
{
VLOG(1) << "MKLDNNBatchNormLayer<Dtype>::Reshape: " << this->layer_param_.name();

this->reshape = (this->width_ == bottom[0]->width() &&
this->height_ == bottom[0]->height() &&
this->channels_ == bottom[0]->channels() &&
this->num_ == bottom[0]->num()) ? false : true;

this->width_ = bottom[0]->width();
this->height_ = bottom[0]->height();
this->num_ = bottom[0]->num();
this->channels_ = bottom[0]->channels();
this->reshape = (this->shape_ == bottom[0]->shape()) ? false : true;

InitStatsBatchVars(this->num_);
InitStatsBatchVars(this->shape_[0]);

//Fix: should reshape the top blob with the real size of bottom blob
//top[0]->Reshape(this->num_, this->channels_, this->height_, this->width_);
Expand All @@ -194,10 +185,15 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNorm(const vector<Blob<Dtype>*>& bott
if (use_weight_bias_) flags |= use_scale_shift;
if (use_global_stats_) flags |= use_global_stats;

int32_t n = this->num_;
int32_t iw = this->width_;
int32_t ih = this->height_;
int32_t ic = this->channels_;
memory::format src_mfmt;
auto tensor_size = this->shape_.size();
if(tensor_size == 4) {
src_mfmt = memory::format::nchw;
} else if(tensor_size == 5) {
src_mfmt = memory::format::ncdhw;
}

const int channels = this->shape_[1];

bool bottom_data_is_prv = (const_cast<Dtype*>(bottom[0]->prv_data()) != NULL);

Expand All @@ -216,13 +212,13 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNorm(const vector<Blob<Dtype>*>& bott
usr_mpd = mem_descr->usr_memory_pd();
prv_mpd = mem_descr->prv_memory_pd();
} else {
input_md.reset(new memory::desc({{n, ic, ih, iw}}, mpcsn, memory::format::nchw)); //MKLDNN batch norm only support 4D memory descriptor!
input_md.reset(new memory::desc({this->shape_}, mpcsn, src_mfmt));
usr_mpd.reset(new memory::primitive_desc(*input_md, cpu_engine));
}
output_md = input_md;
input_stats_md.reset(new memory::desc(*input_md));
CHECK(input_stats_md->data.ndims > 0 &&
input_stats_md->data.dims[0] == this->num_);
input_stats_md->data.dims[0] == this->shape_[0]);
input_stats_md->data.dims[0] = stats_batch_size_;

// ---- Initialize BatchNorm primitive descriptor -------------
Expand Down Expand Up @@ -262,7 +258,7 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNorm(const vector<Blob<Dtype>*>& bott
if (use_weight_bias_) {
//For test in train, memory address of blobs_[3] and blobs_[4] will be changed when share data from train net. If the address
// of blobs_[3] and blobs_[4] are continued, we will use them immediately, otherwise we will copy them to scaleshift_blob_ in Forward.
if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels_)) == this->blobs_[4]->mutable_cpu_data()){
if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels)) == this->blobs_[4]->mutable_cpu_data()){
scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->blobs_[3]->mutable_cpu_data()));
}else {
scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_blob_->mutable_cpu_data()));
Expand Down Expand Up @@ -309,8 +305,8 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNorm(const vector<Blob<Dtype>*>& bott
LOG(INFO) << "MKLDNN batch norm only support 4D memory descriptor! Use 4D for calculation and reshape to 2D for output!";
#endif
vector<int> top_shape;
top_shape.push_back(bottom[0]->num());
top_shape.push_back(bottom[0]->channels());
top_shape.push_back(bottom[0]->shape(0));
top_shape.push_back(bottom[0]->shape(1));
top[0]->Reshape(top_shape);
}
}
Expand All @@ -319,12 +315,15 @@ template <typename Dtype>
template <bool diff>
shared_ptr<memory> MKLDNNBatchNormLayer<Dtype>::GetStatsBatchMemory(
shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_mem, int idx) {
long data_offset =
idx * stats_batch_size_ * this->channels_ * this->width_ * this->height_;
int length = this->shape_[1];
for(int i=2;i<this->shape_.size();i++)
length *= this->shape_[i];

long data_offset = idx * stats_batch_size_ * length;
engine cpu_engine = CpuEngine::Instance().get_engine();
shared_ptr<memory::desc> stats_md = mkldnn_mem->get_memory_desc();
CHECK(stats_md->data.ndims > 0 &&
stats_md->data.dims[0] == this->num_);
stats_md->data.dims[0] == this->shape_[0]);
stats_md->data.dims[0] = stats_batch_size_;
shared_ptr<memory::primitive_desc> stats_mpd(
new memory::primitive_desc(*stats_md, cpu_engine));
Expand All @@ -338,6 +337,8 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNormFwdPrimitive(int idx) {
input_stats[idx] = GetStatsBatchMemory<false>(fwd_bottom_data, idx);
output_stats[idx] = GetStatsBatchMemory<false>(fwd_top_data, idx);

const int channels = this->shape_[1];

// ---- Create BatchNorm --------------------
if (this->phase_ == TEST && !use_global_stats_) {
if (use_weight_bias_) {
Expand All @@ -353,9 +354,9 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNormFwdPrimitive(int idx) {
variance_memory[idx].reset(new memory(BatchNormFwd_pd->variance_primitive_desc()));

if (use_global_stats_) {
caffe_copy<Dtype>(this->channels_, this->blobs_[0]->cpu_data(),
caffe_copy<Dtype>(channels, this->blobs_[0]->cpu_data(),
static_cast<Dtype *>(mean_memory[idx]->get_data_handle()));
caffe_copy<Dtype>(this->channels_, this->blobs_[1]->cpu_data(),
caffe_copy<Dtype>(channels, this->blobs_[1]->cpu_data(),
static_cast<Dtype *>(variance_memory[idx]->get_data_handle()));
if (use_weight_bias_) {
BatchNormFwd[idx].reset(new batch_normalization_forward(*BatchNormFwd_pd,
Expand Down Expand Up @@ -398,9 +399,11 @@ void MKLDNNBatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom
// update top that head at prv
fwd_top_data->sync_before_write();

if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels_)) != this->blobs_[4]->mutable_cpu_data()){
caffe_copy(channels_, this->blobs_[3]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data());
caffe_copy(channels_, this->blobs_[4]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels_));
const int channels = this->shape_[1];

if((this->blobs_[3]->mutable_cpu_data() + this->blobs_[3]->offset(channels)) != this->blobs_[4]->mutable_cpu_data()){
caffe_copy(channels, this->blobs_[3]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data());
caffe_copy(channels, this->blobs_[4]->cpu_data(), this->scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels));
}

for (int stats_batch_idx = 0; stats_batch_idx < num_stats_batches_; stats_batch_idx++) {
Expand Down Expand Up @@ -429,11 +432,11 @@ void MKLDNNBatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom
Dtype *variance_buffer_ = (Dtype *)(variance_memory[stats_batch_idx]->get_data_handle());
this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
this->blobs_[2]->mutable_cpu_data()[0] += 1;
caffe_cpu_axpby<Dtype>(this->channels_, Dtype(1), mean_buffer_,
caffe_cpu_axpby<Dtype>(channels, Dtype(1), mean_buffer_,
moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
int m = bottom[0]->count()/num_stats_batches_/channels_;
int m = bottom[0]->count()/num_stats_batches_/channels;
Dtype bias_correction_factor = m > 1 ? Dtype(m)/(m-1) : 1;
caffe_cpu_axpby<Dtype>(this->channels_, bias_correction_factor,
caffe_cpu_axpby<Dtype>(channels, bias_correction_factor,
variance_buffer_, moving_average_fraction_,
this->blobs_[1]->mutable_cpu_data());
}
Expand All @@ -450,10 +453,14 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNormBwd(
{
if (std::is_same<Dtype, double>::value) NOT_IMPLEMENTED;

int32_t n = this->num_;
int32_t w = this->width_;
int32_t h = this->height_;
int32_t c = this->channels_;
memory::format src_mfmt;
auto tensor_size = this->shape_.size();
if(tensor_size == 4) {
src_mfmt = memory::format::nchw;
} else if(tensor_size == 5) {
src_mfmt = memory::format::ncdhw;
}


unsigned flags = 0;
if (use_weight_bias_) flags |= use_scale_shift;
Expand All @@ -475,16 +482,16 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNormBwd(
usr_diff_mpd = mem_descr->usr_memory_pd();
prv_diff_mpd = mem_descr->prv_memory_pd();
} else {
top_diff_md.reset(new memory::desc({{n, c, h, w}}, mpcsn, memory::format::nchw)); //MKLDNN batch norm only support 4D memory descriptor!
top_diff_md.reset(new memory::desc({this->shape_}, mpcsn, src_mfmt));
usr_diff_mpd.reset(new memory::primitive_desc(*top_diff_md, cpu_engine));
}
top_diff_stats_md.reset(new memory::desc(*top_diff_md));
CHECK(top_diff_stats_md->data.ndims > 0 &&
top_diff_stats_md->data.dims[0] == this->num_);
top_diff_stats_md->data.dims[0] == this->shape_[0]);
top_diff_stats_md->data.dims[0] = stats_batch_size_;
output_stats_md.reset(new memory::desc(output_memory->get_primitive_desc().desc()));
CHECK(output_stats_md->data.ndims > 0 &&
output_stats_md->data.dims[0] == this->num_);
output_stats_md->data.dims[0] == this->shape_[0]);
output_stats_md->data.dims[0] = stats_batch_size_;

// ---- Initialize bnrm primitive descriptor -------------
Expand Down