In [41]:
import numpy as np
import torch
import torch.nn.functional as F

In [42]:
# original image dimensions
# B * C * H * W
batch = 6
orig_img = np.arange(125*batch).reshape((batch, 5,5,5))
print(orig_img.shape)
print(orig_img)
batch, C, H, W = orig_img.shape

(6, 5, 5, 5)
[[[[  0   1   2   3   4]
   [  5   6   7   8   9]
   [ 10  11  12  13  14]
   [ 15  16  17  18  19]
   [ 20  21  22  23  24]]

  [[ 25  26  27  28  29]
   [ 30  31  32  33  34]
   [ 35  36  37  38  39]
   [ 40  41  42  43  44]
   [ 45  46  47  48  49]]

  [[ 50  51  52  53  54]
   [ 55  56  57  58  59]
   [ 60  61  62  63  64]
   [ 65  66  67  68  69]
   [ 70  71  72  73  74]]

  [[ 75  76  77  78  79]
   [ 80  81  82  83  84]
   [ 85  86  87  88  89]
   [ 90  91  92  93  94]
   [ 95  96  97  98  99]]

  [[100 101 102 103 104]
   [105 106 107 108 109]
   [110 111 112 113 114]
   [115 116 117 118 119]
   [120 121 122 123 124]]]


 [[[125 126 127 128 129]
   [130 131 132 133 134]
   [135 136 137 138 139]
   [140 141 142 143 144]
   [145 146 147 148 149]]

  [[150 151 152 153 154]
   [155 156 157 158 159]
   [160 161 162 163 164]
   [165 166 167 168 169]
   [170 171 172 173 174]]

  [[175 176 177 178 179]
   [180 181 182 183 184]
   [185 186 187 188 189]
   [190 191 192 193 1

In [43]:
# original kernels
# n_f * C * F1 *F2
F1 = 3  # Height of kernels
F2 = 2  # Widht of kernels
n_f = 5   # Number of kernels
orig_kernels = np.random.randn(n_f, C, F1, F2)
print(orig_kernels.shape)
print(orig_kernels)
n_f, C, F1, F2 = orig_kernels.shape

(5, 5, 3, 2)
[[[[ 0.14454604 -0.65569735]
   [ 0.39279483 -0.40413499]
   [ 0.41816789 -0.79195857]]

  [[ 0.39503583  0.5263565 ]
   [ 1.04714913  1.08851947]
   [ 1.48599317  1.58629343]]

  [[-0.36004606 -0.60996999]
   [-0.29564425 -1.62871241]
   [-0.7150853  -0.49192752]]

  [[-0.3658548   2.03888203]
   [ 2.37461521  0.39199225]
   [ 0.55288528  1.18202265]]

  [[ 1.59821359  1.00940859]
   [-1.0122152  -0.39126753]
   [-0.45003487  0.34923308]]]


 [[[-0.08953909  0.16082477]
   [-1.19758815  2.24671299]
   [ 2.2030779  -0.29198432]]

  [[ 1.42557857  0.56325701]
   [-0.30325771 -0.29576479]
   [-1.72807525 -1.20018773]]

  [[ 0.98930933 -0.63760167]
   [-0.33744598  1.78219793]
   [-0.97123504 -1.34471207]]

  [[-1.58489657  1.08053359]
   [-0.93286205 -0.54430548]
   [-0.43159127 -0.02087982]]

  [[ 0.21265506  1.45868886]
   [-2.00707449 -0.07679159]
   [-0.43510761 -0.14881025]]]


 [[[ 1.41481219  1.176156  ]
   [ 0.27744498 -0.95543133]
   [-1.34305395  0.27957027]]

  [[

In [44]:
# https://stackoverflow.com/a/40840048
# Refer above link for awesome answer on im2col

def im2col(image,kernel_shape,strides=(1,1)):
    A = image
    B = kernel_shape
    skip = strides

    # Parameters 
    batch, D,M,N = A.shape
    col_extent = N - B[1] + 1
    row_extent = M - B[0] + 1

    # Get batch block indices
    batch_idx = np.arange(batch)[:, None, None] * D * M * N

    # Get Starting block indices
    start_idx = np.arange(B[0])[None, :,None]*N + np.arange(B[1])

    # Generate Depth indeces
    didx=M*N*np.arange(D)
    start_idx=(didx[None, :, None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

    # Get offsetted indices across the height and width of input array
    offset_idx = np.arange(row_extent)[None, :, None]*N + np.arange(col_extent)

    # Get all actual indices & index into input array for final output
    act_idx = (batch_idx + 
        start_idx.ravel()[None, :, None] + 
        offset_idx[:,::skip[0],::skip[1]].ravel())

    out = np.take (A, act_idx)
    
    return out

In [45]:
def conv_2D(input, kernel, stride=(1,1), padding=(0,0)):
    
    input = np.float32(input)
    
    # Padding the image
    P1, P2 = padding
    input = np.pad(input, [(0,0),(0,0),(P1,P1),(P2,P2)])
    
    S1 ,S2 = stride
    B, C, H, W = input.shape
    N_K, C, K1, K2 = kernel.shape
    
    # Output feature map height and width
    H_ = np.int(np.floor((H - F1 + 2 * P1) / S1) + 1)
    W_ = np.int(np.floor((W - F2 + 2 * P2) / S2) + 1)
    
    # im2col
    input = im2col(input, (K1,K2), stride)
    
    input = np.hstack((input))
    
    kernel = kernel.reshape(N_K,-1)
    
    input = np.matmul(kernels, input)
    
    input = np.split(np.array(input), B, axis=1)
    
    input = np.array(input).reshape(B,N_K,H_,W_)
    
    return input

In [46]:
out = conv_2D(orig_img, orig_kernels, stride=(2,1))
print(out)
print(out.shape)

[[[[ -172.90108029  -174.4097514   -175.9184225   -177.42709361]
   [ -187.98779136  -189.49646247  -191.00513358  -192.51380469]]

  [[ -144.10333096  -142.60464159  -141.10595221  -139.60726284]
   [ -129.11643721  -127.61774783  -126.11905846  -124.62036908]]

  [[ -562.79094926  -570.58369433  -578.3764394   -586.16918447]
   [ -640.71839995  -648.51114502  -656.30389009  -664.09663516]]

  [[ -311.50626797  -318.60964862  -325.71302928  -332.81640993]
   [ -382.5400745   -389.64345515  -396.7468358   -403.85021646]]

  [[  226.73922778   231.41689284   236.09455789   240.77222295]
   [  273.51587834   278.19354339   282.87120845   287.5488735 ]]]


 [[[ -361.48496874  -362.99363985  -364.50231096  -366.01098207]
   [ -376.57167982  -378.08035093  -379.58902203  -381.09769314]]

  [[   43.23284097    44.73153035    46.23021972    47.7289091 ]
   [   58.21973473    59.7184241     61.21711348    62.71580286]]

  [[-1536.88408292 -1544.67682799 -1552.46957306 -1560.26231813]
   [-1614

In [47]:
# Comparing with pytorch conv2d 
torch_conv = F.conv2d(torch.Tensor(orig_img), torch.Tensor(orig_kernels),stride=(2,1))
print(torch_conv)
print(torch_conv.shape)

tensor([[[[  563.8593,   572.2688,   580.6783,   589.0879],
          [  647.9550,   656.3644,   664.7741,   673.1835]],

         [[ -395.9245,  -398.3814,  -400.8383,  -403.2951],
          [ -420.4933,  -422.9501,  -425.4070,  -427.8640]],

         [[ -849.4199,  -862.5690,  -875.7178,  -888.8668],
          [ -980.9092,  -994.0580, -1007.2071, -1020.3559]],

         [[  278.7057,   285.6808,   292.6558,   299.6309],
          [  348.4559,   355.4310,   362.4061,   369.3811]],

         [[  516.5458,   526.0540,   535.5621,   545.0704],
          [  611.6278,   621.1359,   630.6441,   640.1525]]],


        [[[ 1615.0543,  1623.4639,  1631.8733,  1640.2828],
          [ 1699.1500,  1707.5593,  1715.9690,  1724.3787]],

         [[ -703.0339,  -705.4908,  -707.9476,  -710.4046],
          [ -727.6026,  -730.0596,  -732.5164,  -734.9733]],

         [[-2493.0359, -2506.1848, -2519.3337, -2532.4829],
          [-2624.5254, -2637.6741, -2650.8230, -2663.9722]],

         [[ 1150.5837,

In [48]:
type(torch_conv)

torch.Tensor