# Im2Col Implementation

In [25]:
import math
import numpy as np

In [26]:
class Array:
    '''
    Dummy 4D tensor stored as a flat 1D array in NCHW order.

    The buffer is ``np.arange(N * C * H * W)``, so the value of every element
    equals its own flat index. That makes any index mapping easy to check.
    '''

    def __init__(self, N, C, H, W):
        '''
        N: number of images, C: channels per image,
        H: image height, W: image width.
        '''
        self.arr = np.arange(N * C * H * W)  # 4D array is stored as 1D array in NCHW format
        self.H = H
        self.W = W
        self.N = N
        self.C = C

    def get_idx(self, n, c, h, w):
        '''
        Map the 4D index (n, c, h, w) to its 1D offset in the NCHW buffer.
        '''
        # Horner form of n*C*H*W + c*H*W + h*W + w
        return ((n * self.C + c) * self.H + h) * self.W + w

    def print_array(self):
        '''
        Print the flat buffer, then every channel of every image as an
        H x W matrix (one text line per row).
        '''
        print(self.arr)
        for n in range(self.N):
            print("Image {}".format(n))
            for c in range(self.C):
                print("Channel {}".format(c))
                rows = []
                for h in range(self.H):
                    cells = "".join(
                        "{} ".format(self.arr[self.get_idx(n, c, h, w)])
                        for w in range(self.W)
                    )
                    rows.append(cells + "\n")
                print("".join(rows))

    def get_im2col_idx(self, kernel_h, kernel_w, stride_w=1, stride_h=1):
        '''
        Reference im2col implementation (plain nested loops).

        Returns a flat list holding, for every image, every output position
        (row-major) and every channel, the flat input indices covered by the
        kernel window (row-major inside the window). The list is empty when
        the kernel does not fit into the image.

        Note the keyword order: stride_w comes before stride_h.
        '''
        # Integer floor division: number of window positions along each axis.
        output_h = (self.H - kernel_h) // stride_h + 1
        output_w = (self.W - kernel_w) // stride_w + 1

        im2col_res = []
        for n in range(self.N):
            for out_row in range(output_h):
                top = out_row * stride_h  # first input row of this window
                for out_col in range(output_w):
                    left = out_col * stride_w  # first input column of this window
                    for c in range(self.C):
                        for i in range(kernel_h):
                            for j in range(kernel_w):
                                im2col_res.append(self.get_idx(n, c, top + i, left + j))
        return im2col_res

