# Implementing a Convolutional Layer from Scratch and with PyTorch

In [2]:
import torch
import torch.nn as nn

## The convolution operation from scratch — a basic implementation of the concept

In [9]:
def conv2d(X, K):
  """Compute the 2-D cross-correlation of ``X`` with kernel ``K``.

  This is the operation deep-learning frameworks call "convolution": the
  kernel is slid over the input (without being flipped) and, at every
  position, the element-wise products are summed. Padding and stride are not
  supported (equivalent to padding 0, stride 1).

  Args:
    X: 2-D input tensor of shape ``(H, W)``.
    K: 2-D kernel tensor of shape ``(h, w)``.

  Returns:
    Tensor of shape ``(H - h + 1, W - w + 1)``, on ``X``'s device, whose
    dtype is the one PyTorch's type promotion gives for ``X * K``.
  """
  h, w = K.shape
  # Allocate the result with the dtype/device that X * K produces; a bare
  # torch.zeros(...) would always be float32 on the CPU, silently downcasting
  # float64 inputs and detaching GPU inputs onto the CPU.
  output = torch.zeros(
      (X.shape[0] - h + 1, X.shape[1] - w + 1),
      dtype=torch.result_type(X, K),
      device=X.device,
  )

  for i in range(output.shape[0]):
    for j in range(output.shape[1]):
      # Window of X under the kernel when its top-left corner is at (i, j).
      output[i, j] = (X[i:i + h, j:j + w] * K).sum()
  return output

In [10]:
# 3x3 input holding 0..8 in row-major order, and a 2x2 kernel holding 0..3.
X = torch.arange(9.0).reshape(3, 3)
K = torch.arange(4.0).reshape(2, 2)
conv2d(X, K)

tensor([[19., 25.],
        [37., 43.]])

## Creating a convolutional layer

In [11]:
class ConvLayer(nn.Module):
  """A single-channel 2-D convolution layer built on the from-scratch ``conv2d``.

  Args:
    kernel_size: Kernel shape as an ``(h, w)`` tuple, or a single int for a
      square ``(k, k)`` kernel.
    **kwargs: Forwarded to ``nn.Module.__init__``.
  """

  def __init__(self, kernel_size, **kwargs):
    # Zero-argument super() actually runs nn.Module.__init__, which sets up
    # the parameter/hook registries the module needs before it can be called.
    super().__init__(**kwargs)
    if isinstance(kernel_size, int):
      kernel_size = (kernel_size, kernel_size)
    # Wrapping in nn.Parameter registers the tensors with the module, so
    # .parameters(), optimizers, state_dict() and .to(device) see them.
    # The N(0, 0.01^2) fill happens before the tensor requires grad; calling
    # normal_() in place on a leaf created with requires_grad=True raises.
    self.w = nn.Parameter(torch.empty(kernel_size).normal_(mean=0.0, std=0.01))
    self.bias = nn.Parameter(torch.zeros(1))

  def forward(self, x):
    """Apply the module-level ``conv2d`` to ``x`` with ``self.w`` and add the bias."""
    return conv2d(x, self.w) + self.bias

In [15]:
# All-ones 6x8 input and a 1x2 kernel [1, 2].
X = torch.ones(6, 8)
k = torch.tensor([[1.0, 2.0]])

In [16]:
# Every window is [1, 1] . [1, 2] = 3; the 1x2 kernel shrinks the width 8 -> 7.
conv2d(X=X, K=k)

tensor([[3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.]])

# Conv layer with padding

In [18]:
def comp_conv2d(conv2d, X):
  """Run a batched 2-D convolution layer on a single 2-D matrix.

  Layers such as ``nn.Conv2d`` take input laid out as
  (batch, channels, H, W). This helper lifts ``X`` to that layout with a
  batch size of 1 and a single channel, applies ``conv2d``, and reshapes the
  result back down to its trailing (spatial) dimensions. The final reshape
  only succeeds when the output's batch and channel sizes are both 1.

  Note: with ``padding=1``, one row/column is padded on either side of the
  input, so a total of 2 rows and 2 columns are added.

  Args:
    conv2d: Callable applied to the 4-D tensor (e.g. an ``nn.Conv2d``).
    X: 2-D input tensor of shape (H, W).

  Returns:
    The callable's output with the two leading dimensions removed.
  """
  # Prepend batch-size and channel dimensions, both of size 1.
  batched_input = X.reshape(1, 1, *X.shape)
  batched_output = conv2d(batched_input)
  # Keep only the spatial dimensions; batch and channel are not of interest.
  spatial_shape = batched_output.shape[2:]
  return batched_output.reshape(spatial_shape)


In [19]:
# Bind the PyTorch layer to its own name: rebinding ``conv2d`` here would
# shadow the from-scratch ``conv2d`` function defined above, which
# ``ConvLayer.forward`` calls.
# A 3x3 kernel with 1 pixel of padding per side gives 8 + 2 - 3 + 1 = 8,
# so the output keeps the input's 8x8 shape.
conv_layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
X = torch.rand(size=(8, 8))
comp_conv2d(conv_layer, X).shape

torch.Size([8, 8])