In [1]:
import numpy as np

In [2]:
# Synthetic benchmark batch: 16 single-channel 28x28 images holding 0, 1, 2, ...
image = np.arange(16 * 28 * 28).reshape(16, 1, 28, 28)

In [3]:
class naive_Conv():
    """
    Conv layer implemented with explicit Python loops (reference implementation).

    Weights have shape (Cout, Cin, F, F) and the bias has shape (Cout,).
    `backward` accumulates the parameter gradients into self.W['grad']
    and (when the bias is enabled) self.b['grad'].
    """
    def __init__(self, Cin, Cout, F, stride=1, padding=0, bias=True):
        """
        Args:
            Cin: number of input channels.
            Cout: number of output channels (filters).
            F: height and width of the square kernel.
            stride: step between neighbouring windows on both spatial axes.
            padding: number of zero rows/columns added on every spatial border.
            bias: if False the bias is fixed at zero and receives no gradient.
        """
        self.Cin = Cin
        self.Cout = Cout
        self.F = F
        self.S = stride
        self.W = {'val': np.random.normal(0.0,np.sqrt(2/Cin),(Cout,Cin,F,F)), 'grad': 0}
        self.use_bias = bias
        # With bias disabled keep a zero vector so forward() needs no special case.
        self.b = {'val': np.random.randn(Cout) if bias else np.zeros(Cout), 'grad': 0}
        self.cache = None
        self.pad = padding

    def forward(self, X):
        """
        Args:
            X: input batch of shape (N, Cin, H, W).
        Returns:
            Y of shape (N, Cout, H_out, W_out) with
            H_out = (H + 2*padding - F) // stride + 1 (likewise W_out).
        """
        X = np.pad(X, ((0,0),(0,0),(self.pad,self.pad),(self.pad,self.pad)), 'constant')
        (N, Cin, H, W) = X.shape
        F, S = self.F, self.S
        H_ = (H - F) // S + 1
        W_ = (W - F) // S + 1
        Y = np.zeros((N, self.Cout, H_, W_))

        for n in range(N):
            for c in range(self.Cout):
                for h in range(H_):
                    for w in range(W_):
                        hs, ws = h * S, w * S
                        Y[n, c, h, w] = np.sum(X[n, :, hs:hs+F, ws:ws+F] * self.W['val'][c]) + self.b['val'][c]

        # Cache the *padded* input; backward() strips the padding from dX.
        self.cache = X
        return Y

    def backward(self, dout):
        """
        Args:
            dout: upstream gradient dL/dY of shape (N, Cout, H_out, W_out).
        Returns:
            dL/dX with the shape of the (unpadded) input last given to forward().
        Side effects:
            Adds dL/dW to self.W['grad'] and, if the bias is enabled, dL/db to
            self.b['grad'].
        """
        X = self.cache
        (N, Cin, H, W) = X.shape
        F, S, p = self.F, self.S, self.pad
        H_ = (H - F) // S + 1
        W_ = (W - F) // S + 1

        # dW: kernel offset (i, j) touches input rows i, i+S, ..., i+S*(H_-1).
        dW = np.zeros(self.W['val'].shape)
        for co in range(self.Cout):
            for ci in range(Cin):
                for i in range(F):
                    for j in range(F):
                        patch = X[:, ci, i:i + S*(H_-1) + 1:S, j:j + S*(W_-1) + 1:S]
                        dW[co, ci, i, j] = np.sum(patch * dout[:, co])

        # db
        db = np.zeros(self.b['val'].shape)
        for co in range(self.Cout):
            db[co] = np.sum(dout[:, co])

        # dX is the full convolution of dout with each kernel rotated by 180
        # degrees in its *spatial* axes. For stride > 1, dout is first dilated
        # (S-1 zeros between entries) so that it lives on the input grid.
        Hd, Wd = S*(H_-1) + 1, S*(W_-1) + 1
        dout_dil = np.zeros((N, self.Cout, Hd, Wd))
        dout_dil[:, :, ::S, ::S] = dout
        # Leading pad F-1; trailing pad H-Hd (>= F-1) also covers input rows and
        # columns that no forward window reached, so every input pixel gets an
        # F x F window.
        dout_pad = np.pad(dout_dil, ((0,0),(0,0),(F-1, H-Hd),(F-1, W-Wd)), 'constant')
        W_rot = self.W['val'][:, :, ::-1, ::-1]

        dX = np.zeros(X.shape)
        for n in range(N):
            for ci in range(Cin):
                for h in range(H):
                    for w in range(W):
                        dX[n, ci, h, w] = np.sum(W_rot[:, ci] * dout_pad[n, :, h:h+F, w:w+F])

        self.W['grad'] += dW
        if self.use_bias:
            self.b['grad'] += db

        # Drop the rows/columns that correspond to forward()'s zero padding.
        return dX[:, :, p:H-p, p:W-p]

