# 2. 합성곱 신경망 내부 구조

## 2.3  합성곱/풀링 구현

### 컨볼루션 층 테스트

In [3]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.layers import Convolution

image = np.arange(16).reshape(1,1,4,4)
print(image.shape)
print(image)

W = np.ones((1,1,2,2))
print(W)
b = np.full((1,), 3)
print(b)
conv = Convolution(W, b)
out = conv.forward(image)
print(out.shape)
print(out)

(1, 1, 4, 4)
[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]]]
[[[[1. 1.]
   [1. 1.]]]]
[3]
(1, 1, 3, 3)
[[[[13. 17. 21.]
   [29. 33. 37.]
   [45. 49. 53.]]]]


### im2col 함수 테스트

In [7]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.util import im2col

# x1 = np.random.rand(1, 1, 7, 7)
# col1 = im2col(x1, 5, 5)
# print(col1.shape)

# x1 = np.random.rand(1, 3, 7, 7)
# col1 = im2col(x1, 5, 5)
# print(col1.shape)

# x2 = np.random.rand(10, 3, 7, 7)
# col2 = im2col(x2, 5, 5)
# print(col2.shape)

x3 = np.random.rand(2, 4, 5, 5)
col2 = im2col(x3, 2, 2)
print(col2.shape)  #(N*OH*OW, C*FH*FW) => (2*4*4, 4*2*2)

(32, 16)


### np.pad()의 동작

In [11]:
import numpy as np
pad=0
input_data = np.arange(16).reshape((1,1,4,4))
print(input_data)
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
print(img)

[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]]]
[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]]]


### 슬라이싱과 인덱싱

In [17]:
a = np.array([1,2,3,4])
print(a, a.shape)
print(a[0], a[0].shape)
print(a[0:1], a[0:1].shape)

[1 2 3 4] (4,)
1 ()
[1] (1,)


In [20]:
a = np.array([[1,2],
              [3,4]])
print(a, a.shape)
print(a[0], a[0].shape)
print(a[0:1], a[0:1].shape)

[[1 2]
 [3 4]] (2, 2)
[1 2] (2,)
[[1 2]] (1, 2)


In [26]:
input_data = np.arange(16).reshape(1,1,4,4)
print(input_data)
print(img[:, :, 0:3:1, 0:3:1].shape)
print(img[:, :, 0:3:1, 0:3:1])

[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]]]
(1, 1, 3, 3)
[[[[ 0  1  2]
   [ 4  5  6]
   [ 8  9 10]]]]


In [28]:
col = np.zeros((1, 1, 2, 2, 3, 3))
print(col.shape)
print(col[:, :, 0, 0, :, :].shape)

(1, 1, 2, 2, 3, 3)
(1, 1, 3, 3)


In [32]:
col[:, :, 0, 0, :, :] = img[:, :, 0:3:1, 0:3:1]
# print(col)
col[:, :, 0, 1, :, :] = img[:, :, 0:3:1, 1:4:1]
# print(col)
col[:, :, 1, 0, :, :] = img[:, :, 1:4:1, 0:3:1]
# print(col)
col[:, :, 1, 1, :, :] = img[:, :, 1:4:1, 1:4:1]
print(col)

[[[[[[ 0.  1.  2.]
     [ 4.  5.  6.]
     [ 8.  9. 10.]]

    [[ 1.  2.  3.]
     [ 5.  6.  7.]
     [ 9. 10. 11.]]]


   [[[ 4.  5.  6.]
     [ 8.  9. 10.]
     [12. 13. 14.]]

    [[ 5.  6.  7.]
     [ 9. 10. 11.]
     [13. 14. 15.]]]]]]


### input 채널이 1개인 경우

In [40]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
# Single-channel walk-through: one 4x4 image convolved with one 2x2 filter
# (stride 1, no padding), written out as im2col + matrix product.
pad = 0
input_data = np.arange(16).reshape(1, 1, 4, 4)
img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], mode='constant')
col = np.zeros((1, 1, 2, 2, 3, 3))  # (N, C, FH, FW, OH, OW)

print(img.shape)
print(col.shape)

