In [1]:
import sys
sys.path.append("..")

In [2]:
def get_conv_outsize(input_size, kernel_size, stride, pad):
    """Return the output length of a convolution along one spatial axis.

    Args:
        input_size: Length of the input along the axis.
        kernel_size: Length of the kernel along the same axis.
        stride: Step between consecutive kernel positions.
        pad: Padding added to *each* side of the input.

    Returns:
        ``(input_size + 2 * pad - kernel_size) // stride + 1``.
    """
    padded_extent = input_size + 2 * pad
    # Number of whole strides the kernel can advance after its first position.
    n_strides = (padded_extent - kernel_size) // stride
    return n_strides + 1

In [3]:
# Sanity check: with a 3x3 kernel, stride 1 and padding 1, a 4x4 input
# should keep its spatial size.
H, W = 4, 4
KH, KW = 3, 3
SH, SW = 1, 1
PH, PW = 1, 1
for out_size in (
    get_conv_outsize(H, KH, SH, PH),
    get_conv_outsize(W, KW, SW, PW),
):
    print(out_size)

4
4


In [4]:
import numpy as np
from mytorch.utils import pair

# Step-by-step im2col on a tiny all-ones image, so the zero padding is
# easy to spot in the resulting matrix.
img = np.ones((1, 1, 3, 3))
kernel_size, stride, pad = 3, 1, 1
to_matrix = True

B, C, H, W = img.shape
(KH, KW), (SH, SW), (PH, PW) = pair(kernel_size), pair(stride), pair(pad)

OH, OW = (
    get_conv_outsize(H, KH, SH, PH),
    get_conv_outsize(W, KW, SW, PW),
)

# The extra (stride - 1) on the trailing edge keeps every strided slice
# below in bounds.
pad_widths = ((0, 0), (0, 0), (PH, PH + SH - 1), (PW, PW + SW - 1))
pad_img = np.pad(img, pad_widths)

col = np.zeros((B, C, KH, KW, OH, OW))

# For each kernel offset (i, j), copy the strided window of the padded
# image that this offset touches across all output positions.
for i, j in np.ndindex(KH, KW):
    i_lim = i + SH * OH
    j_lim = j + SW * OW
    col[:, :, i, j, :, :] = pad_img[:, :, i:i_lim:SH, j:j_lim:SW]

if to_matrix:
    # One row per output position, one column per (channel, kh, kw).
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape((B * OH * OW, -1))

col

array([[0., 0., 0., 0., 1., 1., 0., 1., 1.],
       [0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 1., 1., 0., 1., 1., 0.],
       [0., 1., 1., 0., 1., 1., 0., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 0., 1., 1., 0., 1., 1., 0.],
       [0., 1., 1., 0., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 0., 1., 1., 0., 0., 0., 0.]])

In [5]:
import numpy as np
from mytorch.utils import pair

# Step-by-step col2im: scatter-add the `col` matrix built in the previous
# cell back onto a padded image buffer, then crop the padding away.
img_shape, kernel_size, stride, pad = (1, 1, 3, 3), 3, 1, 1
to_matrix = True

N, C, H, W = img_shape
KH, KW = pair(kernel_size)
SH, SW = pair(stride)
PH, PW = pair(pad)
OH = get_conv_outsize(H, KH, SH, PH)
OW = get_conv_outsize(W, KW, SW, PW)

# Bind the 6-D view to a new name instead of overwriting `col`, so this
# cell can be re-run without first re-running the cell that produced `col`.
if to_matrix:
    col_6d = col.reshape(N, OH, OW, C, KH, KW).transpose(0, 3, 4, 5, 1, 2)
else:
    col_6d = col

# Separate name so the `img` defined by the previous cell is not clobbered.
img_padded = np.zeros((N, C, H + PH * 2 + SH - 1, W + PW * 2 + SW - 1))

