Skip to content

Commit

Permalink
CropCenterレイヤー追加
Browse files Browse the repository at this point in the history
y = x[:, :, crop_size:-crop_size, crop_size:-crop_size] を実現する
  • Loading branch information
lltcggie committed Oct 23, 2018
1 parent 6935fa3 commit 98f9aea
Show file tree
Hide file tree
Showing 6 changed files with 390 additions and 2 deletions.
77 changes: 77 additions & 0 deletions include/caffe/layers/crop_center_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef CAFFE_CROP_CENTER_LAYER_HPP_
#define CAFFE_CROP_CENTER_LAYER_HPP_

#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Crops the same margin off both ends of each axis of the bottom blob.
 *
 * For every axis i the top blob keeps the bottom index range
 * [crop_size[i], shape[i] - crop_size[i]), so its extent along that axis is
 * shape[i] - 2 * crop_size[i]. A crop size of 0 leaves the axis unchanged.
 * crop_center_param.crop_size must hold exactly one entry per bottom axis.
 *
 * Backward zero-fills the bottom diff and copies the top diff into the
 * retained (centre) region.
 */
template <typename Dtype>
class CropCenterLayer : public Layer<Dtype> {
 public:
  explicit CropCenterLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  /// Validates crop_center_param against the bottom blob.
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// Shapes the top blob and refreshes the cached crop sizes and strides.
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "CropCenter"; }
  virtual inline int ExactNumBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  // Per-axis data filled in by Reshape(); each blob has one entry per bottom
  // axis.
  Blob<int> crop_sizes_;    // margin removed from each end of the axis
  Blob<int> src_strides_;   // element stride of each axis in the bottom blob
  Blob<int> dest_strides_;  // element stride of each axis in the top blob

 private:
  // Recursive CPU copy. Walks the retained index range of every axis except
  // the last, then copies one contiguous run along the last axis.
  //   crop_sizes  per-axis margins (one int per bottom axis)
  //   indices     bottom-blob coordinates chosen so far for axes < cur_dim
  //   cur_dim     axis handled by this call
  //   is_forward  true:  src_data is bottom data, dest_data is top data
  //               false: src_data is top diff,    dest_data is bottom diff
  void crop_copy(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top,
      const int* crop_sizes,
      vector<int> indices,
      int cur_dim,
      const Dtype* src_data,
      Dtype* dest_data,
      bool is_forward);
};
} // namespace caffe

#endif // CAFFE_CROP_CENTER_LAYER_HPP_
132 changes: 132 additions & 0 deletions src/caffe/layers/crop_center_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <vector>


#include "caffe/layer.hpp"
#include "caffe/layers/crop_center_layer.hpp"
#include "caffe/net.hpp"