# For every filter offset (y, x), copy the 3x3 grid of pixels that this
# filter element touches while the filter slides over the image.
for y in range(2):
    for x in range(2):
        col[:, :, y, x, :, :] = img[:, :, y:y + 3, x:x + 3]

# Bring the output positions to the front so that, after flattening,
# each row is one receptive field: (N*OH*OW, C*FH*FW).
ret = np.transpose(col, (0, 4, 5, 1, 2, 3))
col = ret.reshape(9, -1)
print(col.shape)
print(col)

# Flatten the filter bank into a (C*FH*FW, FN) matrix.
W = np.ones((1, 1, 2, 2))
col_W = W.reshape(1, -1).T
print(col_W.shape)

b = np.full((1,), 3)
out = col.dot(col_W) + b
print(out)

# (N*OH*OW, FN) -> (N, OH, OW, FN) -> (N, FN, OH, OW)
out = out.reshape(1, 3, 3, -1).transpose(0, 3, 1, 2)
print(out)

(1, 1, 4, 4)
(1, 1, 2, 2, 3, 3)
(9, 4)
[[ 0.  1.  4.  5.]
 [ 1.  2.  5.  6.]
 [ 2.  3.  6.  7.]
 [ 4.  5.  8.  9.]
 [ 5.  6.  9. 10.]
 [ 6.  7. 10. 11.]
 [ 8.  9. 12. 13.]
 [ 9. 10. 13. 14.]
 [10. 11. 14. 15.]]
(4, 1)
[[13.]
 [17.]
 [21.]
 [29.]
 [33.]
 [37.]
 [45.]
 [49.]
 [53.]]
[[[[13. 17. 21.]
   [29. 33. 37.]
   [45. 49. 53.]]]]


### input 채널이 2개인 경우

In [46]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
# Two-channel walk-through: one 2-channel 4x4 image convolved with a single
# 2-channel 2x2 filter (stride 1, no padding).
pad = 0
input_data = np.arange(32).reshape(1, 2, 4, 4)
img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], mode='constant')
col = np.zeros((1, 2, 2, 2, 3, 3))  # (N, C, FH, FW, OH, OW)

print(img.shape)
print(col.shape)

# One 3x3 slab per filter offset; both channels are copied at once.
for y, x in np.ndindex(2, 2):
    col[:, :, y, x, :, :] = img[:, :, y:y + 3, x:x + 3]

# Rows become receptive fields of length C*FH*FW = 8.
ret = np.transpose(col, (0, 4, 5, 1, 2, 3))
col = ret.reshape(9, -1)

# The filter is flattened the same way, so one dot product sums over
# both channels and all filter positions.
W = np.ones((1, 2, 2, 2))
col_W = W.reshape(1, -1).T
b = np.full((1,), 3)
out = col.dot(col_W) + b
out = out.reshape(1, 3, 3, -1).transpose(0, 3, 1, 2)
print(out)

(1, 2, 4, 4)
(1, 2, 2, 2, 3, 3)
[[[[ 87.  95. 103.]
   [119. 127. 135.]
   [151. 159. 167.]]]]


### input 채널이 1개, 필터가 2개인 경우

In [51]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
# One-channel image, two filters: shows how several filters become the
# columns of a single weight matrix.
pad = 0
input_data = np.arange(16).reshape(1, 1, 4, 4)
img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((1, 1, 2, 2, 3, 3))  # (N, C, FH, FW, OH, OW)

# One 3x3 slab of pixels per filter offset (y, x).
for y in range(2):
    for x in range(2):
        col[:, :, y, x, :, :] = img[:, :, y:y + 3, x:x + 3]

# (N, OH, OW, C, FH, FW) -> (N*OH*OW, C*FH*FW): one receptive field per row.
ret = col.transpose(0, 4, 5, 1, 2, 3)
col = ret.reshape(1 * 3 * 3, -1)

W = np.array([
    [[[1, 1],
      [1, 1]]],
    [[[2, 2],
      [2, 2]]],
])  # (FN, C, FH, FW) = (2, 1, 2, 2)
print(W.shape)

