In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

import os
import sys
# Put the current working directory on the import path (once, so re-running
# this cell does not keep appending duplicate entries).
root = os.path.abspath(os.curdir)
if root not in sys.path:
    sys.path.append(root)

# Use the first CUDA device when one is available; otherwise fall back to CPU
# so the notebook still runs end-to-end on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
class my_conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        '''
        Custom convolution layer which manually performs nn.Conv2d.
        See torch documentation for reference

        Args:
            in_channels: number of channels of the input
            out_channels: number of channels produced by the layer
            kernel_size: side length of the (square) kernel
            bias: if True, a learnable per-output-channel bias is added

        NOTE:
        Stride is 1 and padding is zero
        Kernel is square (kernel_size x kernel_size)
        '''
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # initialize weight and bias
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Initialize weight (Kaiming uniform) and bias (uniform in +-1/sqrt(fan_in)).
        '''
        # Copied directly from torch implementation:
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        '''
        Performs 2d convolution where x is input

        Args:
            x: tensor of shape (batch, in_channels, H, W)

        Returns:
            tensor of shape (batch, out_channels, H - k + 1, W - k + 1)
            with k = kernel_size

        Raises:
            ValueError: if x is not 4-dimensional or its channel count does
                not match in_channels
        '''
        if x.dim() != 4:
            raise ValueError(
                'expected a 4D input (batch, channels, height, width), got %dD' % x.dim())
        if x.shape[1] != self.in_channels:
            raise ValueError(
                'expected %d input channels, got %d' % (self.in_channels, x.shape[1]))

        k = self.kernel_size
        # Sliding windows over height (dim 2) then width (dim 3):
        # patches[b, c, i, j, p, q] == x[b, c, i + p, j + q]
        patches = x.unfold(2, k, 1).unfold(3, k, 1)

        # For every window, multiply element-wise with each output channel's
        # kernel and sum over (in_channel, kernel_row, kernel_col).
        out = torch.einsum('bcijpq,ocpq->boij', patches, self.weight)

        if self.bias is not None:
            # reshape so the bias broadcasts over batch and spatial dims
            out = out + self.bias.view(1, -1, 1, 1)
        return out



In [3]:
# def my_conv_func(image, kernel, device):
#     '''
#     perform 2d convolution between image and kernel

#     NOTE:
#     Assumes stride is 1 and padding is zero
#     Assumes kernel is square and odd
#     '''
#     k_size = kernel.shape[-1]
#     h = image.shape[-1] - k_size + 1
#     w = image.shape[-2] - k_size + 1
#     out = torch.zeros(image.shape[0],kernel.shape[0],h,w).to(device)
#     # print('h, w: ', h, w)
#     temp = 0
#     for i_img in range(image.shape[0]):
#         for out_ch in range(kernel.shape[0]):
#             for i in range(w):
#                 for j in range(h):
#                     window = image[i_img,:,i:i+k_size,j:j+k_size]
#                     # print('window shape: ',window.shape)
#                     # print('kernel shape: ',kernel[out_ch].shape)
#                     # print('product shape: ',torch.mul(window,kernel[out_ch]).shape)
#                     # out[i_img,out_ch,i,j] = torch.mul(window,kernel[out_ch]).sum(axis=-1).sum(axis=-1).sum()
#                     out[i_img,out_ch,i,j] = torch.mul(window,kernel[out_ch]).sum()
#     return out





In [4]:
# Seed torch's RNG so the random tensors below (and any layer weights
# initialized after this cell) are the same on every Restart & Run All.
torch.manual_seed(0)

# Problem size for the comparison between nn.Conv2d and my_conv
w, h = 28, 28
kernel_dim = 3
bs = 128
in_ch = 3
out_ch = 3

# Random batch of inputs, laid out as (batch, channels, w, h)
image = torch.randn(bs, in_ch, w, h).to(device)
# Random kernel with the same layout as a conv weight
# (out_ch, in_ch, kernel_dim, kernel_dim)
kernel = torch.randn(out_ch, in_ch, kernel_dim, kernel_dim).to(device)

In [5]:
# Reference layer from torch, moved to the target device
nn_conv = torch.nn.Conv2d(
    in_channels=in_ch,
    out_channels=out_ch,
    kernel_size=kernel_dim,
    bias=True,
).to(device)