for i in range(KH):
    i_lim = i + SH * OH
    for j in range(KW):
        j_lim = j + SW * OW
        # Overlapping patches accumulate, hence += rather than =.
        img_padded[:, :, i:i_lim:SH, j:j_lim:SW] += col_6d[:, :, i, j, :, :]
img_padded[:, :, PH:PH + H, PW:PW + W]

array([[[[4., 6., 4.],
         [6., 9., 6.],
         [4., 6., 4.]]]])

In [6]:
def im2col(img, kernel_size, stride, pad, to_matrix=True):
    """Unfold sliding kernel windows of a 4-D image batch into columns.

    Args:
        img: Array of shape ``(B, C, H, W)``.
        kernel_size: Kernel size; expanded to ``(KH, KW)`` via ``pair``.
        stride: Stride; expanded to ``(SH, SW)`` via ``pair``.
        pad: Zero padding per side; expanded to ``(PH, PW)`` via ``pair``.
        to_matrix: If true, return a 2-D matrix with one row per output
            position; otherwise return the 6-D array.

    Returns:
        Array of shape ``(B * OH * OW, C * KH * KW)`` when ``to_matrix`` is
        true, else ``(B, C, KH, KW, OH, OW)``. The result has the same dtype
        as ``img``.
    """
    B, C, H, W = img.shape
    KH, KW = pair(kernel_size)
    SH, SW = pair(stride)
    PH, PW = pair(pad)

    OH = get_conv_outsize(H, KH, SH, PH)
    OW = get_conv_outsize(W, KW, SW, PW)

    # The extra (stride - 1) on the trailing edge keeps every strided slice
    # below in bounds.
    pad_img = np.pad(img, ((0, 0), (0, 0), (PH, PH + SH - 1), (PW, PW + SW - 1)))

    # Match the input dtype; np.zeros would otherwise default to float64 and
    # silently promote float32 or integer images.
    col = np.zeros((B, C, KH, KW, OH, OW), dtype=img.dtype)

    for i in range(KH):
        i_lim = i + SH * OH
        for j in range(KW):
            j_lim = j + SW * OW
            # All output positions for kernel offset (i, j) in one strided slice.
            col[:, :, i, j, :, :] = pad_img[:, :, i:i_lim:SH, j:j_lim:SW]

    if to_matrix:
        # One row per (batch, oh, ow); one column per (channel, kh, kw).
        col = col.transpose(0, 4, 5, 1, 2, 3).reshape((B * OH * OW, -1))
    return col

def col2im(col, image_shape, kernel_size, stride, pad, to_matrix=True):
    """Fold columns back into an image batch, summing overlapping patches.

    Values that land on the same pixel from different patches are added
    together, so applying this to the output of an im2col-style unfold does
    not in general reproduce the original image.

    Args:
        col: Column data. When ``to_matrix`` is true it must be reshapeable
            to ``(N, OH, OW, C, KH, KW)``; otherwise it is indexed as
            ``(N, C, KH, KW, OH, OW)``.
        image_shape: Target image shape ``(N, C, H, W)``.
        kernel_size: Kernel size; expanded to ``(KH, KW)`` via ``pair``.
        stride: Stride; expanded to ``(SH, SW)`` via ``pair``.
        pad: Padding per side; expanded to ``(PH, PW)`` via ``pair``.
        to_matrix: Whether ``col`` is in the 2-D matrix layout.

    Returns:
        Array of shape ``(N, C, H, W)`` with the same dtype as ``col``.
    """
    # Read the shape from the argument; the previous body read a notebook
    # global named `img_shape` and ignored this parameter.
    N, C, H, W = image_shape
    KH, KW = pair(kernel_size)
    SH, SW = pair(stride)
    PH, PW = pair(pad)
    OH = get_conv_outsize(H, KH, SH, PH)
    OW = get_conv_outsize(W, KW, SW, PW)

    if to_matrix:
        col = col.reshape(N, OH, OW, C, KH, KW).transpose(0, 3, 4, 5, 1, 2)

    # Padded accumulation buffer; the extra (stride - 1) keeps every strided
    # slice below in bounds. Use col's dtype instead of the float64 default.
    img = np.zeros(
        (N, C, H + PH * 2 + SH - 1, W + PW * 2 + SW - 1), dtype=col.dtype
    )

    for i in range(KH):
        i_lim = i + SH * OH
        for j in range(KW):
            j_lim = j + SW * OW
            # Overlapping patches accumulate, hence += rather than =.
            img[:, :, i:i_lim:SH, j:j_lim:SW] += col[:, :, i, j, :, :]
    # Crop the padding away to recover the requested (H, W) extent.
    return img[:, :, PH:PH + H, PW:PW + W]




