In [1]:
import numpy as np

In [173]:
def naive_conv2d(x, filt, stride):
    
    assert len(x.shape) == 4
    assert len(filt.shape) == 4
    assert x.shape[1] == filt.shape[2]
    
    n, c_in, h, w = x.shape
    k, _, _, c_out = filt.shape
    h_out = (h - k)//stride + 1
    w_out = (w - k)//stride + 1
    out = np.zeros((n, c_out, h_out, w_out))
    
    s = stride
    for i in range(n):
        for co in range(c_out):
            for x1 in range(0, h - k + 1, s):
                for y1 in range(0, w - k + 1, s):
                    pixel = 0
                    for m in range(k):
                        for n in range(k):
                            for ci in range(c_in):
                                pixel += x[i, ci, x1 + m, y1 + n]*filt[m, n, ci, co]
                    out[i][co][x1//s][y1//s] = pixel
    
    return out

In [224]:
k = 3
c_in = 1
c_out = 4
h = 5
w = h
n = 2
s = 1

x = np.random.randint(low=0, high=2, size=(n, c_in, h, w))
w = np.random.randn(k, k, c_in, c_out)

In [225]:
filt = np.random.randn(k, k, c_in, c_out)
out = naive_conv2d(x, filt, 1)

In [226]:
out.shape

(2, 4, 3, 3)

In [227]:
def get_patches(x, filt, stride):
    assert len(x.shape) == 4
    assert len(filt.shape) == 4
    assert x.shape[1] == filt.shape[2]
    
    n, c_in, h, w = x.shape
    k, _, _, c_out = filt.shape
    h_out = (h - k)//stride + 1
    w_out = (w - k)//stride + 1
    
    s = stride
    x_f = np.zeros((n*h_out*w_out, k*k*c_in))
    for i in range(n):
        for x1 in range(0, h - k + 1, s):
            for y1 in range(0, w - k + 1, s):
                x_f[y1//s + w_out*x1//s + i*h_out*w_out] = x[i, :, x1:x1+k, y1:y1+k].reshape(1, -1)
                
    return x_f

In [228]:
print(x.shape, filt.shape)

(2, 1, 5, 5) (3, 3, 1, 4)


In [229]:
x_f = get_patches(x, filt, 1)
x_f.shape

(18, 9)

In [230]:
def matmul_conv2d(x, filt, stride):
    
    assert len(x.shape) == 4
    assert len(filt.shape) == 4
    assert x.shape[1] == filt.shape[2]
    
    n, c_in, h, w = x.shape
    k, _, _, c_out = filt.shape
    h_out = (h - k)//stride + 1
    w_out = (w - k)//stride + 1
    
    x_f = get_patches(x, filt, stride)
    
    filt_f = filt.reshape(-1, c_out)
    out = np.matmul(x_f, filt_f)
    out = out.reshape(n, h_out, w_out, c_out)
    out = out.transpose(0, 3, 1, 2)
    return (x_f, out)

In [231]:
stride = 1
x_f, mmconv2d_out = matmul_conv2d(x, filt, stride)
mmconv2d_out.shape

(2, 4, 3, 3)

In [232]:
nconv2d_out = naive_conv2d(x, filt, 1)
nconv2d_out.shape

(2, 4, 3, 3)

In [233]:
np.allclose(mmconv2d_out, nconv2d_out)

True

In [234]:
class Conv2d:
    def __init__(self, kernel_size, in_channels, out_channels, stride):
        self.k = kernel_size
        self.c_in = in_channels
        self.c_out = out_channels
        self.s = stride
        self.w = 0.01*np.random.randn(self.k, self.k, self.c_in, self.c_out)
        
    def __call__(self, x):
        assert len(x.shape) == 4
        assert x.shape[1] == self.c_in
        
        x_f, out = matmul_conv2d(x, self.w, self.s)
        self.x_f = x_f
        return out
    
    def backward(self, G):
        grads = {}
        grads['w'] = np.matmul(self.x_f.T, G)
        grads['x'] = np.matmul(G, self.w.reshape(-1, self.c_out).T)
        return grads

In [235]:
conv1 = Conv2d(3, 1, 4, 1)
x = np.random.randn(16, 1, 32, 32)
y = conv1(x)
y.shape

(16, 4, 30, 30)

In [236]:
grads = conv1.backward(np.ones((y.shape[0]*y.shape[2]*y.shape[3], y.shape[1])))

In [237]:
grads['x'].shape

(14400, 9)

In [238]:
grads['w'].shape

(9, 4)

In [239]:
import torch
import torch.nn as nn
import time

In [240]:
x = np.random.randn(32, 32, 128, 128)
conv1 = Conv2d(3, 32, 64, 1)
tconv = nn.Conv2d(32, 64, 3, 1)
n = 3

In [241]:
t1 = time.time()
for i in range(n):
    out = conv1(x)
t2 = time.time()
print((t2 - t1)/n)

2.5316556294759116


In [242]:
x_t = torch.tensor(x).float()
t1 = time.time()
for i in range(n):
    out = tconv(x_t)
t2 = time.time()
print((t2 - t1)/n)

0.6896743774414062
