In [57]:
import numpy as np
import torch
import torch.nn as nn

For `DATA` we use the NHWC layout, while torch uses NCHW.

For `WEIGHT` we use [KERNEL_SIZE][KERNEL_SIZE][IN_CHANNELS][OUT_CHANNELS] (KKIO), while torch uses [OUT_CHANNELS][IN_CHANNELS][KERNEL_SIZE][KERNEL_SIZE] (OIKK).

In [58]:
def conv_reference(Z, weight):
  """Reference convolution computed by PyTorch, used to check the NumPy versions.

  Z is NHWC and weight is KKIO; torch's conv2d wants NCHW data and OIKK
  weights, so both are permuted on the way in and the result is permuted
  back to NHWC on the way out.
  """
  data_nchw = torch.tensor(Z).permute(0, 3, 1, 2)          # NHWC -> NCHW
  kernel_oikk = torch.tensor(weight).permute(3, 2, 0, 1)   # KKIO -> OIKK
  result_nchw = nn.functional.conv2d(data_nchw, kernel_oikk)
  result_nhwc = result_nchw.permute(0, 2, 3, 1)            # NCHW -> NHWC
  return result_nhwc.contiguous().numpy()

In [59]:
# Seeded generator so the random test tensors (and the error norms printed below) are reproducible.
rng = np.random.default_rng(0)
Z = rng.standard_normal((4, 8, 8, 3))  # NHWC
W = rng.standard_normal((3, 3, 3, 4))  # KKIO

In [60]:
out_ref = conv_reference(Z, W)
print(np.shape(out_ref))  # valid convolution: (N, H-K+1, W-K+1, OUT)

(4, 6, 6, 4)


In [61]:
def conv_mul(Z : np.ndarray, weight : np.ndarray):
  """Valid 2-D convolution (stride 1, no padding) as a sum of K1*K2 matmuls.

  Z: input in NHWC layout, shape (N, H, W, C_in).
  weight: kernel in KKIO layout, shape (K1, K2, C_in, C_out).
  Returns an array of shape (N, H-K1+1, W-K2+1, C_out).

  For every kernel offset (x, y), the input slice shifted by (x, y) is
  multiplied by the (C_in, C_out) matrix weight[x, y] and accumulated.
  """
  N, H, W, _ = Z.shape
  K1, K2, _, OUTC = weight.shape
  out_shape = (N, H - K1 + 1, W - K2 + 1, OUTC)
  # Accumulate in a floating dtype derived from the inputs, so float32 inputs
  # stay float32 (as a single matmul would return) instead of always float64.
  out = np.zeros(out_shape, dtype=np.result_type(Z, weight, np.float32))
  for x in range(K1):
    for y in range(K2):
      out += Z[:, x:x + out_shape[1], y:y + out_shape[2], :] @ weight[x, y, :, :]
  return out

In [62]:
out_conv_mul = conv_mul(Z, W)
print(np.linalg.norm(out_ref - out_conv_mul))
# Make the eyeballed check a real one: the matmul-loop version must match
# PyTorch up to floating-point round-off.
np.testing.assert_allclose(out_conv_mul, out_ref, rtol=1e-7, atol=1e-8)

2.6108350672945522e-14


In [68]:
# Small 2-D example for the stride demos: a 6x6 "image" and a 3x3 kernel,
# both filled with 0, 1, 2, ... in row-major order.
Z = np.arange(36, dtype=np.float32).reshape((6, 6))
W = np.arange(9, dtype=np.float32).reshape((3, 3))

