# Advanced Topics in Deep Learning

## 1. Expressivity of DNN
- ### 1.1 Preliminaries
    - #### Backpropagation
    - #### Gradient vanishing/explosion
- ### 1.2 Signal propagation in DNN
- ### 1.3 Understand Expressivity through Signal propagation in DNN
    - #### Preliminaries: Riemannian geometry
    - #### Exponential expressivity in chaotic phase
    - #### Forward propagation and expressivity
    - #### Backward propagation and gradient vanishing/explosion
- ### 1.4 Effects of different initializations and NN architectures
    - #### Preliminaries: Jacobian spectrum
    - #### Orthogonal vs Gaussian weight initialization
    - #### Bounded vs unbounded activation functions [WIP]
    - #### Dropout [WIP]
    - #### Batch normalization [WIP]
    - #### Batch size [WIP]

## 2. Optimization
- ### 2.1 Proximal optimization
    - #### Approximation 1 - natural gradient
    - #### Approximation 2 - damped Newton method
- ### 2.2 Gauss-Newton matrix - approximation to Hessian matrix
    - #### Benefit of using Gauss-Newton matrix
- ### 2.3 Output space gradient descent

## 3. Learnability of DNN
- ### 3.1 Challenge #1: Gradient vanishing/explosion
- ### 3.2 Challenge #2: Proliferation of saddle points
- ### 3.3 Challenge #3: Failure of gradient-based algorithms
- ### 3.4 Challenge #4: Catastrophic interference/forgetting
- ### 3.5 NN tuning best practice [WIP]
    - #### Batch size
    - #### Learning rate
    
## 4. Self-attention as rank lowering operation 
- ### 4.1 Rank collapsing of pure self-attention 
    - #### MLP
    - #### Skip-connection
    - #### Layer normalization
- ### 4.2 Efficient approximation to self-attention mechanism
    - #### Linformer
    
## Appendix
- ### A1 Iterative equations for $q^l$ and $c^l$ [WIP]
- ### A2 Gradient descent based optimizers
    - #### Gradient descent
    - #### Stochastic gradient descent
    - #### Newton's method
    - #### ADAM
- ### A3 Natural gradient descent
    - #### Natural gradient with Fisher Information Matrix (FIM)
    - #### Derivation of FIM using differential geometry
- ### A4 Morse's lemma
    - #### Morse index
    - #### Morse's lemma

## 1. Expressivity of DNN

In this section we summarize work investigating why NNs are so powerful (high expressivity) and why they are difficult to train (gradient explosion/vanishing).

### 1.1 Preliminaries

#### Backpropagation

Given an $L$-layer NN:

\begin{align}
&x^{(l)}_i = F\Big(\sum_j W_{ij}^{(l)} x^{(l-1)}_j + b^{(l)}_i \Big) \\
&\hat{y}_k = F\Big(\sum_j W_{kj}^{(L)} x^{(L-1)}_j + b^{(L)}_k \Big) \\
&loss = \mathcal{L}(\vec{y}, \hat{y}) = \sum_k \ell(y_k, \hat{y}_k) \label{NN_flow}\tag{Eq 1.1}
\end{align}

where the weight $W^{(l)}_{ij}$ connects node $x^{(l-1)}_j$ in layer $l-1$ to node $x^{(l)}_i$ in layer $l$, and the network output is $\hat{y} = x^{(L)}$. The loss function is written in summation form to highlight the possibility of using stochastic methods (e.g. SGD). To update the weights in the first layer of the NN, one needs to compute the following gradient:

\begin{align}
\frac{\partial \mathcal{L}}{\partial w^{(1)}_{ij}} &= \frac{\partial \mathcal{L}}{\partial x^{(1)}_{i}}\frac{\partial x^{(1)}_{i}}{\partial w^{(1)}_{ij}}\\
\end{align}

Note that $\mathcal{L}$ is not a direct function of $x^{(1)}$, so computing the first factor requires the chain rule

\begin{align}
\frac{\partial \mathcal{L}}{\partial w^{(1)}_{ij}} &= \frac{\partial \mathcal{L}}{\partial x^{(1)}_{i}}\frac{\partial x^{(1)}_{i}}{\partial w^{(1)}_{ij}}\\
&= \sum_{k_1,...k_L}\frac{\partial \mathcal{L}}{\partial x^{(L)}_{k_L}}
\frac{\partial x^{(L)}_{k_L}}{\partial x^{(L-1)}_{k_{L-1}}}
\frac{\partial x^{(L-1)}_{k_{L-1}}}{\partial x^{(L-2)}_{k_{L-2}}}...
\frac{\partial x^{(2)}_{k_2}}{\partial x^{(1)}_{k_1}}
\frac{\partial x^{(1)}_{k_1}}{\partial w^{(1)}_{ij}}\\
&= \sum_{k_1,...k_L}\frac{\partial \mathcal{L}}{\partial x^{(L)}_{k_L}}
J^{(L)}_{k_L, k_{L-1}}\ \cdot
J^{(L-1)}_{k_{L-1}, k_{L-2}}\ \cdot...
J^{(2)}_{k_2, k_1}\ 
\frac{\partial x^{(1)}_{k_1}}{\partial w^{(1)}_{ij}} \\
&= \nabla_{x^{(L)}} \mathcal{L} \cdot J^{(L)} J^{(L-1)} \cdots J^{(2)} \cdot \frac{\partial \pmb x^{(1)}}{\partial w^{(1)}_{ij}} \label{chain_rule_1}\tag{Eq 1.2}
\end{align}

where $J^{(l)}_{ij} = \partial x^{(l)}_i / \partial x^{(l-1)}_j$ are the elements of the Jacobian matrix of layer $l$, and $k_j$ indexes the neurons in layer $j$. The last line of \ref{chain_rule_1} shows that the gradient with respect to the first-layer weights involves a product of $L-1$ Jacobians, one for each layer above it. We will see that this product is the source of gradient vanishing/explosion [ref](https://arxiv.org/pdf/1211.5063.pdf).
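The chain rule in \ref{chain_rule_1} can be checked numerically. The sketch below is a minimal NumPy example, assuming a small $\tanh$ network ($F=\tanh$) with squared-error loss $\mathcal{L}=\frac12\|x^{(L)}-y\|^2$; the layer sizes, seed and helper names are illustrative. It multiplies the layer Jacobians $J^{(l)} = \mathrm{diag}(F'(h^{(l)}))\,W^{(l)}$ from the output down to layer 2, then compares the resulting first-layer gradient with central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 4, 4, 2]  # widths of x^(0), x^(1), x^(2), x^(3); here L = 3
Ws = [rng.normal(scale=0.5, size=(sizes[l + 1], sizes[l])) for l in range(3)]
bs = [rng.normal(scale=0.1, size=sizes[l + 1]) for l in range(3)]


def dtanh(h):
    return 1.0 - np.tanh(h) ** 2


def forward(Ws, x0):
    """Return activations x^(0..L) and pre-activations h^(1..L)."""
    xs, hs = [x0], []
    for W, b in zip(Ws, bs):
        hs.append(W @ xs[-1] + b)
        xs.append(np.tanh(hs[-1]))
    return xs, hs


def loss(Ws, x0, y):
    return 0.5 * np.sum((forward(Ws, x0)[0][-1] - y) ** 2)


def grad_first_layer(Ws, x0, y):
    xs, hs = forward(Ws, x0)
    g = xs[-1] - y  # dL/dx^(L), treated as a row vector
    for l in range(len(Ws) - 1, 0, -1):  # multiply by J^(L), ..., J^(2)
        g = g @ (dtanh(hs[l])[:, None] * Ws[l])
    # dx^(1)_i / dW^(1)_ij = F'(h^(1)_i) x^(0)_j
    return (g * dtanh(hs[0]))[:, None] * x0[None, :]


x0, y = rng.normal(size=3), rng.normal(size=2)
analytic = grad_first_layer(Ws, x0, y)

eps = 1e-6
numeric = np.zeros_like(Ws[0])
for i in range(Ws[0].shape[0]):
    for j in range(Ws[0].shape[1]):
        Wp = [W.copy() for W in Ws]
        Wm = [W.copy() for W in Ws]
        Wp[0][i, j] += eps
        Wm[0][i, j] -= eps
        numeric[i, j] = (loss(Wp, x0, y) - loss(Wm, x0, y)) / (2 * eps)

print(np.max(np.abs(analytic - numeric)))  # only finite-difference error remains
```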

#### Gradient vanishing/explosion
To illustrate gradient vanishing/explosion from \ref{chain_rule_1} mathematically, assume all layers share the same Jacobian $J$, so that the product in \ref{chain_rule_1} becomes the matrix power $J^n$. If $J$ is diagonalizable, $J = Q\Lambda Q^{-1}$ and

\begin{equation}
J^n = Q\Lambda^n Q^{-1}
\end{equation}

is dominated, for large $n$, by the eigenvalue of largest magnitude, $\lambda_1$. If $|\lambda_1|>1$, $J^n$ diverges (gradient explosion). If $|\lambda_1| < 1$, $J^n$ converges to zero (gradient vanishing).
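A quick numerical illustration, assuming a random $4\times 4$ matrix rescaled so that its largest eigenvalue magnitude is exactly $0.9$ or $1.1$ (the size, seed and function name are illustrative): the norm of $J^n$ shrinks or grows geometrically with $n$.

```python
import numpy as np


def power_norms(spectral_radius, n_max=200, dim=4, seed=0):
    """Spectral norms ||J^n||_2 for n = 1..n_max, with J rescaled so that |lambda_1| = spectral_radius."""
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(dim, dim))
    J = A * (spectral_radius / np.max(np.abs(np.linalg.eigvals(A))))
    norms, P = [], np.eye(dim)
    for _ in range(n_max):
        P = J @ P  # P = J^n
        norms.append(np.linalg.norm(P, 2))
    return norms


vanishing = power_norms(0.9)  # |lambda_1| < 1: J^n -> 0
exploding = power_norms(1.1)  # |lambda_1| > 1: J^n diverges
print(vanishing[-1], exploding[-1])
```

For a non-normal $J$ the norm can grow transiently before the geometric decay takes over; the asymptotic rate is still set by $|\lambda_1|$.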


RNNs [ref](https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-recurrent-neural-networks) face a similar gradient explosion/vanishing issue. Consider the following update equations for an RNN:

\begin{align}
z^{(l)} &= Wx^{(l)} + b \\
h^{(l)} &= \sigma(z^{(l)}) \\
z^{(l+1)} &= Uh^{(l)} + Vx^{(l+1)} \label{rnn_eq}\tag{Eq 1.3}
\end{align}

Training the weights (backpropagation through time) involves computing the following Jacobian matrix

\begin{equation}
\frac{\partial h^{(T)}}{\partial h^{(0)}} = \frac{\partial h^{(T)}}{\partial h^{(T-1)}}\frac{\partial h^{(T-1)}}{\partial h^{(T-2)}}...\frac{\partial h^{(1)}}{\partial h^{(0)}}
\end{equation}

which is a product of $T$ Jacobians and hence suffers from the same problem as before. The gradient explosion/vanishing problem can be alleviated by using **residual connections**. Consider the following update equations for a residual network (**ResNet**)
\begin{align}
z^{(l)} &= Wx^{(l)} + b \\
h^{(l)} &= \sigma(z^{(l)}) \\
z^{(l+1)} &= Uh^{(l)} + Vx^{(l)} \label{resnet_eq}\tag{Eq 1.4}
\end{align}

The subtle difference between \ref{rnn_eq} and \ref{resnet_eq} is the $x$ term in the last line. For the RNN, $x$ is the next variable in the sequence, $x^{(l+1)}$, whereas in the ResNet, $x$ is the same variable from the current layer, $x^{(l)}$. The reason this helps mitigate gradient vanishing/explosion can be understood by writing \ref{resnet_eq} in a more general form

\begin{equation}
x^{(l+1)} = x^{(l)} + F(x^{(l)})
\end{equation}

The Jacobian becomes

\begin{equation}
J^{(l+1)}_{ij} = \frac{\partial x^{(l+1)}_i}{\partial x^{(l)}_j} = \delta_{ij} + \frac{\partial F_i(x^{(l)})}{\partial x^{(l)}_j}
\end{equation}

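Hence, if $\partial F/\partial x^{(l)} \approx 0$, each Jacobian is close to the identity, so their product also stays close to the identity and the gradient explodes/vanishes at a much lower rate. Note that centering the inputs of $F$ does not by itself make $\partial F/\partial x^{(l)}$ small: for a sigmoid $F=\sigma$, $\partial F/\partial x$ is largest ($1/4$) when $x$ is close to zero. In practice, the residual branch is instead scaled down at initialization, e.g. by initializing the scale parameter of the last batch-normalization layer in each residual branch to zero, so that $\partial F/\partial x^{(l)} \approx 0$.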
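The effect of the identity term can be seen by comparing the singular values of a product of layer Jacobians with and without it. This is a sketch under simple assumptions (not taken from the cited references): the plain layer Jacobian is a random Gaussian matrix $W$ with entries of variance $1/N$, and the residual layer Jacobian is $I + \epsilon W$ with a small branch scale $\epsilon = 0.1$; `singular_value_ratio` is a hypothetical helper name.

```python
import numpy as np


def singular_value_ratio(residual, depth=50, dim=16, branch_scale=0.1, seed=0):
    """Largest / smallest singular value of the product J^(depth) ... J^(1)."""
    rng = np.random.default_rng(seed)
    P = np.eye(dim)
    for _ in range(depth):
        W = rng.normal(scale=1.0 / np.sqrt(dim), size=(dim, dim))
        # residual layer: I + dF/dx with a small branch; plain layer: just W
        J = np.eye(dim) + branch_scale * W if residual else W
        P = J @ P
    s = np.linalg.svd(P, compute_uv=False)
    return s.max() / s.min()


print(singular_value_ratio(residual=False))  # enormous: some directions explode, others vanish
print(singular_value_ratio(residual=True))   # stays O(1): the product remains well conditioned
```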

### 1.2 Signal propagation in DNN [ref](https://arxiv.org/pdf/1606.05340.pdf)

This section concerns the propagation of input information along layers in a DNN. The framework presented below will help us better understand 
- why DNNs are powerful (i.e. high expressivity) 
- the source of gradient vanishing/explosion
- learning dynamics of DNN
- importance of different NN initializations and architectures [ref](https://arxiv.org/pdf/1802.09979.pdf)

First let's define the feedforward dynamics of a random DNN:

\begin{align}
h_i^{(l)} &= \sum_j w_{ij}^{(l)} x^{(l-1)}_j + b_i \\
x^{(l)}_i &= \phi(h_i^{(l)})
\end{align}

where, for a random DNN, $w_{ij}^{(l)}$ and $b_i$ are drawn independently from zero-mean Gaussian distributions with variances $\sigma_w^2/N_{l-1}$ and $\sigma_b^2$ respectively, with $N_{l-1}$ the width of layer $l-1$. As we will see, these two parameters determine the feedforward dynamics of a random DNN. To quantify how signals propagate through the DNN, define the following two quantities 

\begin{align}
q^l &\equiv \langle (h^{(l)}_i)^2\rangle \\
q^l_{\alpha \beta} &\equiv \langle h^{(l)}_i(x^0_\alpha)\ h^{(l)}_i(x^0_\beta)\rangle
\end{align}

where $\alpha$ and $\beta$ index two input samples. The expectation $\langle * \rangle$ averages over $w_{ij}$ and $b_i$. For a DNN with large enough layer width, $\langle * \rangle$ can be replaced by the average over neurons in the same layer, i.e. assuming a self-averaging property. The first quantity tracks how the average squared length of the hidden vector changes along the network. The normalized form of the second quantity, $c^l_{\alpha \beta} = q^l_{\alpha \beta}\ /\ \sqrt{q^l_{\alpha \alpha} q^l_{\beta \beta}}$, measures the correlation at layer $l$ between the representations of two inputs $x^0_\alpha$ and $x^0_\beta$. 

It is shown that the changes in $q^l$ and $c^l_{\alpha \beta}$ along the layers are governed by the following recursive equations

\begin{align}
q^l &= \mathcal{V}(q^{l-1} \mid \sigma_w, \sigma_b) \label{q_recursive}\tag{Eq 1.2.1}\\
c^l_{\alpha \beta} &= \mathcal{C}(c^{l-1}_{\alpha \beta}, q^l_{\alpha \alpha}=q^*, q^l_{\beta \beta}=q^* 
\mid \sigma_w, \sigma_b) \label{c_recursive}\tag{Eq 1.2.2}
\end{align}

where, for given $(\sigma_w, \sigma_b)$, iterating \ref{q_recursive} converges to a fixed point $q^*$ as $l\to\infty$. The recursive equation for $c^l$, \ref{c_recursive}, assumes the length map has already converged, so that $q^l_{\alpha \alpha} = q^l_{\beta \beta} = q^*$ stay fixed during the iteration. The following table summarizes the fixed points of $q^l$ as a function of $\sigma_w$ and $\sigma_b$ for a $\tanh$ nonlinearity

|$\sigma_w$|$\sigma_b$|$q^*$|stability|
|-----|-----|--------|-----|
|<1 |0  |0| stable |
|>1  |0  |0| unstable |
|>1  |0  |>0| stable |
|>0  |$\neq 0$  |>0| stable |
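For the $\tanh$ networks considered in [ref](https://arxiv.org/pdf/1606.05340.pdf), the length map \ref{q_recursive} takes the form $q^l = \sigma_w^2 \int \mathcal{D}z\, \phi\big(\sqrt{q^{l-1}}\,z\big)^2 + \sigma_b^2$, where $\mathcal{D}z$ is the standard Gaussian measure. The sketch below iterates this map numerically and reproduces rows of the table above. It approximates the Gaussian integral with a dense Riemann sum, a simple choice for illustration rather than the method of the paper; function names and grid settings are illustrative.

```python
import numpy as np

# Dense grid so that E[f(Z)], Z ~ N(0, 1), is approximated by sum(gauss_w * f(z))
z = np.linspace(-10.0, 10.0, 20001)
gauss_w = np.exp(-z**2 / 2) / np.sqrt(2 * np.pi) * (z[1] - z[0])


def length_map(q, sigma_w, sigma_b, phi=np.tanh):
    """One step of q^l = sigma_w^2 E[phi(sqrt(q^{l-1}) Z)^2] + sigma_b^2."""
    return sigma_w**2 * np.sum(gauss_w * phi(np.sqrt(q) * z) ** 2) + sigma_b**2


def fixed_point(sigma_w, sigma_b, q0=1.0, n_iter=500):
    q = q0
    for _ in range(n_iter):
        q = length_map(q, sigma_w, sigma_b)
    return q


print(fixed_point(0.5, 0.0))  # sigma_w < 1, sigma_b = 0: q -> 0
print(fixed_point(1.5, 0.0))  # sigma_w > 1, sigma_b = 0: nonzero q*
print(fixed_point(1.5, 0.3))  # sigma_b != 0: nonzero q*
```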

For $c^l$, it is shown that $c^*=1$ is always a fixed point. To quantify its stability, i.e. how a small decorrelation between two inputs changes from one layer to the next, define

\begin{align}
\chi = \frac{\partial c^l_{\alpha \beta}}{\partial c^{l-1}_{\alpha \beta}}\bigg|_{c=c^*=1}
\end{align}

If $\chi > 1$, $c^*=1$ is unstable (if $c^{l-1}$ decreases slightly below 1, $c^l$ moves further away from 1) and there exists another, stable fixed point $c^*<1$. If $\chi < 1$, $c^*=1$ is stable (if $c^{l-1}$ decreases, $c^l$ moves back toward 1). In the $\sigma_w$-$\sigma_b$ plane, $\chi=1$ is the critical line separating the unstable, 'chaotic' phase (it is closer to a simple bifurcation than to true chaos) from the stable, ordered phase. In the 'chaotic' phase, the correlation between two inputs decreases over layers. In the ordered phase, the representations of the two inputs become perfectly correlated after enough layers.
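The two phases can also be observed directly by propagating a pair of correlated inputs through a wide random $\tanh$ network. This is an illustrative simulation under assumed settings (width 500, depth 30, initial correlation 0.9, $\sigma_b = 0.3$, weights drawn as $w_{ij}\sim\mathcal N(0,\sigma_w^2/N)$); the two $\sigma_w$ values are chosen to sit clearly on either side of the critical line.

```python
import numpy as np


def correlation_through_depth(sigma_w, sigma_b=0.3, depth=30, width=500, c0=0.9, seed=0):
    """Empirical correlation c^l between the pre-activations of two inputs, l = 1..depth."""
    rng = np.random.default_rng(seed)
    # Two inputs with squared length `width` each and correlation c0
    u, v = rng.normal(size=width), rng.normal(size=width)
    v -= (u @ v) / (u @ u) * u  # make v orthogonal to u
    u, v = u / np.linalg.norm(u), v / np.linalg.norm(v)
    x_a = np.sqrt(width) * u
    x_b = np.sqrt(width) * (c0 * u + np.sqrt(1 - c0**2) * v)
    cs = []
    for _ in range(depth):
        W = rng.normal(scale=sigma_w / np.sqrt(width), size=(width, width))
        b = rng.normal(scale=sigma_b, size=width)  # same weights and biases for both inputs
        h_a, h_b = W @ x_a + b, W @ x_b + b
        cs.append(h_a @ h_b / np.sqrt((h_a @ h_a) * (h_b @ h_b)))
        x_a, x_b = np.tanh(h_a), np.tanh(h_b)
    return cs


ordered = correlation_through_depth(sigma_w=0.8)  # chi < 1: correlation -> 1
chaotic = correlation_through_depth(sigma_w=4.0)  # chi > 1: correlation decays to c* < 1
print(ordered[-1], chaotic[-1])
```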

It is shown in [ref](https://arxiv.org/pdf/1711.00165.pdf) that a deep, infinitely wide NN can be expressed as a Gaussian Process (GP), whose kernels at layers $l$ and $l-1$ are related by a recursive formula of the same form as \ref{c_recursive}.

### 1.3 Understand Expressivity through Signal propagation in DNN

#### Preliminaries: Riemannian geometry

Imagine the input data $\{\pmb x_i\}_{i=1}^N$ form a $D_M$-dimensional manifold in a $D$-dimensional space, where $D$ is simply the dimension of $\pmb x_i$. Let's parametrize a subset of data points on this manifold by a scalar parameter $\theta$ so that it traces a 1D curve $\pmb x(\theta)$ on the manifold. As $\pmb x_i$ is passed through the DNN, the manifold is morphed into a different shape: in layer $l$, the 1D curve is mapped to a different curve $\pmb h^{(l)}(\theta) = \pmb h^{(l)}(\pmb x(\theta))$. One can characterize the curve by two metrics

\begin{align}
g^E(\theta) &= \partial_\theta \pmb h(\theta) \cdot \partial_\theta \pmb h(\theta) = \pmb v(\theta)\cdot \pmb v(\theta) \label{euclidean_metric}\tag{Eq 1.3.3}\\
g^G(\theta) &= \partial_\theta \pmb{\hat{v}}(\theta)\cdot \partial_\theta\pmb{\hat{v}}(\theta) 
\label{gauss_metric}\tag{Eq 1.3.4}\\
\end{align}

\ref{euclidean_metric} (Euclidean metric) is the squared speed along the curve at point $\theta$. In \ref{gauss_metric} (Gauss metric), $\pmb{\hat{v}} = \pmb v / \|\pmb v\|$ is the unit tangent vector, so the Gauss metric is the squared rate at which the direction of the curve turns at point $\theta$. Note that $\pmb v(\theta)$ points tangentially along the curve, while $\partial_\theta\pmb{\hat{v}}(\theta)$ points perpendicularly to it (a unit vector can only change direction, not length). The two metrics are related by the curvature $\kappa(\theta)$ 

\begin{align}
g^G(\theta) = \kappa^2(\theta) g^E(\theta) \label{euclidean_gauss_curvature}\tag{Eq 1.3.5}
\end{align}

As a simple example to understand \ref{euclidean_gauss_curvature}, suppose $\pmb h(\theta)$ is a circle. Under a linear expansion $\pmb h(\theta) \to \sqrt{\chi}\, \pmb h(\theta)$, $g^E(\theta)$ increases by a factor $\chi$ whereas the curvature $\kappa$ decreases by a factor $1/\sqrt{\chi}$. Therefore $g^G(\theta)$ remains unchanged: when we expand the radius of the circle, the curvature decreases.
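The circle example can be checked numerically. The sketch below estimates $g^E$ and $g^G$ by central finite differences for a circle of radius $r$ (the radii, angle, step size and helper names are illustrative). For a circle, $g^E = r^2$ and $g^G = 1$, so scaling the radius by $\sqrt{\chi}$ multiplies $g^E$ by $\chi$ while $g^G = \kappa^2 g^E$ stays the same.

```python
import numpy as np


def metrics(h, theta, d=1e-4):
    """Return (g_E, g_G) of the curve h at theta via central finite differences."""
    def tangent(t):
        return (h(t + d) - h(t - d)) / (2 * d)

    def unit_tangent(t):
        v = tangent(t)
        return v / np.linalg.norm(v)

    v = tangent(theta)
    dv_hat = (unit_tangent(theta + d) - unit_tangent(theta - d)) / (2 * d)
    return v @ v, dv_hat @ dv_hat


def circle(r):
    return lambda t: r * np.array([np.cos(t), np.sin(t)])


chi = 3.0
for r in [2.0, 2.0 * np.sqrt(chi)]:
    g_E, g_G = metrics(circle(r), theta=0.7)
    print(r, g_E, g_G, np.sqrt(g_G / g_E))  # g_E = r^2, g_G = 1, curvature kappa = 1/r
```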

#### Exponential expressivity in the chaotic phase

In the ordered phase $\chi < 1$, since the fixed point $c^*=1$ is stable, the curve eventually collapses to a single point (all inputs become maximally correlated). In the chaotic phase $\chi > 1$, the propagation of the curve $\pmb h(\theta)$ behaves very differently. Even as $g^E(\theta)$ expands exponentially with depth, $\kappa$ does not necessarily decrease; its behavior depends on the curvature of the single-neuron nonlinearity. Therefore the Gauss metric also grows exponentially. Intuitively, the stretching of the curve is accompanied by increasing folding of the curve into other dimensions, so that it fills up the hidden representation space.

The expressivity of a DNN can now be understood in two ways. First, the last layer of the DNN is essentially a linear model and therefore has a linear decision boundary $\pmb w \cdot \pmb h^{(L)} - b = 0$ for the input $\pmb h^{(L)}(\theta)$, which, as discussed above, has a complicated geometry. Even though the decision boundary is simple, its input has a complicated geometry. This is similar to the idea behind kernel SVMs.

The second way to understand expressivity is to look at the first layer of the DNN (the input space), where the input $\pmb x(\theta)$ is relatively smooth but the decision boundary $\pmb w \cdot \pmb h^{(L)}(\pmb x) - b = 0$ is no longer linear (the nonlinearity comes from $\pmb h^{(L)}(\pmb x)$). In fact, it can be shown that the geometry of this decision boundary becomes increasingly complicated. Define the decision boundary $G(\pmb x)$ as the set of points

\begin{align}
G(\pmb x) = \{\pmb x \mid \pmb w \cdot \pmb h^{(L)}(\pmb x) - b = 0\}
\end{align}

Intuitively, near a point $\pmb x^*$ on the decision boundary, the boundary manifold can be approximated by a paraboloid with a quadratic form $H$ (the normalized Hessian matrix) whose $D-1$ eigenvalues are the principal curvatures. Numerically, a subset of the principal curvatures is shown to grow exponentially with depth, which implies that the decision boundary becomes increasingly curved and complex.

#### Forward propagation and expressivity [ref](https://arxiv.org/pdf/1611.01232.pdf)

As discussed in Section 1.2, both $q^l$ and $c^l$ converge to their respective fixed points in deep layers. The rate at which they converge sets the depth scales over which information propagates in a DNN: intuitively, the slower the quantities converge, the further information propagates along the depth. It is shown that both quantities converge exponentially

\begin{align}
| q^l_{\alpha \alpha} - q^* | &\sim e^{-l/\xi_q} \\
| c^l_{\alpha \beta} - c^* | &\sim e^{-l/\xi_c}
\end{align}

Like the gradient depth scale $\xi_\nabla$ introduced in the next subsection, $\xi_c$ has the following form (in the ordered phase, where $c^*=1$)

\begin{align}
\xi_c^{-1} = -\log \chi
\end{align}

and therefore exhibits the order-to-chaos transition: at $\chi=1$, $\xi_c$ diverges and information persists indefinitely with depth.
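For $\phi = \tanh$, the slope of the correlation map at $c = 1$ can be evaluated as $\chi = \sigma_w^2 \int \mathcal{D}z\, \big[\phi'\big(\sqrt{q^*}\,z\big)\big]^2$ [ref](https://arxiv.org/pdf/1606.05340.pdf). The sketch below computes $q^*$ and $\chi$, approximating the Gaussian integrals with a dense Riemann sum, and then $-1/\log\chi$, which is the depth scale $\xi_c$ in the ordered phase (where $c^*=1$). The $\sigma_w$, $\sigma_b$ values and function names are illustrative assumptions.

```python
import numpy as np

z = np.linspace(-10.0, 10.0, 20001)
gauss_w = np.exp(-z**2 / 2) / np.sqrt(2 * np.pi) * (z[1] - z[0])  # E[f(Z)] ~ sum(gauss_w * f(z))


def q_star(sigma_w, sigma_b, n_iter=500):
    """Fixed point of the length map for phi = tanh."""
    q = 1.0
    for _ in range(n_iter):
        q = sigma_w**2 * np.sum(gauss_w * np.tanh(np.sqrt(q) * z) ** 2) + sigma_b**2
    return q


def chi(sigma_w, sigma_b):
    """chi = sigma_w^2 E[phi'(sqrt(q*) Z)^2] with phi = tanh, phi' = 1 - tanh^2."""
    dphi = 1.0 - np.tanh(np.sqrt(q_star(sigma_w, sigma_b)) * z) ** 2
    return sigma_w**2 * np.sum(gauss_w * dphi**2)


def depth_scale(sigma_w, sigma_b):
    """xi = -1 / log(chi): positive for chi < 1, negative for chi > 1, diverging as chi -> 1."""
    return -1.0 / np.log(chi(sigma_w, sigma_b))


for s_w in [0.8, 1.2, 2.0, 4.0]:
    c = chi(s_w, 0.3)
    print(s_w, c, -1.0 / np.log(c))
```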

#### Backward propagation and gradient vanishing/explosion [ref](https://arxiv.org/pdf/1611.01232.pdf)

In the previous subsection, we discussed the forward propagation of signals. Here, we show the duality between the forward propagation of signals and the backpropagation of gradients, which gives an intuitive understanding of the source of gradient vanishing/explosion.

Reiterating the DNN forward equations

\begin{align}
h_i^{(l)} &= \sum_j w_{ij}^{(l)} x^{(l-1)}_j + b_i \\
x^{(l)}_i &= \phi(h_i^{(l)})
\end{align}

Let's rewrite the backprop equation in a recursive way

\begin{align}
\frac{\partial \mathcal{L}}{\partial w^{(l)}_{ij}} &= \frac{\partial \mathcal{L}}{\partial h^{(l)}_{i}}\frac{\partial h^{(l)}_{i}}{\partial w^{(l)}_{ij}} = \delta^{(l)}_i \phi(h^{(l-1)}_j)\\
\delta^{(l)}_i &= \phi'(h^{(l)}_i)\sum_j \delta^{(l+1)}_j w^{(l+1)}_{ji} \label{backprop_recursive}\tag{Eq 1.3.6}
\end{align}

where $\delta^{(l)}_i \equiv \partial \mathcal{L} / \partial h^{(l)}_i$. The second line of \ref{backprop_recursive} is an iterative map for $\delta^{(l)}_i$ with depth. Define 

\begin{align}
\tilde q_{\alpha \alpha}^l \equiv \mathbb{E}[(\delta^{(l)}_i)^2]
\end{align}

Roughly speaking, squaring the first line of \ref{backprop_recursive} and taking the expectation gives

\begin{align}
\mathbb{E}\Big[\Big(\frac{\partial \mathcal{L}}{\partial w^{(l)}_{ij}}\Big)^2\Big] &\approx  \mathbb{E}[(\delta^{(l)}_i)^2] \mathbb{E}[\phi^2(h^{(l-1)}_j)] = \tilde q_{\alpha \alpha}^l \mathbb{E}[\phi^2(h^{(l-1)}_j)] 
\end{align}

where the approximation treats $(\delta^{(l)}_i)^2$ and $\phi^2(h^{(l-1)}_j)$ as uncorrelated (a mean-field approximation). The expected magnitude of the gradient is then proportional to $\tilde q_{\alpha \alpha}^l$. It is shown that

\begin{align}
\tilde q_{\alpha \alpha}^l = \tilde q_{\alpha \alpha}^{l+1}\frac{N_{l+1}}{N_l}\chi
\end{align}

which implies that $\tilde q_{\alpha \alpha}^l$ grows or shrinks exponentially as one backpropagates towards the first layer. For constant layer width ($N_{l+1} = N_l$), $\tilde q_{\alpha \alpha}^l$ can therefore be written as

\begin{align}
\tilde q_{\alpha \alpha}^l &= \tilde q_{\alpha \alpha}^L e^{-(L-l)/\xi_\nabla}\\
\xi_\nabla^{-1} &= -\log \chi
\end{align}

Gradient vanishing/explosion can now be understood as follows. In the ordered phase, $\chi < 1$ and $\xi_\nabla > 0$, so the magnitude of the gradient decreases exponentially as one backpropagates towards the first layer and is therefore expected to vanish. In the chaotic phase, $\chi > 1$ and $\xi_\nabla < 0$, so the magnitude of the gradient increases exponentially and explodes. At the critical line $\chi = 1$, the system strikes a fine balance between gradient explosion and vanishing.
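The backward recursion \ref{backprop_recursive} can be simulated directly in a random $\tanh$ network to see the exponential growth or decay of $\tilde q^l$. This is an illustrative sketch with assumed settings (width 300, depth 30, $\sigma_b = 0.3$, weights $w_{ij}\sim\mathcal N(0,\sigma_w^2/N)$, a random error signal at the last layer). Unlike the mean-field analysis, which additionally assumes the backward weights are independent of the forward ones, the simulation reuses the forward weights, as a real network does.

```python
import numpy as np


def backprop_variances(sigma_w, sigma_b=0.3, depth=30, width=300, seed=0):
    """Empirical q_tilde^l = mean_i (delta_i^(l))^2 for l = 1..depth in a random tanh network."""
    rng = np.random.default_rng(seed)
    Ws = [rng.normal(scale=sigma_w / np.sqrt(width), size=(width, width)) for _ in range(depth)]
    bs = [rng.normal(scale=sigma_b, size=width) for _ in range(depth)]
    x, hs = rng.normal(size=width), []
    for W, b in zip(Ws, bs):  # forward pass, storing pre-activations h^(l)
        hs.append(W @ x + b)
        x = np.tanh(hs[-1])
    # Backward pass: delta^(l) = phi'(h^(l)) * (W^(l+1))^T delta^(l+1)
    delta = (1 - np.tanh(hs[-1]) ** 2) * rng.normal(size=width)
    q_tilde = [np.mean(delta**2)]
    for l in range(depth - 2, -1, -1):
        delta = (1 - np.tanh(hs[l]) ** 2) * (Ws[l + 1].T @ delta)
        q_tilde.append(np.mean(delta**2))
    return q_tilde[::-1]  # index 0 is the first layer, index -1 the last


for s_w in [0.8, 4.0]:
    q = backprop_variances(s_w)
    print(s_w, q[0] / q[-1])  # << 1: vanishing (ordered), >> 1: exploding (chaotic)
```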

### 1.4 Effects of different initializations and NN architectures

#### Preliminaries: Jacobian spectrum

Another way to understand $\chi$ is as follows. From \ref{chain_rule_1}, the input-to-output Jacobian can be written as

\begin{align}
J &\equiv J^{(L)} J^{(L-1)} \cdots J^{(1)} \\
&= D^{(L)} W^{(L)} D^{(L-1)} W^{(L-1)} \cdots D^{(1)} W^{(1)}
\end{align}

where $D^{(l)}_{ij} = \phi'(h^{(l)}_i)\delta_{ij}$. The distribution of singular values of $J$ encodes information about the feedforward dynamics of the NN. For example, the mean of the squared singular values of a single-layer Jacobian $J^{(l)}$ (i.e. the second moment of its singular value distribution) equals $\chi$, which can be written in the following form

\begin{align}
\chi = \frac{1}{N}\langle \text{Tr} (DW)^\top(DW) \rangle
\end{align}

Here, since the statistics of $D^{(l)}$ and $W^{(l)}$ do not depend on $l$, we simply write the layerwise $D$ and $W$ without superscript. For the full input-to-output Jacobian $J$, the mean squared singular value is $\chi^L$. Setting $\chi = 1$ keeps this mean fixed with depth, but the squared singular value distribution can still broaden as $L$ grows. The wider this distribution, the more strongly the gradient explodes along some directions and vanishes along others.
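For a deep *linear* network ($\phi$ the identity, so $D^{(l)} = I$ and $\chi = \sigma_w^2$), these statements can be checked directly. The sketch below (illustrative width, depths, $\sigma_w$ and seed; not taken from the cited references) draws Gaussian weights with variance $\sigma_w^2/N$ and reports the mean squared singular value of the product, which should be close to $\chi^L$, together with the largest and smallest squared singular values, which spread apart as $L$ grows.

```python
import numpy as np


def squared_singular_values(depth, sigma_w, width=200, seed=0):
    """Squared singular values of J = W^(L) ... W^(1) for a deep linear network."""
    rng = np.random.default_rng(seed)
    J = np.eye(width)
    for _ in range(depth):
        W = rng.normal(scale=sigma_w / np.sqrt(width), size=(width, width))
        J = W @ J  # D^(l) = I for a linear network, so J^(l) = W^(l)
    return np.linalg.svd(J, compute_uv=False) ** 2


sigma_w = 1.1  # chi = sigma_w^2 = 1.21 for a linear network
for L in [1, 5, 10]:
    s2 = squared_singular_values(L, sigma_w)
    print(L, s2.mean(), sigma_w ** (2 * L), s2.max(), s2.min())
```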

#### Orthogonal vs Gaussian initialization [ref](https://arxiv.org/pdf/1802.09979.pdf)

The first initialization choice concerns the distribution of the weights: should they be drawn from a Gaussian distribution or initialized as orthogonal matrices? [Ref](https://arxiv.org/pdf/1802.09979.pdf) showed that orthogonal weight initialization is preferable, as it produces a more stable (better conditioned) Jacobian spectrum in very deep networks.
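In the simplest setting of a deep linear network with $\sigma_w = 1$ (an illustrative assumption; the cited work also treats nonlinear networks), the difference is stark. A product of orthogonal matrices is orthogonal, so every singular value of the input-to-output Jacobian equals 1 at any depth, while a product of Gaussian matrices has the same mean squared singular value but a spectrum that spreads out with depth. Width, depth and helper names below are illustrative.

```python
import numpy as np


def random_orthogonal(n, rng):
    """Haar-random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    Q, R = np.linalg.qr(rng.normal(size=(n, n)))
    return Q * np.sign(np.diag(R))  # fix column signs so Q is uniformly distributed


def jacobian_singular_values(depth, init, width=100, seed=0):
    rng = np.random.default_rng(seed)
    J = np.eye(width)
    for _ in range(depth):
        if init == "orthogonal":
            W = random_orthogonal(width, rng)
        else:  # Gaussian with variance 1/width, i.e. sigma_w = 1
            W = rng.normal(scale=1.0 / np.sqrt(width), size=(width, width))
        J = W @ J
    return np.linalg.svd(J, compute_uv=False)


for init in ["gaussian", "orthogonal"]:
    s = jacobian_singular_values(depth=20, init=init)
    print(init, s.max(), s.min())
```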

#### Bounded vs unbounded activation functions 

#### Dropout

#### Batch normalization

Batch normalization was first proposed to avoid **internal covariate shift** (ICS), understood as a shift in the distribution of the data in the deeper layers of the network caused by changes in the weights of the earlier layers. Consider layer $l$ of the NN at the $t$-th training iteration

\begin{align}
x^{(l)}_i &= F\Big(\sum_j W_{ij}^{(l)}(t) x^{(l-1)}_j + b^{(l)}_i(t) \Big) \\
&= G(W^{(1)}(t),W^{(2)}(t),...W^{(l)}(t), b^{(1)}(t),b^{(2)}(t),...b^{(l)}(t), x^{(0)})
\end{align}

where $G$ is the composition of $l$ copies of $F$ and is a function of the input $x^{(0)}$. Even with the same $x^{(0)}$, i.e. the same batch of data, $x^{(l)}$ will differ over time due to the changes in $W^{(k)}(t), \ \forall k \leq l$. Consequently, the distribution of $x^{(l)}$ tends to shift over time. Consider the optimization of the subsequent layers (layers $>l$): the loss function takes the form

\begin{align}
\mathcal{L} &= \hat{\mathcal{L}}(W^{(l+1)},W^{(l+2)},...W^{(L)}, b^{(l+1)},b^{(l+2)},...b^{(L)}, x^{(l)})
\end{align}

The argument is that by normalizing the data in each batch, we improve the stability of the loss function $\hat{\mathcal{L}}$, as the distribution of $x^{(l)}$ will not change much after normalization. Following [ref](https://arxiv.org/pdf/1805.11604.pdf), ICS can be quantified as the change in the gradient of layer $l$ caused by updating the preceding layers

\begin{align}
\text{ICS} = \| &\nabla_{W^{(l)}(t)}\mathcal{L}(W^{(1)}(t),W^{(2)}(t),...,W^{(l)}(t),...,W^{(L)}(t); x, y) - \\
&\nabla_{W^{(l)}(t)}\mathcal{L}(W^{(1)}(t+1),W^{(2)}(t+1),...,W^{(l-1)}(t+1),W^{(l)}(t),...,W^{(L)}(t); x, y)\|
\end{align}

##### Pros and cons of batch norm
As saturating activation functions (e.g. sigmoid, tanh) are mostly flat far away from zero, covariate shift can push pre-activations into these flat regions and lead to gradient vanishing. By placing batch norm between $W$ and $\phi$, the data are re-centered before the activation function, which **mitigates the gradient vanishing problem**. 

Batch norm also **improves the robustness of the NN to different hyperparameters**. This is because if the loss function is smoother [ref](https://arxiv.org/pdf/1805.11604.pdf), the gradients are more predictable, i.e. the path to the minimum involves fewer sharp turns. This allows a wider range of learning rates to be used without risking overshooting, which might cause oscillation or divergence.

Batch normalization also **speeds up SGD convergence** by resetting the scale of the variables in the loss function so that a (fixed) learning rate does not result in oscillations or slow convergence (similar to the idea of normalizing the data for any GD algorithm). *For algorithms with an adaptive learning rate, is batch normalization still useful? For example, in Newton's method, the step is adaptive and depends on the Hessian matrix.*

The downside of batch normalization is that it forces the data to be distributed around the linear region of the activation function, which can **reduce the representational power of the network**. One way to address this is to introduce two *learnable* parameters that rescale and re-center the normalized hidden layer outputs.
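A minimal sketch of the training-time batch-norm transform described above, including the two learnable parameters (called `gamma` and `beta` here, following common convention). It assumes pre-activations `h` of shape (batch, features), uses only the current batch statistics (no running averages for inference), and adds a small `eps` for numerical stability.

```python
import numpy as np


def batch_norm_forward(h, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then rescale (gamma) and re-center (beta)."""
    mu = h.mean(axis=0)
    var = h.var(axis=0)
    h_hat = (h - mu) / np.sqrt(var + eps)
    return gamma * h_hat + beta


rng = np.random.default_rng(0)
h = rng.normal(loc=3.0, scale=5.0, size=(64, 4))  # shifted, badly scaled pre-activations
out = batch_norm_forward(h, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean(axis=0), out.std(axis=0))  # approximately 0 and 1 per feature

# With gamma = sqrt(var + eps) and beta = mean, the original h is recovered,
# so the learnable parameters can undo the normalization if that helps the network.
restored = batch_norm_forward(h, np.sqrt(h.var(axis=0) + 1e-5), h.mean(axis=0))
```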

Note that the ICS view of batch norm is challenged by [ref](https://arxiv.org/pdf/1805.11604.pdf), which supports its claim using deep linear networks trained on full-batch data (i.e. no stochasticity).

##### Batch norm and signal propagation depth

It is hypothesized [ref](https://arxiv.org/pdf/1611.01232.pdf) that applying batch norm helps increase the depth scales by controlling the variance of the network weights. Since the scale of the input to each layer is normalized, this in turn controls the variance of the weights in that layer. A network with lower weight variance is less likely to be in the chaotic phase, where it is more susceptible to gradient explosion.

## 2. Optimization

### 2.1 Proximal optimization

[ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/slides/lec03.pdf)

Given a function $f(\theta)$, one way to find its minimum is gradient descent. Starting from a given point $\theta_0$, we can derive the direction of the parameter update using a first-order Taylor expansion

\begin{align}
f(\theta) &\approx f(\theta_0) + \nabla f(\theta_0)^\top (\theta - \theta_0)\\
\theta^* &= \arg \min_\theta \big( f(\theta_0) + \nabla f(\theta_0)^\top (\theta - \theta_0) \big) \label{proximal_taylor_1}\tag{Eq 2.1.1}
\end{align}

The optimal $\theta^*$ in \ref{proximal_taylor_1} lies infinitely far along $-\nabla f(\theta_0)$, because the linearized objective is unbounded below. To prevent this runaway solution, one can keep the step small by adding a distance penalty to the optimization problem

\begin{align}
\theta^* &= \arg \min_\theta \Big( f(\theta_0) + \nabla f(\theta_0)^\top (\theta - \theta_0) + \frac{\lambda}{2}\|\theta - \theta_0\|^2 \Big) \label{proximal_taylor_1_dist}\tag{Eq 2.1.2}
\end{align}

Setting the gradient of this objective with respect to $\theta$ to zero recovers the gradient descent update, with learning rate $1/\lambda$

\begin{align}
\theta_{t+1} = \theta_t - \lambda^{-1}\nabla f(\theta_t)
\end{align}

In general, the distance penalty in \ref{proximal_taylor_1_dist} does not have to be the Euclidean distance: any distance measure $\rho(\delta, \theta^k)$, weighted by a Lagrange multiplier $\lambda$, can be used.

#### Approximation 1 - natural gradient
By taking the infinitesimal step size limit ($\lambda \to \infty$, so that the penalty dominates and $\delta = \theta - \theta_0$ stays small), and using the first-order Taylor expansion of $f(\theta)$ together with the second-order Taylor expansion of the distance function $\rho$ (whose first-order term vanishes at its minimum $\delta = 0$), we obtain the natural gradient update

\begin{align}
\delta^* &= \arg \min_\delta \Big( f(\theta_0) + \nabla f(\theta_0)^\top \delta + \frac{\lambda}{2} \delta^\top G \delta \Big)\\
\delta^* &= -\lambda^{-1}G^{-1}\nabla f(\theta_0) \label{proximal_natural_gradient}\tag{Eq 2.1.3}
\end{align}

where $G = \nabla_\delta^2 \rho(\delta, \theta^k)\big|_{\delta=0}$.
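As a concrete instance, assume $\rho$ is the KL divergence between the model's predictive distributions (a common choice, see Appendix A3) and the model is a single Bernoulli variable parameterized by its logit $\theta$. Then $G$ should equal the Fisher information $p(1-p)$ with $p = \mathrm{sigmoid}(\theta_0)$. The sketch below checks this with a finite-difference second derivative and forms the natural gradient step \ref{proximal_natural_gradient} for the negative log-likelihood of one observation $y = 1$; the values of $\theta_0$, $\lambda$ and the step size `eps` are arbitrary.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def bernoulli_kl(theta0, theta):
    """KL( Bern(sigmoid(theta0)) || Bern(sigmoid(theta)) ) in terms of logits."""
    p, q = sigmoid(theta0), sigmoid(theta)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))


theta0, lam, eps = 0.8, 10.0, 1e-3
p = sigmoid(theta0)

# G = d^2 rho / d delta^2 at delta = 0 (central second difference)
G = (bernoulli_kl(theta0, theta0 + eps) - 2 * bernoulli_kl(theta0, theta0)
     + bernoulli_kl(theta0, theta0 - eps)) / eps**2
print(G, p * (1 - p))  # agree up to finite-difference error

grad_f = -(1 - p)            # derivative of -log sigmoid(theta) at theta0
delta = -grad_f / (lam * G)  # Eq 2.1.3 in one dimension
print(delta, 1 / (lam * p))  # here the natural gradient step equals 1 / (lambda p)
```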

#### Approximation 2 - damped Newton method
For the next approximation, take the second-order Taylor expansions of both $f(\theta)$ and $\rho$. This results in the following update rule

\begin{align}
\delta^* = -(H + \lambda G)^{-1}\nabla f(\theta_0) \label{proximal_damped_newton}\tag{Eq 2.1.4}
\end{align}

where $H$ is the Hessian of $f$ at $\theta_0$. When the Euclidean metric is used for $\rho$, $G=I$ and the update becomes the **damped Newton method** (closely related to **trust region methods**). 

Note that adaptive gradient algorithms (Appendix A2) have a very similar form to \ref{proximal_damped_newton}, except that (1) the Hessian is approximated by the diagonal of the empirical FIM and (2) the gradient is scaled by $H^{-1/2}$ instead of $H^{-1}$. Regarding the first point, the empirical FIM is in general not a good approximation to the Hessian (it is only a good approximation if the model is optimal). The square root in the second point partly compensates for this poor approximation [ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/readings/L05_normalization.pdf).
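A small sketch of \ref{proximal_damped_newton} on an assumed ill-conditioned quadratic $f(\theta) = \frac12\theta^\top H\theta - b^\top\theta$ (so $\nabla f = H\theta - b$), with $G = I$. With $\lambda = 0$ the step is the Newton step, which reaches the minimizer $H^{-1}b$ in one iteration; as $\lambda$ grows, the step approaches the gradient descent step $-\lambda^{-1}\nabla f$. The matrix, vector and helper names are illustrative.

```python
import numpy as np

H = np.array([[10.0, 0.0], [0.0, 0.1]])  # condition number 100
b = np.array([1.0, 1.0])


def grad(theta):
    return H @ theta - b


def damped_newton_step(theta, lam, G=None):
    """delta* = -(H + lam G)^{-1} grad f(theta); G defaults to the Euclidean metric I."""
    G = np.eye(len(theta)) if G is None else G
    return -np.linalg.solve(H + lam * G, grad(theta))


theta0 = np.zeros(2)
print(theta0 + damped_newton_step(theta0, lam=0.0))  # Newton: lands on H^{-1} b = [0.1, 10]
print(damped_newton_step(theta0, lam=1e6) * 1e6)     # close to -grad f(theta0) = [1, 1]: a scaled GD step
```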

### 2.2 Gauss-Newton matrix - approximation to Hessian matrix

[ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/readings/L02_Taylor_approximations.pdf)

The motivation for building a Gauss-Newton matrix is to approximate the Hessian matrix of a loss function in the weight space using the Hessian matrix of the loss function in the output space. The former loss function is of the form

\begin{align}
\mathscr{L}(\theta) =  L(y, f(x, \theta))
\end{align}

whereas the latter loss function is of the form

\begin{align}
\mathcal{L}(y, z) =  L(y, z)
\end{align}

with $z = f(x, \theta)$. Since $\mathscr{L}$ depends on $\theta$ only through $z$, the weight-space Hessian $\nabla_\theta^2 \mathscr{L}$ follows from the chain rule

\begin{align}
\nabla_\theta^2 \mathscr{L} = J_{z\theta}^\top H_z J_{z\theta} + \sum_i \frac{\partial L}{\partial z_i}\nabla_\theta^2 [f(x, \theta)]_i \label{Hessian_approx_chain_rule}\tag{Eq 2.2.1}
\end{align}

where $J_{z\theta} = \partial z/\partial \theta = \partial f/\partial \theta$ is the Jacobian matrix and $H_z = \nabla^2_z L$ is the Hessian matrix in the output space. \ref{Hessian_approx_chain_rule} consists of two terms. The first term can be understood as linearizing the model function $f$ at $\theta$ and making a quadratic approximation to the loss function $L$ at $z = f(x, \theta)$, as it only involves the Jacobian (first-order term) of $z$ and the Hessian (second-order term) of the loss function. This first term 

\begin{align}
G \equiv J_{z\theta}^\top H_z J_{z\theta}
\end{align}

is known as the **Gauss-Newton matrix** and the approximation of the weight-space Hessian using the Gauss-Newton matrix is known as the Gauss-Newton approximation.

<img src="figures/C11/gauss_newton_approx.png" width=500>


The justification of the Gauss-Newton approximation is the assumption that the second term is negligible, which is the case if the model fits all training samples perfectly: for the squared error loss, $\partial L/\partial z_i$ is then zero for all $i$ and the second term vanishes. (The second term also vanishes exactly if $f$ is linear in $\theta$.) This implies that the Gauss-Newton approximation loses its validity for models that fit the data poorly.
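The sketch below illustrates this on an assumed two-parameter model $f(x, \theta) = \theta_1 \tanh(\theta_2 x)$ with squared-error loss $L = \frac12\sum_n (f(x_n,\theta) - y_n)^2$, so $H_z = I$ and $G = J_{z\theta}^\top J_{z\theta}$. The targets are generated by the model itself, so at the true parameters every residual is zero and $G$ matches the (finite-difference) Hessian; away from them the second term of \ref{Hessian_approx_chain_rule} is non-zero and the two differ. The model, data and helper names are illustrative.

```python
import numpy as np

x = np.linspace(-2.0, 2.0, 20)
theta_true = np.array([1.5, 0.7])


def model(theta):
    return theta[0] * np.tanh(theta[1] * x)


y = model(theta_true)  # noise-free targets, so the model can fit them exactly


def loss(theta):
    return 0.5 * np.sum((model(theta) - y) ** 2)


def jacobian(theta):
    """J_{z theta}: one row per sample, columns df/dtheta_1 and df/dtheta_2."""
    t = np.tanh(theta[1] * x)
    return np.stack([t, theta[0] * x * (1.0 - t**2)], axis=1)


def gauss_newton(theta):
    J = jacobian(theta)
    return J.T @ J  # J^T H_z J with H_z = I for the squared error


def hessian_fd(theta, eps=1e-4):
    """Exact Hessian of the loss, estimated by central finite differences."""
    H = np.zeros((2, 2))
    E = np.eye(2) * eps
    for i in range(2):
        for j in range(2):
            H[i, j] = (loss(theta + E[i] + E[j]) - loss(theta + E[i] - E[j])
                       - loss(theta - E[i] + E[j]) + loss(theta - E[i] - E[j])) / (4 * eps**2)
    return H


theta_far = np.array([0.5, 2.0])
print(np.abs(gauss_newton(theta_true) - hessian_fd(theta_true)).max())  # ~0: residuals vanish
print(np.abs(gauss_newton(theta_far) - hessian_fd(theta_far)).max())    # clearly non-zero
```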

#### Benefit of using Gauss-Newton matrix

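Since $G$ only involves first-order derivatives of the function $f$, it is well defined for activation functions such as ReLU. This is not the case for $H_\theta$, which also involves second-order derivatives of the activation function; for ReLU these are zero almost everywhere and undefined at the kink. In addition, $G$ is positive semidefinite whenever $L$ is convex in $z$ ($H_z \succeq 0$), which makes it a convenient curvature matrix for second-order optimizers.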

### 2.3 Output space gradient descent

The optimization problem in the weight space is usually much more difficult than the one in the output space. A well-known example is the optimization of the Rosenbrock function [ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/slides/lec03.pdf). GD in the parameter space has difficulty reaching the global minimum, which lies at the bottom of a narrow, curved valley, whereas GD in the output space finds it easily.
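The sketch below reproduces this comparison under stated assumptions: the Rosenbrock function is written as the squared norm $\|z\|^2$ of the "output" $z(\theta) = \big(a - x,\ \sqrt{b}\,(y - x^2)\big)$ with $a = 1$, $b = 100$. Output-space GD takes the step $\Delta z = -\eta z$ (a straight line toward the minimum $z = 0$) and maps it back to parameter space through the (here square and invertible) Jacobian, $\Delta\theta = J^{-1}\Delta z$; for this loss, that pullback coincides with a Gauss-Newton step of size $\eta$. The pullback choice, step sizes, starting point and iteration counts are illustrative assumptions rather than the exact procedure of the cited lecture.

```python
import numpy as np

A, B = 1.0, 100.0  # Rosenbrock constants; global minimum at (x, y) = (A, A^2)


def residuals(theta):
    """Output z(theta); the Rosenbrock loss is ||z||^2 = (A - x)^2 + B (y - x^2)^2."""
    x, y = theta
    return np.array([A - x, np.sqrt(B) * (y - x**2)])


def jacobian(theta):
    x, _ = theta
    return np.array([[-1.0, 0.0], [-2.0 * np.sqrt(B) * x, np.sqrt(B)]])


def loss(theta):
    z = residuals(theta)
    return float(z @ z)


def parameter_space_gd(theta, lr=1e-3, steps=1000):
    for _ in range(steps):
        theta = theta - lr * 2.0 * jacobian(theta).T @ residuals(theta)  # gradient of ||z||^2
    return theta


def output_space_gd(theta, lr=0.1, steps=1000):
    for _ in range(steps):
        dz = -lr * residuals(theta)                           # GD step on ||z||^2 / 2 in output space
        theta = theta + np.linalg.solve(jacobian(theta), dz)  # pull the step back to parameters
    return theta


start = np.array([-1.2, 1.0])
print(loss(parameter_space_gd(start)))  # still creeping along the curved valley
print(loss(output_space_gd(start)))     # essentially at the global minimum
```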

#### Pullback

## 3. Learnability of DNN

### 3.1 Challenge #1: Gradient vanishing/explosion

[ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/slides/lec08.pdf)

As discussed in Section 1, gradient vanishing/explosion is governed by the largest singular value of the Jacobian matrix: since repeated multiplication by the Jacobian is dominated by its largest singular value, this value determines whether the gradient vanishes or explodes.

Another view of gradient vanishing/explosion comes from its relationship to the sharpness of the minima of the loss landscape, which is quantified by the largest eigenvalue of the Hessian matrix. To understand this, consider the Gauss-Newton approximation of the Hessian matrix

\begin{align}
H \approx G = \mathbb{E}[J_{zw}^\top H_z J_{zw}]
\end{align}

where $H_z=\nabla^2_z \mathcal{L}$. For the squared error loss $\frac{1}{2}\|z - y\|^2$, $H_z = I$ and $G$ becomes the classical Gauss-Newton matrix

\begin{align}
H \approx \mathbb{E}[J_{zw}^\top J_{zw}]
\end{align}

It is speculated that the largest eigenvalue of $H$ is governed by the (squared) largest singular value of $J_{zw}$. This relates large singular values of $J$, which indicate the network's susceptibility to input perturbations and the potential for gradient explosion, to the sharpness of the minima of the loss function. It is shown in [ref](https://arxiv.org/pdf/1609.04836.pdf) that sharp minima generalize (to unseen data) less well than flat minima. Hence there is a preference for converging to a flatter minimum (see the batch size discussion in section 3.5).
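For a single example the connection is exact, since the eigenvalues of $J_{zw}^\top J_{zw}$ are the squared singular values of $J_{zw}$. A quick numerical check, with an arbitrary random matrix standing in for $J_{zw}$ (5 outputs, 50 weights; sizes and seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
J = rng.normal(size=(5, 50))  # stand-in for J_zw of one example
G = J.T @ J                   # Gauss-Newton matrix with H_z = I
lam_max = np.linalg.eigvalsh(G).max()
sigma_max = np.linalg.svd(J, compute_uv=False).max()
print(lam_max, sigma_max**2)  # equal up to floating-point error
```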

### 3.2 Challenge #2: Proliferation of saddle points
[ref](https://ganguli-gang.stanford.edu/pdf/14.SaddlePoint.NIPS.pdf)

It is argued in [ref](https://ganguli-gang.stanford.edu/pdf/14.SaddlePoint.NIPS.pdf) that saddle points are common in high-dimensional spaces. This can be understood using Random Matrix Theory (RMT). Assuming the Hessian matrix at a given critical point of the energy (loss) landscape is a Gaussian random matrix, its eigenvalue distribution follows Wigner's semicircle law, centered at zero. This implies that a critical point is likely to be a saddle point (a mix of positive and negative eigenvalues). As the energy level of the critical point decreases, the distribution shifts to the right, so a low-energy critical point is more likely to be a minimum (no direction with negative curvature). Conversely, as the energy level of the critical point increases, the distribution shifts to the left, with a concentration of eigenvalues near zero. This implies a proliferation of saddle points surrounded by plateaus (indicated by the near-zero eigenvalues), and the plateaus make gradient descent very slow. Below we summarize how different optimizers behave near a saddle point.
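Before turning to the optimizers, the random-matrix argument above can be checked numerically. The sketch assumes the Hessian at a critical point is a symmetric Gaussian (GOE-like) random matrix, normalized so that its spectrum approaches the semicircle on $[-2, 2]$. It estimates how often all eigenvalues are positive (a minimum) as the dimension $n$ grows, and how a positive shift of the spectrum, the analogue of lowering the energy of the critical point, makes minima common again. Function names, the shift and the trial counts are illustrative.

```python
import numpy as np


def goe(n, rng):
    """Symmetric Gaussian matrix whose spectrum approaches the semicircle on [-2, 2]."""
    A = rng.normal(size=(n, n))
    return (A + A.T) / np.sqrt(2 * n)


def fraction_of_minima(n, shift=0.0, trials=1000, seed=0):
    """Fraction of sampled 'Hessians' (GOE + shift * I) whose eigenvalues are all positive."""
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(trials):
        eigenvalues = np.linalg.eigvalsh(goe(n, rng) + shift * np.eye(n))
        hits += int(np.all(eigenvalues > 0))
    return hits / trials


for n in [1, 2, 5, 10]:
    print(n, fraction_of_minima(n), fraction_of_minima(n, shift=3.0))
```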

#### (Vanilla) gradient descent
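- slow descent when the gradient is small (e.g. on plateaus near saddle points)
- generally moves away from saddle points along negative-curvature directions, since it follows the negative gradient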

#### Newton method
- takes a gradient-ascent step along directions of negative curvature (the gradient is rescaled by a negative eigenvalue), so saddle points become attractors

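#### Trust region approach
- damp the Hessian by adding a constant $\alpha$ to the diagonal; an $\alpha$ larger than the magnitude of the most negative eigenvalue removes all negative curvature
- a large damping factor $\alpha$ shrinks the update toward a small gradient-descent step (roughly $-\nabla\mathcal{L}/\alpha$), which can make descent slow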
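A toy illustration of the behaviours above on the quadratic saddle $f(x,y)=\tfrac12(x^2-y^2)$ (chosen for illustration; the step size `lr` and the damping values are arbitrary). Starting slightly off the $x$-axis, GD increases $|y|$ and so moves away from the saddle, the undamped Newton step jumps exactly onto the saddle, and the damped step with $\alpha>1$ moves away again, approaching a small GD step of size $1/\alpha$ as $\alpha$ grows.

```python
import numpy as np

# Toy saddle (illustrative assumption): f(x, y) = 0.5 * (x**2 - y**2), saddle at the origin.
H = np.diag([1.0, -1.0])

def grad(p):
    return H @ p

def gd_step(p, lr=0.1):
    return p - lr * grad(p)

def newton_step(p):
    # Undamped Newton step: -H^{-1} grad.
    return p - np.linalg.solve(H, grad(p))

def damped_newton_step(p, alpha):
    # Damped step: -(H + alpha I)^{-1} grad; alpha > 1 makes H + alpha I positive definite here.
    return p - np.linalg.solve(H + alpha * np.eye(2), grad(p))

p0 = np.array([1.0, 0.1])
print("GD:           ", gd_step(p0))                    # |y| grows: moves away along negative curvature
print("Newton:       ", newton_step(p0))                # lands exactly on the saddle (0, 0)
print("damped a=2:   ", damped_newton_step(p0, 2.0))    # |y| grows: negative curvature removed
print("damped a=100: ", damped_newton_step(p0, 100.0))  # close to a GD step with lr = 1/alpha
```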

#### Truncated Newton method, BFGS approximation
- ignore negative-curvature directions
- cannot escape saddle points, since the negative-curvature directions needed to escape are exactly the ones being ignored

#### Natural gradient descent
- natural gradient favors moves that result in a small change to the model's output distribution; see A3 for details

### 3.3 Challenge #3: Failure of gradient-based algorithms
[ref](https://arxiv.org/pdf/1703.07950.pdf)

### 3.4 Challenge #4: Catastrophic interference/forgetting
[ref1](https://www.pnas.org/doi/10.1073/pnas.1611835114), [ref2](https://en.wikipedia.org/wiki/Catastrophic_interference#Elastic_weight_consolidation)

### 3.5 NN tuning best practices
[ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/readings/L05_normalization.pdf)


#### Batch size

- Large batch vs small batch size

| | small batch size | large batch size |
|---|---|---|
|gradient| noisy | less noisy |
|recency of weights used per update| more recent | older |
|# weight updates required| more | fewer |
|parallelism| less | more$^*$ |
|effective learning rate$^{**}$| small | large |

$^*\ $ the parallelism benefit of a larger batch size plateaus beyond a certain batch size (called **maximal data parallelism**)

$^{**}$ for algorithms that approximate curvature with gradient statistics (e.g. ADAM, RMSprop). This is because the batch empirical FIM is inversely proportional to the batch size.
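The "noisy vs less noisy" row of the table reflects the fact that, for batches sampled with replacement, the variance of the minibatch gradient around the full-batch gradient scales like $1/B$. A minimal sketch on synthetic linear-regression data (an assumption made only because per-example gradients are easy to write down):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 10_000, 5
X = rng.normal(size=(N, d))
y = X @ rng.normal(size=d) + 0.5 * rng.normal(size=N)
w = np.zeros(d)                                  # fixed parameter value at which noise is measured

per_example = (X @ w - y)[:, None] * X           # gradients of 0.5 * (x.w - y)^2, shape (N, d)
full_grad = per_example.mean(axis=0)

def minibatch_grad_noise(B, n_draws=2000):
    # Mean squared distance between a size-B minibatch gradient and the full gradient.
    idx = rng.integers(0, N, size=(n_draws, B))  # batches sampled with replacement
    g = per_example[idx].mean(axis=1)
    return float(np.mean(np.sum((g - full_grad) ** 2, axis=1)))

noise = {B: minibatch_grad_noise(B) for B in [1, 8, 64]}
for B, v in noise.items():
    print(f"B={B:3d}  noise={v:.4f}  noise*B={v * B:.4f}")
```

`noise*B` stays roughly constant across batch sizes, which is the $1/B$ scaling.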

- When using SGD, use a large learning rate early to get close to the optimum, then reduce the learning rate to reduce fluctuations [ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/slides/lec07.pdf)
- Smaller batch sizes make use of more recently updated weights, which can make them more efficient per example than larger batch sizes. For example, \begin{align} \theta^{(k)} &\leftarrow \theta^{(k-1)} - \alpha \nabla \mathcal{L}^{(k)}(\theta^{(k-1)})\\ \theta^{(1)} &\leftarrow \theta^{(0)} - \frac{\alpha}{S} \sum_{k=1}^{S}\nabla\mathcal{L}^{(k)}(\theta^{(0)})\end{align} where the first line is applied $S$ (the batch size) times, once per example $k=1,\dots,S$, and the second line is applied once. At each step, the one-sample batch GD (first line) evaluates the gradient at the most recent weights from the previous step, whereas the full-batch update evaluates all $S$ gradients at the older weights $\theta^{(0)}$.
- The total number of FLOPs per epoch is independent of the batch size. However, a larger batch size makes better use of parallel computing (parallelized matrix multiplication).
- With BN, the batch size affects the amount of stochastic regularization.
- It is shown in [ref](https://arxiv.org/pdf/1609.04836.pdf) that using a larger batch with an SGD-type optimizer leads to convergence to sharper minima, which generalize poorly. This suggests that gradient noise has an implicit regularization effect, i.e. it biases training toward flat minima. This is, however, challenged in [ref](https://arxiv.org/pdf/1811.03600.pdf), which highlights two confounders in the literature:
    - batch norm has its own regularization effect, which is more pronounced for smaller batches
    - some papers fixed the number of epochs, so models with larger batches were trained for fewer iterations
    
#### Learning rate
- For Adam, RMSprop, etc., 
    - smaller batch sizes result in more gradient noise, which inflates the second-moment estimate and hence gives smaller steps
    - larger batch sizes therefore increase the effective learning rate (the batch empirical FIM is inversely proportional to the batch size)
- With homogeneous normalizers such as BN, WN, and LN, there is **implicit learning rate decay**
    - the squared weight norm $\|w^{(k)}\|^2$ grows like $\sqrt{k}$ over $k$ GD steps
    - the weight norm sets the effective learning rate: a step of size $\eta$ on $\pmb w$ changes the direction $\pmb w/\|\pmb w\|$ as if the learning rate were $\eta/\|\pmb w\|^2$. This is illustrated in the following figure <img src="figures/C11/BN_effective_LR.png" width=500> and follows from the identity for scale-invariant losses (true for BN, WN and LN) \begin{align} \nabla \mathcal{L}(\gamma \pmb w) = \gamma^{-1} \nabla \mathcal{L}(\pmb w)\end{align}
    - the effective learning rate therefore decays implicitly like $k^{-1/2}$, which is the same form as the implicit learning rate schedule of popular optimizers such as Adagrad (Appendix A2)
    - the explicit learning rate then has only a transient effect and matters mainly at the beginning of training
    
    

- Different optimizers (SGD, Adam, K-FAC, etc.) have different effective learning rate schedules when combined with homogeneous normalizers.
- For many architectures, weight decay fundamentally affects the training dynamics, so it can’t be tuned independently of the optimizer (as we might expect for a regularization hyperparameter).
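The homogeneity identity can be checked on a toy scale-invariant loss $\mathcal{L}(\pmb w) = f(\pmb w/\|\pmb w\|)$, used here as a stand-in (assumption) for a layer followed by BN/WN/LN. Because such a gradient is orthogonal to $\pmb w$, each GD step satisfies $\|\pmb w - \eta\nabla\mathcal{L}\|^2 = \|\pmb w\|^2 + \eta^2\|\nabla\mathcal{L}\|^2$, so the norm can only grow and the effective learning rate $\eta/\|\pmb w\|^2$ can only shrink. The target direction, dimension and $\eta$ are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10
target = rng.normal(size=d)
target /= np.linalg.norm(target)

def loss(w):
    # Scale-invariant toy loss: depends on w only through its direction.
    u = w / np.linalg.norm(w)
    return 0.5 * np.sum((u - target) ** 2)

def grad(w):
    # Chain rule through u = w / ||w||: (I - u u^T)(u - target) / ||w||.
    norm = np.linalg.norm(w)
    u = w / norm
    g_u = u - target
    return (g_u - u * (u @ g_u)) / norm

w = rng.normal(size=d)
eta = 0.5
loss0 = loss(w)
norms_sq, eff_lr = [], []
for k in range(200):
    norms_sq.append(float(w @ w))
    eff_lr.append(eta / norms_sq[-1])   # effective learning rate on the direction of w
    w = w - eta * grad(w)

print(f"loss {loss0:.4f} -> {loss(w):.4f}")
print(f"||w||^2 {norms_sq[0]:.3f} -> {norms_sq[-1]:.3f}, eff. lr {eff_lr[0]:.4f} -> {eff_lr[-1]:.4f}")
```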




## 4. Self-attention as rank-lowering operation 
[ref1](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf)
[ref2](https://arxiv.org/pdf/2206.03126.pdf)
[ref3](https://arxiv.org/pdf/2006.04768.pdf)
[ref4](https://arxiv.org/pdf/2109.04553.pdf)

### 4.1 Rank collapsing of pure self-attention

It was shown in [ref](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf) that, without MLPs and skip connections, a pure self-attention architecture suffers from rank collapse, a phenomenon in which all embedded tokens become identical (so the output embedding matrix has rank one). Rank collapse is therefore the most extreme case of rank lowering.

The first paper to report this phenomenon [ref](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf) showed that the embedding matrix converges to a rank-1 matrix doubly exponentially with depth $L$. This doubly exponential convergence was hypothesized to arise because self-attention heads mix tokens faster when their input is already low-rank, leading to a cascading effect of rank reduction.
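A minimal sketch of rank collapse in a pure self-attention stack (one head, no MLP, no skip connection, no layer norm). For transparency we assume an identity value projection and fresh random query/key weights at each layer; these are illustrative choices, not the setup of the cited paper. Each attention matrix is row-stochastic with positive entries, so every layer replaces each token by a weighted average of the tokens; the spread between tokens (zero exactly when $\pmb X = \pmb 1\pmb x^\top$) therefore cannot grow with depth.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_layer(X, W_q, W_k):
    # Single-head self-attention with identity value projection (assumption).
    d_k = W_q.shape[1]
    scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d_k)
    P = softmax(scores, axis=-1)        # row-stochastic n x n attention matrix
    return P @ X

def token_spread(X):
    # Largest per-feature range across tokens; zero iff all tokens are identical (X = 1 x^T).
    return float((X.max(axis=0) - X.min(axis=0)).max())

rng = np.random.default_rng(0)
n, d, d_k, depth = 16, 8, 8, 10
X = rng.normal(size=(n, d))
spreads = [token_spread(X)]
for _ in range(depth):
    W_q = rng.normal(size=(d, d_k)) / np.sqrt(d)
    W_k = rng.normal(size=(d, d_k)) / np.sqrt(d)
    X = attention_layer(X, W_q, W_k)
    spreads.append(token_spread(X))

print(np.round(spreads, 6))
```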

It was discussed in [ref](https://arxiv.org/pdf/2206.03126.pdf) that rank collapse causes the gradients with respect to the queries and keys to vanish at initialization. This is reminiscent of the ordered phase in DNNs, where the input manifold is contracted into a single point $(c^l \to c^* = 1)$.

#### Skip-connection
Skip (or residual) connections (with depth-dependent scaling) are shown to help approximately preserve the cosine similarity of the tokens, since the skip branch is simply the identity map. It is argued in [ref](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf) that skip connections combinatorially increase the number of paths of length $l$ from

\begin{align}
|\mathcal{P}_l| = H^l
\end{align}

for an $L$-layer pure self-attention network with $H$ heads per layer (where every path has length $l=L$), to

\begin{align}
|\mathcal{P}_l| = {L \choose l} H^l
\end{align}

for an $L$-layer self-attention network with $H$ heads and a skip connection in each layer, where $l$ is the number of layers at which the path goes through an attention head rather than the skip connection. It was hypothesized that the presence of paths that use skip connections helps reduce rank collapse.
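The path count can be verified by brute force for small $L$ and $H$: with skip connections each layer offers $H+1$ choices (the skip branch or one of the $H$ heads), and grouping the $(H+1)^L$ paths by the number $l$ of heads they traverse recovers ${L \choose l}H^l$. The values of `L` and `H` are arbitrary.

```python
from itertools import product
from math import comb

def count_paths_by_length(L, H):
    # Each layer is either bypassed via the skip connection (choice 0)
    # or traversed through one of its H heads (choices 1..H).
    counts = [0] * (L + 1)
    for path in product(range(H + 1), repeat=L):
        l = sum(1 for c in path if c != 0)   # number of layers that use a head
        counts[l] += 1
    return counts

L, H = 4, 3
counts = count_paths_by_length(L, H)
print("brute force:", counts)
print("formula:    ", [comb(L, l) * H**l for l in range(L + 1)])
```

Only the $l=L$ entry, $H^L$, is available without skip connections.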

[Ref](https://arxiv.org/pdf/2206.03126.pdf) further investigated the effect of skip connections on the average cosine similarity of a pair of tokens (similar to $c^l_{12}$ in Section 1). Rather than simply adding skip connections, they introduce two parameters ($\alpha_1$ after the self-attention block, $\alpha_2$ after the MLP) that scale the residual branches and thus control the relative strength of the skip connections. They showed that scaling down the signal through the self-attention block and the MLP at every layer by a factor of $1/\sqrt{L}$ (hence relatively strengthening the skip connections) preserves the cosine similarity of a pair of tokens as the depth $L\to\infty$:

\begin{align}
\mathbb{E}[C(\pmb X^{(L)})] &= (1+\alpha_1^2)^L(1+\alpha_2^2)^L C(\pmb X)\\
\lim_{L\to \infty}\Big(1+\frac{\alpha_1^2}{L}\Big)^L \Big(1+\frac{\alpha_2^2}{L}\Big)^L C(\pmb X) &= e^{\alpha_1^2 + \alpha_2^2}C(\pmb X)
\end{align}

where the second line takes the limit of the first after replacing $\alpha_i$ by $\alpha_i/\sqrt{L}$, and $C(\pmb X) = \sum_{k,k'}\langle \pmb X_k, \pmb X_{k'}\rangle$ sums the inner products between tokens $k$ and $k'$. A similar conclusion was reached for the preservation of the token norms at large depth. Conversely, without the depth-dependent scaling, the correlation between tokens grows with depth and again results in rank collapse. The $1/\sqrt{L}$ scaling had been proposed in earlier work [ref](https://arxiv.org/pdf/1803.01719.pdf) to stabilize residual networks.
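The depth-scaling argument rests on the elementary limit $(1+\alpha^2/L)^L \to e^{\alpha^2}$. The sketch below compares the unscaled factor $(1+\alpha_1^2)^L(1+\alpha_2^2)^L$ (in log form, since it overflows quickly) with the factor obtained after replacing $\alpha_i$ by $\alpha_i/\sqrt{L}$; the values of $\alpha_1,\alpha_2$ are illustrative assumptions.

```python
import numpy as np

alpha1, alpha2 = 0.8, 0.5   # illustrative residual-branch scales (assumption)

def log_growth(L, scaled):
    # log of (1 + a1^2)^L (1 + a2^2)^L, with a_i -> a_i / sqrt(L) when scaled=True.
    s = np.sqrt(L) if scaled else 1.0
    a1, a2 = alpha1 / s, alpha2 / s
    return L * (np.log1p(a1**2) + np.log1p(a2**2))

for L in [1, 10, 100, 10_000]:
    print(f"L={L:6d}  unscaled log-growth={log_growth(L, False):9.2f}  "
          f"scaled growth={np.exp(log_growth(L, True)):.4f}")
print("predicted limit e^(a1^2 + a2^2) =", np.exp(alpha1**2 + alpha2**2))
```

The unscaled log-growth is linear in $L$ (exponential growth of the token correlation), while the scaled factor settles at a depth-independent constant.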

#### MLP

It is shown that MLPs slightly slow down the convergence to rank 1. The upper bound on how far the output of layer $L$ can be from a rank-1 matrix is inflated by a factor $\lambda^{\frac{3^L-1}{2}}$, where $\lambda$ is the Lipschitz constant of the MLP. Therefore, the more powerful/expressive the MLP (the larger $\lambda$), the more it can slow down the convergence to rank 1.

From a signal propagation perspective, the more expressive an MLP, the more likely it is to be in the chaotic phase, where gradients are more likely to explode. Since rank collapse is associated with vanishing gradients, the gradient-explosion tendency of a more expressive MLP counteracts the gradient-vanishing effect of the self-attention layers.

#### Layer normalization

Layer norm is shown [ref](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf) to be of no help in avoiding rank collapse. Mathematically, layer norm does not change the fact that self-attention is written as a product of matrices. Its rescaling amounts to multiplying by an invertible diagonal matrix, which is a composition of elementary row/column operations, and such operations do not change the rank of a matrix; it is therefore argued that layer norm does not help avoid rank collapse.
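The rank argument can be illustrated numerically. Assumptions: layer norm is applied per token over the feature axis without learned gain or bias, and the learned gain is represented separately by an invertible diagonal matrix `D`. Multiplying by `D` leaves the rank unchanged, and once all tokens are identical, normalizing each token keeps them identical, so layer norm cannot undo a rank-1 collapse.

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    # Per-token normalization over the feature axis, without learned gain or bias (assumption).
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
n, d = 16, 8
rank = np.linalg.matrix_rank

D = np.diag(rng.uniform(0.5, 2.0, size=d))               # invertible diagonal matrix, e.g. a gain

X_full = rng.normal(size=(n, d))                          # generic tokens: rank d
X_collapsed = np.outer(np.ones(n), rng.normal(size=d))    # identical tokens: X = 1 x^T, rank 1

print(rank(X_full), rank(X_full @ D))
print(rank(X_collapsed), rank(layer_norm(X_collapsed)), rank(layer_norm(X_collapsed) @ D))
```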




### 4.2 Efficient approximations to self-attention mechanism

It was discussed in [ref](https://arxiv.org/pdf/2006.04768.pdf) and [ref](https://arxiv.org/pdf/2109.04553.pdf) that, because self-attention is rank-lowering, it can be approximated by matrix factorization or by first projecting the context mapping matrix into a lower-dimensional space, which usually requires less time and space than the $O(n^2)$ cost of standard self-attention.

Recall that the self-attention mechanism can be written as

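\begin{equation}
h_i^{(l+1)} = \sum_{j}a_{ij} V h_j^{(l)}
\end{equation}

where the attention weight $a_{ij}$ is 

\begin{equation}
a_{ij} = \text{softmax}_j\Big(\frac{Q^{(l)} h_i^{(l)} \cdot K^{(l)} h_j^{(l)}}{\sqrt{d_k}}\Big) = \frac{\exp\big(Q^{(l)} h_i^{(l)} \cdot K^{(l)} h_j^{(l)}/\sqrt{d_k}\big)}{\sum_{j'}\exp\big(Q^{(l)} h_i^{(l)} \cdot K^{(l)} h_{j'}^{(l)}/\sqrt{d_k}\big)}
\end{equation}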

and in matrix form

\begin{align}
H^{(l+1)} &= \text{softmax}\Big(\frac{Q^{(l)} H^{(l)} \cdot K^{(l)} H^{(l)}}{\sqrt{d_k}}\Big)H^{(l)}V \\
&= PH^{(l)}V \label{SA_matrix_form}\tag{Eq 4.1}
\end{align}

where $d_k$ is the dimension of the subspace into which the operators $Q$ and $K$ project an embedded vector $h^{(l)}$ (of dimension $d_m$). It is shown in [ref](https://arxiv.org/pdf/2006.04768.pdf) that $P$ is approximately low-rank: its singular value spectrum is highly skewed, so $P$ can be approximated by keeping only its largest singular values. The spectrum is more skewed in higher layers than in lower layers, meaning that in higher layers more information is concentrated in the largest singular values and the effective rank of $P$ is lower. This insight is used to propose more efficient approximations of self-attention.
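A sketch of how the low-rank structure of $P$ can be quantified: compute the singular values of $P$, the fraction of $\|P\|_F^2$ captured by the top $k$ of them, and the error of the best rank-$k$ approximation. The inputs and weights are random (assumption), so the printed spectrum is not the trained-model spectrum reported in the cited paper; only the procedure is illustrated.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d_m, d_k = 128, 64, 16
H = rng.normal(size=(n, d_m))
W_q = rng.normal(size=(d_m, d_k)) / np.sqrt(d_m)
W_k = rng.normal(size=(d_m, d_k)) / np.sqrt(d_m)
P = softmax((H @ W_q) @ (H @ W_k).T / np.sqrt(d_k))   # n x n context mapping matrix

U, S, Vt = np.linalg.svd(P)
energy = np.cumsum(S**2) / np.sum(S**2)               # fraction of ||P||_F^2 in the top-k values

def rank_k_error(k):
    # Relative Frobenius error of the best rank-k approximation (Eckart-Young).
    P_k = (U[:, :k] * S[:k]) @ Vt[:k]
    return np.linalg.norm(P - P_k) / np.linalg.norm(P)

for k in [1, 8, 32, 128]:
    print(f"k={k:3d}  captured energy={energy[k - 1]:.3f}  relative error={rank_k_error(k):.3f}")
```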

#### Linformer [ref](https://arxiv.org/pdf/2006.04768.pdf)

A linear self-attention mechanism aims to 

(1) project the $n\times d$ key matrix $K^{(l)}H^{(l)}$ to a $k\times d$ matrix $E^{(l)}K^{(l)}H^{(l)}$ through a $k\times n$ projection matrix $E^{(l)}$, so that the context mapping matrix $\tilde P$ has dimension $n\times k$ instead of $n\times n$; 

(2) project the $n \times d$ value matrix $H^{(l)}V$ to a $k \times d$ matrix $F^{(l)}H^{(l)}V$ through a $k\times n$ projection matrix $F^{(l)}$. \ref{SA_matrix_form} is then modified to

\begin{align}
H^{(l+1)} &= \text{softmax}\Big(\frac{Q^{(l)} H^{(l)} \cdot E^{(l)}  K^{(l)}H^{(l)}}
{\sqrt{d_k}}\Big)F^{(l)}H^{(l)}V\\
\end{align}

The time complexity is reduced from $O(n^2)$ to $O(nk)$.
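A shape-level sketch of the Linformer update. Assumptions: `Q`, `K`, `V` stand for the already-projected queries, keys and values ($Q^{(l)}H^{(l)}$, $K^{(l)}H^{(l)}$ and $H^{(l)}V$ above), and `E`, `F` are random Gaussian matrices, whereas Linformer learns them. The point is that the context mapping matrix $\tilde P$ is $n\times k$ rather than $n\times n$.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V):
    # Q, K, V: (n, d). Builds the full n x n context mapping matrix P.
    P = softmax(Q @ K.T / np.sqrt(Q.shape[1]))
    return P @ V, P

def linformer_attention(Q, K, V, E, F):
    # E, F: (k, n) projections along the sequence axis.
    K_proj = E @ K                                          # (k, d)
    V_proj = F @ V                                          # (k, d)
    P_tilde = softmax(Q @ K_proj.T / np.sqrt(Q.shape[1]))   # (n, k)
    return P_tilde @ V_proj, P_tilde

rng = np.random.default_rng(0)
n, d, k = 512, 64, 32
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
E = rng.normal(size=(k, n)) / np.sqrt(k)    # random stand-ins for the learned projections
F = rng.normal(size=(k, n)) / np.sqrt(k)

out_full, P = standard_attention(Q, K, V)
out_lin, P_tilde = linformer_attention(Q, K, V, E, F)
print(P.shape, P_tilde.shape, out_full.shape, out_lin.shape)
```

Forming $\tilde P$ costs $O(nkd)$ instead of the $O(n^2 d)$ needed for the full $P$, which is where the linear dependence on $n$ comes from for fixed $k$.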

It was mentioned in [ref](https://proceedings.mlr.press/v139/dong21a/dong21a.pdf) that, due to the low rank imposed by Linformer, rank collapse happens even faster.

## Appendix

### A1 Iterative equations for $q^l$ and $c^l$

Reiterating the DNN forward equations

\begin{align}
h_i^{(l)} &= \sum_j w_{ij}^{(l)} x^{(l-1)}_j + b_i \\
x^{(l)}_i &= \phi(h_i^{(l)}) \label{DNN_forward_pass}\tag{Eq A1.1}
\end{align}

Define 

\begin{align}
q^l &\equiv \langle (h^{(l)}_i)^2\rangle \\
&= \Big\langle \sum_{jk} w_{ij}^{(l)}w_{ik}^{(l)} x^{(l-1)}_j x^{(l-1)}_k \Big\rangle + \langle b_i^2\rangle \\
&\approx \sum_{jk} \langle w_{ij}^{(l)}w_{ik}^{(l)}\rangle \langle x^{(l-1)}_j x^{(l-1)}_k\rangle + \langle b_i^2\rangle \\
&= \sum_{j=1}^{N_{l-1}} \frac{\sigma_w^2}{N_{l-1}} (x^{(l-1)}_j)^2 + \sigma_b^2 \\
&= \frac{\sigma_w^2}{N_{l-1}}\sum_{j=1}^{N_{l-1}} \phi^2(h_j^{(l-1)}) + \sigma_b^2 \label{q_iterative_1}\tag{Eq A1.2}
\end{align}

where the approximation in the third line follows from the mean-field ansatz. The two sums reduce to one from line 3 to line 4 because of the identity $\langle w_{ij}^{(l)}w_{ik}^{(l)}\rangle = \delta_{jk}\sigma_w^2/N_{l-1}$. Finally, the second line of \ref{DNN_forward_pass}, $x^{(l)}_i = \phi(h^{(l)}_i)$, is used in the last line of \ref{q_iterative_1}.

In the large layer width limit $N_{l-1}\gg 1$, the sum can be approximated by an integral. The random variable in the sum is $h_j^{(l-1)}$, which follows a Gaussian with zero mean and variance $q^{l-1}$. The last line in \ref{q_iterative_1} can be written as

\begin{align}
q^l &= \sigma_w^2\int dh\, \mathcal{N}(h; 0, q^{l-1})\, \phi^2(h) + \sigma_b^2 \\
&= \sigma_w^2\int dh \frac{1}{\sqrt{2\pi q^{l-1}}}e^{-h^2/(2q^{l-1})}  \phi^2(h) + \sigma_b^2 \\
&= \sigma_w^2\int dz \frac{e^{-z^2/2}}{\sqrt{2\pi}} \phi^2\Big(\sqrt{q^{l-1}}z\Big) + \sigma_b^2 \\
&= \sigma_w^2\int Dz\, \phi^2\Big(\sqrt{q^{l-1}}z\Big) + \sigma_b^2 
\end{align}

where the change of variable $h = \sqrt{q^{l-1}}\,z$ is used in the third line, and the last line introduces the shorthand $Dz \equiv dz\, e^{-z^2/2}/\sqrt{2\pi}$ for the standard Gaussian measure.
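The last line of the $q$ recursion can be iterated numerically. A minimal sketch, assuming $\phi=\tanh$, $\sigma_w=1.5$, $\sigma_b=0.3$ and an 80-point Gauss-Hermite rule for $\int Dz$ (all illustrative choices); `numpy.polynomial.hermite_e.hermegauss` returns nodes and weights for the weight function $e^{-z^2/2}$.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Gauss-Hermite nodes/weights for the weight exp(-z^2/2); dividing the weights by
# sqrt(2*pi) turns sum(W * f(Z)) into the Gaussian average \int Dz f(z).
Z, W = hermegauss(80)
W = W / np.sqrt(2 * np.pi)

def q_map(q_prev, sigma_w, sigma_b, phi=np.tanh):
    # q^l = sigma_w^2 \int Dz phi(sqrt(q^{l-1}) z)^2 + sigma_b^2
    return sigma_w**2 * np.sum(W * phi(np.sqrt(q_prev) * Z) ** 2) + sigma_b**2

def iterate_q(sigma_w, sigma_b, q0=1.0, n_layers=500):
    # Apply the length map layer by layer; returns the trajectory q^0, q^1, ...
    qs = [q0]
    for _ in range(n_layers):
        qs.append(q_map(qs[-1], sigma_w, sigma_b))
    return qs

qs = iterate_q(sigma_w=1.5, sigma_b=0.3)
q_star = qs[-1]
print("first layers:", np.round(qs[:5], 4), " fixed point q* ~", q_star)
```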

Similarly for $q^l_{12}$

\begin{align}
q^l_{12} = \sigma_w^2\int Dz_1 Dz_2 \phi(u_1)\phi(u_2) + \sigma_b^2
\end{align}

where 

\begin{align}
u_1 &= \sqrt{q^{l-1}_{11}}z_1 \\
u_2 &= \sqrt{q^{l-1}_{22}}\Big[c^{l-1}_{12} z_1 + \sqrt{1- (c^{l-1}_{12})^2} z_2 \Big]
\end{align}

$u_2$ is constructed this way so that, with $z_1$ and $z_2$ independent standard Gaussians, the pre-activations $u_1$ and $u_2$ have variances $q^{l-1}_{11}$ and $q^{l-1}_{22}$ and correlation coefficient $c^{l-1}_{12}$.
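A Monte Carlo check of this construction (the values of $q^{l-1}_{11}$, $q^{l-1}_{22}$ and $c^{l-1}_{12}$ below are arbitrary): with independent standard Gaussians $z_1,z_2$, the sample variances of $u_1,u_2$ and their sample correlation should match $q^{l-1}_{11}$, $q^{l-1}_{22}$ and $c^{l-1}_{12}$.

```python
import numpy as np

rng = np.random.default_rng(0)
q11, q22, c = 1.3, 0.7, 0.6               # illustrative q^{l-1}_11, q^{l-1}_22, c^{l-1}_12
z1, z2 = rng.normal(size=(2, 200_000))    # independent standard Gaussians

u1 = np.sqrt(q11) * z1
u2 = np.sqrt(q22) * (c * z1 + np.sqrt(1 - c**2) * z2)

print("var(u1) =", np.var(u1), " var(u2) =", np.var(u2), " corr =", np.corrcoef(u1, u2)[0, 1])
```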

From the previous definition of the recursive equation for $c^l_{12}$, at the fixed point $q^l_{11}=q^l_{22}=q^*$,

\begin{align}
c^l_{12} &= \mathcal{C}(c^{l-1}_{12}, q^l_{11}=q^*, q^l_{22}=q^* 
\mid \sigma_w, \sigma_b) = \frac{q^l_{12}}{\sqrt{q^l_{11}q^l_{22}}} = \frac{q^l_{12}}{q^*}
\end{align}

The quantity $\chi$ can then be derived

\begin{align}
\chi &= \frac{\partial c^l_{12}}{\partial c^{l-1}_{12}}\bigg|_{c=c^*=1} \\
&= \frac{\sigma_w^2}{q^*}\int Dz_1 Dz_2 \phi(u_1) \frac{\partial}{\partial c^{l-1}_{12}} \phi(u_2) \\
&= \frac{\sigma_w^2}{q^*}\int Dz_1 Dz_2 \phi(u_1) \phi'(u_2)\frac{\partial u_2}{\partial c^{l-1}_{12}} \\
&= \frac{\sigma_w^2}{\sqrt{q^*}}\int Dz_1 Dz_2 \phi(u_1) \phi'(u_2)\Big( z_1 - \frac{c^{l-1}_{12}}{\sqrt{1-(c^{l-1}_{12})^2}} z_2\Big)\\
&= \ ...
\end{align}

where $q^{l-1}_{22}=q^*$ is used for the factor $\sqrt{q^{l-1}_{22}}$ in $\partial u_2/\partial c^{l-1}_{12}$.
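Without finishing the integral by hand, the derivative defining $\chi$ can be estimated numerically. A minimal sketch, assuming $\phi=\tanh$, $\sigma_w=1.5$, $\sigma_b=0.3$, a tensor-product Gauss-Hermite rule for $\int Dz_1Dz_2$, and a one-sided finite difference at $c^*=1$ (the map is only defined for $c\le 1$). It also confirms that $c=1$ is a fixed point of the $c$-map once $q$ has reached $q^*$.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

Z, W = hermegauss(80)
W = W / np.sqrt(2 * np.pi)                 # weights for \int Dz
Z1, Z2 = np.meshgrid(Z, Z, indexing="ij")
W2 = np.outer(W, W)                        # tensor-product weights for \int Dz1 Dz2

sigma_w, sigma_b, phi = 1.5, 0.3, np.tanh  # illustrative choices (assumption)

def q_map(q):
    return sigma_w**2 * np.sum(W * phi(np.sqrt(q) * Z) ** 2) + sigma_b**2

q_star = 1.0
for _ in range(500):                       # iterate the length map to its fixed point q*
    q_star = q_map(q_star)

def c_map(c):
    # c^l = q^l_12 / q*, with u1, u2 built as above and q^{l-1}_11 = q^{l-1}_22 = q*.
    u1 = np.sqrt(q_star) * Z1
    u2 = np.sqrt(q_star) * (c * Z1 + np.sqrt(1.0 - c**2) * Z2)
    q12 = sigma_w**2 * np.sum(W2 * phi(u1) * phi(u2)) + sigma_b**2
    return q12 / q_star

h = 1e-4
chi_fd = (c_map(1.0) - c_map(1.0 - h)) / h   # one-sided difference: c cannot exceed 1
print("C(1) =", c_map(1.0), " chi (finite difference) ~", chi_fd)
```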

### A2 Gradient descent based optimizers

This section reviews the most popular optimizers used to train NNs.

#### Gradient Descent

\begin{equation}
\theta(t+1) \leftarrow \theta(t) -\eta_t \nabla_\theta E(\theta(t))
\end{equation}

where $\eta_t$ is the learning rate. Usually it will be a function ('scheduler') of iteration $t$ to ensure the algorithm converges to the solution. Since the change depends on the slope of $E(\theta)$, the steeper the loss function, the larger the update step will be. This is usually not ideal, and will likely move $\theta$ to a point that is not on the path to the minimum. To solve this issue, we need second-order information of the loss function.

#### Stochastic gradient descent

The stochasticity of SGD comes from the random subsample of data (a batch) used to compute the gradient. The main advantage of SGD is that, instead of computing the loss function (and its gradient) over the full dataset, only a random sample of it (based on the random batch) is computed. Since the loss function of an NN is highly non-convex, doing GD on the full loss function is susceptible to getting stuck at local minima. Because the loss function changes from batch to batch, the parameters are less likely to get stuck at a local minimum, as a minimum for one batch may not be a minimum for another. Another advantage is that computing gradients on batches of data allows parallelization.

#### Newton's method

Taylor expand the loss function to second order near $\theta$

\begin{equation}
E(\theta + v) \approx E(\theta) + \nabla_\theta E(\theta) \cdot v + \frac{1}{2}v^T H v
\end{equation}

where $H$ is the Hessian matrix, defined as $H_{ij} = \partial^2 E /\partial \theta_i \partial \theta_j$. By choosing the optimal $v$ such that $\theta$ moves to the optimal solution in the next step and keeping up to the second order term

\begin{align}
\nabla_\theta E(\theta + v_\text{opt}) &\approx \nabla_\theta E(\theta) + H \cdot v_\text{opt}\\
\nabla_\theta E(\theta + v_\text{opt}) &= 0 \\
v_\text{opt} &= -H^{-1}\nabla_\theta E(\theta)
\end{align}

which has the same form as the gradient descent update, with $H^{-1}$ in place of the learning rate. Therefore, $H^{-1}$ can be interpreted as the optimal (matrix-valued) learning rate. Intuitively, $H$ captures the curvature of the loss function at a point in different directions: the more curved the loss function is in a given direction, the smaller the step $\theta$ takes in that direction.

The downside of Newton's method is that computing $H^{-1}$ is computationally very expensive, as $H$ has $N^2$ entries for $N$ parameters (and inverting it costs $O(N^3)$). Since $H$ can be viewed as the Jacobian of the gradient, one can apply forward-mode autodiff to the computation graph produced by reverse-mode autodiff to compute Hessian-vector products in time linear in $N$, without forming $H$. This trick is called **forward-over-reverse** [ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/readings/L02_Taylor_approximations.pdf). However, it requires access to the computation graph from backprop.

A geometric understanding of Newton's method is as follows. At a given point, approximate the loss function with an elliptic paraboloid (its second-order Taylor expansion). When $H$ is positive definite, Newton's method gives the step that points directly from that point to the minimum of this 'imagined' paraboloid. This means the Newton step in general does not point in the same direction as the negative gradient at that point, due to $H^{-1}$. Note that if the paraboloid is spherical (isotropic) instead of elliptic, $H$ is proportional to the identity and the Newton step points along the negative gradient.
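A small illustration of this picture on an elliptic quadratic bowl (the matrix $H$, the location of the minimum, and the step size are illustrative assumptions). The Newton step lands on the minimum in one step even though its direction differs from the negative gradient; plain GD, whose learning rate is limited by the steepest direction, needs many steps.

```python
import numpy as np

# Elliptic quadratic bowl (assumption): E(theta) = 0.5 * (theta - theta_min)^T H (theta - theta_min)
H = np.array([[10.0, 0.0],
              [0.0, 1.0]])
theta_min = np.array([1.0, -2.0])

def grad(theta):
    return H @ (theta - theta_min)

theta0 = np.array([0.0, 0.0])
g = grad(theta0)
newton_dir = -np.linalg.solve(H, g)    # -H^{-1} grad: points at the bottom of the bowl
gd_dir = -g                            # steepest descent in Euclidean geometry

cos = newton_dir @ gd_dir / (np.linalg.norm(newton_dir) * np.linalg.norm(gd_dir))
print("Newton lands at:", theta0 + newton_dir)
print("cosine(Newton step, -gradient):", cos)

# Plain GD with a fixed learning rate; lr must stay below 2/10 for stability on this bowl.
theta, lr = theta0.copy(), 0.1
for t in range(100):
    theta = theta - lr * grad(theta)
print("GD after 100 steps:", theta)
```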

In the jargon of Riemannian geometry, $H^{-1}\nabla_\theta E(\theta)$ is called the *natural gradient* (discussed in A3), with $H$ acting as the metric tensor $g(\theta)$ of the Riemannian manifold.

#### RMSprop

\begin{align}
g(t) &= \nabla_\theta E(\theta) \\
s(t) &= \beta s(t-1) + (1-\beta) g^2(t) \\
\theta(t) &= \theta(t-1) - \alpha \frac{g(t)}{\sqrt{s(t) + \epsilon 1}}
\end{align}

RMSprop tries to update each weight by a constant magnitude $\alpha$. This is because for ordinary SGD, individual derivatives (of different weights) might be very large or very small, resulting in taking a very large or very small step size in the GD. By ensuring that each update has magnitude approximately $\alpha$ (e.g. $10^{-3}$), we ensure that each weight is changed by only a little bit in each iteration, but over many (e.g. 1000) iterations, the weights still have the opportunity to move a long distance. The learning rate $\alpha$ acts as the free parameter to the algorithm and so RMSprop is not completely 'adaptive'.

#### Adagrad

\begin{align}
g(t) &= \nabla_\theta E(\theta) \\
s(t) &= s(t-1) + g^2(t) \\
\theta(t) &= \theta(t-1) - \alpha \frac{g(t)}{\sqrt{s(t) + \epsilon 1}}
\end{align}

Adagrad is very similar to RMSprop, except that instead of keeping a running exponential average of the squared gradient, it keeps the running sum of the squared gradient. This naturally inflates the factor $\sqrt{s(t) + \epsilon 1}$ over time and therefore acts as an implicit learning rate decay. [ref](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/readings/L05_normalization.pdf) notes that RMSprop and ADAM are preferred to Adagrad because they lack this implicit learning rate decay.
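A sketch of the two update rules as written above (with $\epsilon$ inside the square root), driven by a constant gradient whose two coordinates differ by four orders of magnitude (an illustrative assumption, corresponding to a linear loss). RMSprop ends up moving each coordinate by roughly $\alpha$ per step, whereas Adagrad's steps shrink like $1/\sqrt{t}$, the implicit decay described above.

```python
import numpy as np

def rmsprop_steps(grad_fn, theta, alpha=1e-3, beta=0.9, eps=1e-8, n_steps=200):
    s = np.zeros_like(theta)
    steps = []
    for _ in range(n_steps):
        g = grad_fn(theta)
        s = beta * s + (1 - beta) * g**2      # running exponential average of g^2
        step = alpha * g / np.sqrt(s + eps)
        theta = theta - step
        steps.append(np.abs(step))
    return theta, np.array(steps)

def adagrad_steps(grad_fn, theta, alpha=1e-3, eps=1e-8, n_steps=200):
    s = np.zeros_like(theta)
    steps = []
    for _ in range(n_steps):
        g = grad_fn(theta)
        s = s + g**2                          # running sum, not a running average
        step = alpha * g / np.sqrt(s + eps)
        theta = theta - step
        steps.append(np.abs(step))
    return theta, np.array(steps)

g_const = np.array([100.0, 1e-2])             # per-coordinate gradients of very different size

def grad_fn(theta):
    return g_const

_, rms = rmsprop_steps(grad_fn, np.zeros(2))
_, ada = adagrad_steps(grad_fn, np.zeros(2))
print("RMSprop last step sizes:", rms[-1])                      # ~alpha for both coordinates
print("Adagrad step * sqrt(t):", ada[-1] * np.sqrt(len(ada)))   # ~alpha: steps decay like 1/sqrt(t)
```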

#### ADAM

ADAM incorporates momentum into the optimization algorithm. Momentum speeds up progress along a persistent, gently sloped direction, which would otherwise result in small update steps in GD and SGD (first-order algorithms). This is done by keeping running averages of the first and second moments of the gradient:

\begin{align}
g(t) &= \nabla_\theta E(\theta(t-1)) \\
m(t) &= \alpha m(t-1) + (1-\alpha) g(t) \\
s(t) &= \beta s(t-1) + (1-\beta) g^2(t) \\
\hat m(t) &= \frac{m(t)}{1-\alpha^t} \\
\hat s(t) &= \frac{s(t)}{1-\beta^t} \\
\theta(t) &= \theta(t-1) - \eta_t \frac{\hat m(t)}{\sqrt{\hat s(t) + \epsilon}}
\end{align}

Therefore, the effective learning rate is reduced (increased) in directions where the gradients are consistently large (small).

Technically, ADAM is still a first-order algorithm, as it does not explicitly compute the Hessian matrix; the running average $s(t)$ of squared gradients only serves as a diagonal curvature proxy. An advantage of ADAM over plain first-order algorithms such as GD in high-dimensional optimization is that ADAM escapes saddle points, which proliferate in high dimensions, more easily than GD.

It was mentioned in [ref](https://arxiv.org/pdf/1412.6980.pdf) and [ref](https://arxiv.org/pdf/1301.3584.pdf) that ADAM is similar to natural gradient descent (see A3) in terms of employing a preconditioner that adapts to the geometry of the data since $\hat s(t)$ is an approximation to the diagonal of the Fisher Information Matrix.
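A minimal implementation of the ADAM update above, run on a badly scaled quadratic whose curvatures span four orders of magnitude (illustrative assumption; the hyperparameter values are common defaults, not prescribed by the text). Here `alpha` and `beta` are the moment decay rates of the equations and `eta` is the step size.

```python
import numpy as np

def adam(grad_fn, theta, eta=0.05, alpha=0.9, beta=0.999, eps=1e-8, n_steps=2000):
    # Follows the update rule above; alpha/beta are the first/second-moment decay rates.
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    for t in range(1, n_steps + 1):
        g = grad_fn(theta)
        m = alpha * m + (1 - alpha) * g
        s = beta * s + (1 - beta) * g**2
        m_hat = m / (1 - alpha**t)          # bias correction for the zero initialization
        s_hat = s / (1 - beta**t)
        theta = theta - eta * m_hat / np.sqrt(s_hat + eps)
    return theta

# Badly scaled quadratic E(theta) = 0.5 * sum(h_i * (theta_i - c_i)^2)  (illustrative assumption)
h = np.array([100.0, 1.0, 0.01])
c = np.array([1.0, -2.0, 3.0])

def grad_fn(theta):
    return h * (theta - c)

def loss(theta):
    return 0.5 * np.sum(h * (theta - c) ** 2)

theta0 = np.zeros(3)
theta_final = adam(grad_fn, theta0)
print("loss:", loss(theta0), "->", loss(theta_final), " theta:", theta_final)
```

Early on, the normalized step $\hat m/\sqrt{\hat s}$ has magnitude close to one in every coordinate, so all three coordinates start moving at a similar speed despite their very different curvatures.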

### A3 Natural gradient

[ref1](https://agustinus.kristia.de/techblog/2018/03/14/natural-gradient/), [ref2](https://andrewcharlesjones.github.io/journal/natural-gradients.html), [ref3](https://arxiv.org/pdf/1301.3584.pdf)

The natural gradient can be understood as follows. For vanilla GD, the gradient points in the direction of the largest change in the loss function per unit change in the parameters, where the change in the parameters is measured by Euclidean distance. The natural gradient generalizes this by allowing any distance defined by a given metric $g(\theta)$.

Consider a loss function of a model over a domain of parameters $\pmb w \in \mathbb{R}^d$ 

\begin{equation}
\mathcal{L}(\pmb w)
\end{equation}

The goal in machine learning is to find the value of $\pmb w$ that minimizes the above loss function, which can be done by gradient descent. In general, the domain of the loss function can be considered a manifold, and the loss function is a function that lives on the manifold. One can equip the manifold with a metric $g$, which defines the distance between two (infinitesimally close) points on the manifold:

\begin{equation}
(ds)^2 = \sum_{ij} g_{ij}(\pmb w)dw_i dw_j
\end{equation}

A manifold endowed with a metric is called a **Riemannian manifold** $\Omega(g, \mathbb{R}^d)$. When considering the loss function, we usually choose $g = \pmb I$ (Euclidean distance). In general, $g$ can be any $d \times d$ symmetric positive-definite matrix, possibly depending on $\pmb w$. It is shown in [ref](http://www.yaroslavvb.com/papers/amari-why.pdf) that for a given metric $g$ on $\Omega$, the steepest-descent direction becomes

\begin{equation}
-g^{-1}(\pmb w) \nabla \mathcal{L}(\pmb w)
\end{equation}
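A numerical check that $-g^{-1}\nabla\mathcal{L}$ is the steepest-descent direction under the metric $g$: among all steps $v$ with the same metric length $v^\top g v=\varepsilon^2$, it gives the most negative first-order change $\nabla\mathcal{L}\cdot v$ (a Cauchy-Schwarz argument in the $g$ inner product). The metric, gradient and $\varepsilon$ are random or arbitrary (assumptions).

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
A = rng.normal(size=(d, d))
g = A @ A.T + d * np.eye(d)          # an illustrative symmetric positive-definite metric
grad = rng.normal(size=d)            # gradient of L at the current point
eps = 1e-2                           # step length measured in the metric: v^T g v = eps^2

def g_norm(v):
    return np.sqrt(v @ g @ v)

v_nat = -np.linalg.solve(g, grad)
v_nat *= eps / g_norm(v_nat)         # natural-gradient step with metric length eps

# Random competitor steps rescaled to the same metric length.
V = rng.normal(size=(10_000, d))
V *= eps / np.sqrt(np.einsum("ij,jk,ik->i", V, g, V))[:, None]

print("first-order change, natural step:     ", grad @ v_nat)
print("first-order change, best random step: ", (V @ grad).min())
```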


#### Natural gradient with Fisher information matrix

One can define the distance between two points in the parameter space not by the difference in the parameter values but by the difference in the model predictions. The KL divergence between the likelihood distributions under two parameter values ($\pmb \theta$ and $\pmb \theta + \pmb \delta$) can therefore be used:

\begin{equation}
\Delta s^2 = D_{KL}\big[P(y\mid x, \pmb \theta) \| P(y\mid x, \pmb \theta + \pmb \delta)\big] \label{KL_dist}\tag{Eq 2.1}
\end{equation}

For small parameter differences, the KL divergence is approximately symmetric, and the metric tensor can be approximated by the Fisher information matrix $F$:


\begin{equation}
\Delta s^2 \approx \frac{1}{2}\sum_{ij} F_{ij}\delta_i \delta_j
\end{equation}

where

\begin{equation}
F = \Big\langle\nabla_{\pmb\theta}\log P(y\mid x,\pmb\theta)\ \nabla_{\pmb\theta}\log P(y\mid x,\pmb\theta)^\top \Big\rangle_{x \sim P(x),\ y \sim P(y\mid x,\pmb\theta)}
\end{equation}

Note that the expectation is over $x$ and over $y$ sampled from the model's predictive distribution $P(y\mid x,\pmb\theta)$ (the **true FIM**), instead of over $x$ and the empirical labels $y$ (the **empirical FIM**).

The GD equation becomes

\begin{equation}
\pmb\theta_{t+1} = \pmb\theta_t - \eta F^{-1}\nabla_\theta \mathcal{L}(\pmb \theta)
\end{equation}

Therefore, for a parameter with a strong effect on the outcome probability of the model, the natural gradient will penalize movement in that direction due to the inverse of the metric. Consequently, natural gradient chooses a direction that minimizes the loss function (first-order Taylor expansion) while maintaining constant distance (second-order Taylor expansion, as measured by KL divergence) [ref](https://arxiv.org/pdf/1301.3584.pdf). 
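A sketch for a Bernoulli (logistic-regression) model, chosen as an assumption because its true FIM has a closed form. It first checks that the FIM built from labels sampled from the model's own predictive distribution matches the closed form, then takes natural-gradient steps $\pmb\theta \leftarrow \pmb\theta - \eta F^{-1}\nabla\mathcal{L}$. For this particular model the true FIM equals the Hessian of the mean negative log-likelihood, so these steps coincide with Newton steps; that is a property of this model class, not of the natural gradient in general.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n, d = 500, 3
X = rng.normal(size=(n, d))
y = (rng.random(n) < sigmoid(X @ np.array([1.0, -2.0, 0.5]))).astype(float)  # synthetic labels

def nll(theta):
    # Mean negative log-likelihood of P(y=1 | x, theta) = sigmoid(x . theta).
    z = X @ theta
    return float(np.mean(np.logaddexp(0.0, z) - y * z))

def grad_nll(theta):
    return X.T @ (sigmoid(X @ theta) - y) / n

def true_fim(theta):
    # E_x E_{y ~ P(y|x,theta)}[score score^T] with score = (y - p) x; closed form mean_x p(1-p) x x^T.
    p = sigmoid(X @ theta)
    return (X * (p * (1 - p))[:, None]).T @ X / n

def sampled_fim(theta, n_samples=200):
    # Monte Carlo estimate: labels are drawn from the model itself, not taken from the data.
    p = sigmoid(X @ theta)
    y_model = (rng.random((n_samples, n)) < p).astype(float)
    scores = (y_model - p)[:, :, None] * X[None, :, :]      # shape (n_samples, n, d)
    return np.einsum("snd,sne->de", scores, scores) / (n_samples * n)

theta_probe = np.array([0.5, -1.0, 0.2])
F_true, F_mc = true_fim(theta_probe), sampled_fim(theta_probe)
print("relative MC error of the FIM:", np.linalg.norm(F_mc - F_true) / np.linalg.norm(F_true))

theta, eta = np.zeros(d), 1.0
losses = [nll(theta)]
for _ in range(5):
    theta = theta - eta * np.linalg.solve(true_fim(theta), grad_nll(theta))  # F^{-1} grad
    losses.append(nll(theta))
print("NLL after each natural-gradient step:", np.round(losses, 4))
```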

#### Derivation of the Fisher Information Matrix using differential geometry

Traditionally, an ML model is a function $f$, controlled by some parameters $\theta$, that takes an input $x$ and produces an output $y$: $f(\theta): x \to y$. In the probabilistic formalism, the function is defined implicitly, i.e. the output probability density $p(x \mid \theta)$ is defined instead. The set of density functions forms a functional manifold $\mathcal{F}$, each element of which is parameterized by $\theta \in \mathbb{R}^p$. One of the properties of a manifold is the existence of a *local* map $\varphi$ that maps elements of $\mathcal{F}$ to an infinite-dimensional vector (function) space $V$

\begin{align}
\varphi: \mathcal{F} \to V
\end{align}

the output probability density $p(x\ \mid \theta)$ lives in the vector space $V$. 

So far, the manifold does not give a sense of elevation or curvature. To equip the manifold with such properties, define a function $\mathcal{L}$ on $\mathcal{F}$ 

\begin{align}
\mathcal{L}: \mathcal{F} \to \mathbb{R}
\end{align}

$\mathcal{L}$ is a loss function. The following figure describes the relationship between the $\varphi$ map and the loss function

<img src="figures/C11/data_manifold_map.png" width=200>

The loss function can be explicitly written as 

\begin{align}
\mathcal{L} \circ \varphi^{-1}(\theta) = \langle - \ln p(z \mid \theta) \rangle_z
\end{align}

Once the loss function is defined, the distance on the manifold can be defined as the change in the loss function due to change in $\theta$

\begin{align}
\Delta s^2 = \langle (-\ln p(z \mid \theta+\delta) + \ln p(z \mid \theta))^2 \rangle_z 
\end{align}

A first-order Taylor expansion of $\ln p(z \mid \theta+\delta)$ gives one of the forms of the Fisher Information Matrix $F$

\begin{align}
\Delta s^2 &= \delta^\top \langle (\nabla_\theta \ln p) (\nabla_\theta \ln p)^\top \rangle_z \delta \\
&= \delta^\top F \delta \label{FIM_1}\tag{Eq A3.1}
\end{align}

\ref{FIM_1} expresses the FIM in terms of (the outer product of) first-order derivatives. Note that the FIM is **positive semi-definite** because of this squared form. Another form of the FIM expresses it in terms of second-order derivatives, which can be obtained using the following relationships

\begin{align}
\frac{\partial^2}{\partial \theta^2} \ln p = \frac{1}{p}\frac{\partial^2 p}{\partial \theta^2} - \Big(\frac{\partial}{\partial \theta} \ln p \Big)^2
\end{align}

and 

\begin{align}
\Big\langle\frac{1}{p}\frac{\partial^2 p}{\partial \theta^2}\Big\rangle_z &= \int dz p(z\mid \theta)\frac{1}{p(z\mid \theta)}\frac{\partial^2 p}{\partial \theta^2}\\
&= \frac{\partial^2 }{\partial \theta^2} \int dz p(z \mid \theta)\\
&= 0
\end{align}

With the above relationships, one can write \ref{FIM_1} in another form

\begin{align}
F = \langle -\nabla_\theta^2 \ln p(z \mid \theta) \rangle_z \label{FIM_2}\tag{Eq A3.2}
\end{align}

Finally, \ref{FIM_2} can also be derived by taking the distance to be the KL divergence, $\Delta s^2 = D_{KL}(p(z \mid \theta) \| p(z \mid \theta + \delta))$, expanding $p(z \mid \theta+\delta)$ to second order in $\delta$, and using $\ln(1+x) \approx x - x^2/2$ for small $\delta$ (shown here for a scalar $\theta$):

\begin{align}
&D_{KL}(p(z \mid \theta) \| p(z \mid \theta + \delta))\\
&= \int dz\, p(z \mid \theta) \ln \Big(\frac{p(z \mid \theta)}{p(z \mid \theta + \delta)} \Big)\\
&= \int dz\, p(z \mid \theta) \ln p(z \mid \theta) - \int dz\, p(z \mid \theta) \ln p(z \mid \theta + \delta)\\
&= \int dz\, p \ln p - \int dz\, p \ln \Big(p + \frac{\partial p}{\partial \theta}\delta + \frac{1}{2}\frac{\partial^2 p}{\partial \theta^2}\delta^2 + O(\delta^3)\Big)\\
&= -\int dz\, p \ln \Big(1 + \frac{1}{p}\frac{\partial p}{\partial \theta}\delta + \frac{1}{2p}\frac{\partial^2 p}{\partial \theta^2}\delta^2 + O(\delta^3)\Big)\\
&= -\int dz\, p \Big( \delta\frac{\partial}{\partial \theta} \ln p + \delta^2\frac{1}{2p}\frac{\partial^2 p}{\partial \theta^2} - \frac{1}{2}\Big(\delta \frac{\partial}{\partial \theta} \ln p\Big)^2\Big) + O(\delta^3)\\
&= \frac{\delta^2}{2}\Big\langle \Big(\frac{\partial}{\partial \theta}\ln p\Big)^2\Big\rangle + O(\delta^3)\\
&= -\frac{\delta^2}{2}\Big\langle \frac{\partial^2}{\partial \theta^2}\ln p\Big\rangle + O(\delta^3)
\end{align}

where the second-to-last line uses $\int dz\, \partial_\theta p = \int dz\, \partial^2_\theta p = 0$, and the last line uses the relationships above. The result is $\frac{1}{2}\delta^\top F \delta$ with $F$ given by \ref{FIM_2}. Note that the FIM provides only a (second-order) approximation to the KL divergence.

The significance of \ref{FIM_2} is that it can be used to approximate the Hessian matrix in second-order optimization algorithms. The reason \ref{FIM_2} is only an approximation of the true Hessian lies in the expectation. For \ref{FIM_2}, the expectation is over the empirical distribution of the input $x$ and the model's predictive distribution $p(y \mid x, \theta)$; for the true Hessian, the expectation is over the empirical distribution of $(x,y)$ [ref](https://arxiv.org/pdf/1412.1193.pdf). Also, by \ref{FIM_1}, the FIM is PSD, which is not true for the true Hessian (e.g. at saddle points, where it has negative eigenvalues).

### A4 Morse Lemma
[ref](https://encyclopediaofmath.org/wiki/Morse_lemma)

#### Morse index

The Morse index of a critical point $p$ of a smooth function $f$ on a manifold $M$ is, by definition, the negative index of inertia of the Hessian of $f$ at $p$, that is, the dimension of the maximal subspace of the tangent space $T_p M$ of $M$ at $p$ on which the Hessian is negative definite. This maximal subspace is spanned by the eigenvectors of the Hessian with negative eigenvalues, so the Morse index is simply the number of negative eigenvalues (counted with multiplicity) at that point.
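Connecting back to the saddle-point discussion in Section 3.2: at a critical point, the Morse index is a count of negative Hessian eigenvalues. A minimal sketch (the tolerance `tol` is an illustrative guard against round-off):

```python
import numpy as np

def morse_index(hessian, tol=1e-10):
    # Number of eigenvalues of the symmetric Hessian below -tol.
    eigvals = np.linalg.eigvalsh(hessian)
    return int(np.sum(eigvals < -tol))

# f(x, y, z) = x^2 - y^2 - z^2: critical point at the origin with Hessian diag(2, -2, -2).
H_saddle = np.diag([2.0, -2.0, -2.0])

# The index does not depend on the basis: rotate the Hessian by a random orthogonal matrix.
Q, _ = np.linalg.qr(np.random.default_rng(0).normal(size=(3, 3)))
H_rotated = Q @ H_saddle @ Q.T

print(morse_index(H_saddle), morse_index(H_rotated))
print(morse_index(np.eye(3)), morse_index(-np.eye(3)))   # a minimum and a maximum
```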

Note that the above interpretation defines a surface as a function on a simple manifold $M$ (e.g. $\mathbb{R}^N$). If $M=\mathbb{R}^N$, the tangent space at a point $p$ can be identified with $M$ itself. Another way to define a surface is as the manifold itself (e.g. a sphere), with a local mapping function $\phi$ to a Cartesian space $\mathbb{R}^N$. The link from the second interpretation to the first is established by the Nash embedding theorem.

#### Morse Lemma



### A5 Mathematical formalism of Transformers [ref](https://arxiv.org/pdf/2312.10794.pdf)

### A6 Other resources
[CSC2541 Winter 2021 Topics in Machine Learning: Neural Net Training Dynamics](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/)