In [27]:
class Im2Col:
    '''
    Closed-form im2col index mapping for a flat NCHW buffer.

    Every position of the flattened im2col matrix is converted directly into
    the flat index of the input element it refers to, without nested loops
    over image / output position / channel / kernel.

    The im2col layout is (N, output_h * output_w, C * kernel_h * kernel_w):
    one row per output position, and inside a row the kernel windows of all
    channels one after the other.
    '''

    def __init__(self, N, C, H, W, kernel_h, kernel_w, stride_h, stride_w, arr, is_debug=False):
        self.is_debug = is_debug
        self.arr = arr  # 1D array input (stored only; the index math does not read it)
        self.N = N  # number of images in the batch
        self.C = C  # number of channels in an image
        self.H = H  # height of the image
        self.W = W  # width of the image
        self.kernel_h = kernel_h
        self.kernel_w = kernel_w
        self.stride_h = stride_h
        self.stride_w = stride_w

        self.output_h = self.calculate_output_dim(input_dim=H, kernel_dim=kernel_h, stride=stride_h)
        self.output_w = self.calculate_output_dim(input_dim=W, kernel_dim=kernel_w, stride=stride_w)

        # Rows / columns of one image's im2col matrix and the total element count.
        self.num_output_elements = self.output_h * self.output_w
        self.num_kernel_elements = kernel_h * kernel_w * C
        self.im2col_sz = self.N * self.num_output_elements * self.num_kernel_elements

        if self.is_debug:
            print("output shape is {}X{}".format(self.output_h, self.output_w))
            print("im2col shape {}X{}".format(self.num_output_elements, self.num_kernel_elements))

    def calculate_output_dim(self, input_dim, kernel_dim, stride):
        '''
        Number of kernel positions along one axis (no padding).
        '''
        return (input_dim - kernel_dim) // stride + 1

    def get_image_id(self, idx):
        '''
        Image (batch entry) that the global im2col position idx belongs to.
        '''
        return idx // (self.num_output_elements * self.num_kernel_elements)

    def get_image_idx(self, idx):
        '''
        Position of idx inside its own image's im2col matrix.
        '''
        return idx % (self.num_output_elements * self.num_kernel_elements)

    def get_channel_id(self, idx):
        '''
        Channel addressed by the per-image im2col position idx.
        '''
        # Uses self.kernel_w; the former bare `kernel_w` depended on a
        # module-level global of that name.
        return (idx // (self.kernel_h * self.kernel_w)) % self.C

    def get_kernel_id(self, idx):
        '''
        Output position (row of the im2col matrix) of the per-image position idx.
        '''
        return idx // self.num_kernel_elements

    def get_kernel_idx(self, idx):
        '''
        Offset of the per-image position idx inside its kernel window.
        '''
        return idx % (self.kernel_h * self.kernel_w)

    def get_h(self, kernel_id, kernel_idx):
        '''
        Input row: window start row (stride_h * output row) plus row inside the window.
        '''
        relative_h = kernel_idx // self.kernel_w
        kernel_start_idx = kernel_id // self.output_w
        return self.stride_h * kernel_start_idx + relative_h

    def get_w(self, kernel_id, kernel_idx):
        '''
        Input column: window start column (stride_w * output column) plus column inside the window.
        '''
        relative_w = kernel_idx % self.kernel_w
        kernel_start_idx = kernel_id % self.output_w
        return self.stride_w * kernel_start_idx + relative_w

    def get_im2col_idx(self):
        '''
        Build the complete mapping.

        Returns (im2col_idx, shape): im2col_idx is a 1D integer array of
        length im2col_sz with the flat input index for every im2col position,
        and shape is (N, num_output_elements, num_kernel_elements), suitable
        for im2col_idx.reshape(shape).
        '''
        # Builtin int as dtype: the `np.int` alias was removed from NumPy.
        # Every slot is overwritten below, so no initial fill is needed.
        im2col_idx = np.empty(self.im2col_sz, dtype=int)

        for idx in range(self.im2col_sz):
            image_id = self.get_image_id(idx)
            image_idx = self.get_image_idx(idx)

            channel_id = self.get_channel_id(image_idx)

            kernel_id = self.get_kernel_id(image_idx)
            kernel_idx = self.get_kernel_idx(image_idx)

            h = self.get_h(kernel_id, kernel_idx)
            w = self.get_w(kernel_id, kernel_idx)

            # Flat NCHW offset of input element (image_id, channel_id, h, w).
            im_idx = image_id * (self.C * self.H * self.W) + channel_id * (self.H * self.W) + h * (self.W) + w
            im2col_idx[idx] = im_idx

            if self.is_debug:
                print("idx: {}, image id:{}, image idx: {}, channel id: {}, kernel id:{}, kernel_idx:{}, h:{}, w:{}, im_idx:{}"
                      .format(idx, image_id, image_idx, channel_id, kernel_id, kernel_idx, h, w, im_idx))

        return im2col_idx, (self.N, self.num_output_elements, self.num_kernel_elements)


In [28]:
# Problem size: a batch of N images, each with C channels of H x W elements.
N, C, H, W = 1, 2, 4, 5
# Kernel window extent and step along the height and width axes.
kernel_h, kernel_w = 3, 3
stride_h, stride_w = 2, 2

In [29]:
# Build the dummy NCHW input (each value equals its flat index) and show it.
arr = Array(N, C, H, W)
arr.print_array()

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
Image 0
Channel 0
0 1 2 3 4 
5 6 7 8 9 
10 11 12 13 14 
15 16 17 18 19 

Channel 1
20 21 22 23 24 
25 26 27 28 29 
30 31 32 33 34 
35 36 37 38 39 



In [30]:
# Reference mapping from the nested-loop implementation, as a NumPy array.
ref_mapping = np.array(
    arr.get_im2col_idx(
        kernel_h=kernel_h,
        kernel_w=kernel_w,
        stride_w=stride_w,
        stride_h=stride_h,
    )
)
print(ref_mapping)

[ 0  1  2  5  6  7 10 11 12 20 21 22 25 26 27 30 31 32  2  3  4  7  8  9
 12 13 14 22 23 24 27 28 29 32 33 34]


In [31]:
# Closed-form mapping, built from the same geometry as the reference above.
im2col = Im2Col(
    N=N, C=C, H=H, W=W,
    kernel_h=kernel_h, kernel_w=kernel_w,
    stride_h=stride_h, stride_w=stride_w,
    arr=arr.arr, is_debug=False,
)
im2col_idx, im2col_shape = im2col.get_im2col_idx()
im2col_idx = np.array(im2col_idx)
print(im2col_idx)

[ 0  1  2  5  6  7 10 11 12 20 21 22 25 26 27 30 31 32  2  3  4  7  8  9
 12 13 14 22 23 24 27 28 29 32 33 34]


In [32]:
# Index arrays must match exactly: np.allclose applies a relative tolerance
# (large indices that differ by one would still count as "close") and
# broadcasts mismatched shapes, so an exact comparison is used instead.
print("Is pass?", np.array_equal(ref_mapping, im2col_idx))

Is pass? True


In [33]:
# Reference mapping viewed as (image, output position, kernel element).
print(ref_mapping.reshape(*im2col_shape))

[[[ 0  1  2  5  6  7 10 11 12 20 21 22 25 26 27 30 31 32]
  [ 2  3  4  7  8  9 12 13 14 22 23 24 27 28 29 32 33 34]]]


In [34]:
# Closed-form mapping viewed as (image, output position, kernel element).
print(im2col_idx.reshape(*im2col_shape))

[[[ 0  1  2  5  6  7 10 11 12 20 21 22 25 26 27 30 31 32]
  [ 2  3  4  7  8  9 12 13 14 22 23 24 27 28 29 32 33 34]]]