In [64]:
# Tile Z into a grid of non-overlapping BLOCK x BLOCK blocks: B[bi, bj] is
# Z[bi*BLOCK:(bi+1)*BLOCK, bj*BLOCK:(bj+1)*BLOCK].
# Strides come from Z.strides instead of a hard-coded 4-byte itemsize, so this
# stays correct if Z's dtype or memory layout changes.
BLOCK = 3
row_stride, col_stride = Z.strides
B = np.lib.stride_tricks.as_strided(
  Z,
  shape=(Z.shape[0] // BLOCK, Z.shape[1] // BLOCK, BLOCK, BLOCK),
  strides=(BLOCK * row_stride, BLOCK * col_stride, row_stride, col_stride),
  writeable=False,
)
B = np.ascontiguousarray(B)
B

array([[[[ 0.,  1.,  2.],
         [ 6.,  7.,  8.],
         [12., 13., 14.]],

        [[ 3.,  4.,  5.],
         [ 9., 10., 11.],
         [15., 16., 17.]]],


       [[[18., 19., 20.],
         [24., 25., 26.],
         [30., 31., 32.]],

        [[21., 22., 23.],
         [27., 28., 29.],
         [33., 34., 35.]]]], dtype=float32)

In [65]:
# Every overlapping K x K window of Z (the 2-D im2col view): C[i, j] is
# Z[i:i+K, j:j+K]. Moving to the next window moves by one row/column of Z, so
# the window-position axes reuse the same strides as the within-window axes.
# Strides and shape are derived from Z rather than hard-coded for 6x6 float32.
K = 3  # window size; matches the 3x3 W above
row_stride, col_stride = Z.strides
C = np.lib.stride_tricks.as_strided(
  Z,
  shape=(Z.shape[0] - K + 1, Z.shape[1] - K + 1, K, K),
  strides=(row_stride, col_stride, row_stride, col_stride),
  writeable=False,  # windows overlap, so writes through the view would alias
)
C = np.ascontiguousarray(C)
C

array([[[[ 0.,  1.,  2.],
         [ 6.,  7.,  8.],
         [12., 13., 14.]],

        [[ 1.,  2.,  3.],
         [ 7.,  8.,  9.],
         [13., 14., 15.]],

        [[ 2.,  3.,  4.],
         [ 8.,  9., 10.],
         [14., 15., 16.]],

        [[ 3.,  4.,  5.],
         [ 9., 10., 11.],
         [15., 16., 17.]]],


       [[[ 6.,  7.,  8.],
         [12., 13., 14.],
         [18., 19., 20.]],

        [[ 7.,  8.,  9.],
         [13., 14., 15.],
         [19., 20., 21.]],

        [[ 8.,  9., 10.],
         [14., 15., 16.],
         [20., 21., 22.]],

        [[ 9., 10., 11.],
         [15., 16., 17.],
         [21., 22., 23.]]],


       [[[12., 13., 14.],
         [18., 19., 20.],
         [24., 25., 26.]],

        [[13., 14., 15.],
         [19., 20., 21.],
         [25., 26., 27.]],

        [[14., 15., 16.],
         [20., 21., 22.],
         [26., 27., 28.]],

        [[15., 16., 17.],
         [21., 22., 23.],
         [27., 28., 29.]]],


       [[[18., 19., 20.],
        

In [69]:
print(*(arr.shape for arr in (C, W)))

(4, 4, 3, 3) (3, 3)


In [70]:
# Each row of the reshaped C is one flattened window, so a single matmul with
# the flattened kernel computes every output position; reshape back to the
# output grid. Sizes come from the arrays, not from this example's 3x3 / 4x4.
(C.reshape(-1, W.size) @ W.reshape(W.size, -1)).reshape(C.shape[:2])

array([[ 366.,  402.,  438.,  474.],
       [ 582.,  618.,  654.,  690.],
       [ 798.,  834.,  870.,  906.],
       [1014., 1050., 1086., 1122.]], dtype=float32)

In [71]:
# Re-create the 4-D test tensors (Z and W were overwritten by the 2-D demo),
# using a seeded generator so the check below is reproducible.
rng = np.random.default_rng(1)
Z = rng.standard_normal((4, 8, 8, 3))  # NHWC
W = rng.standard_normal((3, 3, 3, 4))  # KKIO

In [72]:
def conv_im2col(Z : np.ndarray, weight : np.ndarray):
  """Valid 2-D convolution (stride 1, no padding) via im2col and one matmul.

  Z: input in NHWC layout, shape (N, H, W, C).
  weight: kernel in KKIO layout, shape (K1, K2, C, OUT); K1 and K2 may differ.
  Returns an array of shape (N, H-K1+1, W-K2+1, OUT).

  Raises ValueError if the channel counts disagree or the kernel is larger
  than the input.
  """
  Z = np.asarray(Z)
  weight = np.asarray(weight)
  N, H, W, C = Z.shape
  K1, K2, INC, OUTC = weight.shape
  # A real check rather than `assert`: with a channel mismatch, the reshape of
  # `weight` below can still succeed and silently produce a wrong result.
  if C != INC:
    raise ValueError(f"input has {C} channels but weight expects {INC}")
  out_h, out_w = H - K1 + 1, W - K2 + 1
  if out_h < 1 or out_w < 1:
    raise ValueError(f"kernel {K1}x{K2} is larger than input {H}x{W}")

  s_n, s_h, s_w, s_c = Z.strides
  # im2col view: A[n, i, j, a, b, c] aliases Z[n, i + a, j + b, c], so the
  # kernel-offset axes reuse the H and W strides; no data is copied here.
  # Read-only because the windows overlap in memory.
  A = np.lib.stride_tricks.as_strided(
    Z,
    shape=(N, out_h, out_w, K1, K2, C),
    strides=(s_n, s_h, s_w, s_h, s_w, s_c),
    writeable=False,
  )
  mul_size = K1 * K2 * C
  # The reshape materializes the (N*out_h*out_w, K1*K2*C) patch matrix. Its
  # column order (K1, K2, C) matches weight's leading axes, so a single matmul
  # performs the whole convolution.
  out = A.reshape(-1, mul_size) @ weight.reshape(mul_size, OUTC)
  return out.reshape(N, out_h, out_w, OUTC)


In [73]:
# Run the im2col version and the PyTorch reference on the same fresh inputs.
out_conv_im2col, out_ref = conv_im2col(Z, W), conv_reference(Z, W)

In [74]:
print(np.linalg.norm(out_ref - out_conv_im2col))
# Make the eyeballed check a real one: im2col must match PyTorch up to
# floating-point round-off.
np.testing.assert_allclose(out_conv_im2col, out_ref, rtol=1e-7, atol=1e-8)

2.3670215887952075e-14
