In [1]:
import numba as nb
import numpy as np

def conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    """Accumulate a sliding-window convolution of ``x`` with ``w`` into ``rs``.

    The window moves one position at a time over the last two axes of ``x``
    and no padding is applied. For every sample, output position and filter,
    the element-wise product of the filter and the patch under it is summed
    and added to the matching entry of ``rs``.

    Args:
        x: input array indexed as (sample, ..., row, column).
        w: filter bank; ``w[q]`` is multiplied element-wise with each patch.
        rs: output array indexed as (sample, filter, row, column). It is
            updated in place with ``+=``, so values already present are kept.
        n: number of samples to process.
        n_channels, height, width: accepted so that all kernels share one
            signature; not read by this function.
        n_filters: number of filters to apply.
        filter_height, filter_width: spatial extent of each patch.
        out_h, out_w: number of output rows and columns.

    Returns:
        ``rs`` itself (the same array object that was passed in).
    """
    for sample in range(n):
        for row in range(out_h):
            rows = slice(row, row + filter_height)
            for col in range(out_w):
                cols = slice(col, col + filter_width)
                # The ellipsis keeps every axis between the sample axis and
                # the two spatial axes in the patch.
                patch = x[sample, ..., rows, cols]
                for f in range(n_filters):
                    rs[sample, f, row, col] += np.sum(w[f] * patch)
    return rs

@nb.jit(nopython=True)
def jit_conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    """Numba nopython-compiled sliding-window convolution that accumulates into ``rs``.

    The loop body mirrors ``conv_kernel``: for every sample, output position
    and filter, the element-wise product of the filter and the patch under it
    is summed and added to ``rs[sample, filter, row, col]``.

    Args:
        x: input array indexed as (sample, ..., row, column).
        w: filter bank; ``w[q]`` is multiplied element-wise with each patch.
        rs: output array, updated in place with ``+=``.
        n: number of samples to process.
        n_channels, height, width: accepted so that all kernels share one
            signature; not read by this function.
        n_filters: number of filters to apply.
        filter_height, filter_width: spatial extent of each patch.
        out_h, out_w: number of output rows and columns.

    Returns:
        Nothing; the result is left in ``rs``.
    """
    for sample in range(n):
        for row in range(out_h):
            row_stop = row + filter_height
            for col in range(out_w):
                col_stop = col + filter_width
                patch = x[sample, ..., row:row_stop, col:col_stop]
                for f in range(n_filters):
                    rs[sample, f, row, col] += np.sum(w[f] * patch)

def conv(x, w, kernel, args):
    """Allocate a zeroed float32 output and let ``kernel`` fill it.

    Args:
        x: input array, forwarded to ``kernel`` untouched.
        w: filter array, forwarded to ``kernel`` untouched.
        kernel: callable invoked as ``kernel(x, w, rs, *args)``; its return
            value is ignored, so it must write its result into ``rs``.
        args: sequence whose entry 0 is the batch size, entry 4 the number of
            filters and whose last two entries are the output height and
            width.

    Returns:
        The array of shape (n, n_filters, out_h, out_w) after ``kernel`` ran.
    """
    batch_size = args[0]
    filter_count = args[4]
    rows_out, cols_out = args[-2:]
    output_shape = (batch_size, filter_count, rows_out, cols_out)
    rs = np.zeros(output_shape, dtype=np.float32)
    kernel(x, w, rs, *args)
    return rs

def cs231n_conv(x, w, args):
    """Convolve ``x`` with ``w`` using a strided im2col view and one matrix product.

    A six-axis view (channel, filter row, filter column, sample, output row,
    output column) of ``x`` is built without copying, flattened into a
    column matrix, and multiplied by the flattened filters.

    Args:
        x: 4-D input array indexed as (sample, channel, row, column). It may
            be any strided view; it does not have to be C-contiguous.
        w: filter array that reshapes to ``(n_filters, -1)`` with
            ``n_channels * filter_height * filter_width`` columns.
        args: ``(n, n_channels, height, width, n_filters, filter_height,
            filter_width, out_h, out_w)``. ``height`` and ``width`` are
            accepted for signature compatibility; the actual extents are
            taken from ``x``.

    Returns:
        Array of shape (n, n_filters, out_h, out_w).

    Raises:
        ValueError: if ``x`` is not 4-D or the requested windows would reach
            outside ``x``.
    """
    n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w = args
    if x.ndim != 4:
        raise ValueError("x must be a 4-D array")
    # as_strided performs no bounds checking, so refuse any geometry that
    # would make the view read memory that does not belong to x.
    if (n > x.shape[0] or n_channels > x.shape[1]
            or filter_height + out_h - 1 > x.shape[2]
            or filter_width + out_w - 1 > x.shape[3]):
        raise ValueError("args describe windows that fall outside x")
    # Take the byte strides from x itself instead of rebuilding them from
    # height/width/itemsize; the latter is only right for C-contiguous input.
    sample_stride, channel_stride, row_stride, col_stride = x.strides
    shape = (n_channels, filter_height, filter_width, n, out_h, out_w)
    strides = (channel_stride, row_stride, col_stride, sample_stride, row_stride, col_stride)
    x_cols = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides).reshape(
        n_channels * filter_height * filter_width, n * out_h * out_w)
    return w.reshape(n_filters, -1).dot(x_cols).reshape(n_filters, n, out_h, out_w).transpose(1, 0, 2, 3)

