[Research] Fastest Maxpool implementation #174

Closed
mratsim opened this issue Dec 17, 2017 · 3 comments

mratsim (Owner) commented Dec 17, 2017

Max-pooling is, or at least used to be, one of the key components of ConvNets.

Description from CS231n course here.

It is similar to convolution, except that instead of doing a matmul with the pooling window we just take the max. As such, several implementations exist, from naive to very clever:

Direct Max-pooling

Darknet
Caffe

void forward_maxpool_layer(const maxpool_layer l, network net)
{
    int b,i,j,k,m,n;
    int w_offset = -l.pad;
    int h_offset = -l.pad;

    int h = l.out_h;
    int w = l.out_w;
    int c = l.c;

    for(b = 0; b < l.batch; ++b){
        for(k = 0; k < c; ++k){
            for(i = 0; i < h; ++i){
                for(j = 0; j < w; ++j){
                    int out_index = j + w*(i + h*(k + c*b));
                    float max = -FLT_MAX;
                    int max_i = -1;
                    for(n = 0; n < l.size; ++n){
                        for(m = 0; m < l.size; ++m){
                            int cur_h = h_offset + i*l.stride + n;
                            int cur_w = w_offset + j*l.stride + m;
                            int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c));
                            int valid = (cur_h >= 0 && cur_h < l.h &&
                                         cur_w >= 0 && cur_w < l.w);
                            float val = (valid != 0) ? net.input[index] : -FLT_MAX;
                            max_i = (val > max) ? index : max_i;
                            max   = (val > max) ? val   : max;
                        }
                    }
                    l.output[out_index] = max;
                    l.indexes[out_index] = max_i;
                }
            }
        }
    }
}

NNPACK

static void compute_max_pooling_output(
	const struct max_pooling_output_context context[restrict static 1],
	size_t sample, size_t channel)
{
	const size_t channels                  = context->channels;
	const struct nnp_size input_size       = context->input_size;
	const struct nnp_padding input_padding = context->input_padding;
	const struct nnp_size pooling_size     = context->pooling_size;
	const struct nnp_size pooling_stride   = context->pooling_stride;
	const struct nnp_size output_size      = context->output_size;

	const float (*input)[channels][input_size.height][input_size.width] =
		(const float(*)[channels][input_size.height][input_size.width]) context->input;
	float (*output)[channels][output_size.height][output_size.width] =
		(float(*)[channels][output_size.height][output_size.width]) context->output;

	for (size_t y = 0; y < output_size.height; y++) {
		for (size_t x = 0; x < output_size.width; x++) {
			float v = -__builtin_inff();
			for (size_t i = 0; i < pooling_size.height; i++) {
				const size_t s = y * pooling_stride.height + i - input_padding.top;
				if (s < input_size.height) {
					for (size_t j = 0; j < pooling_size.width; j++) {
						const size_t t = x * pooling_stride.width + j - input_padding.left;
						if (t < input_size.width) {
							v = maxf(input[sample][channel][s][t], v);
						}
					}
				}
			}
			output[sample][channel][y][x] = v;
		}
	}
}
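Both direct versions above loop over every output cell and scan its window, never letting out-of-bounds positions win: Darknet substitutes -FLT_MAX for them, while NNPACK skips them with the unsigned `s < input_size.height` check (a negative offset wraps around to a huge `size_t`, so the same test rejects it). A minimal pure-Python sketch of that loop for a single H x W channel, assuming list-of-lists input; `maxpool2d_direct` is a hypothetical name. Like Darknet's `l.indexes`, it records the argmax as a flat index into the input, which is what the backward pass needs.

```python
import math


def maxpool2d_direct(x, kernel, stride, padding):
    """Direct max-pooling of one H x W channel stored as a list of lists.

    Positions that fall into the padding are skipped, which is equivalent to
    padding with -inf. Returns (out, argmax); argmax holds flat row-major
    indices into the unpadded input (-1 if a window only covers padding).
    """
    H, W = len(x), len(x[0])
    kH, kW = kernel
    sH, sW = stride
    pH, pW = padding
    outH = (H + 2 * pH - kH) // sH + 1
    outW = (W + 2 * pW - kW) // sW + 1
    out = [[-math.inf] * outW for _ in range(outH)]
    argmax = [[-1] * outW for _ in range(outH)]
    for i in range(outH):
        for j in range(outW):
            for n in range(kH):
                h = i * sH - pH + n
                if not 0 <= h < H:
                    continue  # this window row lies in the padding
                for m in range(kW):
                    w = j * sW - pW + m
                    # strict ">" keeps the first maximum on ties, like Darknet
                    if 0 <= w < W and x[h][w] > out[i][j]:
                        out[i][j] = x[h][w]
                        argmax[i][j] = h * W + w
    return out, argmax


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
out, idx = maxpool2d_direct(x, kernel=(2, 2), stride=(2, 2), padding=(0, 0))
```

The output size uses the same `(H + 2 * pad - k) div stride + 1` formula as the Nim implementation further down.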

Neon

    def fprop_pool(self, layer, I, O, argmax=None, beta=0.0):
        """
        Forward propagate pooling layer.
        Arguments:
            layer (PoolLayer): The pool layer object, different backends have
                               different pool layers.
            I (Tensor): Input tensor.
            O (Tensor): output tensor.
            argmax (Tensor): tensor to store location of the maximum
            beta (float): scale applied to the existing contents of O before
                          the pooled values are added
        """

        assert layer.sizeI == I.size
        assert layer.sizeO == O.size
        if layer.op == "max":
            assert layer.sizeO == argmax.size
        op = layer.op

        J, T, R, S = layer.JTRS
        C, D, H, W, N = layer.dimI
        K, M, P, Q, N = layer.dimO
        pad_c, pad_d, pad_h, pad_w = layer.padding
        str_c, str_d, str_h, str_w = layer.strides

        array_I = I._tensor.reshape(layer.dimI)
        array_O = O._tensor.reshape(layer.dimO)
        if op == "max":
            array_argmax = argmax._tensor.reshape(layer.dimO)

        for k in range(K):
            sliceC, _ = layer.kSlice[k]

            for m in range(M):
                sliceD, _ = layer.mSlice[m]

                for p in range(P):
                    sliceH, _ = layer.pSlice[p]

                    for q in range(Q):
                        sliceW, _ = layer.qSlice[q]

                        sliceI = array_I[sliceC, sliceD, sliceH, sliceW, :].reshape(-1, N)
                        if op == "max":
                            array_argmax[k, m, p, q, :] = np.argmax(sliceI, axis=0)
                            array_O[k, m, p, q, :] = array_O[k, m, p, q, :] * beta + \
                                np.max(sliceI, axis=0)
                        elif op == "avg":
                            array_O[k, m, p, q, :] = array_O[k, m, p, q, :] * beta + \
                                np.mean(sliceI, axis=0)
                        elif op == "l2":
                            array_O[k, m, p, q, :] = array_O[k, m, p, q, :] * beta + \
                                np.sqrt(np.sum(np.square(sliceI), axis=0))

im2col and argmax-based max-pooling

Important: the argmax-based solutions appear to rely on NumPy fancy indexing (indexing with a tensor), which is not available in Arraymancer.
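The fancy-indexing expression used below, `X_col[max_idx, range(max_idx.size)]`, only picks, for every column j, the element at row `max_idx[j]`. A minimal pure-Python sketch of the same argmax + gather written as plain loops, so without fancy indexing; the helper names are hypothetical. In practice the max can be tracked in the same pass as the argmax (the initial Nim implementation below gets both from one `argmax` call), which removes the gather entirely.

```python
def argmax_columns(cols):
    """For each column of a row-major 2D list, the row index of its maximum."""
    n_rows, n_cols = len(cols), len(cols[0])
    result = []
    for j in range(n_cols):
        best = 0
        for i in range(1, n_rows):
            if cols[i][j] > cols[best][j]:
                best = i
        result.append(best)
    return result


def gather_columns(cols, row_idx):
    """Same as NumPy's cols[row_idx, range(len(row_idx))], written as a loop."""
    return [cols[r][j] for j, r in enumerate(row_idx)]


cols = [[1, 9, 3],
        [4, 2, 8]]
max_idx = argmax_columns(cols)
maxima = gather_columns(cols, max_idx)
```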

https://deepnotes.io/maxpool

class Maxpool():

    def __init__(self,X_dim,size,stride):

        self.d_X, self.h_X, self.w_X = X_dim
        
        self.params = []

        self.size = size
        self.stride = stride
        
        self.h_out = (self.h_X - size)/stride + 1
        self.w_out = (self.w_X - size)/stride + 1
        

        if not self.h_out.is_integer() or not self.w_out.is_integer():
            raise Exception("Invalid dimensions!")
        
        self.h_out,self.w_out  = int(self.h_out), int(self.w_out)
        self.out_dim = (self.d_X,self.h_out,self.w_out)

    def forward(self,X):
        self.n_X = X.shape[0]
        X_reshaped = X.reshape(X.shape[0]*X.shape[1],1,X.shape[2],X.shape[3])

        self.X_col = im2col_indices(X_reshaped, self.size, self.size, padding = 0, stride = self.stride)
        
        self.max_indexes = np.argmax(self.X_col,axis=0)
        out = self.X_col[self.max_indexes,range(self.max_indexes.size)]

        out = out.reshape(self.h_out,self.w_out,self.n_X,self.d_X).transpose(2,3,0,1)
        return out

    def backward(self,dout):

        dX_col = np.zeros_like(self.X_col)
        # flatten the gradient
        dout_flat = dout.transpose(2,3,0,1).ravel()
        
        dX_col[self.max_indexes,range(self.max_indexes.size)] = dout_flat
        
        # get the original X_reshaped structure from col2im
        shape = (self.n_X*self.d_X,1,self.h_X,self.w_X)
        dX = col2im_indices(dX_col,shape,self.size,self.size,padding=0,stride=self.stride)
        dX = dX.reshape(self.n_X,self.d_X,self.h_X,self.w_X)
        return dX,[]
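In `backward`, the fancy-indexed assignment puts each output gradient on the im2col row of its window's maximum, and `col2im_indices` folds the columns back into image layout, summing contributions where windows overlap (stride < size). A hedged sketch of the same effect without im2col, assuming the argmax indices are already flat indices into the input (as Darknet's `l.indexes` are); the function name is hypothetical.

```python
def maxpool_backward_scatter(argmax_flat, dout_flat, input_size):
    """Route each output gradient to the input cell that produced its max.

    argmax_flat[i] is the flat input index chosen for output i (-1 means the
    window only covered padding). Gradients are added, not assigned, because
    overlapping windows (stride < size) can share the same maximum.
    """
    dx = [0.0] * input_size
    for idx, grad in zip(argmax_flat, dout_flat):
        if idx >= 0:
            dx[idx] += grad
    return dx


dx = maxpool_backward_scatter([1, 1, 3], [0.5, 0.25, 1.0], input_size=4)
```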

Chainer

    def forward_cpu(self, x):
        self._in_shape = x[0].shape
        self._in_dtype = x[0].dtype

        col = conv.im2col_cpu(
            x[0], self.kh, self.kw, self.sy, self.sx, self.ph, self.pw,
            pval=-float('inf'), cover_all=self.cover_all)
        n, c, kh, kw, out_h, out_w = col.shape
        col = col.reshape(n, c, kh * kw, out_h, out_w)

        # We select maximum twice, since the implementation using numpy.choose
        # hits its bug when kh * kw >= 32.
        self.indexes = col.argmax(axis=2)
        y = col.max(axis=2)
        return y,

https://wiseodd.github.io/techblog/2016/07/18/convnet-maxpool-layer/
And the corresponding repo

# Let's say our input X is 5x10x28x28
# Our pooling parameter are: size = 2x2, stride = 2, padding = 0
# i.e. result of 10 filters of 3x3 applied to 5 imgs of 28x28 with stride = 1 and padding = 1

# First, reshape it to 50x1x28x28 so that im2col arranges it fully in columns
X_reshaped = X.reshape(n * d, 1, h, w)

# The result will be 4x9800
# Note if we apply im2col to our 5x10x28x28 input, the result won't be as nice: 40x980
X_col = im2col_indices(X_reshaped, size, size, padding=0, stride=stride)

# Next, at each possible patch location, i.e. at each column, we're taking the max index
max_idx = np.argmax(X_col, axis=0)

# Finally, we gather the max value of each column
# The result will be 1x9800
out = X_col[max_idx, range(max_idx.size)]

# Reshape to the output size: 14x14x5x10
out = out.reshape(h_out, w_out, n, d)

# Transpose to get 5x10x14x14 output
out = out.transpose(2, 3, 0, 1)
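The shape comments above follow from the usual im2col layout: one row per (channel, kernel offset) pair and one column per (output position, image) pair. A small sketch that reproduces the numbers, assuming that layout (the helper name is hypothetical). It also shows why the (n * d, 1, h, w) reshape matters: with C = 1 each column holds exactly one window of one channel, so argmax over axis 0 is that window's max; with C = 10 each column stacks the windows of all 10 channels and the argmax would mix channels.

```python
def im2col_shape(n, c, h, w, k, stride, pad=0):
    """(rows, cols) of an im2col matrix for k x k windows.

    rows: one per (channel, kernel offset) pair -> c * k * k
    cols: one per (output position, image) pair -> out_h * out_w * n
    """
    out_h = (h + 2 * pad - k) // stride + 1
    out_w = (w + 2 * pad - k) // stride + 1
    return c * k * k, out_h * out_w * n


print(im2col_shape(50, 1, 28, 28, k=2, stride=2))   # (4, 9800)
print(im2col_shape(5, 10, 28, 28, k=2, stride=2))   # (40, 980)
```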

Reshape-based max-pooling

CS231n assignment

def max_pool_forward_naive(x, pool_param):
  """
  A naive implementation of the forward pass for a max pooling layer.
  Inputs:
  - x: Input data, of shape (N, C, H, W)
  - pool_param: dictionary with the following keys:
    - 'pool_height': The height of each pooling region
    - 'pool_width': The width of each pooling region
    - 'stride': The distance between adjacent pooling regions
  Returns a tuple of:
  - out: Output data
  - cache: (x, pool_param)
  """

  #############################################################################
  # TODO: Implement the max pooling forward pass                              #
  #############################################################################

  N, C, H, W = x.shape

  pool_height, pool_width, stride = pool_param['pool_height'], pool_param['pool_width'], pool_param['stride']

  # First validate the pooling parameters
  assert (H - pool_height) % stride == 0, "Pooling windows do not fit the image height"
  assert (W - pool_width) % stride == 0, "Pooling windows do not fit the image width"

  out_height = (H - pool_height) // stride + 1
  out_width = (W - pool_width) // stride + 1
  out = np.zeros((N, C, out_height, out_width))

  # Pooling layer forward using iterative method
  for ii in range(out_height):
    i = ii * stride  # top row of the current pooling window
    for jj in range(out_width):
      j = jj * stride  # left column of the current pooling window
      # max over the window, for all images and channels at once
      out[:, :, ii, jj] = np.amax(x[:, :, i:i+pool_height, j:j+pool_width].reshape(N, C, -1), axis=2)

  cache = (x, pool_param)
  return out, cache

Matlab from StackOverflow

r = 6 ,c=8

idx = kron(reshape(1:(r*c/4),c/2,[]).',ones(2))

for ii=1:number_feature_map
    data = rand(r,c);
    maxpool{ii} = reshape(accumarray(idx(:),data(:),[],@max),c/2,[]).';
end
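The Matlab trick builds a label map in which every cell of a 2x2 block carries the same id (`kron` of the block-id matrix with `ones(2)`), then `accumarray(..., @max)` reduces each label group with max. A pure-Python sketch of that group-by-label reduction, assuming 0-based ids and dimensions divisible by k; the function names are hypothetical.

```python
def block_labels(r, c, k=2):
    """Label map where all cells of the same k x k block share an id (row-major, 0-based)."""
    return [[(i // k) * (c // k) + (j // k) for j in range(c)] for i in range(r)]


def max_by_label(data, labels, n_labels):
    """Group-by-label maximum, the role played by accumarray(..., @max)."""
    best = [float("-inf")] * n_labels
    for data_row, label_row in zip(data, labels):
        for value, label in zip(data_row, label_row):
            best[label] = max(best[label], value)
    return best


r, c, k = 4, 4, 2
labels = block_labels(r, c, k)
data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
flat = max_by_label(data, labels, (r // k) * (c // k))
# labels are row-major, so regrouping the flat result gives the pooled grid
pooled = [flat[i * (c // k):(i + 1) * (c // k)] for i in range(r // k)]
```

Since max is order-independent, the column-major order in which Matlab's `idx(:)` and `data(:)` visit the cells does not change the result.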

Numpy from StackOverflow

import numpy as np

def max_pool(x):
    """Return maximum in groups of 2x2 for a N,h,w image"""
    N,h,w = x.shape
    return np.amax([x[:,(i>>1)&1::2,i&1::2] for i in range(4)],axis=0)

# Alternative: reshape into 2x2 tiles, then take the max of the 4 values
def max_pool(x):
    """Return maximum in groups of 2x2 for a N,h,w image"""
    N,h,w = x.shape
    x = x.reshape(N,h//2,2,w//2,2).swapaxes(2,3).reshape(N,h//2,w//2,4)
    return np.amax(x,axis=3)

# A 2x2 convolution written in the same four-term style (fragment: a, b, c, d, W and f are not defined here)
np.sum(a*W[f,:,0,0][...,None,None]+b*W[f,:,0,1][...,None,None]+c*W[f,:,1,0][...,None,None]+d*W[f,:,1,1][...,None,None], axis=0)

# Shortest form: max over both tile axes
x.reshape(N, h // 2, 2, w // 2, 2).max(axis=(2, 4))
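The reshape variants work because reshaping a contiguous row-major (h, w) image to (h/k, k, w/k, k) moves no data: index [..., a, b, c, d] refers to row a*k + b and column c*k + d, so reducing over the two size-k axes is a max over each k x k tile. A pure-Python sketch of that index arithmetic on a flat buffer, assuming h and w are multiples of k (hypothetical helper name):

```python
def maxpool_reshape_view(flat, h, w, k):
    """Non-overlapping k x k max-pool of a row-major h x w buffer.

    Viewing the buffer as shape (h//k, k, w//k, k), element [a, b, c, d]
    sits at flat index (a*k + b)*w + c*k + d, i.e. row a*k + b and column
    c*k + d. Reducing over b and d is therefore a max over each tile.
    """
    assert h % k == 0 and w % k == 0, "the k x k tiles must cover the image"
    return [[max(flat[(a * k + b) * w + c * k + d]
                 for b in range(k) for d in range(k))
             for c in range(w // k)]
            for a in range(h // k)]


pooled = maxpool_reshape_view(list(range(1, 17)), h=4, w=4, k=2)
```

This is also why the trick is limited to stride == kernel size: with overlapping windows a cell belongs to several tiles, which no reshape can express.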

Auto-switch between reshape (square pooling regions that tile the image) and im2col:

DLMatFramework

def max_pool_forward_fast(x, pool_param):
  """
  A fast implementation of the forward pass for a max pooling layer.
  This chooses between the reshape method and the im2col method. If the pooling
  regions are square and tile the input image, then we can use the reshape
  method which is very fast. Otherwise we fall back on the im2col method, which
  is not much faster than the naive method.
  """
  N, C, H, W = x.shape
  pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
  stride = pool_param['stride']

  same_size = pool_height == pool_width == stride
  tiles = H % pool_height == 0 and W % pool_width == 0
  if same_size and tiles:
    out, reshape_cache = max_pool_forward_reshape(x, pool_param)
    cache = ('reshape', reshape_cache)
  else:
    out, im2col_cache = max_pool_forward_im2col(x, pool_param)
    cache = ('im2col', im2col_cache)
  return out, cache

def max_pool_forward_reshape(x, pool_param):
  """
  A fast implementation of the forward pass for the max pooling layer that uses
  some clever reshaping.
  This can only be used for square pooling regions that tile the input.
  """
  N, C, H, W = x.shape
  pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
  stride = pool_param['stride']
  assert pool_height == pool_width == stride, 'Invalid pool params'
  assert H % pool_height == 0
  assert W % pool_width == 0
  x_reshaped = x.reshape(N, C, H // pool_height, pool_height,
                         W // pool_width, pool_width)
  out = x_reshaped.max(axis=3).max(axis=4)

  cache = (x, x_reshaped, out)
  return out, cache

def max_pool_forward_im2col(x, pool_param):
  """
  An implementation of the forward pass for max pooling based on im2col.
  This isn't much faster than the naive version, so it should be avoided if
  possible.
  """
  N, C, H, W = x.shape
  pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
  stride = pool_param['stride']

  assert (H - pool_height) % stride == 0, 'Invalid height'
  assert (W - pool_width) % stride == 0, 'Invalid width'

  out_height = (H - pool_height) // stride + 1
  out_width = (W - pool_width) // stride + 1

  x_split = x.reshape(N * C, 1, H, W)
  x_cols = im2col(x_split, pool_height, pool_width, padding=0, stride=stride)
  x_cols_argmax = np.argmax(x_cols, axis=0)
  x_cols_max = x_cols[x_cols_argmax, np.arange(x_cols.shape[1])]
  out = x_cols_max.reshape(out_height, out_width, N, C).transpose(2, 3, 0, 1)

  cache = (x, x_cols, x_cols_argmax, pool_param)
  return out, cache
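The dispatch in `max_pool_forward_fast` boils down to a few integer checks. A hedged sketch that restates it as a standalone predicate (hypothetical name, no padding, mirroring the asserts in the two paths) makes it easy to see which configurations take the fast path: 2x2/2 pooling on 28x28 uses reshape, while overlapping 3x3/2 pooling only fits without padding when H - 3 is even (e.g. 27x27) and then goes through im2col.

```python
def choose_maxpool_path(H, W, pool_height, pool_width, stride):
    """Return 'reshape' or 'im2col' following the checks in max_pool_forward_fast."""
    same_size = pool_height == pool_width == stride
    tiles = H % pool_height == 0 and W % pool_width == 0
    if same_size and tiles:
        return "reshape"
    # max_pool_forward_im2col asserts that the windows fit exactly (no padding)
    if (H - pool_height) % stride != 0 or (W - pool_width) % stride != 0:
        raise ValueError("pooling windows do not fit the input without padding")
    return "im2col"


path_2x2 = choose_maxpool_path(28, 28, 2, 2, 2)
path_3x3 = choose_maxpool_path(27, 27, 3, 3, 2)
```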

mratsim (Owner, Author) commented Dec 17, 2017

Initial implementation

proc maxpool2d*[T](input: Tensor[T],
                   kernel: Size2D,
                   padding: Size2D = (0,0),
                   stride: Size2D = (1,1),
                   argmax: var Tensor[int],
                   result: var Tensor[T]
                   ) =
  ## MaxPool 2D forward pass
  assert input.rank == 4
  let
    N = input.shape[0]
    C = input.shape[1]
    H = input.shape[2]
    W = input.shape[3]
    kH = kernel.height
    kW = kernel.width
    outH = (H + (2 * padding.height) - kH) div stride.height + 1
    outW = (W + (2 * padding.width ) - kW) div stride.width + 1
    channels_col = C * kH * kW
    flatten_size_col = outH * outW
  var x_cols = newTensorUninit[T](channels_col, flatten_size_col)
  let x_split = input.reshape(N * C, 1, H, W)
  im2col(x_split, (kH, kW), padding, stride, -Inf.T, x_cols) # TODO: replace by low(T) when 0.18 for https://github.com/nim-lang/Nim/commit/badba83d38371726bafba5870d5fb927eb453e41
  (argmax, result) = x_cols.argmax(axis = 0)
  result = result.reshape(outH, outW, N, C).permute(2, 3, 0, 1)

This has several disadvantages:

  • im2col expects CHW input and is parallelized over C.
    The reshape(N * C, 1, H, W) makes the effective C equal to 1, so nothing runs in parallel (and it is probably slow).
  • argmax (indices in the im2col "domain") is unusable without the im2col'ed input (x_cols). It would be much better to return argmax in the original tensor "domain".
  • We have to add an extra paddingValue parameter to im2col (-Inf here, so that padded cells never win the max).

It is short and easy to maintain, though.
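Converting an im2col-domain argmax back to the original tensor domain only needs the window geometry. A sketch for the single-channel columns produced by the reshape(N * C, 1, H, W) trick, assuming the rows of a column are ordered row-major over the kernel offsets (kh, kw), as in Caffe-style im2col; the function name and argument order are hypothetical, not Arraymancer's API.

```python
def col_argmax_to_input_index(k_idx, oh, ow, kernel, stride, padding, H, W):
    """Map an argmax taken inside one im2col column back to the input channel.

    k_idx is the row chosen by argmax (0 <= k_idx < kH * kW) in the column of
    output position (oh, ow). Rows are assumed to be ordered row-major over
    the kernel offsets (kh, kw). Returns the flat index h * W + w in the
    unpadded H x W input, or -1 if the maximum came from padding.
    """
    _, kW = kernel
    sH, sW = stride
    pH, pW = padding
    kh, kw = divmod(k_idx, kW)
    h = oh * sH - pH + kh
    w = ow * sW - pW + kw
    if not (0 <= h < H and 0 <= w < W):
        return -1
    return h * W + w


# 4x4 input, 2x2 kernel, stride 2: row 3 of the column for output (1, 1)
idx = col_argmax_to_input_index(3, 1, 1, (2, 2), (2, 2), (0, 0), 4, 4)
```

Storing this index instead of the raw row number lets the backward pass scatter gradients straight into the input without keeping x_cols around.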

mratsim mentioned this issue Dec 18, 2017

mratsim (Owner, Author) commented Dec 18, 2017

Max-pool implementation using reshape. However, it requires fancy indexing with boolean masks for backpropagation:

proc maxpool2d_reshape*[T](input: Tensor[T],
                kernel: Size2D,
                stride: Size2D = (1,1)
                ): tuple[cached_reshaped: Tensor[T], maxpooled: Tensor[T]] {.noinit.}=
  ## Fast maxpool implementation that uses clever reshaping.
  ## This only works for square pooling regions with:
  ##  - kernel size = stride
  ##  - input is a multiple of kernel size

  let
    N = input.shape[0]
    C = input.shape[1]
    H = input.shape[2]
    W = input.shape[3]

    kH = kernel.height
    kW = kernel.width

  assert kH == kW
  assert kH == stride.height
  assert kH == stride.width
  assert H mod kH == 0
  assert W mod kW == 0

  result.cached_reshaped = input.reshape(N, C, H div kH, kH, W div kW, kW)

  result.maxpooled = result.cached_reshaped.max(axis=3).max(axis=5).squeeze
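The boolean-mask backward pass for this reshape version can be sketched as follows: broadcast each tile's maximum back over the tile, mark the cells equal to it, and route the output gradient there. Pure Python on one channel with a k x k kernel and stride k; splitting the gradient equally among tied maxima is an assumed convention (one common choice, not the only one), and the function name is hypothetical.

```python
def maxpool_reshape_backward(x, dout, k):
    """Backward pass of non-overlapping k x k max-pooling (stride k), one channel.

    For each tile, the mask marks the cells equal to the tile maximum; the
    tile's output gradient is routed to them. Ties share it equally here.
    """
    H, W = len(x), len(x[0])
    dx = [[0.0] * W for _ in range(H)]
    for a in range(H // k):
        for c in range(W // k):
            cells = [(a * k + b, c * k + d) for b in range(k) for d in range(k)]
            tile_max = max(x[i][j] for i, j in cells)
            mask = [(i, j) for i, j in cells if x[i][j] == tile_max]
            for i, j in mask:
                dx[i][j] = dout[a][c] / len(mask)
    return dx


dx = maxpool_reshape_backward([[1, 2], [3, 4]], [[1.0]], k=2)
```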

mratsim (Owner, Author) commented Dec 18, 2017

closed by: 30bba67

mratsim closed this as completed Dec 18, 2017