# Take the filter count from W instead of hard-coding it, so that the
# reshape and the bias below always agree with the filter bank.
n_filters = W.shape[0]
col_W = W.reshape(n_filters, -1).T  # (C*FH*FW, FN)
print(col_W.shape)

# One bias entry per filter. A (1,)-shaped bias would still broadcast, but
# it could never represent different biases for different filters.
b = np.full((n_filters,), 3)
out = np.dot(col, col_W) + b  # (9, 4)(4, 2) + (2,) => (9, 2)

# (N*OH*OW, FN) -> (N, OH, OW, FN) -> (N, FN, OH, OW)
out = out.reshape(1, 3, 3, -1).transpose(0, 3, 1, 2)
print(out)

(2, 1, 2, 2)
(4, 2)
[[[[ 13.  17.  21.]
   [ 29.  33.  37.]
   [ 45.  49.  53.]]

  [[ 23.  31.  39.]
   [ 55.  63.  71.]
   [ 87.  95. 103.]]]]


### input 그림이 1개, 채널이 2개, 필터가 3개인 경우

In [55]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
# One 2-channel image, three filters.
pad = 0
input_data = np.arange(32).reshape((1, 2, 4, 4))
img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((1, 2, 2, 2, 3, 3))  # (N, C, FH, FW, OH, OW)

# One 3x3 slab of pixels per filter offset (y, x), for both channels at once.
for y in range(2):
    for x in range(2):
        col[:, :, y, x, :, :] = img[:, :, y:y + 3, x:x + 3]

# (N, OH, OW, C, FH, FW) -> (N*OH*OW, C*FH*FW): one receptive field per row.
ret = col.transpose(0, 4, 5, 1, 2, 3)
col = ret.reshape(1 * 3 * 3, -1)
print(col.shape)

W = np.ones((3, 2, 2, 2))  # (FN, C, FH, FW)

# Take the filter count from W instead of hard-coding it, so that the
# reshape and the bias below always agree with the filter bank.
n_filters = W.shape[0]
col_W = W.reshape(n_filters, -1).T  # (C*FH*FW, FN)
print(col_W.shape)

# One bias entry per filter. A (1,)-shaped bias would still broadcast, but
# it could never represent different biases for different filters.
b = np.full((n_filters,), 3)
out = np.dot(col, col_W) + b  # (9, 8)(8, 3) + (3,) => (9, 3)

# (N*OH*OW, FN) -> (N, OH, OW, FN) -> (N, FN, OH, OW)
out = out.reshape(1, 3, 3, -1).transpose(0, 3, 1, 2)
print(out.shape)
print(out)

(9, 8)
(8, 3)
(1, 3, 3, 3)
[[[[ 87.  95. 103.]
   [119. 127. 135.]
   [151. 159. 167.]]

  [[ 87.  95. 103.]
   [119. 127. 135.]
   [151. 159. 167.]]

  [[ 87.  95. 103.]
   [119. 127. 135.]
   [151. 159. 167.]]]]


### input 그림이 2개, 채널이 2개, 필터가 3개인 경우

In [58]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
# Batch walk-through: two 2-channel 4x4 images against three 2x2 filters.
# No bias term is added in this cell.
pad = 0
input_data = np.arange(64).reshape(2, 2, 4, 4)
img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], mode='constant')
col = np.zeros((2, 2, 2, 2, 3, 3))  # (N, C, FH, FW, OH, OW)

# One 3x3 slab per filter offset, copied for every image and channel at once.
for y, x in np.ndindex(2, 2):
    col[:, :, y, x, :, :] = img[:, :, y:y + 3, x:x + 3]

# Rows are ordered image-major: the first 9 rows belong to image 0,
# the next 9 to image 1.
ret = np.transpose(col, (0, 4, 5, 1, 2, 3))
ret1 = ret.reshape(18, -1)
print(ret1.shape)

# W is rebound from the 4-D filter bank to its (C*FH*FW, FN) matrix form.
W = np.ones((3, 2, 2, 2))
W = W.reshape(3, -1).T
out_temp = ret1.dot(W)

