# Teacher's Assignment No. 14 - Q1

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *12.05.2024* $\qquad$ ***Course:*** *22961 - Deep Learning*

Welcome to question 1 of the fourth assignment of the course *Deep Learning*. \
In this question, we will implement the *SplitLinear* network layer and carry out several gradient calculations related to it.

## Imports

First, we will import the required packages for this assignment.
- [pytorch](https://pytorch.org/) - One of the most fundamental and widely used tensor-handling libraries.

In [1]:
import torch  # pytorch.
import torch.nn as nn  # neural network module.
import torch.nn.functional as F  # functional module.

## SplitLinear Implementation

We will start with the implementation of the *SplitLinear* layer, using pytorch.

In [17]:
class SplitLinear(nn.Module):
    '''SplitLinear layer.
    
    The SplitLinear layer splits the input tensor in half along the feature dimension,
    applies the same (shared) linear transformation to each half, concatenates the
    results, and applies ReLU.
    '''
    def __init__(self, layer_size: int) -> None:
        '''
        Constructor for the SplitLinear layer.

        Args:
            layer_size (int) - Number of input (and output) features. Assumed to be even.
        '''
        super(SplitLinear, self).__init__()
        self.linear = nn.Linear(layer_size // 2, layer_size // 2)

        # Use Xavier initialization for the weights.
        # Reasoning for use in the video.
        nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass of the layer.

        Args:
            x (torch.Tensor) - Input tensor.
                Assumes shape (batch_size, #features), where #features is even.

        Returns:
            torch.Tensor - Output tensor.
        '''

        # Split the input tensor in half.
        x1, x2 = torch.chunk(x, 2, dim=1)

        # Apply linear transformation to each half.
        x1, x2 = self.linear(x1), self.linear(x2)

        # Concatenate the results and apply ReLU.
        x = F.relu(torch.cat([x1, x2], dim=1))

        return x
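
Because both halves go through the same `nn.Linear`, the two calls can also be written as a single call on a reshaped tensor. The sketch below is an illustrative alternative, not the author's implementation; the function names `split_linear_reshape` and `split_linear_chunk` are hypothetical, and it assumes a 2-D input of shape `(batch_size, m)` with `m` even.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_linear_reshape(x: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    """Hypothetical alternative forward pass: one linear call on both halves."""
    batch_size, m = x.shape
    halves = x.reshape(batch_size, 2, m // 2)   # (batch, 2, m/2): first half, second half
    z = linear(halves)                          # nn.Linear acts on the last dim, shared W and b
    return F.relu(z.reshape(batch_size, m))     # back to (batch, m), then ReLU


def split_linear_chunk(x: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    """Hypothetical helper mirroring SplitLinear.forward (chunk, linear, cat, ReLU)."""
    x1, x2 = torch.chunk(x, 2, dim=1)
    return F.relu(torch.cat([linear(x1), linear(x2)], dim=1))


torch.manual_seed(0)
linear = nn.Linear(3, 3)      # shared half-layer for m = 6
x = torch.randn(4, 6)         # batch of 4 samples

same = torch.allclose(split_linear_reshape(x, linear), split_linear_chunk(x, linear))
print(same)
```

The reshape version relies on `nn.Linear` accepting inputs of any shape whose last dimension equals `in_features`, which removes the explicit split and concatenation; the `torch.chunk` version follows `SplitLinear.forward` step by step.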

In [20]:
# Example of a single pass through the `SplitLinear` layer.
split_linear = SplitLinear(6)

# Random input tensor.
X = torch.randn(2, 6)
print(f"Input:\n{X = }")
print(f"{X.shape = }\n")

# Forward pass, step by step (instead of calling `.forward`, so that each stage can be printed).
with torch.no_grad():
    X1, X2 = torch.chunk(X, 2, dim=1)
    print(f"Split:\n{X1 = }\n{X2 = }")
    print(f"{X1.shape = }\n{X2.shape = }\n")

    Z1, Z2 = split_linear.linear(X1), split_linear.linear(X2)
    print(f"Linear:\n{Z1 = }\n{Z2 = }")
    print(f"{Z1.shape = }\n{Z2.shape = }\n")

    Y = F.relu(torch.cat([Z1, Z2], dim=1))
    print(f"Output:\n{Y = }")
    print(f"{Y.shape = }")

Input:
X = tensor([[-1.5434, -1.9621,  1.5148,  0.1254, -0.5788,  0.9155],
        [-1.7521,  0.6839, -1.0127,  1.8485,  0.7364, -0.4879]])
X.shape = torch.Size([2, 6])

Split:
X1 = tensor([[-1.5434, -1.9621,  1.5148],
        [-1.7521,  0.6839, -1.0127]])
X2 = tensor([[ 0.1254, -0.5788,  0.9155],
        [ 1.8485,  0.7364, -0.4879]])
X1.shape = torch.Size([2, 3])
X2.shape = torch.Size([2, 3])

Linear:
Z1 = tensor([[-0.6188, -2.3026,  1.8200],
        [-1.7879, -1.2673,  0.1810]])
Z2 = tensor([[-0.1464, -0.2885,  1.0220],
        [ 0.9029,  1.6032, -0.2288]])
Z1.shape = torch.Size([2, 3])
Z2.shape = torch.Size([2, 3])

Output:
Y = tensor([[0.0000, 0.0000, 1.8200, 0.0000, 0.0000, 1.0220],
        [0.0000, 0.0000, 0.1810, 0.9029, 1.6032, 0.0000]])
Y.shape = torch.Size([2, 6])


## Block diagram

The following block diagram describes the SplitLinear layer.

<img src="block_diagram_q1.png"></img>
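
Equivalently, SplitLinear behaves like a standard $m \times m$ linear layer whose weight matrix is block diagonal with two identical blocks $W$, and whose bias is $b$ repeated twice. The hypothetical helper `to_block_diagonal_linear` below builds such a layer from the shared half-layer and checks that both produce the same output; it is a sketch for illustration, not part of the original notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def to_block_diagonal_linear(half: nn.Linear) -> nn.Linear:
    """Hypothetical helper: the full m x m nn.Linear equivalent to the shared half-layer."""
    h = half.in_features
    full = nn.Linear(2 * h, 2 * h)
    with torch.no_grad():
        # W_full = [[W, 0], [0, W]] and b_full = [b, b]
        full.weight.copy_(torch.block_diag(half.weight, half.weight))
        full.bias.copy_(torch.cat([half.bias, half.bias]))
    return full


torch.manual_seed(0)
half = nn.Linear(3, 3)                  # shared half-layer for m = 6
full = to_block_diagonal_linear(half)
x = torch.randn(2, 6)

# SplitLinear-style forward pass vs. the block-diagonal standard layer.
x1, x2 = torch.chunk(x, 2, dim=1)
y_split = F.relu(torch.cat([half(x1), half(x2)], dim=1))
y_full = F.relu(full(x))
print(torch.allclose(y_split, y_full))
```

This view also explains the parameter savings: of the $m^2 + m$ entries of the full layer, the off-diagonal blocks are fixed at zero and the two diagonal blocks (and the two bias halves) are tied, leaving only $(\frac{m}{2})^2 + \frac{m}{2}$ free parameters.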

## Analysis of SplitLinear vs. Standard Linear Layer
### Parameters in SplitLinear Layer
- Input size: $m$ (even)
- Output size: $m$
- Weight matrix: $(\frac{m}{2}, \frac{m}{2})$ (shared by both halves)
- Bias vector: $(\frac{m}{2})$ (shared by both halves)
- Total Parameters: $(\frac{m}{2})^2 + \frac{m}{2}$

### Parameters in Standard Linear Layer
- Weight matrix: $(m, m)$
- Bias vector: $(m)$
- Total Parameters: $m^2 + m$
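
Both totals can be confirmed by counting parameters in PyTorch. Since `SplitLinear` holds a single `nn.Linear(m // 2, m // 2)`, counting that layer's parameters is the same as counting the SplitLinear layer's. A minimal sketch for $m = 6$ (the helper `count_params` is hypothetical):

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Hypothetical helper: total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


m = 6
split_half = nn.Linear(m // 2, m // 2)  # the single shared layer inside SplitLinear
standard = nn.Linear(m, m)              # standard linear layer of the same size

print(count_params(split_half))  # (m/2)^2 + m/2 = 9 + 3 = 12
print(count_params(standard))    # m^2 + m = 36 + 6 = 42
```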

### Ratio of Parameters
$$
\frac{\#SplitLinear}{\#Linear}
    = \frac{(\frac{m}{2})^2 + \frac{m}{2}}{m^2 + m} 
    = \frac{\frac{m}{4} + \frac{1}{2}}{m + 1} 
    = \frac{1}{4} \cdot \frac{m + 2}{m + 1}
    \xrightarrow[m \rightarrow \infty]{} \frac{1}{4} 
$$
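
A quick exact check of the simplification above, using Python's `fractions.Fraction` to avoid floating-point rounding (the helper `param_ratio` is hypothetical):

```python
from fractions import Fraction


def param_ratio(m: int) -> Fraction:
    """Hypothetical helper: exact (#SplitLinear) / (#Linear) for an even layer size m."""
    split = (m // 2) ** 2 + m // 2
    full = m ** 2 + m
    return Fraction(split, full)


for m in (2, 6, 100, 10_000):
    r = param_ratio(m)
    closed_form = Fraction(m + 2, 4 * (m + 1))  # (1/4) * (m + 2) / (m + 1)
    assert r == closed_form
    print(m, r, float(r))
```

For $m = 6$, for example, the ratio is $\frac{12}{42} = \frac{2}{7} \approx 0.286$, and it decreases toward $\frac{1}{4}$ as $m$ grows.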

## Gradient Calculation

$ \def\d{\partial} $
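To calculate the gradient of the cost $C$ with respect to a weight $w_{p, q}$ of the `SplitLinear` layer, where $p, q \in \{1, \dots, \frac{m}{2}\}$, we can use the *chain rule*. Denote by $Z$ the output of the linear transformations (before ReLU) and by $Y = \text{ReLU}(Z)$ the layer output. Because the same weight $w_{p, q}$ is applied to both halves of the input, it affects two outputs, $Y_p$ and $Y_{p + \frac{m}{2}}$, and its gradient is the sum of the two contributions.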

Assuming we have $\frac{\d C}{\d Y}$ already calculated, we get

$$
\frac{\d C}{\d w_{p, q}} = \frac{\d C}{\d Y_p} \cdot \frac{\d Y_p}{\d Z_p} \cdot \frac{\d Z_p}{\d w_{p, q}} 
                         + \frac{\d C}{\d Y_{p + \frac{m}{2}}} \cdot \frac{\d Y_{p + \frac{m}{2}}}{\d Z_{p + \frac{m}{2}}} \cdot \frac{\d Z_{p + \frac{m}{2}}}{\d w_{p, q}} 
$$
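
Each factor can be evaluated explicitly. Writing $x$ for the layer input and using the shared weights, for $p, q \in \{1, \dots, \frac{m}{2}\}$ we have $Z_p = \sum_q w_{p, q} x_q + b_p$ and $Z_{p + \frac{m}{2}} = \sum_q w_{p, q} x_{q + \frac{m}{2}} + b_p$, so for a single sample:

$$
\frac{\d Y_p}{\d Z_p} = \mathbf{1}[Z_p > 0], \qquad
\frac{\d Z_p}{\d w_{p, q}} = x_q, \qquad
\frac{\d Z_{p + \frac{m}{2}}}{\d w_{p, q}} = x_{q + \frac{m}{2}}
$$

$$
\frac{\d C}{\d w_{p, q}} = \frac{\d C}{\d Y_p} \cdot \mathbf{1}[Z_p > 0] \cdot x_q
                         + \frac{\d C}{\d Y_{p + \frac{m}{2}}} \cdot \mathbf{1}[Z_{p + \frac{m}{2}} > 0] \cdot x_{q + \frac{m}{2}}
$$

Over a batch, the per-sample gradients are summed. The sketch below checks this expression against PyTorch autograd (0-based indices in the code). The cost $C = \sum G \odot Y$ with a random tensor $G$ is an arbitrary choice made for illustration, so that $\frac{\d C}{\d Y} = G$ is known exactly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
m, batch_size = 6, 4
h = m // 2
linear = nn.Linear(h, h)              # shared half-layer (weights w_{p,q}, bias b_p)
x = torch.randn(batch_size, m)

# SplitLinear forward pass with shared weights for both halves.
x1, x2 = torch.chunk(x, 2, dim=1)
z = torch.cat([linear(x1), linear(x2)], dim=1)  # pre-activation Z
y = F.relu(z)

# Assumed cost C = sum(G * Y), chosen so that dC/dY = G exactly.
g = torch.randn(batch_size, m)
cost = (g * y).sum()
cost.backward()

# Chain rule, summed over the batch:
# dC/dw[p, q] = sum_b g[b, p] 1(Z[b, p] > 0) x[b, q] + g[b, p+h] 1(Z[b, p+h] > 0) x[b, q+h]
relu_grad = (z > 0).to(x.dtype)       # dY/dZ for ReLU
delta = g * relu_grad                 # dC/dZ, shape (batch, m)
manual = delta[:, :h].T @ x[:, :h] + delta[:, h:].T @ x[:, h:]  # shape (h, h)

match = torch.allclose(manual, linear.weight.grad, atol=1e-6)
print(match)
```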

