In [None]:
### from https://github.com/mratsim/Arraymancer/issues/174

def max_pool_forward_fast(x, pool_param):
    """
    A fast implementation of the forward pass for a max pooling layer.

    This chooses between the reshape method and the im2col method. If the pooling
    regions are square, use a stride equal to their size, and tile the input
    image, then we can use the reshape method, which is very fast. Otherwise we
    fall back on the im2col method, which is not much faster than the naive
    method.

    Inputs:
    - x: array whose shape is unpacked as (N, C, H, W).
    - pool_param: dict with keys 'pool_height', 'pool_width' and 'stride'.

    Returns a tuple (out, cache). cache is ('reshape', ...) or ('im2col', ...),
    tagging which implementation produced it (presumably so a backward pass can
    dispatch on it -- confirm against the caller).
    """
    _, _, H, W = x.shape
    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']

    # The reshape trick only works for non-overlapping square windows that
    # cover H and W exactly.
    same_size = pool_height == pool_width == stride
    tiles = H % pool_height == 0 and W % pool_width == 0
    if same_size and tiles:
        out, reshape_cache = max_pool_forward_reshape(x, pool_param)
        cache = ('reshape', reshape_cache)
    else:
        out, im2col_cache = max_pool_forward_im2col(x, pool_param)
        cache = ('im2col', im2col_cache)
    return out, cache

