RoIPool [12] is a standard operation for extracting a small feature map (e.g., $7 \times 7$) from each RoI. RoIPool first quantizes a floating-number RoI to the discrete granularity of the feature map, this quantized RoI is then subdivided into spatial bins which are themselves quantized, and finally feature values covered by each bin are aggregated (usually by max pooling). Quantization is performed, e.g., on a continuous coordinate x by computing [x/16], where 16 is a feature map stride and [·] is rounding; likewise, quantization is performed when dividing into bins (e.g., 7×7). These quantizations introduce misalignments between the RoI and the extracted features. While this may not impact classification, which is robust to small translations, it has a large negative effect on predicting pixel-accurate masks.

To address this, we propose an RoIAlign layer that removes the harsh quantization of RoIPool, properly aligning the extracted features with the input. Our proposed change is simple: we avoid any quantization of the RoI boundaries or bins (i.e., we use x/16 instead of [x/16]). We use bilinear interpolation [22] to compute the exact values of the input features at four regularly sampled locations in each RoI bin, and aggregate the result (using max or average), see Figure 3 for details. We note that the results are not sensitive to the exact sampling locations, or how many points are sampled, as long as no quantization is performed.

<img src='fig3_roi_align.png'>

In [93]:
from collections import namedtuple
from easydict import EasyDict
import numpy as np

In [302]:
def pre_calc_for_bilinear_interpolate(
    height, width, pooled_height, pooled_width,
    iy_upper, ix_upper,
    roi_start_h, roi_start_w,
    bin_size_h, bin_size_w,
    roi_bin_grid_h, roi_bin_grid_w,
    pre_calc):
    
    
    for ph in range(pooled_height):
        for pw in range(pooled_width):
            # For iy_upper = ix_upper = 2, iy + 0.5 = ix + 0.5 = [0.5, 1.5]
            for iy in range(iy_upper):
                yy = (roi_start_h + ph * bin_size_h + 
                    (iy + 0.5) * bin_size_h / roi_bin_grid_h)
                
                
                for ix in range(ix_upper):
                    xx = (roi_start_w + pw * bin_size_w + 
                        (ix + 0.5) * bin_size_w / roi_bin_grid_w)
                    
                    
                    print(roi_start_h, roi_start_w)
                    print(ph, pw)
                    print(bin_size_h, bin_size_w)
                    print(roi_bin_grid_h,roi_bin_grid_w)
                    print(iy, ix)
                    print(yy, xx)
                    print()
                    
                    x = xx
                    y = yy
                    
                    
                    pc = EasyDict()
                    
                    if (y < -1) or (y > height) or (x < -1) or (x > width):
                        pc.pos1 = (0, 0)
                        pc.pos2 = (0, 0)
                        pc.pos3 = (0, 0)
                        pc.pos4 = (0, 0)
                        pc.w1 = 0
                        pc.w2 = 0
                        pc.w3 = 0
                        pc.w4 = 0
                        pre_calc.append(pc);
                        continue
                        
                    if y <= 0:
                        y = 0
                    
                    if x <= 0:
                        x = 0
                        
                    y_low = int(y)
                    x_low = int(x)
                    y_high = None
                    x_high = None
                    
                    if y_low >= height - 1:
                        y_high = y_low = height - 1
                        y = y_low
                    else:
                        y_high = y_low + 1
                        
                    if x_low >= width - 1:
                        x_high = x_low = width - 1
                        x = x_low
                    else:
                        x_high = x_low + 1
                        
                    ly = y - y_low
                    lx = x - x_low
                    hy = 1. - ly
                    hx = 1. - lx
                    
                    pc.w1 = hy * hx
                    pc.w2 = hy * lx
                    pc.w3 = ly * hx
                    pc.w4 = ly * lx
                    
                    pc.pos1 = (y_low, x_low)
                    pc.pos2 = (y_low, x_high)
                    pc.pos3 = (y_high, x_low)
                    pc.pos4 = (y_high, x_high)
                    

                    
                    pre_calc.append(pc)
                    
    return pre_calc
                

