In [58]:
import torch
import torch.nn.functional as F
from IPython.display import Image

## Case 1: gradient of a summed max-pool output

In [12]:
# A single 4x4 plane holding 1..16 in row-major order, with a leading
# dimension of 1, i.e. shape (1, 4, 4). Reshaping straight to 3-D replaces the
# separate in-place unsqueeze; .float() still yields a fresh float32 leaf tensor.
X = torch.arange(1, 17).reshape(1, 4, 4).float()
X.requires_grad_(True)

tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.]]], requires_grad=True)

In [13]:
# Non-overlapping 2x2 max pooling (stride equals the kernel size) over the
# 4x4 plane, producing a 2x2 map of window maxima. Uses the `F` alias
# imported at the top of the notebook rather than the full module path.
pooled_X = F.max_pool2d(X, kernel_size=(2, 2), stride=2)
pooled_X

tensor([[[ 6.,  8.],
         [14., 16.]]], grad_fn=<MaxPool2DWithIndicesBackward0>)

In [14]:
# Collapse the pooled map to a scalar so backward() needs no explicit gradient
# argument; d(sum)/d(element) = 1, so every pooled element gets upstream grad 1.
y = pooled_X.sum()
y.backward()


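For $y=\max(a,b)$, the (sub)gradient sends all of the upstream gradient to whichever input is the maximum. At a tie ($a=b$) exactly one input must receive it — here $a$, by convention — so the two partial derivatives always sum to 1:

\begin{equation}
\frac{\partial y}{\partial a} = 
\begin{cases} 
1 & \text{if } a \geq b \\
0 & \text{else}
\end{cases}
\end{equation}

\begin{equation}
\frac{\partial y}{\partial b} = 
\begin{cases} 
0 & \text{if } a \geq b \\
1 & \text{else}
\end{cases}
\end{equation}

Max pooling applies this rule inside every window. Each pooled output passes its upstream gradient (1 here, because the pooled map is simply summed) to the single input position that produced the maximum, and every other position receives 0.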

In [15]:
# Each non-overlapping 2x2 window routes its upstream gradient of 1 to exactly
# one input element (its maximum), so X.grad must be a 0/1 mask with a single
# 1 per window. Check that invariant before displaying the gradient.
# unfold: (1, 4, 4) -> (1, 2, 2, 2, 2); the last two dims hold each window.
window_grad_sums = X.grad.unfold(1, 2, 2).unfold(2, 2, 2).sum(dim=(-2, -1))
assert ((X.grad == 0) | (X.grad == 1)).all(), X.grad
assert torch.equal(window_grad_sums, torch.ones_like(window_grad_sums)), window_grad_sums
X.grad

tensor([[[0., 0., 0., 0.],
         [0., 1., 0., 1.],
         [0., 0., 0., 0.],
         [0., 1., 0., 1.]]])

## Case 2: locating the maxima with `return_indices`

In [54]:
# A (1, 1, 4, 4) input: one sample, one channel, a 4x4 plane. Gradient
# tracking is enabled at construction, so x is created directly as a leaf.
x = torch.tensor(
    [[[[1.0, 3.0, 2.0, 4.0],
       [5.0, 6.0, 1.0, 2.0],
       [7.0, 8.0, 3.0, 1.0],
       [2.0, 4.0, 6.0, 8.0]]]],
    requires_grad=True,
)
x

tensor([[[[1., 3., 2., 4.],
          [5., 6., 1., 2.],
          [7., 8., 3., 1.],
          [2., 4., 6., 8.]]]], requires_grad=True)

In [55]:
# Non-overlapping 2x2 max pooling. return_indices=True additionally yields,
# for each window, the flat index of its maximum within the 4x4 input plane.
y, max_indices = F.max_pool2d(
    x,
    kernel_size=(2, 2),
    stride=(2, 2),
    return_indices=True,
)

In [56]:
# Pooled values next to the flat (row * 4 + col) argmax index of each window.
(y, max_indices)

(tensor([[[[6., 4.],
           [8., 8.]]]], grad_fn=<MaxPool2DWithIndicesBackward0>),
 tensor([[[[ 5,  3],
           [ 9, 15]]]]))

In [57]:
y.sum().backward()
# The windows do not overlap, so the input gradient should be exactly a
# scatter of 1s at the flat argmax indices reported by max_pool2d, with 0 elsewhere.
expected_grad = (
    torch.zeros_like(x)
    .flatten(start_dim=-2)
    .scatter(-1, max_indices.flatten(start_dim=-2), 1.0)
    .view_as(x)
)
assert torch.equal(x.grad, expected_grad), expected_grad
x.grad

tensor([[[[0., 0., 0., 1.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.],
          [0., 0., 0., 1.]]]])

In [59]:
# Reference figure (presumably a diagram of the max-pool gradient, per its filename).
Image(width=400, url='../../../imgs/maxpool_grad.png')