Skip to content

Commit

Permalink
Fix ReLU issue per
Browse files (browse the repository at this point in the history)
  • Loading branch information
Chuck Cho committed Dec 22, 2016
1 parent 0e7e225 commit 9b635f2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
8 changes: 8 additions & 0 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ inline void setTensorNdDesc(cudnnTensorDescriptor_t* desc,
std::vector<int> stride) {
CHECK_EQ(shape.size(), stride.size()) <<
"Dimensions of shape and stride don't match !";
// fill shape with 1 to create tensors with at least 4 dimensions
// to prevent CUDNN_STATUS_BAD_PARAM error in CUDNN v4
// TODO(christian.payer@gmx.net): check CUDNN doc, probably fixed
// in newer versions
for (int i = shape.size(); i < 4; ++i) {
shape.push_back(1);
stride.push_back(1);
}
CUDNN_CHECK(cudnnSetTensorNdDescriptor(*desc, dataType<Dtype>::type,
shape.size(), shape.data(), stride.data()));
}
Expand Down
12 changes: 4 additions & 8 deletions src/caffe/layers/cudnn_relu_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
ReLULayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
cudnn::createActivationDescriptor<Dtype>(&activ_desc_, CUDNN_ACTIVATION_RELU);
handles_setup_ = true;
}
Expand All @@ -21,12 +21,8 @@ template <typename Dtype>
void CuDNNReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
ReLULayer<Dtype>::Reshape(bottom, top);
const int N = bottom[0]->num();
const int K = bottom[0]->channels();
const int H = bottom[0]->height();
const int W = bottom[0]->width();
cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
cudnn::setTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
cudnn::setTensorNdDesc<Dtype>(&bottom_desc_, bottom[0]->shape());
cudnn::setTensorNdDesc<Dtype>(&top_desc_, bottom[0]->shape());
}

template <typename Dtype>
Expand Down

0 comments on commit 9b635f2

Please sign in to comment.