In [152]:
import cupy as np


class Max:
    """Max-reduction layer over one axis with a hand-written backward pass.

    ``forward`` remembers the input shape, the output shape and the arg-max
    positions along ``axis``; ``backward`` uses them to place the upstream
    gradient at those positions and zeros everywhere else.

    Args:
        axis: Integer axis to reduce over. Negative values count from the
            last axis, exactly as they do for ``np.max``.
    """

    def __init__(self, axis=0):
        self.axis = axis
        self.x_shape = None  # shape of the most recent input to forward()
        self.o_shape = None  # shape of the most recent output of forward()
        self.mask = None     # arg-max indices along `axis` from forward()

    def forward(self, x):
        """Return the maximum of ``x`` along ``self.axis``.

        Caches ``x.shape``, the arg-max indices and the output shape so that
        ``backward`` can be called afterwards.

        Args:
            x: Input array with at least one dimension.

        Returns:
            ``np.max(x, axis=self.axis)``.
        """
        y = np.max(x, axis=self.axis)

        self.x_shape = x.shape
        self.mask = np.argmax(x, axis=self.axis)
        self.o_shape = y.shape

        return y

    def backward(self, y):
        """Scatter the upstream gradient back to the arg-max positions.

        Args:
            y: Upstream gradient, shaped like the output of the last
                ``forward`` call.

        Returns:
            Array of shape ``x_shape`` (default float dtype of ``np.zeros``)
            holding ``y`` at the positions selected by ``forward`` and zeros
            elsewhere.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.x_shape is None:
            raise RuntimeError("Max.backward() called before Max.forward()")

        ndim = len(self.x_shape)
        # Map a negative axis (e.g. -1) onto its non-negative equivalent.
        # Without this the arg-max indices are never placed in the index
        # tuple below and the lookup runs past the end of the grid list.
        axis = self.axis + ndim if self.axis < 0 else self.axis

        dx = np.zeros(self.x_shape)
        if ndim == 1:
            # Output is a scalar: a single position receives the gradient.
            dx[self.mask] = y
        else:
            # One coordinate grid per *output* dimension ...
            grids = np.indices(self.o_shape)
            index = [grids[k] for k in range(grids.shape[0])]
            # ... with the arg-max indices slotted in at the reduced axis,
            # giving one full coordinate per output element.
            index.insert(axis, self.mask)
            dx[tuple(index)] = y

        return dx

In [153]:
# Build a Max layer that reduces over the first axis.
m = Max(axis=0)


# 36 random integers in [1, 10), arranged as a (3, 4, 3) array.
# NOTE(review): no random seed is set, so re-running this cell will not
# reproduce the values recorded below.
x = np.random.randint(1, 10, 36).reshape(3, 4, 3)
x

array([[[5, 5, 3],
        [1, 8, 4],
        [3, 8, 7],
        [2, 6, 7]],

       [[8, 1, 2],
        [4, 3, 5],
        [4, 6, 1],
        [2, 6, 6]],

       [[7, 3, 7],
        [1, 9, 1],
        [6, 6, 5],
        [2, 8, 9]]])

In [154]:
# Forward pass: the maximum over axis 0 collapses the (3, 4, 3) input to (4, 3).
y = m.forward(x)
y

array([[8, 5, 7],
       [4, 9, 5],
       [6, 8, 7],
       [2, 8, 9]], dtype=int32)

In [155]:
# Backward pass: the forward output is reused as the upstream gradient, so each
# maximum is written back at the position it was taken from and all other
# entries stay 0.
# NOTE(review): the recorded output below also contains printed index arrays,
# which the `backward` defined in this notebook does not emit — presumably left
# over from an earlier version with debug prints; re-run to refresh the output.
dx = m.backward(y)
dx

[[1 0 2]
 [1 2 1]
 [2 0 0]
 [0 2 2]]
[[[0 0 0]
  [1 1 1]
  [2 2 2]
  [3 3 3]]

 [[0 1 2]
  [0 1 2]
  [0 1 2]
  [0 1 2]]]
[array([[1, 0, 2],
       [1, 2, 1],
       [2, 0, 0],
       [0, 2, 2]], dtype=int64), array([[0, 0, 0],
       [1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]]), array([[0, 1, 2],
       [0, 1, 2],
       [0, 1, 2],
       [0, 1, 2]])]


array([[[0., 5., 0.],
        [0., 0., 0.],
        [0., 8., 7.],
        [2., 0., 0.]],

       [[8., 0., 0.],
        [4., 0., 5.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 7.],
        [0., 9., 0.],
        [6., 0., 0.],
        [0., 8., 9.]]])