namespace caffe {

/**
 * Validates crop_center_param against the bottom blob: there must be exactly
 * one bottom, exactly one crop_size entry per bottom axis, and every axis must
 * keep at least one element after cropping both ends
 * (2 * crop_size + 1 <= extent).
 *
 * @param bottom input blobs; only bottom[0] is inspected
 * @param top    output blobs; not touched here
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const CropCenterParameter& param = this->layer_param_.crop_center_param();
  CHECK_EQ(bottom.size(), 1) << "Wrong number of bottom blobs.";
  const int input_dim = bottom[0]->num_axes();
  CHECK_EQ(param.crop_size().size(), input_dim) << "Wrong param.crop_size.";
  for (int i = 0; i < input_dim; ++i) {
    // crop_size is an unsigned 32-bit proto field. Do the arithmetic in 64
    // bits so that huge values neither wrap to a negative int (which would
    // slip past the check) nor overflow when doubled.
    const long long min_extent = 2LL * param.crop_size(i) + 1;
    CHECK_LE(min_extent, bottom[0]->shape(i))
        << "crop size bigger than input shape";
  }
}

/**
 * Shapes top[0] to the cropped extent of bottom[0] and refreshes the cached
 * per-axis crop sizes and strides.
 *
 * The crop sizes are re-validated here because the bottom blob may have a
 * different shape than it had during LayerSetUp(); without the checks a
 * smaller input would yield a zero or negative top extent, and a different
 * axis count would index past the end of crop_size.
 *
 * @param bottom input blobs; bottom[0] supplies the uncropped shape
 * @param top    output blobs; top[0] is reshaped
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const CropCenterParameter& param = this->layer_param_.crop_center_param();
  const int input_dim = bottom[0]->num_axes();
  CHECK_EQ(param.crop_size().size(), input_dim) << "Wrong param.crop_size.";

  // Start from the bottom shape and shrink every axis by twice its margin.
  vector<int> new_shape(bottom[0]->shape());
  const vector<int> per_axis_shape(1, input_dim);
  crop_sizes_.Reshape(per_axis_shape);
  int* crop_size_data = crop_sizes_.mutable_cpu_data();
  for (int i = 0; i < input_dim; ++i) {
    // 64-bit arithmetic: crop_size is a uint32 proto field.
    const long long min_extent = 2LL * param.crop_size(i) + 1;
    CHECK_LE(min_extent, bottom[0]->shape(i))
        << "crop size bigger than input shape";
    // The check above guarantees the value fits in an int.
    const int crop_size = static_cast<int>(param.crop_size(i));
    new_shape[i] = bottom[0]->shape(i) - crop_size * 2;
    crop_size_data[i] = crop_size;
  }
  top[0]->Reshape(new_shape);

  // Element strides of each axis in the bottom (src) and top (dest) blobs;
  // consumed by the GPU kernels to map a top index to a bottom index.
  src_strides_.Reshape(per_axis_shape);
  dest_strides_.Reshape(per_axis_shape);
  int* src_stride_data = src_strides_.mutable_cpu_data();
  int* dest_stride_data = dest_strides_.mutable_cpu_data();
  for (int i = 0; i < input_dim; ++i) {
    src_stride_data[i] = bottom[0]->count(i + 1, input_dim);
    dest_stride_data[i] = top[0]->count(i + 1, input_dim);
  }
}

/**
 * Recursive CPU copy between the bottom blob and its centre crop.
 *
 * Axes before the last are walked over their retained index range
 * [crop, extent - crop); once the last axis is reached, one contiguous run of
 * top[0]->shape(last) elements is copied.
 *
 * @param bottom      bottom[0] is the uncropped blob
 * @param top         top[0] is the cropped blob
 * @param crop_sizes  per-axis margins, one int per bottom axis
 * @param indices     bottom coordinates chosen so far for axes < cur_dim
 * @param cur_dim     axis handled by this call
 * @param src_data    bottom data (forward) or top diff (backward)
 * @param dest_data   top data (forward) or bottom diff (backward)
 * @param is_forward  selects the copy direction
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top,
    const int* crop_sizes,
    vector<int> indices,
    int cur_dim,
    const Dtype* src_data,
    Dtype* dest_data,
    bool is_forward) {
  const int margin = crop_sizes[cur_dim];
  const int last_dim = bottom[0]->num_axes() - 1;

  if (cur_dim < last_dim) {
    // Not at the innermost axis yet: fix this axis' coordinate and descend.
    const int stop = bottom[0]->shape(cur_dim) - margin;
    for (int pos = margin; pos < stop; ++pos) {
      indices[cur_dim] = pos;
      crop_copy(bottom, top, crop_sizes, indices, cur_dim + 1,
                src_data, dest_data, is_forward);
    }
    return;
  }

  // Innermost axis: build the start coordinates of the run in both blobs.
  std::vector<int> bottom_index(cur_dim + 1, 0);
  std::vector<int> top_index(cur_dim + 1, 0);
  for (int axis = 0; axis < cur_dim; ++axis) {
    bottom_index[axis] = indices[axis];
    top_index[axis] = indices[axis] - crop_sizes[axis];
  }
  bottom_index[cur_dim] = margin;  // the run starts just past the margin
  top_index[cur_dim] = 0;

  const int run_length = top[0]->shape(cur_dim);
  const int bottom_offset = bottom[0]->offset(bottom_index);
  const int top_offset = top[0]->offset(top_index);
  if (is_forward) {
    caffe_copy(run_length, src_data + bottom_offset, dest_data + top_offset);
  } else {
    // Backward: src_data is the top diff and dest_data is the bottom diff.
    caffe_copy(run_length, src_data + top_offset, dest_data + bottom_offset);
  }
}

/**
 * CPU forward pass: copies the centre region of bottom[0] into top[0] using
 * the crop sizes cached by Reshape().
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* uncropped = bottom[0]->cpu_data();
  Dtype* cropped = top[0]->mutable_cpu_data();
  const int* margins = crop_sizes_.cpu_data();
  // Scratch coordinates for the recursion, one slot per axis.
  const std::vector<int> start_indices(top[0]->num_axes(), 0);
  const int first_dim = 0;
  const bool forward = true;
  crop_copy(bottom, top, margins, start_indices, first_dim,
            uncropped, cropped, forward);
}

/**
 * CPU backward pass: zero-fills the bottom diff and copies the top diff into
 * its centre region. Does nothing when propagate_down[0] is false.
 *
 * @param top             top[0] supplies the incoming gradient
 * @param propagate_down  propagate_down[0] gates the whole computation
 * @param bottom          bottom[0] receives the gradient
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // Return before requesting the diff buffers so that a layer which does not
  // back-propagate never touches the bottom diff.
  if (!propagate_down[0]) {
    return;
  }
  const Dtype* top_diff = top[0]->cpu_diff();
  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();

  // Elements in the cropped-away margin receive no gradient.
  caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
  std::vector<int> indices(top[0]->num_axes(), 0);
  crop_copy(bottom, top, crop_sizes_.cpu_data(), indices, 0, top_diff,
            bottom_diff, false);
}

// CPU_ONLY builds: STUB_GPU presumably provides placeholder definitions for
// the GPU methods declared in the header (Caffe macro; confirm against its
// definition).
#ifdef CPU_ONLY
STUB_GPU(CropCenterLayer);
#endif

// Instantiate the layer template and register it with the layer factory
// (Caffe macros; the type string reported by type() is "CropCenter").
INSTANTIATE_CLASS(CropCenterLayer);
REGISTER_LAYER_CLASS(CropCenter);

} // namespace caffe
87 changes: 87 additions & 0 deletions src/caffe/layers/crop_center_layer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include <vector>

#include "caffe/layers/crop_center_layer.hpp"

namespace caffe {

// Maps a flat index into the cropped (dest) blob to the flat index of the
// same element in the uncropped (src) blob.
//   index         flat index into the cropped blob
//   ndims         number of axes
//   src_strides   per-axis element strides of the uncropped blob
//   dest_strides  per-axis element strides of the cropped blob
//   crop_sizes    per-axis margin removed from the start of each axis
__device__ static int compute_uncropped_index(
    int index,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* crop_sizes) {
  int remaining = index;
  int uncropped = 0;
  for (int axis = 0; axis < ndims; ++axis) {
    const int stride = dest_strides[axis];
    // Peel off this axis' coordinate in the cropped blob...
    const int coord = remaining / stride;
    remaining %= stride;
    // ...and shift it by the margin to get the coordinate in the source.
    uncropped += (coord + crop_sizes[axis]) * src_strides[axis];
  }
  return uncropped;
}

// Forward kernel: for every cropped-blob index handled by this thread, reads
// the matching element of the uncropped blob `src` and stores it in `dest`.
// `nthreads` is the number of elements in the cropped blob.
template <typename Dtype>
__global__ void crop_center_kernel_forward(const int nthreads,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* crop_sizes,
    const Dtype* src, Dtype* dest) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    dest[index] = src[compute_uncropped_index(
        index, ndims, src_strides, dest_strides, crop_sizes)];
  }
}

// Backward kernel: the mirror image of the forward kernel. Here `dest` is the
// read-only gradient of the cropped blob and `src` is the gradient of the
// uncropped blob being written; only elements inside the retained region are
// assigned.
template <typename Dtype>
__global__ void crop_center_kernel_backward(const int nthreads,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* crop_sizes,
    Dtype* src, const Dtype* dest) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    src[compute_uncropped_index(
        index, ndims, src_strides, dest_strides, crop_sizes)] = dest[index];
  }
}

/**
 * GPU forward pass: launches crop_center_kernel_forward over every element of
 * top[0], using the crop sizes and strides cached by Reshape().
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* uncropped = bottom[0]->gpu_data();
  Dtype* cropped = top[0]->mutable_gpu_data();
  const int count = top[0]->count();
  const int num_axes = bottom[0]->num_axes();
  const int* src_strides = src_strides_.gpu_data();
  const int* dest_strides = dest_strides_.gpu_data();
  const int* margins = crop_sizes_.gpu_data();
  // NOLINT_NEXT_LINE(whitespace/operators)
  crop_center_kernel_forward<<<CAFFE_GET_BLOCKS(count),
      CAFFE_CUDA_NUM_THREADS>>>(count, num_axes,
      src_strides, dest_strides, margins,
      uncropped, cropped);
}

/**
 * GPU backward pass: zero-fills the bottom diff, then launches
 * crop_center_kernel_backward to copy the top diff into its centre region.
 * Does nothing when propagate_down[0] is false.
 *
 * @param top             top[0] supplies the incoming gradient
 * @param propagate_down  propagate_down[0] gates the whole computation
 * @param bottom          bottom[0] receives the gradient
 */
template <typename Dtype>
void CropCenterLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // Return before requesting the diff buffers so that a layer which does not
  // back-propagate never touches the bottom diff on the device.
  if (!propagate_down[0]) {
    return;
  }
  const Dtype* top_diff = top[0]->gpu_diff();
  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
  const int n = top[0]->count();

