In [3]:
import numpy as np
import torch
import torch.nn.functional as F

In [4]:
# original image dimensions
# B * C * H * W
batch = 6
orig_img = np.arange(125*batch).reshape((batch, 5,5,5))
print(orig_img.shape)
print(orig_img)
batch, C, H, W = orig_img.shape

(6, 5, 5, 5)
[[[[  0   1   2   3   4]
   [  5   6   7   8   9]
   [ 10  11  12  13  14]
   [ 15  16  17  18  19]
   [ 20  21  22  23  24]]

  [[ 25  26  27  28  29]
   [ 30  31  32  33  34]
   [ 35  36  37  38  39]
   [ 40  41  42  43  44]
   [ 45  46  47  48  49]]

  [[ 50  51  52  53  54]
   [ 55  56  57  58  59]
   [ 60  61  62  63  64]
   [ 65  66  67  68  69]
   [ 70  71  72  73  74]]

  [[ 75  76  77  78  79]
   [ 80  81  82  83  84]
   [ 85  86  87  88  89]
   [ 90  91  92  93  94]
   [ 95  96  97  98  99]]

  [[100 101 102 103 104]
   [105 106 107 108 109]
   [110 111 112 113 114]
   [115 116 117 118 119]
   [120 121 122 123 124]]]


 [[[125 126 127 128 129]
   [130 131 132 133 134]
   [135 136 137 138 139]
   [140 141 142 143 144]
   [145 146 147 148 149]]

  [[150 151 152 153 154]
   [155 156 157 158 159]
   [160 161 162 163 164]
   [165 166 167 168 169]
   [170 171 172 173 174]]

  [[175 176 177 178 179]
   [180 181 182 183 184]
   [185 186 187 188 189]
   [190 191 192 193 1

In [6]:
# original kernels
# n_f * C * F1 *F2
F1 = 1  # Height of kernels
F2 = 2  # Widht of kernels
n_f = 5   # Number of kernels
orig_kernels = np.random.randn(n_f, C, F1, F2)
print(orig_kernels.shape)
print(orig_kernels)
n_f, C, F1, F2 = orig_kernels.shape

(5, 5, 1, 2)
[[[[ 0.32907632 -0.61277709]]

  [[-0.34693412 -1.2685598 ]]

  [[-2.03336758  0.63359331]]

  [[-0.26326701 -0.14111646]]

  [[-0.21222008 -1.45509832]]]


 [[[ 0.8522614  -0.75966416]]

  [[ 1.06993907  0.97275591]]

  [[ 0.87168064  0.07844427]]

  [[ 0.16401603 -0.63010016]]

  [[-0.84870921 -2.03870041]]]


 [[[-1.46003937 -1.51779083]]

  [[-1.18531413 -0.72428635]]

  [[-1.26425539  0.78176508]]

  [[-0.22022234  1.12377326]]

  [[ 0.43419215  0.02678643]]]


 [[[-0.97716024  1.71098225]]

  [[-0.43053897 -0.01436663]]

  [[-0.68710595 -0.20036684]]

  [[-0.60669235 -1.47378968]]

  [[ 0.16255626 -0.72758056]]]


 [[[ 0.12208474  0.15074406]]

  [[-0.04827451 -0.70138772]]

  [[ 0.88863529  0.86352188]]

  [[-1.98381511 -1.48698009]]

  [[ 0.12243378  0.04054969]]]]


In [7]:
# https://stackoverflow.com/a/40840048
# Refer above link for awesome answer on im2col

def im2col(image,kernel_shape,strides=(1,1)):
    A = image
    B = kernel_shape
    skip = strides

    # Parameters 
    batch, D,M,N = A.shape
    col_extent = N - B[1] + 1
    row_extent = M - B[0] + 1

    # Get batch block indices
    batch_idx = np.arange(batch)[:, None, None] * D * M * N

    # Get Starting block indices
    start_idx = np.arange(B[0])[None, :,None]*N + np.arange(B[1])

    # Generate Depth indeces
    didx=M*N*np.arange(D)
    start_idx=(didx[None, :, None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

    # Get offsetted indices across the height and width of input array
    offset_idx = np.arange(row_extent)[None, :, None]*N + np.arange(col_extent)

    # Get all actual indices & index into input array for final output
    act_idx = (batch_idx + 
        start_idx.ravel()[None, :, None] + 
        offset_idx[:,::skip[0],::skip[1]].ravel())

    out = np.take (A, act_idx)
    
    return out

In [8]:
def conv_2D(input, kernel, stride=(1,1), padding=(0,0)):
    
    input = np.float32(input)
    print(input.shape)
    # Padding the image
    P1, P2 = padding
    input = np.pad(input, [(0,0),(0,0),(P1,P1),(P2,P2)])
    print(input.shape)
    S1 ,S2 = stride
    B, C, H, W = input.shape
    N_K, C, K1, K2 = kernel.shape
    
    # Output feature map height and width
    H_ = np.int(np.floor((H - F1) / S1) + 1)
    W_ = np.int(np.floor((W - F2) / S2) + 1)
    
    print(H_,W_)
    
    # im2col
    input = im2col(input, (K1,K2), stride)
    
    input = np.hstack((input))
    
    kernel = kernel.reshape(N_K,-1)
    
    input = np.matmul(kernel, input)
    
    input = np.split(np.array(input), B, axis=1)
    
    print(input)
    
    input = np.array(input).reshape(B,N_K,H_,W_)
    
    return input

In [9]:
out = conv_2D(orig_img, orig_kernels, stride=(2,1), padding=(1,1))
print(out)
print(out.shape)

(6, 5, 5, 5)
(6, 5, 7, 7)
4 6
[array([[   0.        ,    0.        ,    0.        ,    0.        ,
           0.        ,    0.        , -170.34768797, -337.13397468,
        -342.50464552, -347.87531636, -353.24598719, -174.04917823,
        -198.78727163, -390.84068304, -396.21135388, -401.58202471,
        -406.95269555, -199.31630292,    0.        ,    0.        ,
           0.        ,    0.        ,    0.        ,    0.        ],
       [   0.        ,    0.        ,    0.        ,    0.        ,
           0.        ,    0.        , -234.77276369, -228.84129874,
        -229.10937535, -229.37745196, -229.64552857,   16.7454812 ,
        -258.54540908, -231.52206485, -231.79014146, -232.05821807,
        -232.32629468,   37.83736048,    0.        ,    0.        ,
           0.        ,    0.        ,    0.        ,    0.        ],
       [   0.        ,    0.        ,    0.        ,    0.        ,
           0.        ,    0.        ,  106.39397048,   21.66293914,
          17.65

In [10]:
# Comparing with pytorch conv2d 
torch_conv = F.conv2d(torch.Tensor(orig_img), torch.Tensor(orig_kernels),stride=(2,1), padding=(1,1))
print(torch_conv)
print(torch_conv.shape)

tensor([[[[    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,
               0.0000],
          [ -170.3477,  -337.1340,  -342.5046,  -347.8753,  -353.2460,
            -174.0492],
          [ -198.7873,  -390.8407,  -396.2113,  -401.5820,  -406.9527,
            -199.3163],
          [    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,
               0.0000]],

         [[    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,
               0.0000],
          [ -234.7728,  -228.8413,  -229.1094,  -229.3774,  -229.6455,
              16.7455],
          [ -258.5454,  -231.5221,  -231.7901,  -232.0582,  -232.3263,
              37.8374],
          [    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,
               0.0000]],

         [[    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,
               0.0000],
          [  106.3940,    21.6629,    17.6575,    13.6522,     9.6468,
             -99.2038],
          [  103.2964,   -18.3910,   -22.3964,