In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

import os
import sys
root = os.path.abspath(os.curdir)
sys.path.append(root)

device = 'cuda:0'

In [2]:
class my_conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        '''
        Custom convolution layer which manually performs nn.Conv2d. 
        See torch documentation for reference

        NOTE:
        Assumes stride is 1 and padding is zero
        Assumes kernel is square and odd
        '''
        super(my_conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        
        # initialize weight and bias 
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Copied directly from torch implementation:
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        '''
        Performs 2d convolution where x is input
        '''
        # build output tensor
        out = nn.functional.conv2d(x,self.weight,self.bias)
        # h = x.shape[-1] - self.kernel_size + 1
        # w = x.shape[-2] - self.kernel_size + 1
        # bs = x.shape[0]
        # out = torch.empty(bs,self.out_channels,h,w).to(self.weight.device)

        # # iterate 
        # for i_img in range(bs):
        #     for out_ch in range(self.out_channels):
        #         for i in range(w):
        #             for j in range(h):
        #                 window = x[i_img,:,i:i+self.kernel_size,j:j+self.kernel_size]
        #                 out[i_img,out_ch,i,j] = torch.mul(window,self.weight[out_ch]).sum()
        return out



In [3]:
def my_conv_func(image, kernel, device):
    '''
    perform 2d convolution between image and kernel

    NOTE:
    Assumes stride is 1 and padding is zero
    Assumes kernel is square and odd
    '''
    k_size = kernel.shape[-1]
    h = image.shape[-1] - k_size + 1
    w = image.shape[-2] - k_size + 1
    out = torch.zeros(image.shape[0],kernel.shape[0],h,w).to(device)
    # print('h
    temp = 0
    for i_img in range(image.shape[0]):
        for out_ch in range(kernel.shape[0]):
            for i in range(w):
                for j in range(h):
                    window = image[i_img,:,i:i+k_size,j:j+k_size]
                    # print('window shape: ',window.shape)
                    # print('kernel shape: ',kernel[out_ch].shape)
                    # print('product shape: ',torch.mul(window,kernel[out_ch]).shape)
                    # out[i_img,out_ch,i,j] = torch.mul(window,kernel[out_ch]).sum(axis=-1).sum(axis=-1).sum()
                    out[i_img,out_ch,i,j] = torch.mul(window,kernel[out_ch]).sum()
    return out





In [4]:
w,h = 28,28
kernel_dim = 3
bs = 128
in_ch = 3
out_ch = 3

image = torch.randn(bs,in_ch,w,h).to(device)
kernel = torch.randn(out_ch,in_ch,kernel_dim,kernel_dim).to(device)

In [5]:
nn_conv = torch.nn.Conv2d(in_ch,out_ch,kernel_dim,bias=True).to(device)

conv = my_conv(in_ch,out_ch,kernel_dim,bias=True).to(device)
conv.weight = nn_conv.weight
conv.bias = nn_conv.bias

In [6]:
nn_out = nn_conv(image)
out = conv(image)

In [7]:
print(nn_out)

tensor([[[[ 0.4208, -0.0754,  0.7706,  ...,  0.1006, -0.6059, -0.4483],
          [ 0.3534,  0.2688, -0.9063,  ..., -0.7834, -0.3541, -0.5853],
          [-0.6435,  0.0699, -0.6839,  ...,  0.4528, -0.2854, -0.2596],
          ...,
          [ 0.9964, -0.0986,  0.2047,  ..., -0.7568, -0.1168,  0.6002],
          [ 0.2317, -0.8512, -0.3799,  ...,  0.4066,  0.0589,  0.6695],
          [-0.6516,  0.0934,  0.2734,  ..., -0.6227,  0.2407, -0.3211]],

         [[-0.3042, -0.2758, -0.1975,  ...,  1.5376,  0.6500,  0.4462],
          [ 1.0039, -0.5616,  0.5842,  ...,  0.2520,  0.4768,  0.4664],
          [ 0.7620, -0.1323, -0.1765,  ..., -0.8102,  0.9546,  0.1983],
          ...,
          [-0.5213,  0.4124,  0.4536,  ..., -0.0458, -0.1512, -0.8351],
          [-0.1682, -0.1358, -0.1503,  ...,  0.3268,  0.1639, -0.5658],
          [ 0.1460, -0.2299, -0.2249,  ...,  0.3312, -0.4331, -0.8430]],

         [[ 0.3673,  0.0788, -0.0543,  ...,  0.1888,  0.0589,  0.4586],
          [ 0.7396, -0.7322,  

In [8]:
print(out)

tensor([[[[ 0.4208, -0.0754,  0.7706,  ...,  0.1006, -0.6059, -0.4483],
          [ 0.3534,  0.2688, -0.9063,  ..., -0.7834, -0.3541, -0.5853],
          [-0.6435,  0.0699, -0.6839,  ...,  0.4528, -0.2854, -0.2596],
          ...,
          [ 0.9964, -0.0986,  0.2047,  ..., -0.7568, -0.1168,  0.6002],
          [ 0.2317, -0.8512, -0.3799,  ...,  0.4066,  0.0589,  0.6695],
          [-0.6516,  0.0934,  0.2734,  ..., -0.6227,  0.2407, -0.3211]],

         [[-0.3042, -0.2758, -0.1975,  ...,  1.5376,  0.6500,  0.4462],
          [ 1.0039, -0.5616,  0.5842,  ...,  0.2520,  0.4768,  0.4664],
          [ 0.7620, -0.1323, -0.1765,  ..., -0.8102,  0.9546,  0.1983],
          ...,
          [-0.5213,  0.4124,  0.4536,  ..., -0.0458, -0.1512, -0.8351],
          [-0.1682, -0.1358, -0.1503,  ...,  0.3268,  0.1639, -0.5658],
          [ 0.1460, -0.2299, -0.2249,  ...,  0.3312, -0.4331, -0.8430]],

         [[ 0.3673,  0.0788, -0.0543,  ...,  0.1888,  0.0589,  0.4586],
          [ 0.7396, -0.7322,  

In [9]:
print(nn_out.shape)
print(out.shape)
print(((nn_out - out)<0.00001).sum() == np.prod(nn_out.shape))
# print(((out - my_out)))

torch.Size([128, 3, 26, 26])
torch.Size([128, 3, 26, 26])
tensor(True, device='cuda:0')