  // Elements in the cropped-away margin receive no gradient.
  caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
  // NOLINT_NEXT_LINE(whitespace/operators)
  crop_center_kernel_backward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n,
      bottom[0]->num_axes(),
      src_strides_.gpu_data(),
      dest_strides_.gpu_data(),
      crop_sizes_.gpu_data(),
      bottom_diff, top_diff);
}

INSTANTIATE_LAYER_GPU_FUNCS(CropCenterLayer);

} // namespace caffe
2 changes: 1 addition & 1 deletion src/caffe/layers/crop_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace caffe {

__device__ int compute_uncropped_index(
__device__ static int compute_uncropped_index(
int index,
const int ndims,
const int* src_strides,
Expand Down
8 changes: 7 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 149 (last added: clip_param)
// LayerParameter next available layer-specific ID: 150 (last added: crop_center_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
Expand Down Expand Up @@ -383,6 +383,7 @@ message LayerParameter {
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
optional CropParameter crop_param = 144;
optional CropCenterParameter crop_center_param = 149;
optional DataParameter data_param = 107;
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
Expand Down Expand Up @@ -660,6 +661,11 @@ message CropParameter {
repeated uint32 offset = 2;
}

// Parameters for the CropCenter layer.
message CropCenterParameter {
  // Number of elements removed from BOTH ends of each axis. The layer requires
  // exactly one entry per bottom axis and 2 * crop_size + 1 <= axis extent.
  // An entry of 0 leaves that axis uncropped. For a 4-axis blob,
  // crop_size: [0, 0, c, c] with c > 0 gives y = x[:, :, c:-c, c:-c].
  repeated uint32 crop_size = 2;
}

message DataParameter {
enum DB {
LEVELDB = 0;
Expand Down
Loading

0 comments on commit 98f9aea

Please sign in to comment.