In [1]:
from mxnet import autograd, np, npx
from mxnet.gluon import nn
# Switch MXNet to its NumPy-compatible array/shape semantics; presumably
# required before the `mxnet.np` calls in the cells below -- TODO confirm.
npx.set_np()

In [8]:
def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with the kernel ``K``.

    The kernel is placed at every position where it fits entirely inside
    ``X`` (no padding, step of one); each output entry is the sum of the
    elementwise product between ``K`` and the window of ``X`` it covers.

    Args:
        X: 2-D input array.
        K: 2-D kernel array.

    Returns:
        Array of shape ``(X.shape[0] - h + 1, X.shape[1] - w + 1)``, where
        ``(h, w)`` is the shape of ``K``.
    """
    kernel_rows, kernel_cols = K.shape
    out_rows = X.shape[0] - kernel_rows + 1
    out_cols = X.shape[1] - kernel_cols + 1
    Y = np.zeros((out_rows, out_cols))

    for row in range(out_rows):
        for col in range(out_cols):
            # Patch of X currently covered by the kernel.
            window = X[row:row + kernel_rows, col:col + kernel_cols]
            Y[row, col] = (window * K).sum()
    return Y

In [9]:
# Small worked example: a 3x3 input and a 2x2 kernel give a 2x2 output.
X = np.array([
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
])
K = np.array([
    [0, 1],
    [2, 3],
])
corr2d(X, K)

array([[19., 25.],
       [37., 43.]])

### CNN

In [10]:
class Conv2D(nn.Block):
    """Minimal 2-D convolutional layer built on top of ``corr2d``.

    The layer owns a kernel (``weight``) of shape ``kernel_size`` and a
    single scalar ``bias``. Its forward pass cross-correlates the input
    with the kernel and then adds the bias to every output element.

    Args:
        kernel_size: Shape ``(h, w)`` of the kernel parameter.
        **kwargs: Forwarded unchanged to ``nn.Block.__init__``.
    """

    def __init__(self, kernel_size, **kwargs):
        super(Conv2D, self).__init__(**kwargs)

        # Register both parameters with the block's parameter dictionary.
        self.weight = self.params.get('weight', shape=kernel_size)
        self.bias = self.params.get('bias', shape=(1,))

    def forward(self, x):
        """Return ``corr2d(x, weight) + bias`` for a 2-D input ``x``."""
        # The bias is added to the correlation *output*. Adding it to the
        # kernel instead would shift each output by bias * window.sum(),
        # which is not a constant offset.
        return corr2d(x, self.weight.data()) + self.bias.data()

In [11]:
# Synthetic 6x8 "image": ones everywhere except columns 2..5, which are
# zero, giving two vertical edges (1 -> 0 and 0 -> 1).
X = np.ones(shape=(6, 8))
X[:, 2:6] = 0
X

array([[1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.]])

In [12]:
# 1x2 difference kernel: each output is X[i, j] - X[i, j + 1], so it is +1
# where the values drop from 1 to 0, -1 where they rise from 0 to 1, and
# 0 wherever horizontal neighbours are equal.
K = np.array([[1, -1]])
Y = corr2d(X=X, K=K)
Y

array([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

In [13]:
# On the transposed image horizontal neighbours are always equal, so the
# same kernel produces an all-zero output.
corr2d(X.transpose(), K)

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

In [15]:
# Learn the 1x2 kernel that maps X to Y using Gluon's built-in layer.
# The bias is disabled because the manual update below only touches the
# weight; with the default bias the layer would carry a parameter that
# receives gradients but is never trained.
conv2d = nn.Conv2D(1, kernel_size=(1, 2), use_bias=False)
conv2d.initialize()

# Gluon's Conv2D takes 4-D input in (batch, channel, height, width) layout,
# so add singleton batch and channel axes to the input and the target.
X = X.reshape(1, 1, 6, 8)
Y = Y.reshape(1, 1, 6, 7)

for i in range(10):
    with autograd.record():
        Y_hat = conv2d(X)
        # Elementwise squared error between prediction and target.
        l = (Y_hat - Y) ** 2
    l.backward()
    # Plain gradient-descent step, written in place into the weight buffer.
    conv2d.weight.data()[:] -= 3e-2 * conv2d.weight.grad()

    # Report the summed loss every second iteration.
    if (i + 1) % 2 == 0:
        print('batch %d loss: %.3f' % (i + 1, l.sum()))

batch 2 loss: 4.949
batch 4 loss: 0.831
batch 6 loss: 0.140
batch 8 loss: 0.024
batch 10 loss: 0.004
