# Max pool

Max pooling is an operation which does the following:

* Split the input array elements into a number of equally sized groups, called "pools".
* Find the maximum element in each pool.

The pools are defined by a kernel size, and the kernel function is just `max`. The kernel is swept over the array; with the default stride (equal to the kernel size) the kernel positions do not overlap, and at each position the kernel takes the maximum value of that pool.

The basic cases are described by the equations in the PyTorch docs: <https://docs.pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html>
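For reference, the output-size equations from those docs (with `ceil_mode=False`) can be written as follows. The symbols are our own shorthand, not PyTorch's names: $(k_H, k_W)$ is the kernel size, $(s_H, s_W)$ the stride (which defaults to the kernel size), $p$ the padding added to each side, and $(d_H, d_W)$ the dilation.

$$
\begin{aligned}
H_{\text{out}} &= \left\lfloor \frac{H_{\text{in}} + 2p - d_H (k_H - 1) - 1}{s_H} \right\rfloor + 1 \\
W_{\text{out}} &= \left\lfloor \frac{W_{\text{in}} + 2p - d_W (k_W - 1) - 1}{s_W} \right\rfloor + 1
\end{aligned}
$$

For the 2x4 example below with a 2x1 kernel (stride 2x1, no padding, no dilation), this gives $H_{\text{out}} = \lfloor (2 - 1 - 1)/2 \rfloor + 1 = 1$ and $W_{\text{out}} = \lfloor (4 - 0 - 1)/1 \rfloor + 1 = 4$.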

For example, suppose we have the row-major 2x4 array:

$$
\begin{pmatrix}
0 & 1 & 2 & 3 \\
4 & 5 & 6 & 7 \\
\end{pmatrix}
$$

With a 2x1 kernel, exactly four non-overlapping kernel positions fit in the array, so we have four pools:

$$
\begin{pmatrix}
0 \\
4 \\
\end{pmatrix}
\begin{pmatrix}
1 \\
5 \\
\end{pmatrix}
\begin{pmatrix}
2 \\
6 \\
\end{pmatrix}
\begin{pmatrix}
3 \\
7 \\
\end{pmatrix}
$$

The maximum is taken within each pool, and the maxima are arranged in the same way as the pools. So the result is:


$$
\begin{pmatrix}
4 & 5 & 6 & 7 \\
\end{pmatrix}
$$
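As a quick check of this example, here is a minimal NumPy sketch for the special case where the stride equals the kernel size, there is no dilation or padding, and each dimension divides evenly by the kernel size, so the pools tile the array exactly. The helper name `max_pool2d_reshape` is hypothetical; it reshapes the array into blocks and takes the max of each block.

```python
import numpy as np


def max_pool2d_reshape(x, kH, kW):
    """Non-overlapping max pooling of a 2D array (stride == kernel size,
    no dilation or padding). Assumes each dimension divides evenly."""
    H, W = x.shape
    assert H % kH == 0 and W % kW == 0, "dimensions must divide evenly"
    # Split rows into H // kH blocks of kH rows and columns into W // kW
    # blocks of kW columns, then take the max within each (kH, kW) block.
    return x.reshape(H // kH, kH, W // kW, kW).max(axis=(1, 3))


x = np.arange(8).reshape(2, 4)
print(max_pool2d_reshape(x, 2, 1))  # [[4 5 6 7]]
```

The reshape trick only works for non-overlapping, undilated pools; dilation or a stride different from the kernel size needs a windowed approach instead.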


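There is also an optional parameter called dilation, which defaults to (1, 1) for 2D max pooling. Dilation is the spacing between the input elements that the kernel reads: with a dilation of $d$, consecutive kernel elements are $d$ positions apart in the input.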
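Put as a formula (our own notation, ignoring padding): along one dimension with kernel size $k$, dilation $d$ and stride $s$, the window for output index $o$ reads the input positions

$$
\{\, s\,o + d\,i \;:\; i = 0, 1, \dots, k - 1 \,\}
$$

so it spans $d(k - 1) + 1$ positions from first to last. With $d = 1$ this is an ordinary contiguous window of $k$ positions; with $k = 2$, $d = 2$, $s = 2$ and $o = 0$ it is $\{0, 2\}$, spanning 3 positions.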

For instance, take the example array from above again and apply a 2x2 kernel with dilation $(1, 2)$. The stride defaults to the kernel size, $(2, 2)$.

In the row dimension, we have 2 rows and a kernel size of 2. A dilation of 1 means the kernel reads consecutive rows, so it covers both rows and only one kernel position fits in this dimension.

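In the column dimension, we have 4 columns and a kernel size of 2, but a dilation of 2 means the kernel reads every second column: columns 0 and 2 (0-indexed), spanning 3 columns in total. With a stride of 2, the next position would need columns 2 and 4, and column 4 does not exist, so again only one kernel position fits. So we have just one pool: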

$$
\begin{pmatrix}
0 & 2 \\
4 & 6 \\
\end{pmatrix}
$$

Then applying the max kernel to that pool, we get the final result:


$$
\begin{pmatrix}
6 \\
\end{pmatrix}
$$
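The dilated example can be checked with NumPy slicing, where the slice step plays the role of the dilation. This sketch hard-codes the example's sizes; the variable names (including `eW` for the number of columns the dilated kernel spans) are hypothetical.

```python
import numpy as np

x = np.arange(8).reshape(2, 4)
kH, kW = 2, 2  # kernel size
dH, dW = 1, 2  # dilation
sW = kW        # stride defaults to the kernel size

# A dilated kernel of size k spans d * (k - 1) + 1 input positions.
eW = dW * (kW - 1) + 1                # 3 columns
n_cols = (x.shape[1] - eW) // sW + 1  # only 1 window position fits

# The single pool: rows 0, 1 (step dH = 1) and columns 0, 2 (step dW = 2).
pool = x[0:kH * dH:dH, 0:kW * dW:dW]
print(pool)        # [[0 2]
                   #  [4 6]]
print(pool.max())  # 6
```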


```python
import numpy as np
import itertools
import torch
```

Here's a small Python implementation of 2D max pooling, based on the above.

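```python
def max_pool2d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    dilation=(1, 1),
    # ceil_mode=False,       # not supported
    # return_indices=False,  # not supported
):
    # input has shape (N, C, H, W); kernel_size, stride and dilation are
    # (height, width) pairs; padding is a single int for both dimensions.
    N, C, iH, iW = input.shape
    kH, kW = kernel_size
    dH, dW = dilation

    # As in PyTorch, the stride defaults to the kernel size.
    if stride is None:
        stride = kernel_size
    sH, sW = stride

    # Pad H and W with the dtype's lowest value so padded elements never
    # win the max (PyTorch pads with -inf implicitly).
    if padding > 0:
        if np.issubdtype(input.dtype, np.floating):
            fill = -np.inf
        else:
            fill = np.iinfo(input.dtype).min
        pad = (padding, padding)
        input = np.pad(input, ((0, 0), (0, 0), pad, pad), constant_values=fill)

    # Output size, from the formula in the PyTorch MaxPool2d docs.
    oH = (iH + 2 * padding - dH * (kH - 1) - 1) // sH + 1
    oW = (iW + 2 * padding - dW * (kW - 1) - 1) // sW + 1

    output = np.empty_like(input, shape=(N, C, oH, oW))

    for h, w in itertools.product(range(oH), range(oW)):
        # The window for output (h, w) starts at (sH * h, sW * w) and reads
        # kH x kW elements spaced dH rows and dW columns apart.
        output[:, :, h, w] = np.max(
            input[
                :, :,
                (sH * h):(sH * h + kH * dH):dH,
                (sW * w):(sW * w + kW * dW):dW,
            ],
            axis=(-2, -1))

    return output
```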

```python
a = np.arange(36).reshape((1, 1, 6, 6))
kernel_size = (3, 2)
dilation = (2, 3)
b = max_pool2d(
    a,
    kernel_size,
    dilation=dilation,
)
print(a)
print(b)
c = torch.nn.functional.max_pool2d(
    torch.from_numpy(a),
    kernel_size,
    dilation=dilation,
).numpy()
print(c)
assert np.allclose(b, c)
```

```
[[[[ 0  1  2  3  4  5]
   [ 6  7  8  9 10 11]
   [12 13 14 15 16 17]
   [18 19 20 21 22 23]
   [24 25 26 27 28 29]
   [30 31 32 33 34 35]]]]
[[[[27 29]]]]
[[[[27 29]]]]
```
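As an alternative sketch (not how PyTorch implements it), the Python loop can be replaced by NumPy's `sliding_window_view` (available since NumPy 1.20): take every window of the dilated kernel's full extent at unit stride, then keep every stride-th window and every dilation-th element inside each window. The function name `max_pool2d_windows` is hypothetical, and padding is left out for brevity.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def max_pool2d_windows(x, kernel_size, stride=None, dilation=(1, 1)):
    """Max pooling over an (N, C, H, W) array without padding.
    Hypothetical helper; stride defaults to the kernel size."""
    kH, kW = kernel_size
    dH, dW = dilation
    sH, sW = stride if stride is not None else kernel_size
    # Full extent of the dilated kernel in each dimension.
    eH = dH * (kH - 1) + 1
    eW = dW * (kW - 1) + 1
    # All eH x eW windows at unit stride; the window dims are appended:
    # shape (N, C, H - eH + 1, W - eW + 1, eH, eW).
    win = sliding_window_view(x, (eH, eW), axis=(-2, -1))
    # Keep every sH-th / sW-th window position (the stride) and every
    # dH-th / dW-th element inside each window (the dilation).
    win = win[:, :, ::sH, ::sW, ::dH, ::dW]
    return win.max(axis=(-2, -1))


a = np.arange(36).reshape((1, 1, 6, 6))
print(max_pool2d_windows(a, (3, 2), dilation=(2, 3)))  # [[[[27 29]]]]
```

Taking every $s$-th of the $H - e_H + 1$ unit-stride positions leaves $\lfloor (H - e_H)/s \rfloor + 1$ of them, which matches the output-size formula with no padding.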