In [7]:

# Round trip through im2col / col2im on an all-ones image. col2im adds up
# overlapping patches, so `re_im` shows how many patches cover each pixel
# rather than reproducing `img`.
img = np.ones((1, 1, 3, 3))
col = im2col(img, 3, 1, 1)
re_im = col2im(col, img.shape, 3, 1, 1)
for arr in (img, col, re_im):
    print(arr)

[[[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]]]
[[0. 0. 0. 0. 1. 1. 0. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 0. 1. 1. 0.]
 [0. 1. 1. 0. 1. 1. 0. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 0. 1. 1. 0. 1. 1. 0.]
 [0. 1. 1. 0. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 0. 1. 1. 0. 0. 0. 0.]]
[[[[4. 6. 4.]
   [6. 9. 6.]
   [4. 6. 4.]]]]


In [8]:
from mytorch import as_variable, Variable
from mytorch.functions_conv import im2col
import mytorch.functions as F

def conv2d_simple(x, weight, b=None, stride=1, pad=0):
    """2-D convolution implemented as im2col followed by a linear layer.

    Args:
        x: Input with shape ``(N, C, H, W)``; wrapped with ``as_variable``.
        weight: Kernels with shape ``(OC, C, KH, KW)``; wrapped with
            ``as_variable``.
        b: Optional bias forwarded to ``F.linear``.
        stride: Stride; expanded to ``(SH, SW)`` via ``pair``.
        pad: Padding per side; expanded to ``(PH, PW)`` via ``pair``.

    Returns:
        Result reshaped and transposed to ``(N, OC, OH, OW)``.
    """
    x, weight = as_variable(x), as_variable(weight)

    N, C, H, W = x.shape
    OC, C, KH, KW = weight.shape

    # Derive the output size from this call's arguments; the previous body
    # read OH / OW from whatever notebook globals happened to be set.
    SH, SW = pair(stride)
    PH, PW = pair(pad)
    OH = get_conv_outsize(H, KH, SH, PH)
    OW = get_conv_outsize(W, KW, SW, PW)

    # Presumably yields one row per output position, (N * OH * OW, C * KH * KW);
    # the reshape of `t` below relies on that layout. TODO confirm against
    # mytorch.functions_conv.im2col.
    col = im2col(x, (KH, KW), stride, pad)
    # Flatten each kernel to a row, then transpose to (C * KH * KW, OC).
    weight = weight.reshape(OC, -1).transpose()
    t = F.linear(col, weight, b)
    y = t.reshape(N, OH, OW, OC).transpose(0, 3, 1, 2)
    return y

In [10]:

from mytorch.functions_conv import conv2d_simple
# Smoke test: with a 3x3 kernel, stride 1 and pad 1 the spatial size is
# preserved, and the gradient w.r.t. x has the same shape as x.
N, C = 1, 5
H, W = 15, 15
OC, KH, KW = 8, 3, 3

weight = np.random.randn(OC, C, KH, KW)
x = Variable(np.random.randn(N, C, H, W))

y = conv2d_simple(x, weight, stride=1, pad=1)
y.backward()

print(y.shape)
print(x.grad.shape)

(1, 8, 15, 15)
(1, 5, 15, 15)