In [4]:
naive_conv = naive_Conv(1,6,5)
%timeit naive_conv.forward(image)

484 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Time the backward pass, reusing the forward output as a stand-in upstream
# gradient (it already has the (N, Cout, H_out, W_out) shape backward expects).
out = naive_conv.forward(image)
%timeit naive_conv.backward(out)

132 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
class im2col_Conv():
    """
    Conv layer (configurable stride, no padding) that lowers the convolution to
    a matrix product via `im2col`, one image of the batch at a time.

    Weights have shape (out_channels, in_channels, kernel_size, kernel_size) and
    the bias has shape (out_channels, 1). `backward` adds the parameter
    gradients to self.W['grad'] and self.b['grad'].
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.weight_size = (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        self.stride = stride

        self.W = {'val': np.random.standard_normal(self.weight_size), 'grad': 0}
        self.b = {'val':  np.random.standard_normal((self.out_channels,1)), 'grad': 0}

        self.cache = None

    def forward(self, x):
        """
        Args:
            x: input batch of shape (N, in_channels, H, W).
        Returns:
            Array of shape (N, out_channels, H_out, W_out) where
            H_out = (H - kernel_size) // stride + 1 (likewise W_out).
        Side effects:
            Stores the input shape and the per-image im2col matrices
            (N, H_out*W_out, in_channels*kernel_size**2) for `backward`.
        """
        (N,C,H,W) = x.shape
        self.input_shape = x.shape
        H_out = (H - self.kernel_size) // self.stride + 1
        W_out = (W - self.kernel_size) // self.stride + 1
        conv_out = np.zeros((N, self.out_channels, H_out, W_out))
        self.col_image = []

        weight_cols = self.W['val'].reshape(self.out_channels, -1)
        for i in range(N):
            img_i = x[i][np.newaxis, :]
            x_cols = im2col(img_i, self.kernel_size, self.stride)
            conv_out[i] = (np.dot(weight_cols, x_cols.T) + self.b['val']).reshape(self.out_channels, H_out, W_out)
            self.col_image.append(x_cols)
        self.col_image = np.array(self.col_image)

        return conv_out

    def backward(self, error):
        """
        Args:
            error: upstream gradient of shape (N, out_channels, H_out, W_out).
        Returns:
            Gradient w.r.t. the input, shaped like the input last passed to
            `forward`.
        Side effects:
            Adds dL/dW to self.W['grad'] and dL/db to self.b['grad'].
        """
        (N,C,_,_) = error.shape
        error_col = error.reshape(N,C,-1)
        for i in range(N):
            self.W['grad'] += np.dot(error_col[i], self.col_image[i]).reshape(self.W['val'].shape)
        self.b['grad'] += np.sum(error_col, axis=(0,2)).reshape(self.b['val'].shape)

        error_pad = self._dilate_and_pad(error)

        # The input gradient is a full convolution of the error with every
        # kernel rotated by 180 degrees; input/output channel roles swap.
        flip_weights = self.W['val'][:, :, ::-1, ::-1]
        flip_weights = flip_weights.swapaxes(0,1)
        col_flip_weights = flip_weights.reshape(self.in_channels, -1)

        # Stride 1 here: the stride was already accounted for by dilation.
        col_pad_delta = np.array([im2col(error_pad[i][np.newaxis, :], self.kernel_size, 1) for i in range(N)])
        next_delta = np.dot(col_pad_delta, col_flip_weights.T)
        next_delta = np.reshape(next_delta.transpose(0,2,1), self.input_shape)

        return next_delta

    def _dilate_and_pad(self, error):
        """
        Map `error` onto the input grid for the input-gradient convolution.

        Inserts stride-1 zeros between neighbouring entries, then zero-pads to
        (N, C, H + k - 1, W + k - 1) so that a stride-1 k x k sweep yields one
        window per input pixel. For stride 1 this is plain (k-1) zero padding.
        """
        N, C, H_out, W_out = error.shape
        H, W = self.input_shape[2], self.input_shape[3]
        k, s = self.kernel_size, self.stride
        Hd, Wd = s * (H_out - 1) + 1, s * (W_out - 1) + 1
        dilated = np.zeros((N, C, Hd, Wd), dtype=error.dtype)
        dilated[:, :, ::s, ::s] = error
        # Trailing pad H - Hd (>= k - 1) also covers input rows/columns that no
        # forward window reached when (H - k) is not a multiple of the stride.
        return np.pad(dilated, ((0, 0), (0, 0), (k - 1, H - Hd), (k - 1, W - Wd)), 'constant')

def im2col(image, kernel_size, stride):
    """
    Unfold every kernel_size x kernel_size window of `image` into one row.

    Args:
        image: array of shape (N, C, H, W). im2col_Conv passes a single image
            (N == 1); with N > 1 each row concatenates that window from every
            image in the batch.
        kernel_size: window height and width.
        stride: step between windows along both spatial axes.
    Returns:
        Array with one row of N*C*kernel_size**2 values per window, windows in
        top-to-bottom, left-to-right order.
    """
    (N, C, H, W) = image.shape
    image_col = [image[:, :, i:i+kernel_size, j:j+kernel_size].reshape(-1)
                 for i in range(0, H - kernel_size + 1, stride)
                 for j in range(0, W - kernel_size + 1, stride)]
    return np.array(image_col)

In [7]:
im2col_conv = im2col_Conv(1,6,5)
%timeit im2col_conv.forward(image)

28.4 ms ± 544 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# `out` comes from naive_conv.forward and matches this layer's output shape, so
# it serves as the upstream gradient; each timed call also adds into
# im2col_conv.W['grad'] and im2col_conv.b['grad'].
%timeit im2col_conv.backward(out)

66.2 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
class fastim2col_Conv():
    """
    Conv layer (configurable stride, no padding) that lowers the convolution to
    a batched matrix product using the stride-trick `im2col_fast`, processing
    the whole batch at once.

    Weights have shape (out_channels, in_channels, kernel_size, kernel_size) and
    the bias has shape (out_channels, 1). `backward` adds the parameter
    gradients to self.W['grad'] and self.b['grad'].
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.weight_size = (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        self.stride = stride

        self.W = {'val': np.random.standard_normal(self.weight_size), 'grad': 0}
        self.b = {'val':  np.random.standard_normal((self.out_channels,1)), 'grad': 0}

        self.cache = None

    def forward(self, x):
        """
        Args:
            x: input batch of shape (N, in_channels, H, W).
        Returns:
            Array of shape (N, out_channels, H_out, W_out) where
            H_out = (H - kernel_size) // stride + 1 (likewise W_out).
        Side effects:
            Stores the input shape and the im2col matrix
            (N, H_out*W_out, in_channels*kernel_size**2) for `backward`.
        """
        (N,C,H,W) = x.shape
        self.input_shape = x.shape
        H_out = (H - self.kernel_size) // self.stride + 1
        W_out = (W - self.kernel_size) // self.stride + 1

        weight_cols = self.W['val'].reshape(self.out_channels, -1)
        self.col_image = im2col_fast(x, self.kernel_size, self.stride)
        # (N, P, Cout) + (1, Cout) -> (N, Cout, P) -> (N, Cout, H_out, W_out)
        conv_out = (np.dot(self.col_image, weight_cols.T) + self.b['val'].T).transpose(0,2,1).reshape(N, self.out_channels, H_out, W_out)

        return conv_out

    def backward(self, error):
        """
        Args:
            error: upstream gradient of shape (N, out_channels, H_out, W_out).
        Returns:
            Gradient w.r.t. the input, shaped like the input last passed to
            `forward`.
        Side effects:
            Adds dL/dW to self.W['grad'] and dL/db to self.b['grad'].
        """
        (N,C,_,_) = error.shape
        error_col = error.reshape(N,C,-1)
        # Sum over batch and output positions in a single contraction:
        # (N, Cout, P) x (N, P, Cin*k*k) -> (Cout, Cin*k*k).
        self.W['grad'] += np.tensordot(error_col, self.col_image, axes=([0, 2], [0, 1])).reshape(self.W['val'].shape)
        self.b['grad'] += np.sum(error_col, axis=(0,2)).reshape(self.b['val'].shape)

        error_pad = self._dilate_and_pad(error)

        # dX[ci] = sum over co of full_conv(error[co], rot180(W[co, ci])), so the
        # kernels are flipped spatially and the in/out channel axes swapped.
        flip_weights = self.W['val'][:, :, ::-1, ::-1].swapaxes(0, 1)
        col_flip_weights = flip_weights.reshape(self.in_channels, -1)

        # Stride 1 here: the stride was already accounted for by dilation.
        col_pad_delta = im2col_fast(error_pad, self.kernel_size, 1)
        next_delta = np.dot(col_pad_delta, col_flip_weights.T)
        next_delta = np.reshape(next_delta.transpose(0,2,1), self.input_shape)

        return next_delta

    def _dilate_and_pad(self, error):
        """
        Map `error` onto the input grid for the input-gradient convolution.

        Inserts stride-1 zeros between neighbouring entries, then zero-pads to
        (N, C, H + k - 1, W + k - 1) so that a stride-1 k x k sweep yields one
        window per input pixel. For stride 1 this is plain (k-1) zero padding.
        """
        N, C, H_out, W_out = error.shape
        H, W = self.input_shape[2], self.input_shape[3]
        k, s = self.kernel_size, self.stride
        Hd, Wd = s * (H_out - 1) + 1, s * (W_out - 1) + 1
        dilated = np.zeros((N, C, Hd, Wd), dtype=error.dtype)
        dilated[:, :, ::s, ::s] = error
        # Trailing pad H - Hd (>= k - 1) also covers input rows/columns that no
        # forward window reached when (H - k) is not a multiple of the stride.
        return np.pad(dilated, ((0, 0), (0, 0), (k - 1, H - Hd), (k - 1, W - Wd)), 'constant')

def im2col_fast(image, kernel_size, stride):
    """
    Unfold all kernel_size x kernel_size windows of a batch without Python loops.

    Builds a strided view A with A[n, c, h, w, i, j] = image[n, c, h*stride + i,
    w*stride + j], then reorders it so that each window becomes one row.

    Args:
        image: array of shape (N, C, H, W).
        kernel_size: window height and width.
        stride: step between windows along both spatial axes.
    Returns:
        Array of shape (N, H_out*W_out, C*kernel_size**2); the columns are
        ordered (C, kernel row, kernel column), matching a
        (Cout, C, k, k) weight tensor reshaped to (Cout, -1).
    """
    N, C, H, W = image.shape
    H_out = (H - kernel_size) // stride + 1
    W_out = (W - kernel_size) // stride + 1
    shape = (N, C, H_out, W_out, kernel_size, kernel_size)
    strides = (*image.strides[:-2], image.strides[-2]*stride, image.strides[-1]*stride, *image.strides[-2:])
    A = np.lib.stride_tricks.as_strided(image, shape=shape, strides=strides)
    return A.transpose(0,2,3,1,4,5).reshape(N, H_out*W_out, C*kernel_size*kernel_size)

In [10]:
fastim2col_conv = fastim2col_Conv(1,6,5)
%timeit fastim2col_conv.forward(image)

3.63 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# Same upstream gradient `out` as used for the other two layers; each timed call
# also adds into fastim2col_conv.W['grad'] and fastim2col_conv.b['grad'].
%timeit fastim2col_conv.backward(out)

19.6 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
