# CNN

In [2]:
import torch
import torch.nn as nn

## conv 2d

In [12]:
# example 2d

def conv2d(input, kernel, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    '''
    Naive 2D convolution (cross-correlation, no bias) written with explicit
    loops over the output positions.

    input: [batch_size, in_channel, height, width]
    kernel: [out_channel, in_channel, kernel_height, kernel_width]
    stride: (vertical, horizontal) step between neighbouring windows
    padding: (vertical, horizontal) number of zeros added on each side
    dilation: (vertical, horizontal) spacing between kernel taps

    return: [batch_size, out_channel, output_height, output_width], on the
            device of input and with the promoted dtype of input and kernel

    raises ValueError when input and kernel disagree on in_channel
    '''
    batch_size, in_channel, height, width = input.shape
    out_channel, kernel_in_channel, kernel_height, kernel_width = kernel.shape

    # Without this check a 1-channel input would be silently broadcast
    # across the kernel's channels instead of being rejected.
    if kernel_in_channel != in_channel:
        raise ValueError(
            f"input has {in_channel} channel(s) but kernel expects {kernel_in_channel}"
        )

    # padding (new_zeros keeps the dtype and device of input)
    padded_height = height + 2 * padding[0]
    padded_width = width + 2 * padding[1]
    padded_input = input.new_zeros((batch_size, in_channel, padded_height, padded_width))
    padded_input[:, :, padding[0]:padding[0] + height, padding[1]:padding[1] + width] = input

    # extent of the input covered by one (dilated) kernel
    span_height = dilation[0] * (kernel_height - 1) + 1
    span_width = dilation[1] * (kernel_width - 1) + 1

    # output size
    output_height = ((padded_height - span_height) // stride[0]) + 1
    output_width = ((padded_width - span_width) // stride[1]) + 1

    output = torch.zeros(
        (batch_size, out_channel, output_height, output_width),
        dtype=torch.result_type(input, kernel),
        device=input.device,
    )

    for i in range(output_height):
        top = i * stride[0]
        for j in range(output_width):
            left = j * stride[1]
            # [batch_size, in_channel, kernel_height, kernel_width]
            window = padded_input[
                :, :,
                top:top + span_height:dilation[0],
                left:left + span_width:dilation[1],
            ]
            # Broadcast the window against every filter at once:
            # [batch, 1, C, kh, kw] * [out, C, kh, kw] -> [batch, out, C, kh, kw]
            output[:, :, i, j] = (window.unsqueeze(1) * kernel).sum(dim=(2, 3, 4))

    return output



In [19]:
# Toy problem sizes shared by the example cells.
batch_size, in_channel, out_channel = 2, 3, 4
height, width = 4, 5
kernel_height, kernel_width = 2, 2

# Random input batch and filter bank laid out as documented on conv2d.
input = torch.rand((batch_size, in_channel, height, width))
kernel = torch.rand((out_channel, in_channel, kernel_height, kernel_width))

In [20]:
# Hand-written convolution with unit stride and no padding.
res1 = conv2d(input, kernel, stride=(1, 1), padding=(0, 0))
res1.size()

torch.Size([2, 4, 3, 4])

In [21]:
# Reference layer with the same geometry. It is built without `kernel`,
# so it holds its own parameters and only the output shape is compared.
m = nn.Conv2d(
    in_channels=in_channel,
    out_channels=out_channel,
    kernel_size=(kernel_height, kernel_width),
    stride=(1, 1),
    padding=(0, 0),
)
res2 = m(input)
res2.size()

torch.Size([2, 4, 3, 4])

## pooling

In [33]:
# example max pooling 2d

def max_pool_2d(input, kernel_size=(2, 2), stride=(1, 1), padding=(0, 0)):
    '''
    Naive 2D max pooling written with explicit loops over the output
    positions.

    input: [batch_size, channel, height, width]
    kernel_size: (kernel_height, kernel_width) of the pooling window
    stride: (vertical, horizontal) step between neighbouring windows
    padding: (vertical, horizontal) number of padded cells on each side

    return: [batch_size, channel, output_height, output_width], with the
            dtype and device of input

    Padded cells hold the smallest representable value (-inf for floating
    point input), so they never win the max; a window that lies entirely
    in the padding therefore yields that fill value.
    '''
    batch_size, channel, height, width = input.shape
    kernel_height, kernel_width = kernel_size

    # Padding with zeros would be wrong for negative inputs: the border
    # maxima would become 0 instead of the largest real value.
    if input.is_floating_point():
        pad_value = float('-inf')
    else:
        pad_value = torch.iinfo(input.dtype).min

    # padding (new_full keeps the dtype and device of input)
    padded_height = height + 2 * padding[0]
    padded_width = width + 2 * padding[1]
    padded_input = input.new_full((batch_size, channel, padded_height, padded_width), pad_value)
    padded_input[:, :, padding[0]:padding[0] + height, padding[1]:padding[1] + width] = input

    # output size
    output_height = ((padded_height - kernel_height) // stride[0]) + 1
    output_width = ((padded_width - kernel_width) // stride[1]) + 1

    # every element is overwritten in the loop below
    output = input.new_empty((batch_size, channel, output_height, output_width))

    for i in range(output_height):
        top = i * stride[0]
        for j in range(output_width):
            left = j * stride[1]
            window = padded_input[:, :, top:top + kernel_height, left:left + kernel_width]
            # reduce both spatial axes of the window in one call
            output[:, :, i, j] = window.amax(dim=(2, 3))

    return output



In [35]:
# Hand-written max pooling: 2x2 window, unit stride, no padding.
res1 = max_pool_2d(input, kernel_size=(2, 2), stride=(1, 1), padding=(0, 0))
res1.size()

torch.Size([2, 3, 3, 4])

In [37]:
# Reference pooling layer configured with the same window and stride.
m = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1))
res2 = m(input)
res2.size()

torch.Size([2, 3, 3, 4])

In [38]:
# Elementwise check of the hand-written pooling against the reference.
torch.eq(res1, res2)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]],


        [[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])