# (N*OH*OW, FN) -> (N, OH, OW, FN) -> (N, FN, OH, OW)
out = out_temp.reshape(2, 3, 3, -1).transpose(0, 3, 1, 2)
print(out.shape)
print(out)

(18, 8)
(2, 3, 3, 3)
[[[[ 84.  92. 100.]
   [116. 124. 132.]
   [148. 156. 164.]]

  [[ 84.  92. 100.]
   [116. 124. 132.]
   [148. 156. 164.]]

  [[ 84.  92. 100.]
   [116. 124. 132.]
   [148. 156. 164.]]]


 [[[340. 348. 356.]
   [372. 380. 388.]
   [404. 412. 420.]]

  [[340. 348. 356.]
   [372. 380. 388.]
   [404. 412. 420.]]

  [[340. 348. 356.]
   [372. 380. 388.]
   [404. 412. 420.]]]]


### 합성곱 연산의 정확한 고찰

In [59]:
import sys, os
sys.path.append(os.pardir)
import numpy as np

# Window-major gathering: iterate over the 3x3 output positions and copy
# the 2x2 patch under the filter at each one.
filter_h, filter_w = 2, 2
out_h, out_w = 3, 3
input_data = np.arange(16).reshape(1, 1, 4, 4)
print(input_data.shape)
col = np.zeros((1, 1, out_h, out_w, filter_h, filter_w))  # (N, C, OH, OW, FH, FW)
print(col.shape)

for y, x in np.ndindex(out_h, out_w):
    # Both sides have shape (1, 1, filter_h, filter_w) = (1, 1, 2, 2).
    col[:, :, y, x, :, :] = input_data[:, :, y:y + filter_h, x:x + filter_w]

print(col)

(1, 1, 4, 4)
(1, 1, 3, 3, 2, 2)
[[[[[[ 0.  1.]
     [ 4.  5.]]

    [[ 1.  2.]
     [ 5.  6.]]

    [[ 2.  3.]
     [ 6.  7.]]]


   [[[ 4.  5.]
     [ 8.  9.]]

    [[ 5.  6.]
     [ 9. 10.]]

    [[ 6.  7.]
     [10. 11.]]]


   [[[ 8.  9.]
     [12. 13.]]

    [[ 9. 10.]
     [13. 14.]]

    [[10. 11.]
     [14. 15.]]]]]]


In [60]:
import sys, os
sys.path.append(os.pardir)
import numpy as np

# Offset-major gathering: iterate over the 2x2 filter offsets and copy,
# for each one, the whole 3x3 grid of pixels that offset ever touches.
# This needs only FH*FW slice copies instead of OH*OW.
filter_h, filter_w = 2, 2
out_h, out_w = 3, 3
input_data = np.arange(16).reshape(1, 1, 4, 4)
print(input_data.shape)
col = np.zeros((1, 1, filter_h, filter_w, out_h, out_w))  # (N, C, FH, FW, OH, OW)
print(col.shape)

for y in range(filter_h):
    rows = slice(y, y + out_h)
    for x in range(filter_w):
        cols = slice(x, x + out_w)
        # Both sides have shape (1, 1, out_h, out_w) = (1, 1, 3, 3).
        col[:, :, y, x, :, :] = input_data[:, :, rows, cols]

print(col)

# Reorder to (N, OH, OW, C, FH, FW): indexing by output position now
# yields the 2x2 patch under the filter at that position.
col = np.transpose(col, (0, 4, 5, 1, 2, 3))
print(col)

(1, 1, 4, 4)
(1, 1, 2, 2, 3, 3)
[[[[[[ 0.  1.  2.]
     [ 4.  5.  6.]
     [ 8.  9. 10.]]

    [[ 1.  2.  3.]
     [ 5.  6.  7.]
     [ 9. 10. 11.]]]


   [[[ 4.  5.  6.]
     [ 8.  9. 10.]
     [12. 13. 14.]]

    [[ 5.  6.  7.]
     [ 9. 10. 11.]
     [13. 14. 15.]]]]]]
