<a href="https://colab.research.google.com/github/danielpaulMBRDI/danielpaulMBRDI/blob/main/3_basics_cnn_architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
from torchsummary import summary

## Utility Function

In [None]:
def plot_convs(image, conv_layer, axis=False):
    """Plot an image next to every feature map a convolution produces from it.

    The first panel ("Original") shows the input image; it is followed by one
    panel per output channel of ``conv_layer`` ("Filter 1", "Filter 2", ...).
    Any number of output channels is supported.

    Args:
        image: Un-batched, channels-first image tensor ``(C, H, W)``. It is
            rescaled with ``image / 2 + 0.5`` for display, which undoes a
            normalization with mean 0.5 and std 0.5.
        conv_layer: Callable (e.g. a ``torch.nn.Conv2d``) applied to the image
            after a leading batch dimension has been added.
        axis: If False (the default) the axes of every panel are hidden.
    """
    # Add the batch dimension the layer expects, then drop it again so that
    # iterating over the result yields one (H, W) map per output channel.
    feature_maps = conv_layer(image[None, :]).detach()[0]
    n = feature_maps.shape[0]

    # Width grows by 2 per filter: 8, 10 and 12 for one, two and three filters.
    fig, axes = plt.subplots(figsize=(6 + 2 * n, 4), ncols=n + 1)

    npimg = (image / 2 + 0.5).numpy()  # unnormalize
    axes[0].imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    axes[0].set_title("Original")

    for number, (ax, feature_map) in enumerate(zip(axes[1:], feature_maps), start=1):
        ax.imshow(feature_map)
        ax.set_title(f"Filter {number}")

    for ax in axes:
        ax.grid(False)
        if not axis:
            ax.axis(False)
    plt.tight_layout()

## 0. Basics
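- Input of image-format data is usually a 4-D array
<br> **(num_instance, width, height, depth)** <br>
  (PyTorch, which this notebook uses, orders these dimensions as **(num_instance, depth, height, width)**)
    - **num_instance:** number of data instances. Usually designated as **None** to accommodate fluctuating data size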
    - **width:** width of an image
    - **height:** height of an image
    - **depth:** depth of an image. Color images are usually with depth = 3 (3 channels for RGB). Black/white images are usually with depth = 1 (only one channel)
    
<img src="http://xrds.acm.org/blog/wp-content/uploads/2016/06/Figure1.png" style="width: 400px"/>

In [None]:
# Fetch the CIFAR-10 training split into ./data (downloaded when missing),
# then report how many samples it holds and the size of the first image.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)
print(len(trainset))
print(trainset[0][0].size)

In [None]:
# Draw the first nine samples of `trainset` on a 3x3 grid.
fig = plt.figure(figsize=(10, 10))
for i in range(9):
    panel = fig.add_subplot(3, 3, i + 1)
    panel.imshow(trainset[i][0])

plt.show()

In [None]:
# Fetch the MNIST training split into ./data (downloaded when missing),
# then report how many samples it holds and the size of the first image.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
)
print(len(trainset))
print(trainset[0][0].size)

In [None]:
# showing figures
fig = plt.figure(figsize = (10, 10))
for i in range(9):
  fig.add_subplot(3, 3, i+1)
  plt.imshow(trainset.__getitem__(i)[0])

plt.show()

## 1. Filters/kernels
The number of filters can be chosen freely.

The number of filters equals the depth (number of channels) of the next layer.

In PyTorch, convolutional layers are defined as torch.nn.Conv2d, there are 5 important arguments we need to know:

1. in_channels: how many features are we passing in. Our features are our colour bands, in greyscale, we have 1 feature, in colour, we have 3 channels.
2. out_channels: how many kernels do we want to use. Analogous to the number of hidden nodes in a hidden layer of a fully connected network.
3. kernel_size: the size of the kernel. Above we were using 3x3. Common sizes are 3x3, 5x5, 7x7.
4. stride: the "step-size" of the kernel.
5. padding: the number of pixels we should pad to the outside of the image so we can get edge pixels.

In [None]:
# Preprocessing applied to every sample: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Reload the CIFAR-10 training split, this time with the transform attached.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# 1 kernel of (5,5)
# The dataset's transform includes ToTensor, so the sample is already a
# tensor; index it directly instead of round-tripping through NumPy.
image = trainset[0][0]
# torch.nn.Conv2d(3, number of kernels, kernel_size=(5, 5))
conv_layer = torch.nn.Conv2d(3, 1, kernel_size=(5, 5))
plot_convs(image, conv_layer)
print(image.shape)

In [None]:
# 2 kernels of (3,3)
# Conv2d has no defaults for in_channels, out_channels and kernel_size:
# 3 input channels (RGB image), 2 output channels (one per kernel).
conv_layer = torch.nn.Conv2d(3, 2, kernel_size=(3, 3))
plot_convs(image, conv_layer)

In [None]:
# 3 kernels of (5,5)
# Conv2d has no defaults for in_channels, out_channels and kernel_size:
# 3 input channels (RGB image), 3 output channels (one per kernel).
conv_layer = torch.nn.Conv2d(3, 3, kernel_size=(5, 5))
plot_convs(image, conv_layer)

## 2. Strides

In [None]:
# One 5x5 kernel that advances two pixels per step; the plot keeps its axes
# visible (axis=True).
conv_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=1,
    kernel_size=(5, 5),
    stride=2,
)
plot_convs(image, conv_layer, axis=True)

In [None]:
# 1 kernel of (5,5) with stride of 3
# Conv2d has no defaults for in_channels, out_channels and kernel_size, so
# they are spelled out here together with the stride.
conv_layer = torch.nn.Conv2d(3, 1, kernel_size=(5, 5), stride=3)
plot_convs(image, conv_layer, axis=True)

In [None]:
# 1 kernel of (5,5) with stride of 4
# Conv2d has no defaults for in_channels, out_channels and kernel_size, so
# they are spelled out here together with the stride.
conv_layer = torch.nn.Conv2d(3, 1, kernel_size=(5, 5), stride=4)
plot_convs(image, conv_layer, axis=True)

What's important with CNNs is that the size of our input data does not impact how many parameters we have in our convolutional layers. For example, your kernels don't care how big your image is (i.e., 28 x 28 or 256 x 256), all that matters is:

* How many features ("channels") you have: in_channels
* How many filters you use in each layer: out_channels
* How big the filters are: kernel_size

## 3. Padding
Zero padding can be applied in convolution/pooling as we have seen above. But custom padding can be applied as well
* nn.ConstantPad1d(padding, value): apply constant padding on 1D data
  * padding: the shape of padding (if tuple, (paddingLeft, paddingRight))
  * value: the value of padding
* nn.ConstantPad2d(padding, value): apply constant padding on 2D data
  * padding: the shape of padding (if tuple, (paddingLeft, paddingRight, paddingTop, paddingBottom))
  * value: the value of padding
* nn.ZeroPad2d(padding): apply zero padding on 2D data
  * padding: the shape of padding (if tuple, (paddingLeft, paddingRight, paddingTop, paddingBottom))

In [None]:
# p = nn.ConstantPad1d((left, right), val)
# 1-D padding touches only the last dimension, so it takes a 2-tuple; the
# 4-tuple (left, right, top, bottom) form is for ConstantPad2d.
p = nn.ConstantPad1d((1, 1), -1)  # pad one element on each side with the constant -1
x = torch.ones(1, 1, 3)
print(p(x))

In [None]:
# Zero-pad a 3x3 block of ones with one extra column on the left only;
# the tuple order is (left, right, top, bottom).
x = torch.ones((1, 1, 3, 3))
p = nn.ZeroPad2d(padding=(1, 0, 0, 0))
print(p(x))

If we use a kernel with no padding, our output image will be smaller as we noted earlier.

In [None]:
# One 5x5 kernel, no padding; the plot keeps its axes visible (axis=True).
conv_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=1,
    kernel_size=(5, 5),
)
plot_convs(image, conv_layer, axis=True)

As we saw, we can add padding to the outside of the image to avoid this:

In [None]:
# One 5x5 kernel with 2 pixels of padding on every side; the plot keeps its
# axes visible (axis=True).
conv_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=1,
    kernel_size=(5, 5),
    padding=2,
)
plot_convs(image, conv_layer, axis=True)

Setting $padding = kernel\_size // 2$ will result in an output the same shape as the input, provided the kernel size is odd and the stride is 1. Think about why this is...

## 4. Pooling

Pooling is how we can reduce the number of features we get out of a torch.nn.Flatten(), and therefore the number of parameters in the layer that follows it. It's pretty simple, we just aggregate the data, usually using the maximum or average of a window of pixels. 

We use "pooling layers" to reduce the shape of our image as it's passing through the network. So when we eventually torch.nn.Flatten(), we'll have fewer features in that flattened layer! We can implement pooling with torch.nn.MaxPool2d()

In [None]:
class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by one linear output.

    The ``Linear(1250, 1)`` head fixes the expected input to a single channel
    of 100x100 pixels: the 3x3 convolutions use padding=1 and keep height and
    width, each 2x2 pool halves them, and the last convolution has 2 output
    channels, leaving 2 * 25 * 25 = 1250 values after flattening.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # (1, 100, 100) -> (3, 100, 100) -> (3, 50, 50)
            nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            # (3, 50, 50) -> (2, 50, 50) -> (2, 25, 25)
            nn.Conv2d(in_channels=3, out_channels=2, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            # (2, 25, 25) -> 1250 -> 1
            nn.Flatten(),
            nn.Linear(1250, 1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.main(x)

In [None]:
model = CNN()
# torchsummary's `device` argument defaults to "cuda", which makes it feed a
# CUDA tensor to the model whenever a GPU is visible. The model was just
# created on the CPU, so ask for a CPU input explicitly.
summary(model, (1, 100, 100), device="cpu")

## 5. Flattening

torch.nn.Flatten()

To be connected to a fully connected layer (dense layer), the output of a convolutional/pooling layer must first be "flattened".

Resulting shape = (number of instances, width × height × depth)

In [None]:
class CNN(nn.Module):
    """Two conv -> ReLU stages (no pooling) followed by one linear output.

    The ``Linear(20000, 1)`` head fixes the expected input to a single channel
    of 100x100 pixels: the 3x3 convolutions use padding=1 and keep height and
    width, and the last convolution has 2 output channels, leaving
    2 * 100 * 100 = 20000 values after flattening.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # (1, 100, 100) -> (3, 100, 100)
            nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # (3, 100, 100) -> (2, 100, 100)
            nn.Conv2d(in_channels=3, out_channels=2, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # (2, 100, 100) -> 20000 -> 1
            nn.Flatten(),
            nn.Linear(20000, 1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.main(x)

In [None]:
model = CNN()
# torchsummary's `device` argument defaults to "cuda", which makes it feed a
# CUDA tensor to the model whenever a GPU is visible. The model was just
# created on the CPU, so ask for a CPU input explicitly.
summary(model, (1, 100, 100), device="cpu")