# Hand-written layer with the same configuration
conv = my_conv(
    in_channels=in_ch,
    out_channels=out_ch,
    kernel_size=kernel_dim,
    bias=True,
).to(device)

# Point the custom layer at the reference layer's Parameter objects so both
# layers compute with exactly the same weight and bias.
conv.weight, conv.bias = nn_conv.weight, nn_conv.bias

In [6]:
# Run the same batch through the reference layer first, then the custom one
nn_out, out = nn_conv(image), conv(image)

In [7]:
# Show the output of the reference torch layer (nn_conv) for visual comparison
print(nn_out)

tensor([[[[ 9.1441e-01,  7.3346e-01,  1.6453e-01,  ...,  5.9262e-01,
            2.0248e-01, -1.5027e-01],
          [ 9.7775e-03,  1.0576e+00,  1.1865e+00,  ...,  8.3206e-02,
            9.0618e-01,  2.0995e-01],
          [-9.4050e-01,  2.8482e-01, -7.2557e-01,  ..., -8.4529e-02,
           -5.2521e-01,  5.7985e-01],
          ...,
          [-6.2867e-01,  2.7249e-01,  3.6793e-02,  ...,  4.2819e-01,
            8.9468e-01,  6.6741e-01],
          [ 6.4750e-01,  2.9265e-01, -1.7791e-01,  ...,  1.1379e+00,
            5.7991e-01,  4.2833e-01],
          [ 1.1297e-01,  4.8767e-01,  7.7206e-02,  ...,  6.7862e-01,
            6.1621e-01, -1.4540e-01]],

         [[-3.3686e-01, -3.5505e-01,  5.6457e-01,  ...,  6.1172e-01,
            3.2040e-01,  7.3732e-01],
          [-1.4169e+00,  1.8984e-01, -1.1344e-01,  ..., -4.3453e-01,
           -1.4161e-01,  3.8977e-01],
          [ 7.9556e-01, -6.0562e-01, -6.4430e-01,  ..., -1.2078e+00,
            8.1462e-01, -8.2979e-02],
          ...,
     

In [8]:
# Show the output of the custom layer (conv) for visual comparison
print(out)

tensor([[[[ 9.1441e-01,  7.3346e-01,  1.6453e-01,  ...,  5.9262e-01,
            2.0248e-01, -1.5027e-01],
          [ 9.7776e-03,  1.0576e+00,  1.1865e+00,  ...,  8.3206e-02,
            9.0618e-01,  2.0995e-01],
          [-9.4050e-01,  2.8482e-01, -7.2557e-01,  ..., -8.4529e-02,
           -5.2521e-01,  5.7985e-01],
          ...,
          [-6.2867e-01,  2.7249e-01,  3.6793e-02,  ...,  4.2819e-01,
            8.9468e-01,  6.6741e-01],
          [ 6.4750e-01,  2.9265e-01, -1.7791e-01,  ...,  1.1379e+00,
            5.7991e-01,  4.2833e-01],
          [ 1.1297e-01,  4.8767e-01,  7.7206e-02,  ...,  6.7862e-01,
            6.1621e-01, -1.4540e-01]],

         [[-3.3686e-01, -3.5505e-01,  5.6457e-01,  ...,  6.1172e-01,
            3.2040e-01,  7.3732e-01],
          [-1.4169e+00,  1.8984e-01, -1.1344e-01,  ..., -4.3453e-01,
           -1.4161e-01,  3.8977e-01],
          [ 7.9556e-01, -6.0562e-01, -6.4430e-01,  ..., -1.2078e+00,
            8.1462e-01, -8.2979e-02],
          ...,
     

In [9]:
# Both layers should produce outputs of the same shape
print(nn_out.shape)
print(out.shape)
# Element-wise agreement within tolerance. The comparison is on the ABSOLUTE
# difference: with a signed difference, an arbitrarily large negative error
# would still be "< tolerance" and the check would pass incorrectly.
print(((nn_out - out).abs() < 0.00001).all())

torch.Size([128, 3, 26, 26])
torch.Size([128, 3, 26, 26])
tensor(True, device='cuda:0')