[[[[[[ 0.  1.]
     [ 4.  5.]]]


   [[[ 1.  2.]
     [ 5.  6.]]]


   [[[ 2.  3.]
     [ 6.  7.]]]]



  [[[[ 4.  5.]
     [ 8.  9.]]]


   [[[ 5.  6.]
     [ 9. 10.]]]


   [[[ 6.  7.]
     [10. 11.]]]]



  [[[[ 8.  9.]
     [12. 13.]]]


   [[[ 9. 10.]
     [13. 14.]]]


   [[[10. 11.]
     [14. 15.]]]]]]


### 합성곱 미분

In [93]:
# Backward pass of a one-filter, one-channel 2x2 convolution over a 4x4
# input, written out step by step.

# im2col matrix saved from the forward pass: (N*OH*OW, C*FH*FW) = (9, 4).
col = np.array(
[[ 0.,  1.,  4.,  5.],
 [ 1.,  2.,  5.,  6.],
 [ 2.,  3.,  6.,  7.],
 [ 4.,  5.,  8.,  9.],
 [ 5.,  6.,  9., 10.],
 [ 6.,  7., 10., 11.],
 [ 8.,  9., 12., 13.],
 [ 9., 10., 13., 14.],
 [10., 11., 14., 15.]])

# col_W must be the flattened (C*FH*FW, FN) weight matrix used in the forward
# pass. Calling .T on the raw 4-D filter would reverse all four axes and turn
# the product below into an N-D dot, which is only right for a symmetric
# filter such as all ones.
W = np.ones((1, 1, 2, 2))           # (FN, C, FH, FW)
col_W = W.reshape(1, -1).T          # (4, 1)

# Upstream gradient: (N, FN, OH, OW) -> (N, OH, OW, FN) -> (N*OH*OW, FN).
dout = np.arange(1, 10).reshape(1, 1, 3, 3)
dout = dout.transpose(0, 2, 3, 1)
dout = dout.reshape(-1, 1)          # dout.shape => (9, 1)

db = np.sum(dout, axis=0)           # db.shape => (1,)

dW = np.dot(col.T, dout)            # (4, 9)(9, 1) => (4, 1)
dW = dW.transpose(1, 0)             # (FN, C*FH*FW)
dW = dW.reshape(1, 1, 2, 2)         # back to (FN, C, FH, FW)

dcol = np.dot(dout, col_W.T)        # (9, 1)(1, 4) => (9, 4)

# Undo the im2col flattening: (N, OH, OW, C, FH, FW) -> (N, C, FH, FW, OH, OW).
dcol = dcol.reshape(1, 3, 3, 1, 2, 2)
dcol = dcol.transpose(0, 3, 4, 5, 1, 2)

print(dcol[:, :, 0, 0, :, :].shape)
dx = np.zeros((1, 1, 4, 4))
print(dx[:, :, 0:3:1, 0:3:1].shape)

# col2im: each filter offset scatters its 3x3 gradient slab back onto the
# pixels it read from. Overlapping windows accumulate, hence +=.
for y in range(2):
    for x in range(2):
        dx[:, :, y:y + 3, x:x + 3] += dcol[:, :, y, x, :, :]
print(dx)

(1, 1, 3, 3)
(1, 1, 3, 3)
[[[[ 1.  3.  5.  3.]
   [ 5. 12. 16.  9.]
   [11. 24. 28. 15.]
   [ 7. 15. 17.  9.]]]]


In [90]:
import numpy as np
def my_im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """Unfold a batch of images into a 2-D matrix of receptive fields.

    Args:
        input_data: 4-D array unpacked as (N, C, H, W).
        filter_h: Filter height.
        filter_w: Filter width.
        stride: Step between neighbouring windows.
        pad: Number of zero rows/columns added on every side of H and W.

    Returns:
        Float array of shape (N*out_h*out_w, C*filter_h*filter_w); each row
        is one window flattened in (C, filter_h, filter_w) order.
    """
    batch, channels, height, width = input_data.shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Zero-pad the two spatial axes only.
    padding = [(0, 0), (0, 0), (pad, pad), (pad, pad)]
    img = np.pad(input_data, padding, mode='constant')

    patches = np.zeros((batch, channels, filter_h, filter_w, out_h, out_w))
    for dy, dx in np.ndindex(filter_h, filter_w):
        # All pixels that filter element (dy, dx) touches as the window slides.
        rows = slice(dy, dy + stride * out_h, stride)
        cols = slice(dx, dx + stride * out_w, stride)
        patches[:, :, dy, dx, :, :] = img[:, :, rows, cols]

    # (N, C, FH, FW, OH, OW) -> (N, OH, OW, C, FH, FW) -> 2-D
    patches = patches.transpose(0, 4, 5, 1, 2, 3)
    return patches.reshape(batch * out_h * out_w, -1)

