# Programming Assignment 1: Convolution and Back-Propagation

**UBC CPEN 455: Deep Learning, 2023 Winter Term 2**

**Created By Renjie Liao**

**Date: Feb. 19, 2024**

---
# Setup

We will use PyTorch to implement this assignment.

In [6]:
# Imports
import pdb
import torch
import numpy as np
from torch import nn
from math import pi
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms, utils

torch.set_default_dtype(torch.float64)

## load MNIST images
B = 5 # batch size
train_set = datasets.MNIST('./data',
                            train=True,
                            download=True,
                            transform=transforms.ToTensor())

loader = torch.utils.data.DataLoader(train_set, batch_size=B)

## load a batch of MNIST images as a PyTorch tensor (shape: B x C x H x W)
# B: batch size
# C: number of channels
# H: height of images
# W: width of images
img, label = next(iter(loader)) # img shape: B x C x H x W, label shape: B

## create a random filter (shape: D x C x K x K)
K = 3 # kernel size
P = 1 # padding size
C = 1 # channel size
D = 2 # number of filters
filter = torch.randn(D, C, K, K) # filter shape: D x C x K x K

---
# Q1 [60Pts]: 2D convolution and its gradient




## 1.1 [30Pts]  Implement 2D convolution:

Discrete convolution can be implemented in multiple ways, e.g., matrix multiplication in spatial/Fourier domains.

First, let us take a look at 1D convolution in spatial domain. Suppose we have a 1D signal with $n$ elements $x_1, x_2, \dots, x_n$ and a 1D filter with $m$ weights $h_1, h_2, \dots, h_m$. Note that we typically use filters with odd sizes for the ease of indexing.

The (discrete) convolution with zero-padding and stride 1 is defined as:

\begin{align}
    y_i = (h \ast x)_i = \sum_{j=1}^{m} h_j x_{i - \lfloor m/2 \rfloor - 1 + j}, \quad i = 1, \dots, n,
\end{align}
where padded values $x_{-\lfloor m/2 \rfloor + 1}, \dots, x_{0}, x_{n+1}, \dots, x_{n - \lfloor m/2 \rfloor - 1 + m}$ are all zeros.
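
As a sanity check on the definition above, here is a minimal sketch that evaluates the sum literally with 1-based indices and compares it with `F.conv1d`. The helper name `conv1d_direct` and the toy values of `x` and `h` are illustrative assumptions, not part of the starter code.

```python
import torch
import torch.nn.functional as F

def conv1d_direct(x, h):
    # y_i = sum_j h_j * x_{i - floor(m/2) - 1 + j}; out-of-range x entries are zero padding
    n, m = x.numel(), h.numel()
    half = m // 2
    y = torch.zeros(n, dtype=x.dtype)
    for i in range(1, n + 1):           # 1-based indices, as in the formula
        for j in range(1, m + 1):
            k = i - half - 1 + j        # 1-based index into x
            if 1 <= k <= n:             # padded entries contribute nothing
                y[i - 1] += h[j - 1] * x[k - 1]
    return y

x = torch.tensor([1.0, 2.0, 3.0, 4.0])  # toy signal (assumed values)
h = torch.tensor([1.0, 0.0, -1.0])      # toy filter (assumed values)
y = conv1d_direct(x, h)

# F.conv1d applies the same sliding dot product (it does not flip the kernel)
y_ref = F.conv1d(x.view(1, 1, -1), h.view(1, 1, -1), padding=h.numel() // 2).view(-1)
print(y, y_ref)
```

Both results agree because `F.conv1d`, like the formula above, slides the filter over the zero-padded signal without reversing it.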

If you have forgotten the concepts of padding and stride, take a look at [this guide](https://arxiv.org/pdf/1603.07285.pdf) or [these pictures](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md).

In the 1D case, we can illustrate the two matrix multiplication views of spatial convolution as below.

1.   **im2col**:
The key idea is to first extract from the signal $x$ the spatial window needed for each output element and then take the dot product of each window with the filter.
If we put each window as a column in a matrix (the right-hand matrix in the product below), then the convolution becomes the following matrix multiplication (N.B.: the products between the filter and the individual columns can be computed in parallel).

\begin{align}
    y^\top = (h \ast x)^{\top} = \begin{bmatrix}
                h_m & h_{m-1} & \cdots & h_3 & h_2 & h_1
            \end{bmatrix}
            \begin{bmatrix}
                x_{m - \lfloor m/2 \rfloor} & x_{m - \lfloor m/2 \rfloor + 1} & \cdots & x_m & x_{m+1} & \cdots & 0 & 0 \\
                \vdots & \vdots & \cdots & x_{m-1} & x_m & \cdots & \vdots & \vdots \\
                x_1 & x_2 & \cdots & \vdots & x_{m-1} & \cdots  & x_{n-1} & x_n \\
                0 & x_1 & \cdots & \vdots & \vdots & \cdots  & x_{n-2} & x_{n-1} \\
                \vdots & 0 & \cdots & \vdots & \vdots & \cdots & \vdots & \vdots \\
                0 & 0 & \cdots & x_1 & x_2 & \cdots & x_{n - \lfloor m/2 \rfloor - 1} & x_{n - \lfloor m/2 \rfloor}
            \end{bmatrix}.
\end{align}

2.   **filter2row**: The key idea is to convert the filter into a sparse banded (Toeplitz) matrix and the signal into a vector.
Then the convolution is simply the matrix-vector product between the two.

\begin{align}
        y = h \ast x =
            \begin{bmatrix}
                h_{\lfloor m/2 \rfloor + 1} & h_{\lfloor m/2 \rfloor + 2} & \cdots & h_m & 0 & \cdots & \cdots & \cdots & \cdots & 0 \\
                h_{\lfloor m/2 \rfloor} & h_{\lfloor m/2 \rfloor + 1} & \cdots & h_{m-1} & h_m & 0 & \cdots & \cdots & \cdots & 0 \\
                \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\
                h_1 & h_2 & \cdots & \cdots & \cdots & \cdots & h_m & 0 & \cdots & 0 \\
                0 & h_1 & h_2 & \cdots & \cdots & \cdots & \cdots & h_m & \cdots & 0 \\
                \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\
                0 & 0 & \cdots & \cdots & \cdots & 0 & h_1 & h_2 & \cdots & h_{\lfloor m/2 \rfloor + 2} \\
                0 & 0 & \cdots & \cdots & \cdots & \cdots & 0 & h_1 & \cdots & h_{\lfloor m/2 \rfloor + 1}
            \end{bmatrix}
            \begin{bmatrix}
                x_1 \\
                x_2 \\
                x_3 \\
                \vdots \\
                x_n
            \end{bmatrix}
\end{align}
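
The two views can be sketched in 1D as follows. This is a minimal illustration under assumed toy values for `x` and `h`, not the 2D solution. The window matrix built here stacks each window top to bottom in the order $h_1, \dots, h_m$, which is the reverse of the row order drawn above but gives the same product.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])   # signal, n = 5 (assumed values)
h = torch.tensor([1.0, 2.0, 3.0])             # filter, m = 3 (assumed values)
n, m = x.numel(), h.numel()
half = m // 2

x_pad = F.pad(x, (half, half))                # zero-pad both ends

# im2col: column i holds the window of the padded signal used for output i
cols = x_pad.unfold(0, m, 1).T                # shape (m, n)
y_im2col = h @ cols                           # shape (n,)

# filter2row: row i holds the filter shifted to output position i
T = torch.zeros(n, n + 2 * half, dtype=x.dtype)
for i in range(n):
    T[i, i:i + m] = h
T = T[:, half:half + n]                       # drop the columns that multiply zero padding
y_filter2row = T @ x                          # shape (n,)

y_ref = F.conv1d(x.view(1, 1, -1), h.view(1, 1, -1), padding=half).view(-1)
print(y_im2col, y_filter2row, y_ref)
```

After truncation, the first row of `T` starts at $h_{\lfloor m/2 \rfloor + 1}$ and the last row ends there, matching the banded matrix shown above.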

**Task**:
Implement the 2D convolution in the spatial domain via matrix multiplication following the above two views: **im2col** and **filter2row**.
The starter code is provided below.
You just need to fill in the missing parts of function ***conv2d_im2col*** and ***conv2d_filter2row***.
If your implementation is correct, the ***unit_test*** will output:

*Your implementation of xxx is correct!*

Otherwise, it will output:

*Your implementation of xxx is wrong!*

**N.B.**: We assume the strides along height and width are the same and the kernel is square.

**Hint**: you can reduce the 2D case to 1D and follow the above construction.

In [7]:
## implement the following two functions
def conv2d_im2col(img, filter, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    ### Fill in this function ###
    # Args:
    #   img: images, shape B x C x H x W
    #   filter: filters, shape D x C x K x K
    #   channel_size: number of channels, scalar (C)
    #   num_filters: number of filters, scalar (D)
    #   kernel_size: kernel size, scalar (K)
    #   stride: stride size, scalar
    #   padding: padding size, scalar
    #
    # Returns:
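    #   out: convolved images, shape B x D x H_out x W_out


    # output spatial size along each axis (H and W need not be equal)
    out_h = (img.shape[2] + 2 * padding - kernel_size) // stride + 1
    out_w = (img.shape[3] + 2 * padding - kernel_size) // stride + 1

    # unfold extracts every K x K window as a column: (B, C*K*K, L), L = out_h * out_w
    cols = F.unfold(img, kernel_size=kernel_size, padding=padding, stride=stride)
    # flatten each filter into a row: (D, C*K*K)
    reshaped_filter = filter.reshape(num_filters, kernel_size * kernel_size * channel_size)
    # (D, C*K*K) @ (B, C*K*K, L) broadcasts over the batch -> (B, D, L)
    out = reshaped_filter @ cols
    out = out.reshape(img.shape[0], num_filters, out_h, out_w)
    return out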

def conv2d_filter2row(img, filt, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    """
    Filter-to-row (Toeplitz) 2D convolution via matrix multiplication.

    Args:
        img:     (B, C, H, W) tensor
        filt:    (D, C, K, K) tensor     # D=num_filters, C=channel_size
        channel_size: int (C)
        num_filters: int (D)
        kernel_size: int (K)
        stride:  int (S)
        padding: int (P)

    Returns:
        out:     (B, D, H_out, W_out) tensor
    """
    B, C, H, W = img.shape
    D, C2, K, _ = filt.shape

    S, P = stride, padding

    # Compute output spatial dimensions
    H_out = (H + 2 * P - K) // S + 1
    W_out = (W + 2 * P - K) // S + 1

    # Zero-pad the image so we can safely index without boundary checks
    img_padded = F.pad(img, (P, P, P, P))  # (left, right, top, bottom)

    # The big Toeplitz-like matrix M encodes the convolution
    M = torch.zeros((D * H_out * W_out, C * (H + 2 * P) * (W + 2 * P)), dtype=img.dtype)

    # Build M row by row
    for d in range(D):                     # for each filter
        for i in range(H_out):             # output height
            for j in range(W_out):         # output width
                row_index = d * (H_out * W_out) + i * W_out + j
                for c in range(C):         # input channel
                    for p in range(K):     # kernel height
                        for q in range(K): # kernel width
                            h_in = i * S + p
                            w_in = j * S + q
                            col_index = c * (H + 2 * P) * (W + 2 * P) + h_in * (W + 2 * P) + w_in
                            M[row_index, col_index] = filt[d, c, p, q]

    # Apply M to each flattened padded image
    img_flat = img_padded.reshape(B, -1).T          # shape (C*(H+2P)*(W+2P), B)
    out_flat = M @ img_flat                         # shape (D*H_out*W_out, B)
    out = out_flat.T.reshape(B, D, H_out, W_out)    # reshape to 4D output

    return out

def unit_test_conv2d(img, filter, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    # call your implemented "im2col" conv2D
    y_im2col = conv2d_im2col(img, filter, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)

    # ground truth conv2D
    y_gt = F.conv2d(img, filter, stride=stride, padding=padding)

    diff = (y_im2col - y_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of conv2d_im2col is correct!")
    else:
        print("Your implementation of conv2d_im2col is wrong!")

    # call your implemented "filter2row" conv2D
    y_filter2row = conv2d_filter2row(img, filter, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)

    diff = (y_filter2row - y_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of conv2d_filter2row is correct!")
    else:
        print("Your implementation of conv2d_filter2row is wrong!")


unit_test_conv2d(img, filter, channel_size=C, num_filters=D, kernel_size=K, stride=1, padding=P)

Your implementation of conv2d_im2col is correct!
Your implementation of conv2d_filter2row is correct!


## 1.2 [20Pts] Implement the gradient of 2D convolution

We now turn to the gradient of 2D convolution.
In particular, given a batch of images $x$ with the shape $B \times C \times H \times W$ and filters $h$ with shape $D \times C \times K \times K$, we can view the convolution (zero-padding and stride 1) as a function
\begin{align}
    y = f(h, x),
\end{align}
that would produce an output tensor $y$ with shape $B \times D \times H \times W$.

If we vectorize $x$, $h$, and $y$, then the Jacobian matrix $\nabla f = [\frac{\partial y}{\partial h}, \frac{\partial y}{\partial x}]$ would be of shape $BDHW \times (DCKK + BCHW)$.
In practice, we almost never need to compute the Jacobian matrix directly as it is unnecessary for back-propagation.
Instead, we often need to compute the product between the transposed Jacobian and a vector (a.k.a. vector-Jacobian product), i.e., ${\frac{\partial y}{\partial h}}^{\top} v$ and ${\frac{\partial y}{\partial x}}^{\top} v$.
For example, the vector $v$ could be the gradient of some loss $\ell$ (scalar) w.r.t. the output above (i.e., $\frac{\partial \ell}{\partial y}$).
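
To make the vector-Jacobian product concrete, the sketch below compares it against the explicit Jacobian on a deliberately tiny example. The tensor sizes, the seed, and the helper `f` are assumptions chosen so that the full Jacobian stays small enough to build; for realistic sizes only the first approach is practical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 4, dtype=torch.float64)                      # B=1, C=1, H=W=4
h = torch.randn(2, 1, 3, 3, dtype=torch.float64, requires_grad=True)  # D=2, K=3

def f(h_):
    # convolution viewed as a function of the filters only
    return F.conv2d(x, h_, stride=1, padding=1)

y = f(h)
v = torch.randn_like(y)

# vector-Jacobian product computed without ever forming the Jacobian
vjp = torch.autograd.grad(y, h, grad_outputs=v)[0]            # shape (2, 1, 3, 3)

# explicit Jacobian dy/dh, feasible only because the example is tiny
J = torch.autograd.functional.jacobian(f, h.detach())          # shape y.shape + h.shape
J = J.reshape(y.numel(), h.numel())                            # (BDHW, DCKK) = (32, 18)
vjp_explicit = (J.T @ v.reshape(-1)).reshape(h.shape)

print((vjp - vjp_explicit).norm())
```

The two results should match up to floating-point error, which is why back-propagation only ever needs the product and not the matrix itself.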

**Task**:
Given input images $x$, filters $h$, output $y$, and a vector $v$, implement the gradients ${\frac{\partial y}{\partial h}}^{\top} v$ and ${\frac{\partial y}{\partial x}}^{\top} v$ in the function below.

**N.B.**: The function needs to return the gradients in the original shapes, i.e., ${\frac{\partial y}{\partial h}}^{\top} v$ should have the same shape as $h$ ($D \times C \times K \times K$) and ${\frac{\partial y}{\partial x}}^{\top} v$ should have the same shape as $x$ ($B \times C \times H \times W$).

In [10]:
## implement the following functions
def grad_conv2d(img, filter, out, grad_out, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    ### Fill in this function ###
    # Args:
    #   img: images, shape B x C x H x W
    #   filter: filters, shape D x C x K x K
    #   out: convolved images, shape B x D x H x W
    #   grad_out: gradient w.r.t. output, shape B x D x H x W
    #   channel_size: number of channels, scalar (C)
    #   num_filters: number of filters, scalar (D)
    #   kernel_size: kernel size, scalar (K)
    #   stride: stride size, scalar
    #   padding: padding size, scalar
    #
    # Returns:
    #   grad_img: gradient w.r.t. img, shape B x C x H x W
    #   grad_filter: gradient w.r.t. filter, shape D x C x K x K

    B, C, H_img, W_img = img.shape
    D = num_filters
    K = kernel_size
    S, P = stride, padding

    # 1) Unfold input into columns: (B, C*K*K, L) where L = H_out * W_out
    cols = F.unfold(img, kernel_size=K, padding=P, stride=S)  # (B, CK2, L)
    _, CK2, L = cols.shape

    # 2) Flatten grad_out spatial to (B, D, L)
    go = grad_out.reshape(B, D, L)  # (B, D, L)

    # 3) Gradient w.r.t. filter:
    # (B, D, L) @ (B, L, CK2) -> (B, D, CK2); sum over batch -> (D, CK2)
    grad_filter_mat = (go @ cols.transpose(1, 2)).sum(dim=0)   # (D, CK2)
    grad_filter = grad_filter_mat.view(D, C, K, K)

    # 4) Gradient w.r.t. image:
    # weight matrix W_flat: (D, CK2); need W_flat^T @ go[b] => (CK2, L) for each b
    W_flat = filter.view(D, CK2)            # (D, CK2)
    dcols = (W_flat.T @ go)                 # broadcasting -> (B, CK2, L)

    # fold columns back to image: (B, C*K*K, L) -> (B, C, H, W)
    grad_img = F.fold(dcols, output_size=(H_img, W_img),
                      kernel_size=K, padding=P, stride=S)

    return grad_img, grad_filter

def unit_test_grad_conv2d(img, filter, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    filter.requires_grad = True
    img.requires_grad = True

    ### ground truth conv2D
    img_out = F.conv2d(img, filter, stride=stride, padding=padding)

    # create a random vector v
    v = torch.randn_like(img_out)

    # call your implemented "grad_conv2d" function
    grad_img, grad_filter = grad_conv2d(img, filter, img_out, v, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)

    # compute ground-truth gradients
    grad_img_gt = torch.autograd.grad(img_out, img, grad_outputs=v, retain_graph=True)[0]
    grad_filter_gt = torch.autograd.grad(img_out, filter, grad_outputs=v, retain_graph=True)[0]
    #pdb.set_trace()

    diff = (grad_img - grad_img_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_img is correct!")
    else:
        print("Your implementation of grad_img is wrong!")

    diff = (grad_filter - grad_filter_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_filter is correct!")
    else:
        print("Your implementation of grad_filter is wrong!")

unit_test_grad_conv2d(img, filter, channel_size=C, num_filters=D, kernel_size=K, stride=1, padding=P)

Your implementation of grad_img is correct!
Your implementation of grad_filter is correct!


---
## 1.3 [10Pts]: Implement gradient checking via the finite difference approximation

We verify the correctness of the implementation of gradient operators by calling PyTorch's autograd function.
However, PyTorch's autograd function just calls the gradient operators implemented by the PyTorch team.
How do they verify the correctness of their implementation?

The answer is **finite difference approximation**.
Following the setup in 1.2, given a batch of images $x$ with the shape $B \times C \times H \times W$ and filters $h$ with shape $D \times C \times K \times K$, we have the convolution (zero-padding and stride 1)
\begin{align}
    y = f(h, x).
\end{align}

Again, mentally vectorizing $x$, $h$, and $y$ helps in following the math.
Given any vector $v$ with the same shape as $y$, we are interested in computing ${\frac{\partial y}{\partial h}}^{\top} v$ and ${\frac{\partial y}{\partial x}}^{\top} v$.
These two gradients are equivalent to ${\frac{\partial \ell}{\partial h}}$ and ${\frac{\partial \ell}{\partial x}}$ where
\begin{align}
    \ell(h, x) = y^{\top}v = f(h, x)^{\top} v.
\end{align}
Note that here $\ell$ becomes a scalar.
Based on Taylor's theorem, we have
\begin{align}
    d^{\top} \frac{\partial \ell}{\partial h} = \lim_{\epsilon \to 0} \frac{\ell(h + \epsilon \cdot d, x) - \ell(h - \epsilon \cdot d, x)}{2 \epsilon},
\end{align}
where $d$ could be any direction vector and $\epsilon$ is a scalar.
For our purpose, we just need to set $d$ to be the unit vector to compute the per-dimension value of $\frac{\partial \ell}{\partial h}$.
Specifically, if we set $d$ as the $i$-th unit vector $e_i$, i.e., $d[i] = 1$ and $d[j] = 0, \forall j \neq i$, we can then compute
\begin{align}
    \frac{\partial \ell}{\partial h}[i] &= \lim_{\epsilon \to 0} \frac{\ell(h + \epsilon \cdot e_i, x) - \ell(h - \epsilon \cdot e_i, x)}{2 \epsilon} \\
    & \approx \frac{\ell(h + \epsilon \cdot e_i, x) - \ell(h - \epsilon \cdot e_i, x)}{2 \epsilon}.
\end{align}
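
The central-difference formula can be illustrated on a toy loss whose gradient is known in closed form. This is a minimal sketch, not the convolution checker itself: the helper `finite_diff_grad`, the loss $\ell(h) = \sum_k a_k h_k^2$, and the random values are all assumptions made for illustration.

```python
import torch

def finite_diff_grad(loss_fn, h, eps=1.0e-5):
    # approximate d loss / d h one coordinate at a time with central differences
    h = h.detach().clone()
    grad = torch.zeros_like(h)
    for i in range(h.numel()):
        e_i = torch.zeros(h.numel(), dtype=h.dtype)
        e_i[i] = 1.0                        # i-th unit vector
        e_i = e_i.reshape(h.shape)
        plus = loss_fn(h + eps * e_i)
        minus = loss_fn(h - eps * e_i)
        grad.view(-1)[i] = (plus - minus) / (2 * eps)
    return grad

torch.manual_seed(0)
a = torch.randn(2, 3, dtype=torch.float64)
h = torch.randn(2, 3, dtype=torch.float64)

def loss(t):
    # toy scalar loss with known gradient 2 * a * t
    return (a * t ** 2).sum()

grad_fd = finite_diff_grad(loss, h)
grad_exact = 2 * a * h
print((grad_fd - grad_exact).abs().max())
```

Double precision matters here: with `eps = 1e-5`, single-precision rounding error in the numerator would be amplified by the division by $2\epsilon$.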

**Task**: Implement the finite-difference based gradient checker for ${\frac{\partial \ell}{\partial h}}$ and ${\frac{\partial \ell}{\partial x}}$.


**N.B.**: For efficiency in the unit test, you can use F.conv2d to compute the convolution in your implementation of *grad_checker*. This part of the assignment is meant to help you understand how to implement finite differences. In practice, however, if we want to verify our own implementation of conv2d, then we should use our conv2d instead of F.conv2d from PyTorch.


In [None]:
## implement the following functions
def grad_checker(img, filter, conv, grad_out, epsilon=1.0e-5, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    ### Fill in this function ###
    # Args:
    #   img: images, shape B x C x H x W
    #   filter: filters, shape D x C x K x K
    #   conv: convolution function
    #   grad_out: gradient w.r.t. output, shape B x D x H x W
    #   epsilon: finite-difference step size, scalar
    #   channel_size: number of channels, scalar (C)
    #   num_filters: number of filters, scalar (D)
    #   kernel_size: kernel size, scalar (K)
    #   stride: stride size, scalar
    #   padding: padding size, scalar
    #
    # Returns:
    #   grad_img: gradient w.r.t. img, shape B x C x H x W
    #   grad_filter: gradient w.r.t. filter, shape D x C x K x K
    pass


def unit_test_grad_checker(img, filter, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    epsilon = 1.0e-5
    filter.requires_grad = True
    img.requires_grad = True

    ### ground truth conv2D
    img_out = F.conv2d(img, filter, stride=stride, padding=padding)

    # create a random vector v
    v = torch.randn_like(img_out)

    # call your implemented "grad_checker" function
    grad_img, grad_filter = grad_checker(img, filter, F.conv2d, v, epsilon=epsilon, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)

    # compute ground-truth gradients
    grad_img_gt = torch.autograd.grad(img_out, img, grad_outputs=v, retain_graph=True)[0]
    grad_filter_gt = torch.autograd.grad(img_out, filter, grad_outputs=v, retain_graph=True)[0]
    #pdb.set_trace()

    diff = (grad_img - grad_img_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_img is correct!")
    else:
        print("Your implementation of grad_img is wrong!")

    diff = (grad_filter - grad_filter_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_filter is correct!")
    else:
        print("Your implementation of grad_filter is wrong!")

unit_test_grad_checker(img, filter, channel_size=C, num_filters=D, kernel_size=K, stride=1, padding=P)

---
# Q2 [5Pts]: Implement ReLU and its gradient

**Task**: Implement the ReLU operator, i.e., $f(x) = \max(x, 0)$, and its gradient operator ${\frac{\partial f}{\partial x}}^{\top} v$ for any given tensor $v$ that is of the same shape as $x$.

**N.B.**: For simplicity, we can assume the input $x$ is of shape $B \times C \times H \times W$ as before.
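
Because ReLU acts element-wise, its Jacobian is diagonal, so the vector-Jacobian product reduces to masking $v$. The sketch below checks this with autograd on a three-element toy input; the values of `x` and `v` are assumptions chosen for readability.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)  # toy input (assumed values)
v = torch.tensor([10.0, 20.0, 30.0])                    # toy upstream vector (assumed values)

y = F.relu(x)
vjp = torch.autograd.grad(y, x, grad_outputs=v)[0]

# element-wise view: entries of v pass through only where x is positive
mask = (x > 0).to(v.dtype)
print(vjp, mask * v)
```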

In [None]:
## implement the following functions
def func_relu(x):
    ### Fill in this function ###
    # Args:
    #   x: input, shape B x C x H x W
    #
    # Returns:
    #   y: output, shape B x C x H x W

    pass


def grad_relu(x, y, grad_out):
    ### Fill in this function ###
    # Args:
    #   x: input, shape B x C x H x W
    #   y: output, shape B x C x H x W
    #   grad_out: gradient w.r.t. output y, shape B x C x H x W
    #
    # Returns:
    #   grad_x: gradient w.r.t. x, shape B x C x H x W
    pass


def unit_test_relu(x):
    x.requires_grad = True

    # call your implemented "func_relu" function
    y = func_relu(x)

    # ground truth ReLU
    y_gt = F.relu(x)

    diff = (y - y_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of func_relu is correct!")
    else:
        print("Your implementation of func_relu is wrong!")

    # create a random vector v
    v = torch.randn_like(y)

    # call your implemented "grad_relu" function
    grad_x = grad_relu(x, y, v)

    # compute ground-truth gradients
    grad_x_gt = torch.autograd.grad(y_gt, x, grad_outputs=v, retain_graph=True)[0]

    diff = (grad_x - grad_x_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_relu is correct!")
    else:
        print("Your implementation of grad_relu is wrong!")

unit_test_relu(torch.randn_like(img))

---
# Q3 [20Pts]: Implement Batch-normalization (BN) for convolution and its gradient

Given a batch of input images $x$ with shape $B \times C \times H \times W$, we compute a single mean and a single standard deviation per channel as below,
\begin{align}
    \mu[c] &= \frac{1}{BHW} \sum_{i=1}^{B} \sum_{m=1}^{H} \sum_{n=1}^{W} x[i, c, m, n] \\
    \sigma^2[c] &= \frac{1}{BHW} \sum_{i=1}^{B} \sum_{m=1}^{H} \sum_{n=1}^{W} (x[i, c, m, n] - \mu[c])^2.
\end{align}
Then we perform BN, $y = f(x, \beta, \gamma)$, as,
\begin{align}
    y[i,c,m,n] &= \gamma[c] \frac{x[i,c,m,n] - \mu[c]}{\sqrt{\sigma^2[c] + \epsilon}} + \beta[c],
\end{align}
where $\gamma$ and $\beta$ are learnable parameters of shape $C$ and
$\epsilon$ is a small constant added for numerical stability.
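
The per-channel statistics above reduce over the batch and both spatial axes. The sketch below, on an assumed random tensor with $\gamma = 1$ and $\beta = 0$, evaluates the forward formulas directly and compares the result with `nn.BatchNorm2d`; it covers the forward pass only, not the gradient.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5, dtype=torch.float64)   # B=4, C=3, H=W=5 (assumed sizes)
eps = 1.0e-5

# one mean and one (biased) variance per channel, reducing over B, H, and W
mu = x.mean(dim=(0, 2, 3), keepdim=True)                    # shape (1, C, 1, 1)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)    # divides by B*H*W
y = (x - mu) / torch.sqrt(var + eps)

# reference: BN without learnable parameters, always using batch statistics
bn = nn.BatchNorm2d(3, eps=eps, affine=False, track_running_stats=False)
y_ref = bn(x)
print((y - y_ref).abs().max())
```

Note the `unbiased=False` argument: the formula divides by $BHW$, whereas the default variance estimator divides by $BHW - 1$.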

**Task**: For simplicity, we fix the learnable parameters as $\gamma = 1$ and $\beta = 0$.
Implement BN for convolution and its gradient operator ${\frac{\partial f}{\partial x}}^{\top} v$ for any $v$ with the same shape as $y$.



In [None]:
## implement the following functions
def func_batch_norm(x, epsilon=1.0e-5):
    ### Fill in this function ###
    # Args:
    #   x: input, shape B x C x H x W
    #   epsilon: constant, scalar
    #
    # Returns:
    #   y: output, shape B x C x H x W
    pass


def grad_batch_norm(x, y, grad_out, epsilon=1.0e-5):
    ### Fill in this function ###
    # Args:
    #   x: input, shape B x C x H x W
    #   y: output, shape B x C x H x W
    #   grad_out: gradient w.r.t. output y, shape B x C x H x W
    #   epsilon: constant, scalar
    #
    # Returns:
    #   grad_x: gradient w.r.t. x, shape B x C x H x W
    pass


def unit_test_batch_norm(x):
    x.requires_grad = True
    epsilon = 1e-5

    # call your implemented "func_batch_norm" function
    y = func_batch_norm(x, epsilon=epsilon)

    # ground truth BN
    BN_gt = nn.BatchNorm2d(x.shape[1], eps=epsilon, momentum=1.0, affine=False, track_running_stats=False)
    y_gt = BN_gt(x)

    diff = (y - y_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of func_batch_norm is correct!")
    else:
        print("Your implementation of func_batch_norm is wrong!")

    # create a random vector v
    v = torch.randn_like(y)

    # call your implemented "grad_batch_norm" function
    grad_x = grad_batch_norm(x, y, v, epsilon=epsilon)

    # compute ground-truth gradients
    grad_x_gt = torch.autograd.grad(y_gt, x, grad_outputs=v, retain_graph=True)[0]

    diff = (grad_x - grad_x_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_batch_norm is correct!")
    else:
        print("Your implementation of grad_batch_norm is wrong!")

unit_test_batch_norm(torch.randn_like(img))

---
# Q4 [15Pts]: Implement a simple CNN and back-propagation (BP)

Now we are ready to build a deep CNN and learn it with back-propagation.
In particular, let us build a simple CNN with the following architecture:

Conv $→$ BN $→$ ReLU $→$ Conv $→$ BN $→$ ReLU $→$ Linear.

Here, for all layers, the convolutions are the same as before (i.e., kernel size $3 \times 3$, zero-padding, number of filters $D = 2$, and stride 1), the BNs are without learnable $\gamma$ and $\beta$, and the last linear layer would map whatever input dimension to $10$ classes in MNIST.

**Task**: Implement the above CNN, compute the cross-entropy loss, and compute gradient of the loss w.r.t. filter weights.

**N.B.**: You can use F.cross_entropy provided by PyTorch. But for other operators like Conv, BN, and ReLU and their gradients, you should use your previous implementations.
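
To see how shapes flow through the architecture, the sketch below traces a random batch using PyTorch's built-in operators. It is a reference for shapes only and does not replace the required implementations; the helper `block`, the random inputs, and the assumption of $28 \times 28$ MNIST images flattened before the linear layer are illustrative choices.

```python
import torch
import torch.nn.functional as F

B, C, H, W, D, K, P = 5, 1, 28, 28, 2, 3, 1
x = torch.randn(B, C, H, W)            # stand-in for a batch of MNIST images
filter_1 = torch.randn(D, C, K, K)     # 1st layer maps C -> D channels
filter_2 = torch.randn(D, D, K, K)     # 2nd layer maps D -> D channels

def block(t, filt):
    # Conv -> BN (no gamma/beta, batch statistics) -> ReLU
    t = F.conv2d(t, filt, stride=1, padding=P)
    t = F.batch_norm(t, None, None, training=True, eps=1.0e-5)
    return F.relu(t)

z = block(block(x, filter_1), filter_2)     # shape (B, D, H, W): stride 1, padding 1 keep H and W
features = z.reshape(B, -1)                 # shape (B, D * H * W)
weight = torch.randn(features.shape[1], 10)
logits = features @ weight                  # shape (B, 10)
print(z.shape, features.shape, logits.shape)
```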

In [None]:
## implement the following two functions
def CNN(img, filter_1, filter_2, weight, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    ### Fill in this function ###
    # Args:
    #   img: images, shape B x C x H x W
    #   filter_1: filters at 1st layer, shape D x C x K x K
    #   filter_2: filters at 2nd layer, shape D x D x K x K
    #   weight: weights of linear readout layer, shape ? x 10
    #   channel_size: number of channels, scalar (C)
    #   num_filters: number of filters, scalar (D)
    #   kernel_size: kernel size, scalar (K)
    #   stride: stride size, scalar
    #   padding: padding size, scalar
    #
    # Returns:
    #   out: logits, shape B x 10

    ### 1st layer


    ### 2nd layer


    ### linear readout

    pass


def grad_CNN(img, filter_1, filter_2, weight, grad_loss, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    ### Fill in this function ###
    # Args:
    #   img: images, shape B x C x H x W
    #   filter_1: filters at 1st layer, shape D x C x K x K
    #   filter_2: filters at 2nd layer, shape D x D x K x K
    #   weight: weights of linear readout layer, shape ? x 10
    #   grad_loss: gradient of loss w.r.t. logits, shape B x 10
    #   channel_size: number of channels, scalar (C)
    #   num_filters: number of filters, scalar (D)
    #   kernel_size: kernel size, scalar (K)
    #   stride: stride size, scalar
    #   padding: padding size, scalar
    #
    # Returns:
    #   grad_filter_1: gradient w.r.t. filter_1, shape D x C x K x K
    #   grad_filter_2: gradient w.r.t. filter_2, shape D x D x K x K
    #   grad_weight: gradient w.r.t. weight, shape ? x 10

    ### 1st layer


    ### 2nd layer


    ### linear readout

    pass


def unit_test_CNN(img, label, filter_1, filter_2, weight, channel_size=1, num_filters=1, kernel_size=3, stride=1, padding=1):
    # call your implemented "CNN"
    img.requires_grad_()
    filter_1.requires_grad_()
    filter_2.requires_grad_()
    weight.requires_grad_()
    y = CNN(img, filter_1, filter_2, weight, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)
    y.requires_grad_()

    # compute loss function
    loss = F.cross_entropy(y, label).mean()
    loss.requires_grad_()

    # compute gradient of loss w.r.t. logits
    grad_loss = torch.autograd.grad(loss, y, retain_graph=True)[0]

    # call your implemented "grad_CNN" function
    grad_filter_1, grad_filter_2, grad_weight = grad_CNN(img, filter_1, filter_2, weight, grad_loss, channel_size=channel_size, num_filters=num_filters, kernel_size=kernel_size, stride=stride, padding=padding)

    # compute ground-truth gradients
    grad_filter_1_gt = torch.autograd.grad(loss, filter_1, retain_graph=True)[0]
    grad_filter_2_gt = torch.autograd.grad(loss, filter_2, retain_graph=True)[0]
    grad_weight_gt = torch.autograd.grad(loss, weight, retain_graph=True)[0]

    diff = (grad_filter_1 - grad_filter_1_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_filter_1 is correct!")
    else:
        print("Your implementation of grad_filter_1 is wrong!")

    diff = (grad_filter_2 - grad_filter_2_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_filter_2 is correct!")
    else:
        print("Your implementation of grad_filter_2 is wrong!")

    diff = (grad_weight - grad_weight_gt).norm()
    if diff < 1.0e-5:
        print("Your implementation of grad_weight is correct!")
    else:
        print("Your implementation of grad_weight is wrong!")


filter_1 = torch.randn(D, C, K, K) # filter shape: D x C x K x K
filter_2 = torch.randn(D, D, K, K) # filter shape: D x D x K x K

### compute the correct shape and then replace None with it ###
weight = torch.randn(None, 10) # weight of the last linear layer

unit_test_CNN(img, label, filter_1, filter_2, weight, channel_size=C, num_filters=D, kernel_size=K, stride=1, padding=P)