forked from BVLC/caffe
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
y = x[:, :, crop_size:-crop_size, crop_size:-crop_size] を実現する
- Loading branch information
Showing
6 changed files
with
390 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#ifndef CAFFE_CROP_CENTER_LAYER_HPP_ | ||
#define CAFFE_CROP_CENTER_LAYER_HPP_ | ||
|
||
#include <utility> | ||
#include <vector> | ||
|
||
#include "caffe/blob.hpp" | ||
#include "caffe/layer.hpp" | ||
#include "caffe/proto/caffe.pb.h" | ||
|
||
namespace caffe { | ||
|
||
/** | ||
* @brief Takes a Blob and crop. | ||
* | ||
* TODO(dox): thorough documentation for Forward, Backward, and proto params. | ||
*/ | ||
|
||
template <typename Dtype> | ||
class CropCenterLayer : public Layer<Dtype> { | ||
public: | ||
explicit CropCenterLayer(const LayerParameter& param) | ||
: Layer<Dtype>(param) {} | ||
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top); | ||
virtual void Reshape(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top); | ||
|
||
virtual inline const char* type() const { return "CropCenter"; } | ||
virtual inline int ExactNumBottomBlobs() const { return 1; } | ||
virtual inline int ExactNumTopBlobs() const { return 1; } | ||
|
||
protected: | ||
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top); | ||
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); | ||
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top); | ||
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); | ||
|
||
Blob<int> crop_sizes_; | ||
Blob<int> src_strides_; | ||
Blob<int> dest_strides_; | ||
|
||
private: | ||
// Recursive copy function. | ||
void crop_copy(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top, | ||
const int* offsets, | ||
vector<int> indices, | ||
int cur_dim, | ||
const Dtype* src_data, | ||
Dtype* dest_data, | ||
bool is_forward); | ||
|
||
// Recursive copy function: this is similar to crop_copy() but loops over all | ||
// but the last two dimensions to allow for ND cropping while still relying on | ||
// a CUDA kernel for the innermost two dimensions for performance reasons. An | ||
// alterantive implementation could rely on the kernel more by passing | ||
// offsets, but this is problematic because of its variable length. | ||
// Since in the standard (N,C,W,H) case N,C are usually not cropped a speedup | ||
// could be achieved by not looping the application of the copy_kernel around | ||
// these dimensions. | ||
void crop_copy_gpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top, | ||
const vector<int>& offsets, | ||
vector<int> indices, | ||
int cur_dim, | ||
const Dtype* src_data, | ||
Dtype* dest_data, | ||
bool is_forward); | ||
}; | ||
} // namespace caffe | ||
|
||
#endif // CAFFE_CROP_CENTER_LAYER_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#include <algorithm> | ||
#include <functional> | ||
#include <map> | ||
#include <set> | ||
#include <vector> | ||
|
||
|
||
#include "caffe/layer.hpp" | ||
#include "caffe/layers/crop_center_layer.hpp" | ||
#include "caffe/net.hpp" | ||
|
||
|
||
namespace caffe { | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
// LayerSetup() handles the number of dimensions; Reshape() handles the sizes. | ||
// bottom[0] supplies the data | ||
const CropCenterParameter& param = this->layer_param_.crop_center_param(); | ||
CHECK_EQ(bottom.size(), 1) << "Wrong number of bottom blobs."; | ||
int input_dim = bottom[0]->num_axes(); | ||
CHECK_EQ(param.crop_size().size(), input_dim) << "Wrong param.crop_size."; | ||
for(int i = 0; i < input_dim; i++) { | ||
int crop_size = param.crop_size(i); | ||
CHECK_LE(crop_size * 2 + 1, bottom[0]->shape(i)) << "crop size bigger than input shape"; | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
const CropCenterParameter& param = this->layer_param_.crop_center_param(); | ||
int input_dim = bottom[0]->num_axes(); | ||
|
||
// Initialize crop_sizes_ to 0 and the new shape to the current shape of the data. | ||
vector<int> new_shape(bottom[0]->shape()); | ||
vector<int> offsets_shape(1, input_dim); | ||
crop_sizes_.Reshape(offsets_shape); | ||
int* crop_size_data = crop_sizes_.mutable_cpu_data(); | ||
|
||
// Determine crop offsets and the new shape post-crop. | ||
for (int i = 0; i < input_dim; ++i) { | ||
int crop_size = param.crop_size(i); | ||
int new_size = bottom[0]->shape(i) - crop_size * 2; | ||
new_shape[i] = new_size; | ||
crop_size_data[i] = crop_size; | ||
} | ||
top[0]->Reshape(new_shape); | ||
// Compute strides | ||
src_strides_.Reshape(offsets_shape); | ||
dest_strides_.Reshape(offsets_shape); | ||
for (int i = 0; i < input_dim; ++i) { | ||
src_strides_.mutable_cpu_data()[i] = bottom[0]->count(i + 1, input_dim); | ||
dest_strides_.mutable_cpu_data()[i] = top[0]->count(i + 1, input_dim); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top, | ||
const int* crop_sizes, | ||
vector<int> indices, | ||
int cur_dim, | ||
const Dtype* src_data, | ||
Dtype* dest_data, | ||
bool is_forward) { | ||
int crop_size = crop_sizes[cur_dim]; | ||
if (cur_dim + 1 < bottom[0]->num_axes()) { | ||
// We are not yet at the final dimension, call copy recursively | ||
for (int i = crop_size; i < bottom[0]->shape(cur_dim) - crop_size; ++i) { | ||
indices[cur_dim] = i; | ||
crop_copy(bottom, top, crop_sizes, indices, cur_dim+1, | ||
src_data, dest_data, is_forward); | ||
} | ||
} else { | ||
std::vector<int> ind_red(cur_dim + 1, 0); | ||
std::vector<int> ind_off(cur_dim + 1, 0); | ||
for (int j = 0; j < cur_dim; ++j) { | ||
ind_red[j] = indices[j] - crop_sizes[j]; | ||
ind_off[j] = indices[j]; | ||
} | ||
ind_red[cur_dim] = 0; | ||
ind_off[cur_dim] = crop_sizes[cur_dim]; | ||
// do the copy | ||
int N = top[0]->shape(cur_dim); | ||
if (is_forward) { | ||
caffe_copy(N, | ||
src_data + bottom[0]->offset(ind_off), | ||
dest_data + top[0]->offset(ind_red)); | ||
} else { | ||
// in the backwards pass the src_data is top_diff | ||
// and the dest_data is bottom_diff | ||
caffe_copy(N, | ||
src_data + top[0]->offset(ind_red), | ||
dest_data + bottom[0]->offset(ind_off)); | ||
} | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
std::vector<int> indices(top[0]->num_axes(), 0); | ||
const Dtype* bottom_data = bottom[0]->cpu_data(); | ||
Dtype* top_data = top[0]->mutable_cpu_data(); | ||
crop_copy(bottom, top, crop_sizes_.cpu_data(), indices, 0, bottom_data, top_data, | ||
true); | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { | ||
const Dtype* top_diff = top[0]->cpu_diff(); | ||
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); | ||
|
||
if (propagate_down[0]) { | ||
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff); | ||
std::vector<int> indices(top[0]->num_axes(), 0); | ||
crop_copy(bottom, top, crop_sizes_.cpu_data(), indices, 0, top_diff, | ||
bottom_diff, false); | ||
} | ||
} | ||
|
||
#ifdef CPU_ONLY | ||
STUB_GPU(CropCenterLayer); | ||
#endif | ||
|
||
INSTANTIATE_CLASS(CropCenterLayer); | ||
REGISTER_LAYER_CLASS(CropCenter); | ||
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#include <vector> | ||
|
||
#include "caffe/layers/crop_center_layer.hpp" | ||
|
||
namespace caffe { | ||
|
||
__device__ static int compute_uncropped_index( | ||
int index, | ||
const int ndims, | ||
const int* src_strides, | ||
const int* dest_strides, | ||
const int* crop_sizes) { | ||
int dest_index = index; | ||
int src_index = 0; | ||
for (int i = 0; i < ndims; ++i) { | ||
int coord = dest_index / dest_strides[i]; | ||
dest_index -= coord * dest_strides[i]; | ||
src_index += src_strides[i] * (coord + crop_sizes[i]); | ||
} | ||
return src_index; | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void crop_center_kernel_forward(const int nthreads, | ||
const int ndims, | ||
const int* src_strides, | ||
const int* dest_strides, | ||
const int* crop_sizes, | ||
const Dtype* src, Dtype* dest) { | ||
CUDA_KERNEL_LOOP(index, nthreads) { | ||
int src_index = compute_uncropped_index( | ||
index, ndims, src_strides, dest_strides, crop_sizes); | ||
dest[index] = src[src_index]; | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void crop_center_kernel_backward(const int nthreads, | ||
const int ndims, | ||
const int* src_strides, | ||
const int* dest_strides, | ||
const int* crop_sizes, | ||
Dtype* src, const Dtype* dest) { | ||
CUDA_KERNEL_LOOP(index, nthreads) { | ||
int src_index = compute_uncropped_index( | ||
index, ndims, src_strides, dest_strides, crop_sizes); | ||
src[src_index] = dest[index]; | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
const Dtype* bottom_data = bottom[0]->gpu_data(); | ||
Dtype* top_data = top[0]->mutable_gpu_data(); | ||
int n = top[0]->count(); | ||
// NOLINT_NEXT_LINE(whitespace/operators) | ||
crop_center_kernel_forward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n, | ||
bottom[0]->num_axes(), | ||
src_strides_.gpu_data(), | ||
dest_strides_.gpu_data(), | ||
crop_sizes_.gpu_data(), | ||
bottom_data, top_data); | ||
} | ||
|
||
template <typename Dtype> | ||
void CropCenterLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, | ||
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { | ||
const Dtype* top_diff = top[0]->gpu_diff(); | ||
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); | ||
int n = top[0]->count(); | ||
|
||
if (propagate_down[0]) { | ||
caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff); | ||
// NOLINT_NEXT_LINE(whitespace/operators) | ||
crop_center_kernel_backward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n, | ||
bottom[0]->num_axes(), | ||
src_strides_.gpu_data(), | ||
dest_strides_.gpu_data(), | ||
crop_sizes_.gpu_data(), | ||
bottom_diff, top_diff); | ||
} | ||
} | ||
|
||
INSTANTIATE_LAYER_GPU_FUNCS(CropCenterLayer); | ||
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.