In [91]:
def my_col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Fold a receptive-field matrix back into image layout.

    Values belonging to overlapping windows are summed, so applying this to
    the result of an im2col call does not return the original image wherever
    windows overlap.

    Args:
        col: 2-D array with N*out_h*out_w rows and C*filter_h*filter_w columns.
        input_shape: Target image shape (N, C, H, W).
        filter_h: Filter height.
        filter_w: Filter width.
        stride: Step between neighbouring windows.
        pad: Zero padding that was applied on every side of H and W.

    Returns:
        Float array of shape (N, C, H, W).
    """
    batch, channels, height, width = input_shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # 2-D -> (N, OH, OW, C, FH, FW) -> (N, C, FH, FW, OH, OW)
    patches = col.reshape(batch, out_h, out_w, channels, filter_h, filter_w)
    patches = patches.transpose(0, 3, 4, 5, 1, 2)

    # The canvas includes the padding plus stride - 1 extra rows/columns.
    canvas_h = height + 2 * pad + stride - 1
    canvas_w = width + 2 * pad + stride - 1
    canvas = np.zeros((batch, channels, canvas_h, canvas_w))

    for dy, dx in np.ndindex(filter_h, filter_w):
        rows = slice(dy, dy + stride * out_h, stride)
        cols = slice(dx, dx + stride * out_w, stride)
        canvas[:, :, rows, cols] += patches[:, :, dy, dx, :, :]

    # Crop the padding (and the extra margin) away.
    return canvas[:, :, pad:height + pad, pad:width + pad]

In [92]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.util import im2col, col2im

x1 = np.arange(16).reshape((1,1,4,4))
col = im2col(x1, 2, 2)
print(x1)
# print(col)

x2 = my_col2im(col,x1.shape, 2, 2)
print(x2)

[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]]]
[[[[ 0.  2.  4.  3.]
   [ 8. 20. 24. 14.]
   [16. 36. 40. 22.]
   [12. 26. 28. 15.]]]]


### 풀링 소스 분석

In [97]:
import sys, os
sys.path.append(os.pardir)
import numpy as np
from common.util import im2col
from common.util import col2im
from common.layers import Convolution

class MyPooling:
    """Max-pooling layer that prints its intermediate arrays for study.

    The forward pass unfolds the input with ``im2col``, takes the maximum of
    every pooling window and remembers which element won. The backward pass
    routes each upstream gradient to that winning element and folds the
    result back with ``col2im``.
    """

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        """Store the pooling window size, stride and padding."""
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Saved by forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Return the max-pooled version of ``x``.

        Args:
            x: 4-D array unpacked as (N, C, H, W).

        Returns:
            Array of shape (N, C, out_h, out_w).
        """
        N, C, H, W = x.shape
        # Include the padding in the output size. im2col below is called with
        # self.pad, so ignoring it here makes the final reshape fail for any
        # pad > 0. With pad == 0 this equals the previous formula.
        # NOTE(review): assumes common.util.im2col uses this same
        # floor-division formula - confirm.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        print('col', col)
        # Split the channel-concatenated rows so that every row holds exactly
        # one pooling window of one channel.
        col = col.reshape(-1, self.pool_h*self.pool_w)
        print('col', col)

        arg_max = np.argmax(col, axis=1)
        print('arg_max', arg_max)
        out = np.max(col, axis=1)
        print('out', out)
        # Rows are ordered (N, out_h, out_w, C); move C in front of the
        # spatial axes.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)
        print('out', out)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        """Propagate ``dout`` back to the input of the last forward() call.

        Args:
            dout: Upstream gradient with the same shape as forward()'s output.

        Returns:
            The array produced by ``col2im`` for the saved input shape.
        """
        # Same (N, out_h, out_w, C) ordering as the rows built in forward().
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        # One row per pooling window; only the arg-max slot receives gradient.
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        print('dmax', dmax)
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Merge (C, pool_size) back into one row per output position, which
        # is the layout col2im is given here.
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        print('dcol', dcol)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx

