diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu index 55df914ab0..f62554d0ef 100644 --- a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu @@ -11,907 +11,250 @@ *************************************************************************/ #include "common_mlu_helper.hpp" -__nram__ char buffer[MAX_NRAM_SIZE]; - -#define ALIGN_SIZE 64 -#define BUFFER_SIZE (MAX_NRAM_SIZE * 480 / 512) #define ROI_OFFSET 5 -#define SAMPLING_NUM 4 -#define DIM_BOX 5 -#define BLOCK_INPUT_OUTPUT 2 +__nram__ char buffer[MAX_NRAM_SIZE]; namespace forward { template -__mlu_func__ void bilinearInterpolate( - T *tmp_sum, T *nram_in, T *offset_bottom_data, const int roi_bin_grid_h, - const int roi_bin_grid_w, const T bin_size_h, const T bin_size_w, - const int input_height, const int input_width, const int channels, - const int channel_align, const int cyc_channel, T y_pre, T x_pre, - T zero_sign_tmp, bool is_normal_c, int index) { - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - T y = (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)) <= 0.0 - ? 0.0 - : (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)); - int y_low = int(y); - int y_high; - if (y_low >= input_height - 1) { - y_high = y_low = input_height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - T ly = y - y_low; - T hy = 1.0 - ly; - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - T x = (x_pre + ((ix + 0.5) * bin_size_w) / (T)(roi_bin_grid_w)) <= 0.0 - ? 0.0 - : (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)); - T zero_sign = - (T)(x >= -1.0 && x <= input_width && y >= -1.0 && y <= input_height) * - zero_sign_tmp; - int x_low = int(x); - int x_high; - if (x_low >= input_width - 1) { - x_high = x_low = input_width - 1; - x = T(x_low); - } else { - x_high = x_low + 1; - } - T lx = x - x_low; - T hx = 1.0 - lx; - - T w1 = hy * hx * zero_sign; - T w2 = hy * lx * zero_sign; - T w3 = ly * hx * zero_sign; - T w4 = ly * lx * zero_sign; - - // load - int cpy_len = (x_high - x_low) * channels; - int temp_size = cyc_channel < (channels - index * cyc_channel) - ? cyc_channel - : channels - index * cyc_channel; - int cpy_size = is_normal_c ? channels * sizeof(T) : temp_size * sizeof(T); - - int32_t offset1 = (y_low * input_width + x_low) * channels; - int32_t offset2 = (y_high * input_width + x_low) * channels; - - T *tmp1 = is_normal_c - ? offset_bottom_data + offset1 - : offset_bottom_data + offset1 + cyc_channel * index; - T *tmp2 = is_normal_c - ? offset_bottom_data + offset2 - : offset_bottom_data + offset2 + cyc_channel * index; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + cyc_channel; - T *tmp_cyc3 = nram_in + cyc_channel * 2; - T *tmp_cyc4 = nram_in + cyc_channel * 3; - - __asm__ volatile("sync;"); - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - __nramset(nram_in, channel_align, T(0)); - } else { - __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); - __asm__ volatile("sync;"); - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); - __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, 1, - SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); - } +__mlu_func__ void bilinearInterpolate(const int input_height, + const int input_width, T y, T x, T *w1, + T *w2, T *w3, T *w4, int *x_low, + int *x_high, int *y_low, int *y_high, + bool *empty) { + // deal with cases that inverse elements are of feature map boundary + if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { + *empty = true; + return; } -} -template -__mlu_func__ void roialignForwardNpartKernel( - T *input, T *rois, T *output, T *nram_buffer, const bool aligned, - const int channels, const int pooled_height, const int pooled_width, - const int input_height, const int input_width, const int sampling_ratio, - const float spatial_scale, const int num_rois, const int max_elements) { - /* - * NRAM partition - * |----------------------NRAM------------------------| - * | | - * | output | - * |--------------------------------------------------| - * | | - * | input | - * | | - * |--------------------------------------------------| - * | rois(batch_id, x1, y1, x2, y2) | - * |--------------------------------------------------| - * - * channel data will loop inside of input_nram, when channel * size(T) > - * input_nram - */ - - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; + if (y <= 0) y = 0; + if (x <= 0) x = 0; - // multi-core params - int inter_num = num_rois / taskDim; - int rem_num = num_rois % taskDim; - int offset_length; - int task_length; + int y_low_ = int(y); + int x_low_ = int(x); - // the length dealt by every core and the offset of taskId - if (taskId < rem_num) { - task_length = inter_num + 1; - offset_length = taskId * (inter_num + 1); + if (y_low_ >= input_height - 1) { + *y_high = y_low_ = input_height - 1; + y = (T)y_low_; } else { - task_length = inter_num; - offset_length = rem_num * (inter_num + 1) + (taskId - rem_num) * inter_num; - } - - int max_size = max_elements; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size * 2; - - int pooled_size = pooled_height * pooled_width; - T *top_data = output + offset_length * pooled_size * channels; - T *task_rois = rois + offset_length * ROI_OFFSET; - - for (int roi_id = 0; roi_id < task_length; roi_id++) { - // For each roi, find the corresponding feature map which it belongs to, - // and compute the scaling_factor to map it to that feature map. - T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = task_rois + roi_id * ROI_OFFSET; - - int batch_id = roi_id_tmp[0]; - T roi_xmin = roi_id_tmp[1]; - T roi_ymin = roi_id_tmp[2]; - T roi_xmax = roi_id_tmp[3]; - T roi_ymax = roi_id_tmp[4]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - roi_width = roi_width > 1.0 ? roi_width : 1.0; - roi_height = roi_height > 1.0 ? roi_height : 1.0; - } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); - - for (int ph = 0; ph < pooled_height; ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, channel_align, channel_align, - y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = - samp_channel / (max_elements * SAMPLING_NUM) + - (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); - int cyc_channel = max_elements; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph - } // loop for num_roi -} - -template -__mlu_func__ void roialignForwardHpartKernel( - T *input, T *rois, T *output, T *nram_buffer, const bool aligned, - const int channels, const int pooled_height, const int pooled_width, - const int input_height, const int input_width, const int sampling_ratio, - const float spatial_scale, const int num_rois, const int max_elements) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int taskdim_cyc = taskDim / num_rois > 1 ? taskDim / num_rois : 1; - int roi_id = taskId / taskdim_cyc; - if (taskId >= taskdim_cyc * num_rois) { - return; + *y_high = y_low_ + 1; } - // multi-core params - int inter_num = pooled_height / taskdim_cyc; - int rem_num = pooled_height % taskdim_cyc; - int offset_length; - int task_length; - - if ((taskId % taskdim_cyc) < rem_num) { - task_length = inter_num + 1; - offset_length = (taskId % taskdim_cyc) * (inter_num + 1); + if (x_low_ >= input_width - 1) { + *x_high = x_low_ = input_width - 1; + x = T(x_low_); } else { - task_length = inter_num; - offset_length = rem_num * (inter_num + 1) + - ((taskId % taskdim_cyc) - rem_num) * inter_num; - } - - int max_size = max_elements * 2; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - - int pooled_size = pooled_height * pooled_width; - T *top_data = - output + (roi_id * pooled_size + offset_length * pooled_width) * channels; - T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = rois + roi_id * ROI_OFFSET; - - int batch_id = roi_id_tmp[0]; - T roi_xmin = roi_id_tmp[1]; - T roi_ymin = roi_id_tmp[2]; - T roi_xmax = roi_id_tmp[3]; - T roi_ymax = roi_id_tmp[4]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; + *x_high = x_low_ + 1; } - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); + *y_low = y_low_; + *x_low = x_low_; - for (int ph = offset_length; ph < (offset_length + task_length); ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, - bin_size_w, input_height, input_width, channels, - channel_align, channel_align, y_pre, x_pre, - zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / (max_elements * SAMPLING_NUM) + - (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); - int cyc_channel = max_elements; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph -} - -__mlu_global__ void MLUUnion1KernelRoialign( - const void *input, const void *rois, const int channels, const bool aligned, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const cnrtDataType_t data_type, void *output) { - size_t data_type_size = - (data_type == CNRT_FLOAT32) ? sizeof(float) : sizeof(half); - int max_elements = PAD_DOWN( - (BUFFER_SIZE / (int)data_type_size) / (ROI_OFFSET + 1), ALIGN_SIZE); - - if (taskDim < num_rois || (num_rois * pooled_height < taskDim)) { - switch (data_type) { - case CNRT_FLOAT16: { - half *nram_buffer = (half *)buffer; - roialignForwardNpartKernel( - (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_elements); - }; break; - case CNRT_FLOAT32: { - float *nram_buffer = (float *)buffer; - roialignForwardNpartKernel( - (float *)input, (float *)rois, (float *)output, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); - }; break; - default: - break; - } - } else { - switch (data_type) { - case CNRT_FLOAT16: { - half *nram_buffer = (half *)buffer; - roialignForwardHpartKernel( - (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_elements); - }; break; - case CNRT_FLOAT32: { - float *nram_buffer = (float *)buffer; - roialignForwardHpartKernel( - (float *)input, (float *)rois, (float *)output, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); - }; break; - default: - break; - } - } + T ly = y - y_low_; + T lx = x - x_low_; + T hy = 1.0 - ly; + T hx = 1.0 - lx; + *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; return; } template -__mlu_func__ void buSelection(T *rois_count, T *nram_temp, const int num_rois) { - for (int i = 0; i < num_rois; ++i) { - for (int j = 1; j < num_rois; ++j) { - if (rois_count[(j - 1) * 2] < rois_count[j * 2]) { - nram_temp[0] = rois_count[(j - 1) * 2]; - rois_count[(j - 1) * 2] = rois_count[j * 2]; - rois_count[j * 2] = nram_temp[0]; - nram_temp[1] = rois_count[(j - 1) * 2 + 1]; - rois_count[(j - 1) * 2 + 1] = rois_count[j * 2 + 1]; - rois_count[j * 2 + 1] = nram_temp[1]; - } - } - } -} - -template -__mlu_func__ void getPatitionList(T *h_nram, T *n_nram, T *roi_count, - int pooled_height, int num_rois, T sum, - int split_num, int &h_flag, int &n_flag) { - T avg_sum = sum / split_num; - T *h_nram_temp = h_nram; - T *n_nram_temp = n_nram; - - int n_index = 0; - T n_sum = 0; - h_flag = 0; - n_flag = 0; - int list_align = PAD_UP(ALIGN_SIZE * 5, ALIGN_SIZE); - __bang_write_zero(h_nram, list_align); - for (int i = 0; i < num_rois; i++) { - if (roi_count[2 * i] >= avg_sum) { - int h_num = std::ceil(roi_count[2 * i] / avg_sum); - int h_split = pooled_height / h_num; - int h_rem = pooled_height % h_num; - T h_sum = 0.0; - - for (int j = 0; j < h_num; j++) { - h_nram_temp[0] = i; - h_nram_temp[1] = h_sum; - h_nram_temp[2] = (j < h_rem) ? (h_split + 1) : h_split; - h_sum += h_nram_temp[2]; - h_nram_temp += 3; - n_nram_temp += 2; - h_flag++; - } - } else { - if (roi_count[2 * i] + n_sum > avg_sum) { - n_nram_temp[0] = i - n_index; - n_nram_temp[1] = i - 1; - n_sum = 0.0; - n_index = 0; - n_nram_temp += 2; - i--; - n_flag++; - } else { - n_index++; - n_sum += roi_count[2 * i]; - } - } - } - if (n_flag == 0 && n_index != 0) { - n_flag = 1; - n_nram[(h_flag + n_flag - 1) * 2] = num_rois - 1; - } - - n_nram[(h_flag + n_flag) * 2 - 1] = num_rois - 1; - - if (h_flag + n_flag > taskDim) { - getPatitionList(h_nram, n_nram, roi_count, pooled_height, num_rois, sum, - split_num - 1, h_flag, n_flag); - } - return; +__mlu_func__ void computeChannel(T *input_core, T *nram_in, T *output_core, + T *nram_out, const int roi_bin_grid_h, + const int roi_bin_grid_w, const T roi_start_h, + const T roi_start_w, const int ph, + const int pw, const T bin_size_h, + const T bin_size_w, const float count, + const int input_height, const int input_width, + const int channels, const int cyc_num, + const int max_elements) { + int cyc_channel = max_elements; + + for (int i = 0; i < cyc_num; i++) { + int real_channel = + (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; + int align_channel = PAD_UP(real_channel, NFU_ALIGN_SIZE / sizeof(T)); + __bang_write_zero(nram_out, align_channel); + uint32_t real_size = real_channel * sizeof(T); + + int iy, ix; + for (iy = 0; iy < roi_bin_grid_h; iy++) { + // 1. compute the coordinates of the y axis in the current roi_bin_grid_h + T y = roi_start_h + ph * bin_size_h + + (T)(iy + 0.5) * bin_size_h / (T)(roi_bin_grid_h); + for (ix = 0; ix < roi_bin_grid_w; ix++) { + // 2. compute the coordinates of the x axis in the current + // roi_bin_grid_w + T x = roi_start_w + pw * bin_size_w + + (T)(ix + 0.5) * bin_size_w / (T)(roi_bin_grid_w); + + // 3. compute the four weights (w1, w2, w3 and w4), the height (y_low + // and y_high) and weight (x_low and x_high) of input feature map in + // the current roi bin grid, and the flag (empty) which shows if x, y + // are out of input feature map ranges + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bool empty = false; + + bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, &w3, &w4, + &x_low, &x_high, &y_low, &y_high, &empty); + + // 4. compute interpolation of the current roi bin grid + // tmp_cyc1, temp_cyc2, tmp_cyc3 and tmp_cyc4 store the input values + // to compute the interpolation, and then reused to compute + // the argmax_x and argmax_y. + T *tmp_cyc1 = nram_in + cyc_channel; + T *tmp_cyc2 = nram_in + cyc_channel * 2; + T *tmp_cyc3 = nram_in + cyc_channel * 3; + T *tmp_cyc4 = nram_in + cyc_channel * 4; + + if (empty) { // exits abnormal values + __bang_write_zero(nram_in, align_channel); + } else { + __bang_write_zero(nram_in, align_channel); + uint32_t offset1 = (y_low * input_width + x_low) * channels; + uint32_t offset2 = (y_low * input_width + x_high) * channels; + uint32_t offset3 = (y_high * input_width + x_low) * channels; + uint32_t offset4 = (y_high * input_width + x_high) * channels; + T *input1 = (T *)input_core + offset1 + i * cyc_channel; + T *input2 = (T *)input_core + offset2 + i * cyc_channel; + T *input3 = (T *)input_core + offset3 + i * cyc_channel; + T *input4 = (T *)input_core + offset4 + i * cyc_channel; + + // load the four pixels (p1, p2, p3 and p4) of input feature map to + // compute interpolation + __memcpy(tmp_cyc1, input1, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc2, input2, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc3, input3, real_size, GDRAM2NRAM); + __memcpy(tmp_cyc4, input4, real_size, GDRAM2NRAM); + + // interpolation value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4 + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, align_channel); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, align_channel); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, align_channel); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, align_channel); + + __bang_add(nram_in, tmp_cyc1, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc2, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc3, nram_in, align_channel); + __bang_add(nram_in, tmp_cyc4, nram_in, align_channel); + } + // 5. compute sum value and corresponding coordinates of x axis and y + // axis. Update the sum value. + __bang_add(nram_out, nram_in, nram_out, align_channel); + } // loop_roi_grid_w + } // loop_roi_grid_h + T count_value = (T)(1.0 / count); + __bang_mul_const(nram_out, nram_out, count_value, align_channel); + __memcpy(output_core + i * cyc_channel, nram_out, real_size, NRAM2GDRAM); + } // loop_cyc_num } template -__mlu_func__ void mergeAndSplitQuantity( - T *rois, T *rois_sort, T *split_list, T *roi_count, T *nram_rois, - const bool aligned, const int pooled_height, const int pooled_width, - const int sampling_ratio, const float spatial_scale, const int num_rois, - int &h_split_num, int &n_split_num) { - /* take the coordinates out of ROIS and actually calculate the actual - * calculation size. The sorted calculation scale is partition, large scale - * is split H, small is N. - */ - T *h_tem = split_list; - T *n_tem = split_list + 3 * ALIGN_SIZE; - int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 1), ALIGN_SIZE); - int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); - __bang_write_zero(nram_rois, num_rois_align); - T sum = 0.0; - int temp_offset = 0; - __memcpy((void *)(nram_rois + 1), (void *)rois, ROI_OFFSET * sizeof(T), - GDRAM2NRAM, (ROI_OFFSET + 1) * sizeof(T), ROI_OFFSET * sizeof(T), - (num_rois - 1)); - T *nram_temp = roi_count + count_align; - for (int roi_id = 0; roi_id < num_rois; roi_id++) { - T offset = aligned ? (T)0.5 : (T)0; - - T roi_xmin = nram_rois[temp_offset + 2]; - T roi_ymin = nram_rois[temp_offset + 3]; - T roi_xmax = nram_rois[temp_offset + 4]; - T roi_ymax = nram_rois[temp_offset + 5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; +__mlu_func__ void roialignForwardAvg( + T *input, T *rois, T *output, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const T spatial_scale, + const int num_rois) { + // find limit for channel, the nram space is divided to 6 parts that are + // input, 4 weights to compute the interpolation (w1, w2, w3, w4), output + + // max_elements : 300 : float datatype : 27296, half datatype : 54592 + // max_elements : 200 : float datatype : 16384, half datatype : 32768 + int max_elements = (PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE)) / sizeof(T); + int cyc_num = channels / max_elements + (int)(channels % max_elements != 0); + T offset = aligned ? (T)0.5 : (T)0.0; + int task_num = num_rois * pooled_height * pooled_width; + T *nram_out = (T *)buffer; + T *nram_in = nram_out + max_elements; + if (task_num < taskDim) { + if (taskId >= task_num) { + return; } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - sum += count; - *(roi_count + 2 * roi_id) = count; - *(roi_count + 2 * roi_id + 1) = roi_id; - - *(nram_rois + roi_id * (ROI_OFFSET + 1)) = count; - temp_offset += (ROI_OFFSET + 1); } - buSelection(roi_count, nram_temp, num_rois); - - temp_offset = 0; - for (int i = 0; i < num_rois; i++) { - for (int j = 0; j < num_rois; j++) { - if (roi_count[2 * i] == nram_rois[j * (ROI_OFFSET + 1)]) { - rois_sort[temp_offset] = nram_rois[j * (ROI_OFFSET + 1)]; - rois_sort[temp_offset + 1] = nram_rois[j * (ROI_OFFSET + 1) + 1]; - rois_sort[temp_offset + 2] = nram_rois[j * (ROI_OFFSET + 1) + 2]; - rois_sort[temp_offset + 3] = nram_rois[j * (ROI_OFFSET + 1) + 3]; - rois_sort[temp_offset + 4] = nram_rois[j * (ROI_OFFSET + 1) + 4]; - rois_sort[temp_offset + 5] = nram_rois[j * (ROI_OFFSET + 1) + 5]; - nram_rois[j * (ROI_OFFSET + 1)] = -1.0; - break; - } + for (int bin_idx = taskId; bin_idx < task_num; bin_idx = bin_idx + taskDim) { + if (bin_idx >= task_num) { + return; } - temp_offset += (ROI_OFFSET + 1); - } - getPatitionList(h_tem, n_tem, roi_count, pooled_height, num_rois, sum, - taskDim, h_split_num, n_split_num); -} - -template -__mlu_func__ void roialignForwardNpartKernelForBinPart( - T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, - T *nram_buffer, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_size) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int max_elements = max_size * SAMPLING_NUM; - int offset_length; - int task_length; - - T *n_split_nram = split_list + 3 * ALIGN_SIZE + 2 * taskId; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - T *task_rois = rois_sort + (int)n_split_nram[0] * (ROI_OFFSET + 1); - - offset_length = (int)n_split_nram[0]; - task_length = n_split_nram[1] - n_split_nram[0] + 1; - int pooled_size = pooled_height * pooled_width; - - for (int roi_id = offset_length; roi_id < offset_length + task_length; - roi_id++) { - // For each roi, find the corresponding feature map which it belongs to, - // and compute the scaling_factor to map it to that feature map. - T offset = aligned ? (T)0.5 : (T)0; - int rea_out_id = rois_count[roi_id * 2 + 1]; - T *top_data = output + rea_out_id * pooled_size * channels; - T *nram_rois = task_rois + (roi_id - offset_length) * (ROI_OFFSET + 1); - int batch_id = nram_rois[1]; - T roi_xmin = nram_rois[2]; - T roi_ymin = nram_rois[3]; - T roi_xmax = nram_rois[4]; - T roi_ymax = nram_rois[5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; + // (n,ph.pw) is a c in the pooled output + int pw = bin_idx % pooled_width; + int ph = (bin_idx / pooled_width) % pooled_height; + int n = bin_idx / pooled_width / pooled_height; + + T *roi_id_tmp = rois + n * ROI_OFFSET; + // 1. compute width and height of roi region. + int batch_idx = (int)roi_id_tmp[0]; + T roi_x1 = roi_id_tmp[1]; + T roi_y1 = roi_id_tmp[2]; + T roi_x2 = roi_id_tmp[3]; + T roi_y2 = roi_id_tmp[4]; + T roi_start_w = roi_x1 * spatial_scale - offset; + T roi_start_h = roi_y1 * spatial_scale - offset; + T roi_end_w = roi_x2 * spatial_scale - offset; + T roi_end_h = roi_y2 * spatial_scale - offset; + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1.0 ? roi_width : 1.0; - roi_height = roi_height > 1.0 ? roi_height : 1.0; + roi_width = roi_width > (T)(1.0) ? roi_width : (T)(1.0); + roi_height = roi_height > (T)(1.0) ? roi_height : (T)(1.0); } - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_in, max_elements); - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < max_elements; - - for (int ph = 0; ph < pooled_height; ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, channel_align, channel_align, - y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / max_elements + - (int)(samp_channel % max_elements != 0); - int cyc_channel = max_elements / SAMPLING_NUM; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph - } // loop for num_roi -} - -template -__mlu_func__ void roialignForwardHpartKernelForBinPart( - T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, - T *nram_buffer, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_size) { - int channel_align = PAD_UP(channels, ALIGN_SIZE); - int samp_channel_align = channel_align * SAMPLING_NUM; - int samp_channel = channels * SAMPLING_NUM; - int max_elements = max_size * SAMPLING_NUM; - - T *h_split_nram = split_list; - T *nram_out = nram_buffer; - T *nram_in = nram_out + max_size; - T *nram_rois = rois_sort + (int)h_split_nram[taskId * 3] * (ROI_OFFSET + 1); - - int offset_length = (int)h_split_nram[taskId * 3 + 1]; - int task_length = (int)h_split_nram[taskId * 3 + 2]; - int rea_out_id = (int)h_split_nram[taskId * 3]; - - rea_out_id = rois_count[rea_out_id * 2 + 1]; - int pooled_size = pooled_height * pooled_width; - T *top_data = - output + - (rea_out_id * pooled_size + offset_length * pooled_width) * channels; - - T offset = aligned ? (T)0.5 : (T)0; - - int batch_id = nram_rois[1]; - T roi_xmin = nram_rois[2]; - T roi_ymin = nram_rois[3]; - T roi_xmax = nram_rois[4]; - T roi_ymax = nram_rois[5]; - - roi_xmin = roi_xmin * (T)spatial_scale - offset; - roi_ymin = roi_ymin * (T)spatial_scale - offset; - roi_xmax = roi_xmax * (T)spatial_scale - offset; - roi_ymax = roi_ymax * (T)spatial_scale - offset; - - T roi_width = roi_xmax - roi_xmin; - T roi_height = roi_ymax - roi_ymin; - - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; + // 2. compute float-type width and height of roi bin region. + T bin_size_w = (T)roi_width / (T)pooled_width; + T bin_size_h = (T)roi_height / (T)pooled_height; + + // 3. compute int-type width and height of roi bin region. + int roi_bin_grid_h, roi_bin_grid_w; + roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : int(ceilf(roi_height / pooled_height)); + roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : int(ceilf(roi_width / pooled_width)); + float count = (float)((roi_bin_grid_h * roi_bin_grid_w) > 1 + ? roi_bin_grid_h * roi_bin_grid_w + : 1.0); + T *input_core = input + batch_idx * channels * input_width * input_height; + T *output_core = output + bin_idx * channels; + // 4. compute avg value and corresponding coordinates of x axis and y axis. + computeChannel(input_core, nram_in, output_core, nram_out, roi_bin_grid_h, + roi_bin_grid_w, roi_start_h, roi_start_w, ph, pw, bin_size_h, + bin_size_w, count, input_height, input_width, channels, + cyc_num, max_elements); } - - T bin_size_h = roi_height / (T)pooled_height; - T bin_size_w = roi_width / (T)pooled_width; - T *offset_bottom_data = - input + batch_id * channels * input_width * input_height; - - T *tmp_sum = nram_out; - __bang_write_zero(nram_in, max_elements); - __bang_write_zero(nram_out, max_size); - - // We use roi_bin_grid to sample the grid, and perform average pooling - // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_h)); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : (int)std::ceil((float)(bin_size_w)); - T count = roi_bin_grid_h * roi_bin_grid_w; - T zero_sign_tmp = 1.0f / count; - bool is_normal_c = samp_channel_align < max_elements; - - for (int ph = offset_length; ph < (offset_length + task_length); ph++) { - T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid - for (int pw = 0; pw < pooled_width; pw++) { - T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid - // Bilinear interpolatation - if (is_normal_c) { - bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, - bin_size_w, input_height, input_width, channels, - channel_align, channel_align, y_pre, x_pre, - zero_sign_tmp, is_normal_c, 0); - } else { - // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / max_elements + - (int)(samp_channel % max_elements != 0); - int cyc_channel = max_elements / SAMPLING_NUM; - for (int i = 0; i < cyc_num; ++i) { - int real_channel = cyc_channel < (channels - i * cyc_channel) - ? cyc_channel - : channels - i * cyc_channel; - int align_channel = (i == cyc_num - 1) - ? PAD_UP(real_channel, ALIGN_SIZE) - : cyc_channel; - bilinearInterpolate( - (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, - roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, - input_height, input_width, channels, align_channel, cyc_channel, - y_pre, x_pre, zero_sign_tmp, is_normal_c, i); - - __memcpy(top_data + cyc_channel * i, tmp_sum, - real_channel * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - } - - // copy output data to ddr when channel num is not aligned with 64 - if (is_normal_c) { - __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); - __bang_write_zero(nram_out, max_size); - } - top_data += channels; - } // loop for pw - } // loop for ph } -__mlu_global__ void MLUUnion1KernelBinPartRoialign( +__mlu_global__ void MLUUnion1KernelRoiAlignAvg( const void *input, const void *rois, const int channels, const bool aligned, const int pooled_height, const int pooled_width, const int input_height, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, const cnrtDataType_t data_type, void *output) { - int h_split_num = 0; - int n_split_num = 0; - int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 4), ALIGN_SIZE); - int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); - int list_align = ALIGN_SIZE * 5; - int sum_size = num_rois_align + count_align + list_align; - + // make sure that memcore is not used if (coreId == 0x80) { return; } switch (data_type) { case CNRT_FLOAT16: { - int max_channel = - PAD_DOWN((BUFFER_SIZE / sizeof(half) - sum_size) / (ROI_OFFSET + 1), - ALIGN_SIZE); - half *rois_sort = (half *)buffer; - __bang_write_zero(rois_sort, sum_size); - half *rois_count = (half *)(rois_sort + num_rois_align); - half *split_list = (half *)(rois_count + count_align); - half *nram_rois = (half *)(split_list + list_align); - mergeAndSplitQuantity((half *)rois, (half *)rois_sort, (half *)split_list, - (half *)rois_count, (half *)nram_rois, aligned, - pooled_height, pooled_width, sampling_ratio, - spatial_scale, num_rois, h_split_num, n_split_num); - half *nram_buffer = (half *)nram_rois; - __bang_write_zero(nram_rois, num_rois_align); - - if (taskId < h_split_num) { - roialignForwardHpartKernelForBinPart( - (half *)input, (half *)rois, (half *)output, (half *)rois_sort, - (half *)split_list, (half *)rois_count, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_channel); - } else { - if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { - roialignForwardNpartKernelForBinPart( - (half *)input, (half *)rois, (half *)output, (half *)rois_sort, - (half *)split_list, (half *)rois_count, (half *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, - max_channel); - } else { - return; - } - } + roialignForwardAvg((half *)input, (half *)rois, (half *)output, aligned, + channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, + (half)spatial_scale, num_rois); }; break; case CNRT_FLOAT32: { - int max_channel = - PAD_DOWN((BUFFER_SIZE / sizeof(float) - sum_size) / (ROI_OFFSET + 1), - ALIGN_SIZE); - float *rois_sort = (float *)buffer; - __bang_write_zero(rois_sort, sum_size); - float *rois_count = (float *)(rois_sort + num_rois_align); - float *split_list = (float *)(rois_count + count_align); - float *nram_rois = (float *)(split_list + list_align); - mergeAndSplitQuantity((float *)rois, (float *)rois_sort, - (float *)split_list, (float *)rois_count, - (float *)nram_rois, aligned, pooled_height, - pooled_width, sampling_ratio, spatial_scale, - num_rois, h_split_num, n_split_num); - float *nram_buffer = (float *)nram_rois; - __bang_write_zero(nram_rois, num_rois_align); - - if (taskId < h_split_num) { - roialignForwardHpartKernelForBinPart( - (float *)input, (float *)rois, (float *)output, (float *)rois_sort, - (float *)split_list, (float *)rois_count, (float *)nram_buffer, - aligned, channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, max_channel); - } else { - if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { - roialignForwardNpartKernelForBinPart( - (float *)input, (float *)rois, (float *)output, - (float *)rois_sort, (float *)split_list, (float *)rois_count, - (float *)nram_buffer, aligned, channels, pooled_height, - pooled_width, input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_channel); - } else { - return; - } - } + roialignForwardAvg((float *)input, (float *)rois, (float *)output, + aligned, channels, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, + (float)spatial_scale, num_rois); }; break; default: break; } + return; } } // namespace forward @@ -1131,21 +474,9 @@ void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, void *output) { - // set thresholds for degradation caused by sorting - const int sort_border = 100; // threshold of num_rois - const int sort_cluster_num = 16; // threshold of cluster - - if (num_rois > sort_border || k_dim.y < sort_cluster_num) { - forward::MLUUnion1KernelRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, spatial_scale, num_rois, - d_type, output); - } else { - forward::MLUUnion1KernelBinPartRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, spatial_scale, num_rois, - d_type, output); - } + forward::MLUUnion1KernelRoiAlignAvg<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); } void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,