# 64 input images of shape 3 x 28 x 28 (simulating MNIST)
x = np.random.randn(64, 3, 28, 28).astype(np.float32)
# 16 kernels of spatial size 5 x 5, each spanning all input channels
w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)

n, n_channels, height, width = x.shape
n_filters, _, filter_height, filter_width = w.shape
# The kernels slide one position at a time with no padding, so each spatial
# dimension shrinks by (filter size - 1).
out_h = height - filter_height + 1
out_w = width - filter_width + 1
# Positional argument pack shared by all three convolution implementations.
args = (n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w)

# L2 norm of the pairwise differences between the three implementations.
print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, conv_kernel, args)).ravel()))
print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
print(np.linalg.norm((conv(x, w, conv_kernel, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
# Time each implementation on the same inputs (IPython magic).
%timeit conv(x, w, conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel, args)
%timeit cs231n_conv(x, w, args)

0.00113585
0.000733545
0.00112681
3.63 s ± 194 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
300 ms ± 20.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.69 ms ± 223 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


+ 注意：这里如果用 `np.allclose` 做 `assert` 的话会通不过；事实上，仅仅是将数组的 `dtype` 从 `float64` 变成 `float32`，精度就会下降很多，毕竟卷积涉及的运算太多，舍入误差会不断累积。

In [2]:
def max_pool_kernel(x, rs, *args):
    """Add the maximum of every sliding window of ``x`` to ``rs`` in place.

    The window advances one position at a time along both spatial axes.

    Args:
        x: input array indexed as (sample, channel, row, column).
        rs: output array with the same index order; each entry is increased
            (``+=``) by the maximum of the window anchored at its position.
        *args: ``(n, n_channels, pool_height, pool_width, out_h, out_w)``.

    Returns:
        Nothing; the result is left in ``rs``.
    """
    batch_size, channel_count, pool_h, pool_w, rows_out, cols_out = args
    for sample in range(batch_size):
        for channel in range(channel_count):
            for row in range(rows_out):
                rows = slice(row, row + pool_h)
                for col in range(cols_out):
                    patch = x[sample, channel, rows, slice(col, col + pool_w)]
                    rs[sample, channel, row, col] += np.max(patch)

@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    """Numba nopython-compiled sliding-window maximum that accumulates into ``rs``.

    The loop body mirrors ``max_pool_kernel``: the window advances one
    position at a time along both spatial axes and its maximum is added to
    the matching entry of ``rs``.

    Args:
        x: input array indexed as (sample, channel, row, column).
        rs: output array with the same index order, updated in place.
        *args: ``(n, n_channels, pool_height, pool_width, out_h, out_w)``.

    Returns:
        Nothing; the result is left in ``rs``.
    """
    batch_size, channel_count, pool_h, pool_w, rows_out, cols_out = args
    for sample in range(batch_size):
        for channel in range(channel_count):
            for row in range(rows_out):
                row_stop = row + pool_h
                for col in range(cols_out):
                    col_stop = col + pool_w
                    patch = x[sample, channel, row:row_stop, col:col_stop]
                    rs[sample, channel, row, col] += np.max(patch)

def max_pool(x, kernel, args):
    """Allocate a zeroed float32 pooling output and let ``kernel`` fill it.

    Args:
        x: input array, forwarded to ``kernel`` untouched.
        kernel: callable invoked as ``kernel(x, rs, *args)``; its return
            value is ignored, so it must write its result into ``rs``.
        args: sequence whose first two entries are the batch size and the
            channel count and whose last two entries are the output height
            and width.

    Returns:
        The array of shape (n, n_channels, out_h, out_w) after ``kernel``
        ran.
    """
    n, n_channels = args[:2]
    out_h, out_w = args[-2:]
    # Pooling works channel by channel, so the output has one plane per input
    # channel; size it from args rather than from any module-level filter
    # count left over from the convolution code.
    rs = np.zeros([n, n_channels, out_h, out_w], dtype=np.float32)
    kernel(x, rs, *args)
    return rs

# Pooling window size.
pool_height, pool_width = 2, 2
# x is not redefined in this cell; it is the input array created for the
# convolution benchmark.
n, n_channels, height, width = x.shape
# The pooling window advances one position at a time, so each spatial
# dimension shrinks by (pool size - 1).
out_h = height - pool_height + 1
out_w = width - pool_width + 1
# Positional argument pack shared by both pooling kernels.
args = (n, n_channels, pool_height, pool_width, out_h, out_w)

# The pure-Python and Numba kernels must produce matching results.
assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))
# Time both kernels on the same input (IPython magic).
%timeit max_pool(x, max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel, args)

592 ms ± 25.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.5 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