def max_pool_forward_reshape(x, pool_param):
    """
    A fast implementation of the forward pass for the max pooling layer that uses
    some clever reshaping.
    This can only be used for square pooling regions that tile the input.

    Inputs:
    - x: array whose shape is unpacked as (N, C, H, W).
    - pool_param: dict with keys 'pool_height', 'pool_width' and 'stride';
      all three must be equal.

    Returns a tuple (out, cache): out has shape
    (N, C, H // pool_height, W // pool_width) and cache is (x, x_reshaped, out).
    Raises AssertionError if the window is not square with stride equal to its
    size, or if it does not tile H and W.
    """
    N, C, H, W = x.shape
    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']
    assert pool_height == pool_width == stride, 'Invalid pool params'
    assert H % pool_height == 0
    assert W % pool_width == 0
    # Split each spatial axis into (number of windows, offset within window).
    # Floor division keeps the dimensions integral: Python 3 `/` yields floats,
    # which ndarray.reshape rejects.
    x_reshaped = x.reshape(N, C, H // pool_height, pool_height,
                           W // pool_width, pool_width)
    # Reduce the in-window row axis (3), then the in-window column axis, which
    # has moved from position 5 to 4 after the first reduction.
    out = x_reshaped.max(axis=3).max(axis=4)

    cache = (x, x_reshaped, out)
    return out, cache

def max_pool_forward_im2col(x, pool_param):
    """
    An implementation of the forward pass for max pooling based on im2col.
    This isn't much faster than the naive version, so it should be avoided if
    possible.

    Inputs:
    - x: array whose shape is unpacked as (N, C, H, W).
    - pool_param: dict with keys 'pool_height', 'pool_width' and 'stride'.

    Returns a tuple (out, cache): out has shape (N, C, out_height, out_width)
    and cache is (x, x_cols, x_cols_argmax, pool_param).
    Raises AssertionError if the window/stride do not fit H or W exactly.

    NOTE(review): relies on an `im2col` helper that is not defined in this
    notebook; it must be importable where this runs.
    """
    N, C, H, W = x.shape
    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']

    assert (H - pool_height) % stride == 0, 'Invalid height'
    assert (W - pool_width) % stride == 0, 'Invalid width'

    # Floor division keeps these integral; Python 3 `/` yields floats, which
    # ndarray.reshape rejects below.
    out_height = (H - pool_height) // stride + 1
    out_width = (W - pool_width) // stride + 1

    # Treat each (n, c) plane as its own single-channel image, so each im2col
    # column presumably holds one pooling window of one plane.
    x_split = x.reshape(N * C, 1, H, W)
    x_cols = im2col(x_split, pool_height, pool_width, padding=0, stride=stride)
    # Keep the argmax (stored in the cache) and gather the matching maxima.
    x_cols_argmax = np.argmax(x_cols, axis=0)
    x_cols_max = x_cols[x_cols_argmax, np.arange(x_cols.shape[1])]
    # NOTE(review): assumes im2col orders columns as (out_height, out_width, N*C)
    # -- confirm against the im2col implementation in use.
    out = x_cols_max.reshape(out_height, out_width, N, C).transpose(2, 3, 0, 1)

    cache = (x, x_cols, x_cols_argmax, pool_param)
    return out, cache

In [4]:
import numpy as np

In [1]:
def max_pool_forward_reshape(x, pool_param):
    """
    A fast implementation of the forward pass for the max pooling layer that uses
    some clever reshaping.
    This can only be used for square pooling regions that tile the input.

    Inputs:
    - x: array whose shape is unpacked as (N, C, H, W).
    - pool_param: dict with keys 'pool_height', 'pool_width' and 'stride';
      all three must be equal.

    Returns a tuple (out, cache): out has shape
    (N, C, H // pool_height, W // pool_width) and cache is (x, x_reshaped, out).
    Raises AssertionError if the window is not square with stride equal to its
    size, or if it does not tile H and W.
    """
    N, C, H, W = x.shape
    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']
    assert pool_height == pool_width == stride, 'Invalid pool params'
    assert H % pool_height == 0
    assert W % pool_width == 0
    # Split each spatial axis into (number of windows, offset within window).
    # Floor division keeps the dimensions integral: Python 3 `/` yields floats,
    # which ndarray.reshape rejects.
    x_reshaped = x.reshape(N, C, H // pool_height, pool_height,
                           W // pool_width, pool_width)
    # Reduce the in-window row axis (3), then the in-window column axis, which
    # has moved from position 5 to 4 after the first reduction.
    out = x_reshaped.max(axis=3).max(axis=4)

    cache = (x, x_reshaped, out)
    return out, cache

In [38]:
def mp(x, pool_size=2):
    """
    Non-overlapping max pooling of a batch of images via reshaping.

    Args:
        x: array whose shape is unpacked as (N, C, H, W).
        pool_size: side length of the square pooling window. The stride equals
            the window size, so windows do not overlap. Defaults to 2.

    Returns:
        Array of shape (N, C, H // pool_size, W // pool_size) holding the
        maximum of each window.

    Raises:
        ValueError: if pool_size is not positive or does not divide H and W.
    """
    N, C, H, W = x.shape
    # Check positivity first so the modulo below never divides by zero.
    if pool_size < 1 or H % pool_size or W % pool_size:
        raise ValueError(
            "pool_size={} must be positive and divide H={} and W={}".format(
                pool_size, H, W))
    # Split each spatial axis into (number of windows, offset within window),
    # then reduce over both in-window axes at once.
    x_reshaped = x.reshape(N, C, H // pool_size, pool_size,
                           W // pool_size, pool_size)
    return x_reshaped.max(axis=(3, 5))

In [39]:
# Three deterministic 4x4 planes holding 0..15, 16..31 and 32..47.
x0, x1, x2 = (np.arange(16).reshape(4, 4) + offset for offset in (0, 16, 32))
# Pack them as the three channels of one (1, 3, 4, 4) float image.
x = np.zeros((1, 3, 4, 4))
x[0] = np.stack((x0, x1, x2))

In [40]:
# Three random int32 4x4 planes with values in [0, 10), [0, 100), [0, 1000);
# the generator draws them in that order.
x0, x1, x2 = (
    (np.random.rand(4, 4) * scale).astype(np.int32) for scale in (10, 100, 1000)
)
# Pack them as the three channels of one (1, 3, 4, 4) float image.
x = np.zeros((1, 3, 4, 4))
x[0] = np.stack((x0, x1, x2))

In [41]:
x  # display the current input array

array([[[[  4.,   1.,   4.,   0.],
         [  0.,   4.,   0.,   2.],
         [  8.,   0.,   3.,   1.],
         [  7.,   5.,   3.,   8.]],

        [[ 73.,  13.,  71.,  22.],
         [ 14.,   5.,  31.,  12.],
         [ 44.,  72.,  32.,  40.],
         [ 93.,  86.,  67.,  72.]],

        [[735., 759., 311., 965.],
         [591.,  64., 971., 400.],
         [689., 932., 983., 574.],
         [239., 752.,  33., 220.]]]])

In [42]:
mp(x)  # 2x2 max pool of every (H, W) plane in x

array([[[[  4.,   4.],
         [  8.,   8.]],

        [[ 73.,  71.],
         [ 93.,  72.]],

        [[759., 971.],
         [932., 983.]]]])

In [66]:
def mp(x, pool_size=2):
    """
    Non-overlapping max pooling of a channels-first image via reshaping.

    Args:
        x: array whose shape is unpacked as (C, H, W).
        pool_size: side length of the square pooling window. The stride equals
            the window size, so windows do not overlap. Defaults to 2.

    Returns:
        Array of shape (C, H // pool_size, W // pool_size) holding the maximum
        of each window.

    Raises:
        ValueError: if pool_size is not positive or does not divide H and W.
    """
    C, H, W = x.shape
    # Check positivity first so the modulo below never divides by zero.
    if pool_size < 1 or H % pool_size or W % pool_size:
        raise ValueError(
            "pool_size={} must be positive and divide H={} and W={}".format(
                pool_size, H, W))
    # Split each spatial axis into (number of windows, offset within window),
    # then reduce over both in-window axes at once.
    x_reshaped = x.reshape(C, H // pool_size, pool_size,
                           W // pool_size, pool_size)
    return x_reshaped.max(axis=(2, 4))

In [67]:
# Three deterministic 4x4 planes holding 0..15, 16..31 and 32..47.
x0, x1, x2 = (np.arange(16).reshape(4, 4) + offset for offset in (0, 16, 32))
# Channels-first (C, H, W) float array; assignment casts the int planes to float.
x = np.zeros((3, 4, 4))
x[:] = np.stack((x0, x1, x2))

In [47]:
# Three random int32 4x4 planes with values in [0, 10), [0, 100), [0, 1000);
# the generator draws them in that order.
x0, x1, x2 = (
    (np.random.rand(4, 4) * scale).astype(np.int32) for scale in (10, 100, 1000)
)
# Channels-first (C, H, W) float array built from the three planes.
x = np.zeros((3, 4, 4))
x[:] = np.stack((x0, x1, x2))

In [68]:
x  # display the current input array

array([[[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]],

       [[16., 17., 18., 19.],
        [20., 21., 22., 23.],
        [24., 25., 26., 27.],
        [28., 29., 30., 31.]],

       [[32., 33., 34., 35.],
        [36., 37., 38., 39.],
        [40., 41., 42., 43.],
        [44., 45., 46., 47.]]])

In [69]:
mp(x)  # 2x2 max pool of every channel plane in x

array([[[ 5.,  7.],
        [13., 15.]],

       [[21., 23.],
        [29., 31.]],

       [[37., 39.],
        [45., 47.]]])

In [92]:
# This is the variant to use: channels-last (H, W, C) input, the same layout as max_pool below.

def mp(x, pool_size=2):
    """
    Non-overlapping max pooling of a channels-last image via reshaping.

    Args:
        x: array whose shape is unpacked as (H, W, C).
        pool_size: side length of the square pooling window. The stride equals
            the window size, so windows do not overlap. Defaults to 2.

    Returns:
        Array of shape (H // pool_size, W // pool_size, C) holding the maximum
        of each window, per channel.

    Raises:
        ValueError: if pool_size is not positive or does not divide H and W.
    """
    H, W, C = x.shape
    # Check positivity first so the modulo below never divides by zero.
    if pool_size < 1 or H % pool_size or W % pool_size:
        raise ValueError(
            "pool_size={} must be positive and divide H={} and W={}".format(
                pool_size, H, W))
    # Split each spatial axis into (number of windows, offset within window),
    # then reduce over both in-window axes at once; channels stay last.
    x_reshaped = x.reshape(H // pool_size, pool_size,
                           W // pool_size, pool_size, C)
    return x_reshaped.max(axis=(1, 3))

In [93]:
# Three deterministic 4x4 planes holding 0..15, 16..31 and 32..47.
x0, x1, x2 = (np.arange(16).reshape(4, 4) + offset for offset in (0, 16, 32))
# Channels-last (H, W, C) float array: plane k becomes x[:, :, k].
x = np.zeros((4, 4, 3))
x[:] = np.stack((x0, x1, x2), axis=-1)

In [96]:
# Three random int32 4x4 planes with values in [0, 10), [0, 100), [0, 1000);
# the generator draws them in that order.
x0, x1, x2 = (
    (np.random.rand(4, 4) * scale).astype(np.int32) for scale in (10, 100, 1000)
)
# Channels-last (H, W, C) float array: plane k becomes x[:, :, k].
x = np.zeros((4, 4, 3))
x[:] = np.stack((x0, x1, x2), axis=-1)

In [94]:
x  # display the current input array

array([[[ 0., 16., 32.],
        [ 1., 17., 33.],
        [ 2., 18., 34.],
        [ 3., 19., 35.]],

       [[ 4., 20., 36.],
        [ 5., 21., 37.],
        [ 6., 22., 38.],
        [ 7., 23., 39.]],

       [[ 8., 24., 40.],
        [ 9., 25., 41.],
        [10., 26., 42.],
        [11., 27., 43.]],

       [[12., 28., 44.],
        [13., 29., 45.],
        [14., 30., 46.],
        [15., 31., 47.]]])

In [95]:
mp(x)  # 2x2 max pool over the spatial (H, W) axes; channels stay last

array([[[ 5., 21., 37.],
        [ 7., 23., 39.]],

       [[13., 29., 45.],
        [15., 31., 47.]]])

In [97]:
x  # display the current input array

array([[[  2.,  41., 659.],
        [  9.,  58., 678.],
        [  3.,  60., 441.],
        [  0.,   9., 222.]],

       [[  3.,  99.,  60.],
        [  5.,  32., 207.],
        [  3.,  73., 633.],
        [  9.,  27., 668.]],

       [[  7.,  94., 178.],
        [  6.,  58., 605.],
        [  6.,  87., 393.],
        [  1.,  77., 624.]],

       [[  0.,  14.,  36.],
        [  2.,  80., 587.],
        [  2.,  66., 607.],
        [  4.,  38., 893.]]])

In [98]:
mp(x)  # 2x2 max pool over the spatial (H, W) axes; channels stay last

array([[[  9.,  99., 678.],
        [  9.,  73., 668.]],

       [[  7.,  94., 605.],
        [  6.,  87., 893.]]])

In [105]:
def max_pool(patch_ir, pool_size=8):
    """
    Non-overlapping max pooling of a channels-last patch into a descriptor.

    Args:
        patch_ir: array whose shape is unpacked as (height, width, nr_channels).
        pool_size: side length of the square pooling window. The stride equals
            the window size. Defaults to 8.

    Returns:
        Array of shape (height // pool_size, width // pool_size, nr_channels)
        holding the per-channel maximum of each window.

    Raises:
        ValueError: if pool_size is not positive or does not divide height and
            width.
    """
    height, width, nr_channels = patch_ir.shape
    # Check positivity first so the modulo below never divides by zero.
    if pool_size < 1 or height % pool_size or width % pool_size:
        raise ValueError(
            "pool_size={} must be positive and divide height={} and width={}".format(
                pool_size, height, width))

    # Split each spatial axis into (number of windows, offset within window),
    # then reduce over both in-window axes at once.
    patch_ir_reshaped = patch_ir.reshape(height // pool_size, pool_size,
                                         width // pool_size, pool_size,
                                         nr_channels)
    patch_descr = patch_ir_reshaped.max(axis=(1, 3))

    return patch_descr

In [106]:
# All-zero test patch; max_pool unpacks its shape as (height, width, nr_channels).
patch_ir = np.zeros((8, 16, 32))

In [107]:
max_pool(patch_ir)  # 8x8 windows over the 8x16 grid -> shape (1, 2, 32)

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [None]:
# Fallback for pooling windows that do not tile the input exactly (slower, im2col-based implementation).

In [None]:
def mp_slow(x, pool_param):
    """
    Max pooling forward pass via im2col, for windows that need not tile the input.

    Inputs:
    - x: array whose shape is unpacked as (N, C, H, W).
    - pool_param: dict with keys 'pool_height', 'pool_width' and 'stride'.

    Returns a tuple (out, cache): out has shape (N, C, out_height, out_width)
    and cache is (x, x_cols, x_cols_argmax, pool_param).
    Raises AssertionError if the window/stride do not fit H or W exactly.

    NOTE(review): relies on an `im2col` helper that is not defined in this
    notebook; it must be importable where this runs.
    """
    N, C, H, W = x.shape
    pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
    stride = pool_param['stride']

    assert (H - pool_height) % stride == 0, 'Invalid height'
    assert (W - pool_width) % stride == 0, 'Invalid width'

    # Floor division keeps these integral; Python 3 `/` yields floats, which
    # ndarray.reshape rejects below.
    out_height = (H - pool_height) // stride + 1
    out_width = (W - pool_width) // stride + 1

    # Treat each (n, c) plane as its own single-channel image, so each im2col
    # column presumably holds one pooling window of one plane.
    x_split = x.reshape(N * C, 1, H, W)
    x_cols = im2col(x_split, pool_height, pool_width, padding=0, stride=stride)
    # Keep the argmax (stored in the cache) and gather the matching maxima.
    x_cols_argmax = np.argmax(x_cols, axis=0)
    x_cols_max = x_cols[x_cols_argmax, np.arange(x_cols.shape[1])]
    # NOTE(review): assumes im2col orders columns as (out_height, out_width, N*C)
    # -- confirm against the im2col implementation in use.
    out = x_cols_max.reshape(out_height, out_width, N, C).transpose(2, 3, 0, 1)

    cache = (x, x_cols, x_cols_argmax, pool_param)
    return out, cache