-
Notifications
You must be signed in to change notification settings - Fork 37
/
pointwise_spatial_attention_layer.cu
221 lines (212 loc) · 10.3 KB
/
pointwise_spatial_attention_layer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layers/pointwise_spatial_attention_layer.hpp"
namespace caffe {
template <typename Dtype>
__global__ void PSAForward_buffer_mask_collect_gpu(const int nthreads,
    const int feature_H_, const int feature_W_,
    const int mask_H_, const int mask_W_,
    const int half_mask_H_, const int half_mask_W_,
    const Dtype* mask_data, Dtype* buffer_data) {
  // Scatter each pixel's attention window into the dense "collect" buffer.
  //
  // Work items: one per (n, h, w) pixel. Forward_gpu launches this with
  // nthreads = num_ * feature_H_ * feature_W_.
  // mask_data   : indexed as [N, mask_H_ * mask_W_, feature_H_, feature_W_].
  // buffer_data : indexed as [N, H*W, H*W]. Mask cell (mh, mw) seen at pixel
  //               (h, w) is stored at [n][source][pixel], where
  //               source = (h + mh - half_mask_H_, w + mw - half_mask_W_).
  // Only cells whose source lies inside the feature map are written; every
  // other buffer entry is left untouched (Forward_gpu zero-fills it first).
  const int plane = feature_H_ * feature_W_;
  const int window = mask_H_ * mask_W_;
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int pw = index % feature_W_;
    const int ph = (index / feature_W_) % feature_H_;
    const int img = index / plane;
    const int pixel = ph * feature_W_ + pw;
    // Feature-map offset of mask cell (0, 0) when the window sits on (ph, pw).
    const int row_shift = ph - half_mask_H_;
    const int col_shift = pw - half_mask_W_;
    // Clip the window so every visited cell maps inside the feature map.
    const int row_begin = max(0, -row_shift);
    const int row_end = min(mask_H_, feature_H_ - row_shift);
    const int col_begin = max(0, -col_shift);
    const int col_end = min(mask_W_, feature_W_ - col_shift);
    // Both views are strided by `plane` per mask cell / source pixel.
    const Dtype* src = mask_data + img * window * plane + pixel;
    Dtype* dst = buffer_data + img * plane * plane + pixel;
    for (int mh = row_begin; mh < row_end; ++mh) {
      for (int mw = col_begin; mw < col_end; ++mw) {
        const int source = (mh + row_shift) * feature_W_ + (mw + col_shift);
        dst[source * plane] = src[(mh * mask_W_ + mw) * plane];
      }
    }
  }
}
template <typename Dtype>
__global__ void PSAForward_buffer_mask_distribute_gpu(const int nthreads,
    const int feature_H_, const int feature_W_,
    const int mask_H_, const int mask_W_,
    const int half_mask_H_, const int half_mask_W_,
    const Dtype* mask_data, Dtype* buffer_data) {
  // Scatter each pixel's attention window into the dense "distribute" buffer.
  //
  // Work items: one per (n, h, w) pixel. Forward_gpu launches this with
  // nthreads = num_ * feature_H_ * feature_W_.
  // mask_data   : indexed as [N, mask_H_ * mask_W_, feature_H_, feature_W_].
  // buffer_data : indexed as [N, H*W, H*W]. Mask cell (mh, mw) seen at pixel
  //               (h, w) is stored at [n][pixel][target], where
  //               target = (h + mh - half_mask_H_, w + mw - half_mask_W_),
  //               i.e. each work item fills (part of) its own buffer row.
  // Only cells whose target lies inside the feature map are written; every
  // other buffer entry is left untouched (Forward_gpu zero-fills it first).
  const int plane = feature_H_ * feature_W_;
  const int window = mask_H_ * mask_W_;
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int pw = index % feature_W_;
    const int ph = (index / feature_W_) % feature_H_;
    const int img = index / plane;
    const int pixel = ph * feature_W_ + pw;
    // Feature-map offset of mask cell (0, 0) when the window sits on (ph, pw).
    const int row_shift = ph - half_mask_H_;
    const int col_shift = pw - half_mask_W_;
    // Clip the window so every visited cell maps inside the feature map.
    const int row_begin = max(0, -row_shift);
    const int row_end = min(mask_H_, feature_H_ - row_shift);
    const int col_begin = max(0, -col_shift);
    const int col_end = min(mask_W_, feature_W_ - col_shift);
    const Dtype* src = mask_data + img * window * plane + pixel;
    // Start of this pixel's row in the per-image (H*W x H*W) buffer.
    Dtype* row = buffer_data + (img * plane + pixel) * plane;
    for (int mh = row_begin; mh < row_end; ++mh) {
      for (int mw = col_begin; mw < col_end; ++mw) {
        const int target = (mh + row_shift) * feature_W_ + (mw + col_shift);
        row[target] = src[(mh * mask_W_ + mw) * plane];
      }
    }
  }
}
template <typename Dtype>
void PointwiseSpatialAttentionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Forward pass:
  //   1. expand the windowed masks in bottom[1] into a dense per-image
  //      (H*W x H*W) buffer (collect or distribute layout),
  //   2. optionally normalize that buffer through the internal softmax layer,
  //   3. per image, top = (1 / normalization_factor_) * bottom[0] * buffer,
  //      treating bottom[0] as a (channels_ x H*W) matrix.

  // Entries outside each clipped window are never written by the kernels,
  // so the buffer must start from zero.
  caffe_gpu_set(mask_buffer_.count(), Dtype(0), mask_buffer_.mutable_gpu_data());
  const int pixel_count = num_ * feature_H_ * feature_W_;
  const int psa_type = this->layer_param_.pointwise_spatial_attention_param().psa_type();
  if (psa_type == PointwiseSpatialAttentionParameter_PSAType_COLLECT) {
    PSAForward_buffer_mask_collect_gpu<Dtype><<<CAFFE_GET_BLOCKS(pixel_count), CAFFE_CUDA_NUM_THREADS>>>(
        pixel_count, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_,
        bottom[1]->gpu_data(), mask_buffer_.mutable_gpu_data());
    CUDA_POST_KERNEL_CHECK;
  } else if (psa_type == PointwiseSpatialAttentionParameter_PSAType_DISTRIBUTE) {
    PSAForward_buffer_mask_distribute_gpu<Dtype><<<CAFFE_GET_BLOCKS(pixel_count), CAFFE_CUDA_NUM_THREADS>>>(
        pixel_count, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_,
        bottom[1]->gpu_data(), mask_buffer_.mutable_gpu_data());
    CUDA_POST_KERNEL_CHECK;
  } else {
    LOG(FATAL) << "Unknown PSA type.";
  }

  // When softmax is enabled, the normalized copy (mask_buffer_prob_) is the
  // one used as aggregation weights; otherwise the raw buffer is used.
  if (is_softmax_) {
    softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
  }
  const Dtype* weights = is_softmax_ ? mask_buffer_prob_.gpu_data()
                                     : mask_buffer_.gpu_data();

  const int spatial = feature_H_ * feature_W_;
  const Dtype scale = Dtype(1.0 / normalization_factor_);
  for (int img = 0; img < num_; ++img) {
    const Dtype* features = bottom[0]->gpu_data() + bottom[0]->offset(img);
    const Dtype* img_weights = weights + mask_buffer_.offset(img);
    Dtype* output = top[0]->mutable_gpu_data() + top[0]->offset(img);
    caffe_gpu_gemm(CblasNoTrans, CblasNoTrans,
        channels_, spatial, spatial,
        scale, features, img_weights, Dtype(0), output);
  }
}
template <typename Dtype>
__global__ void PSABackward_buffer_mask_collect_gpu(const int nthreads,
    const int feature_H_, const int feature_W_,
    const int mask_H_, const int mask_W_,
    const int half_mask_H_, const int half_mask_W_,
    const Dtype* buffer_diff, Dtype* mask_diff) {
  // Gather gradients from the dense "collect" buffer back into mask layout.
  // This is the inverse of PSAForward_buffer_mask_collect_gpu's scatter.
  //
  // Work items: one per (n, h, w) pixel. Backward_gpu launches this with
  // nthreads = num_ * feature_H_ * feature_W_.
  // buffer_diff : indexed as [N, H*W, H*W]; read at [n][source][pixel].
  // mask_diff   : indexed as [N, mask_H_ * mask_W_, feature_H_, feature_W_].
  // Cells whose source falls outside the feature map are not written
  // (Backward_gpu zero-fills mask_diff beforehand).
  const int plane = feature_H_ * feature_W_;
  const int window = mask_H_ * mask_W_;
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int pw = index % feature_W_;
    const int ph = (index / feature_W_) % feature_H_;
    const int img = index / plane;
    const int pixel = ph * feature_W_ + pw;
    // Feature-map offset of mask cell (0, 0) when the window sits on (ph, pw).
    const int row_shift = ph - half_mask_H_;
    const int col_shift = pw - half_mask_W_;
    // Clip the window so every visited cell maps inside the feature map.
    const int row_begin = max(0, -row_shift);
    const int row_end = min(mask_H_, feature_H_ - row_shift);
    const int col_begin = max(0, -col_shift);
    const int col_end = min(mask_W_, feature_W_ - col_shift);
    const Dtype* src = buffer_diff + img * plane * plane + pixel;
    Dtype* dst = mask_diff + img * window * plane + pixel;
    for (int mh = row_begin; mh < row_end; ++mh) {
      for (int mw = col_begin; mw < col_end; ++mw) {
        const int source = (mh + row_shift) * feature_W_ + (mw + col_shift);
        dst[(mh * mask_W_ + mw) * plane] = src[source * plane];
      }
    }
  }
}
template <typename Dtype>
__global__ void PSABackward_buffer_mask_distribute_gpu(const int nthreads,
    const int feature_H_, const int feature_W_,
    const int mask_H_, const int mask_W_,
    const int half_mask_H_, const int half_mask_W_,
    const Dtype* buffer_diff, Dtype* mask_diff) {
  // Gather gradients from the dense "distribute" buffer back into mask layout.
  // This is the inverse of PSAForward_buffer_mask_distribute_gpu's scatter.
  //
  // Work items: one per (n, h, w) pixel. Backward_gpu launches this with
  // nthreads = num_ * feature_H_ * feature_W_.
  // buffer_diff : indexed as [N, H*W, H*W]; read at [n][pixel][target].
  // mask_diff   : indexed as [N, mask_H_ * mask_W_, feature_H_, feature_W_].
  // Cells whose target falls outside the feature map are not written
  // (Backward_gpu zero-fills mask_diff beforehand).
  const int plane = feature_H_ * feature_W_;
  const int window = mask_H_ * mask_W_;
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int pw = index % feature_W_;
    const int ph = (index / feature_W_) % feature_H_;
    const int img = index / plane;
    const int pixel = ph * feature_W_ + pw;
    // Feature-map offset of mask cell (0, 0) when the window sits on (ph, pw).
    const int row_shift = ph - half_mask_H_;
    const int col_shift = pw - half_mask_W_;
    // Clip the window so every visited cell maps inside the feature map.
    const int row_begin = max(0, -row_shift);
    const int row_end = min(mask_H_, feature_H_ - row_shift);
    const int col_begin = max(0, -col_shift);
    const int col_end = min(mask_W_, feature_W_ - col_shift);
    // Start of this pixel's row in the per-image (H*W x H*W) buffer.
    const Dtype* row = buffer_diff + (img * plane + pixel) * plane;
    Dtype* dst = mask_diff + img * window * plane + pixel;
    for (int mh = row_begin; mh < row_end; ++mh) {
      for (int mw = col_begin; mw < col_end; ++mw) {
        const int target = (mh + row_shift) * feature_W_ + (mw + col_shift);
        dst[(mh * mask_W_ + mw) * plane] = row[target];
      }
    }
  }
}
template <typename Dtype>
void PointwiseSpatialAttentionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // Backward pass for top = scale * bottom[0] * W, per image, where W is the
  // (H*W x H*W) aggregation buffer (softmax-normalized when is_softmax_):
  //   d bottom[0] = scale * d top * W^T
  //   d W         = scale * bottom[0]^T * d top
  // d W is then pushed through the optional softmax and gathered back into
  // the windowed mask layout of bottom[1].
  const int spatial = feature_H_ * feature_W_;
  const Dtype scale = Dtype(1.0 / normalization_factor_);

  if (propagate_down[0]) {
    const Dtype* weights = is_softmax_ ? mask_buffer_prob_.gpu_data()
                                       : mask_buffer_.gpu_data();
    for (int img = 0; img < num_; ++img) {
      const Dtype* grad_out = top[0]->gpu_diff() + top[0]->offset(img);
      const Dtype* img_weights = weights + mask_buffer_.offset(img);
      Dtype* grad_features = bottom[0]->mutable_gpu_diff() + bottom[0]->offset(img);
      caffe_gpu_gemm(CblasNoTrans, CblasTrans,
          channels_, spatial, spatial,
          scale, grad_out, img_weights, Dtype(0), grad_features);
    }
  }

  if (!propagate_down[1]) {
    return;
  }

  // Gradient w.r.t. whichever buffer was used as weights in Forward_gpu.
  Dtype* grad_weights = is_softmax_ ? mask_buffer_prob_.mutable_gpu_diff()
                                    : mask_buffer_.mutable_gpu_diff();
  for (int img = 0; img < num_; ++img) {
    const Dtype* grad_out = top[0]->gpu_diff() + top[0]->offset(img);
    const Dtype* features = bottom[0]->gpu_data() + bottom[0]->offset(img);
    Dtype* img_grad_weights = grad_weights + mask_buffer_.offset(img);
    caffe_gpu_gemm(CblasTrans, CblasNoTrans,
        spatial, spatial, channels_,
        scale, features, grad_out, Dtype(0), img_grad_weights);
  }
  // Route the gradient through the softmax so mask_buffer_'s diff is filled.
  if (is_softmax_) {
    softmax_layer_->Backward(softmax_top_vec_, softmax_propagate_down_, softmax_bottom_vec_);
  }

  // The gather kernels only write in-window cells, so clear the rest first.
  caffe_gpu_set(bottom[1]->count(), Dtype(0), bottom[1]->mutable_gpu_diff());
  const int pixel_count = num_ * spatial;
  const int psa_type = this->layer_param_.pointwise_spatial_attention_param().psa_type();
  if (psa_type == PointwiseSpatialAttentionParameter_PSAType_COLLECT) {
    PSABackward_buffer_mask_collect_gpu<Dtype><<<CAFFE_GET_BLOCKS(pixel_count), CAFFE_CUDA_NUM_THREADS>>>(
        pixel_count, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_,
        mask_buffer_.gpu_diff(), bottom[1]->mutable_gpu_diff());
    CUDA_POST_KERNEL_CHECK;
  } else if (psa_type == PointwiseSpatialAttentionParameter_PSAType_DISTRIBUTE) {
    PSABackward_buffer_mask_distribute_gpu<Dtype><<<CAFFE_GET_BLOCKS(pixel_count), CAFFE_CUDA_NUM_THREADS>>>(
        pixel_count, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_,
        mask_buffer_.gpu_diff(), bottom[1]->mutable_gpu_diff());
    CUDA_POST_KERNEL_CHECK;
  } else {
    LOG(FATAL) << "Unknown PSA type.";
  }
}
// Emit the explicit template instantiations of Forward_gpu / Backward_gpu for
// this layer (Caffe macro; presumably float and double — confirm in Caffe's
// headers).
INSTANTIATE_LAYER_GPU_FUNCS(PointwiseSpatialAttentionLayer);
} // namespace caffe