In [337]:
def roi_align(bottom_data, spatial_scale,
              height, width, pooled_height, pooled_width, 
             sampling_ratio, bottom_rois):
    n_rois = len(bottom_rois)
    result = []
    
    image = bottom_data
    for n in range(n_rois):
        
        
        # Resize roi based on feature map stride
        roi_start_w = bottom_rois[n, 0] * spatial_scale
        roi_start_h = bottom_rois[n, 1] * spatial_scale
        roi_end_w = bottom_rois[n, 2] * spatial_scale
        roi_end_h = bottom_rois[n, 3] * spatial_scale
        
        # roi dims
        roi_width = max(roi_end_w - roi_start_w, 1.)
        roi_height = max(roi_end_h - roi_start_h, 1.)
        
        # fixed number of bins of different sizes
        bin_size_h = roi_height / pooled_height
        bin_size_w = roi_width / pooled_width
        
        roi_bin_grid_h = sampling_ratio if sampling_ratio > 0 \
                                else np.ceil(roi_height / pooled_height)
        
        roi_bin_grid_w = sampling_ratio if sampling_ratio > 0 \
                                else np.ceil(roi_width / pooled_width)
        
        count = roi_bin_grid_h * roi_bin_grid_w
        
        pre_calc = []
        pre_calc_for_bilinear_interpolate(
            height,
            width,
            pooled_height,
            pooled_width,
            roi_bin_grid_h,
            roi_bin_grid_w,
            roi_start_h,
            roi_start_w,
            bin_size_h,
            bin_size_w,
            roi_bin_grid_h,
            roi_bin_grid_w,
            pre_calc)
        
        res = np.zeros((pooled_height, pooled_width, image.shape[-1]))
        
        pre_calc_index = 0
            
        for ph in range(pooled_height):
            for pw in range(pooled_width):

                output_val = 0

                for iy in range(roi_bin_grid_h):
                    for ix in range(roi_bin_grid_w):
                        pc = pre_calc[pre_calc_index]
                        print(pc)
                        output_val += pc.w1 * image[pc.pos1[0], pc.pos1[1]] + \
                                pc.w2 * image[pc.pos2[0], pc.pos2[1]] + \
                                pc.w3 * image[pc.pos3[0], pc.pos3[1]] + \
                                pc.w4 * image[pc.pos4[0], pc.pos4[1]]

                        pre_calc_index += 1
                        
                    output_val =  output_val / count
                    
                res[ph, pw] = output_val
                
        result.append(res)
        
        
    return result
                        
                        
                        
            

In [359]:
bottom_data = np.reshape(np.arange(100), (2, 5, 5, 2)).astype('float32')[1]
spatial_scale = 1
height = width = 5
pooled_height = pooled_width = 4
sampling_ratio = 2
bottom_rois = np.stack([[1., 4., 4., 4]]).astype('float32') 


In [360]:
a = roi_align(bottom_data, spatial_scale, height, width, 
              pooled_height, pooled_width, 
             sampling_ratio, bottom_rois)

4.0 1.0
0 0
0.25 0.75
2 2
0 0
4.0625 1.1875

4.0 1.0
0 0
0.25 0.75
2 2
0 1
4.0625 1.5625

4.0 1.0
0 0
0.25 0.75
2 2
1 0
4.1875 1.1875

4.0 1.0
0 0
0.25 0.75
2 2
1 1
4.1875 1.5625

4.0 1.0
0 1
0.25 0.75
2 2
0 0
4.0625 1.9375

4.0 1.0
0 1
0.25 0.75
2 2
0 1
4.0625 2.3125

4.0 1.0
0 1
0.25 0.75
2 2
1 0
4.1875 1.9375

4.0 1.0
0 1
0.25 0.75
2 2
1 1
4.1875 2.3125

4.0 1.0
0 2
0.25 0.75
2 2
0 0
4.0625 2.6875

4.0 1.0
0 2
0.25 0.75
2 2
0 1
4.0625 3.0625

4.0 1.0
0 2
0.25 0.75
2 2
1 0
4.1875 2.6875

4.0 1.0
0 2
0.25 0.75
2 2
1 1
4.1875 3.0625

4.0 1.0
0 3
0.25 0.75
2 2
0 0
4.0625 3.4375

4.0 1.0
0 3
0.25 0.75
2 2
0 1
4.0625 3.8125

4.0 1.0
0 3
0.25 0.75
2 2
1 0
4.1875 3.4375

4.0 1.0
0 3
0.25 0.75
2 2
1 1
4.1875 3.8125

4.0 1.0
1 0
0.25 0.75
2 2
0 0
4.3125 1.1875

4.0 1.0
1 0
0.25 0.75
2 2
0 1
4.3125 1.5625

4.0 1.0
1 0
0.25 0.75
2 2
1 0
4.4375 1.1875

4.0 1.0
1 0
0.25 0.75
2 2
1 1
4.4375 1.5625

4.0 1.0
1 1
0.25 0.75
2 2
0 0
4.3125 1.9375

4.0 1.0
1 1
0.25 0.75
2 2
0 1
4.3125 2.3125

4.0 1.0
1 

In [361]:
a[0]

array([[[57.96875, 58.59375],
        [58.90625, 59.53125],
        [59.84375, 60.46875],
        [60.78125, 61.40625]],

       [[57.96875, 58.59375],
        [58.90625, 59.53125],
        [59.84375, 60.46875],
        [60.78125, 61.40625]],

       [[57.96875, 58.59375],
        [58.90625, 59.53125],
        [59.84375, 60.46875],
        [60.78125, 61.40625]],

       [[57.96875, 58.59375],
        [58.90625, 59.53125],
        [59.84375, 60.46875],
        [60.78125, 61.40625]]])