### 채널이 1개인 경우 맥스 풀링 테스트

In [95]:
x = np.array([[[[ 7, 11, 13, 15],
                [ 3,  4,  2,  3],
                [ 1,  2, 17,  9],
                [ 1,  8,  3, 10]]]])
print(x)

[[[[ 7 11 13 15]
   [ 3  4  2  3]
   [ 1  2 17  9]
   [ 1  8  3 10]]]]


In [96]:
pool = MyPooling(2,2,2)
pool.forward(x)

col [[ 7. 11.  3.  4.]
 [13. 15.  2.  3.]
 [ 1.  2.  1.  8.]
 [17.  9.  3. 10.]]
col [[ 7. 11.  3.  4.]
 [13. 15.  2.  3.]
 [ 1.  2.  1.  8.]
 [17.  9.  3. 10.]]
arg_max [1 1 3 0]
out [11. 15.  8. 17.]
out [[[[11. 15.]
   [ 8. 17.]]]]


array([[[[11., 15.],
         [ 8., 17.]]]])

In [None]:
dout = np.array([[[[1,2],
                   [3,4]]]])
dout = dout.transpose(0, 2, 3, 1)
print(dout.shape)
print(dout.size)
print(dout.flatten())

In [None]:
pool_size = pool.pool_h * pool.pool_w
dmax = np.zeros((dout.size, pool_size))
dmax[np.arange(pool.arg_max.size), pool.arg_max.flatten()] = dout.flatten()
print(dmax)

In [None]:
dmax = dmax.reshape(dout.shape + (pool_size,)) 
print(dmax.shape)
print(dmax)

In [None]:
dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
print(dcol.shape)
print(dcol)

In [None]:
dx = my_col2im(dcol, (1,1,4,4), 2,2,2,0)
print(dx)

In [None]:
dout = np.array([[[[1,2],
                   [3,4]]]])
dx = pool.backward(dout)
print(dx)

### 채널이 2개인 경우 맥스 풀링 테스트

In [98]:
x = np.array([[[[ 7, 11, 13, 15],
                [ 3,  4,  2,  3],
                [ 1,  2, 17,  9],
                [ 1,  8,  3, 10]],
               [[ 5,  8,  6,  7],
                [10,  4, 11, 13],
                [ 8,  3, 10,  4],
                [ 1,  2,  5, 15]]
              ]])
print(x.shape)

(1, 2, 4, 4)


In [99]:
pool = MyPooling(2,2,2)
out = pool.forward(x)
print(out.shape)

col [[ 7. 11.  3.  4.  5.  8. 10.  4.]
 [13. 15.  2.  3.  6.  7. 11. 13.]
 [ 1.  2.  1.  8.  8.  3.  1.  2.]
 [17.  9.  3. 10. 10.  4.  5. 15.]]
col [[ 7. 11.  3.  4.]
 [ 5.  8. 10.  4.]
 [13. 15.  2.  3.]
 [ 6.  7. 11. 13.]
 [ 1.  2.  1.  8.]
 [ 8.  3.  1.  2.]
 [17.  9.  3. 10.]
 [10.  4.  5. 15.]]
arg_max [1 2 1 3 3 0 0 3]
out [11. 10. 15. 13.  8.  8. 17. 15.]
out [[[[11. 15.]
   [ 8. 17.]]

  [[10. 13.]
   [ 8. 15.]]]]
(1, 2, 2, 2)


In [None]:
dout = np.array([[[[1,1],
                   [1,1]],
                  [[2,2],
                   [2,2]]]])
print(dout.shape)
dx = pool.backward(dout)
print(dx.shape)
print(dx)