Skip to content

Commit

Permalink
update roialign cuda impl to onnx opset16 (#12036)
Browse files Browse the repository at this point in the history
* roialign opset16

* fix

* fix
  • Loading branch information
jywu-msft authored and RandyShuai committed Jul 6, 2022
1 parent babd325 commit 2de1e4f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Status RoiAlign<T>::ComputeInternal(OpKernelContext* context) const {
num_roi_cols,
reinterpret_cast<typename ToCudaType<T>::MappedType*>(Y.template MutableData<T>()),
this->mode_ == RoiAlignMode::avg,
this->half_pixel_,
batch_indices_ptr->template Data<int64_t>());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ __global__ void RoIAlignForward(
int64_t roi_cols,
T* top_data,
const bool is_mode_avg,
const bool half_pixel,
const int64_t* batch_indices_ptr) {
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) {
// (n, c, ph, pw) is an element in the pooled output
Expand All @@ -106,17 +107,16 @@ __global__ void RoIAlignForward(
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
const auto roi_batch_ind = batch_indices_ptr[n];

bool continuous_coordinate = false;
// Do not using rounding; this implementation detail is critical
T roi_offset = continuous_coordinate ? T(0.5) : T(0);
T roi_offset = half_pixel ? T(0.5) : T(0);
T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!continuous_coordinate) { // backward compatiblity
if (!half_pixel) { // backward compatiblity
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
Expand Down Expand Up @@ -188,6 +188,7 @@ void RoiAlignImpl(
int64_t roi_cols,
T* top_data,
const bool is_mode_avg,
const bool half_pixel,
const int64_t* batch_indices_ptr) {
int blocksPerGrid = (int)(ceil(static_cast<float>(nthreads) / GridDim::maxThreadsPerBlock));
RoIAlignForward<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
Expand All @@ -204,6 +205,7 @@ void RoiAlignImpl(
roi_cols,
top_data,
is_mode_avg,
half_pixel,
batch_indices_ptr);
}

Expand All @@ -223,6 +225,7 @@ void RoiAlignImpl(
int64_t roi_cols, \
T* top_data, \
const bool is_mode_avg, \
const bool half_pixel, \
const int64_t* batch_indices_ptr);

SPECIALIZED_IMPL(float)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ void RoiAlignImpl(
int64_t roi_cols,
T* top_data,
const bool is_mode_avg,
const bool half_pixel,
const int64_t* batch_indices_ptr);

} // namespace cuda
Expand Down

0 comments on commit 2de1e4f

Please sign in to comment.