# The dynamics of learning in deep linear networks

In [175]:
from functools import reduce
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

Here we implement the deep linear network trained by gradient descent and test the claims made in "Exact solution to the nonlinear dynamics of learning in deep linear neural networks" by Andrew Saxe, James McClelland, and Surya Ganguli.

# Part 1: setting the stage

### Target function

We want a linear network to approximate a **target function** $f$ which maps $N$-vectors to $M$-vectors. 

To program the network we need a **dataset** of examples of the input-output mapping, $\mathcal D = \{(x^\mu, y^\mu)\}$ for $\mu = 1, 2 \dots P$.

We take $X$ as the first 100 whitened principal components of the first 50000 MNIST examples, and $Y$ as a one-hot coded category label.

So $N = 100$, $M = 10$, and $P = 50000$.

In [148]:
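M = 10
N = 100
P = 50000
X, Y = pd.read_pickle('mnist.pkl')
X = X[:P]
Y = Y[:P]
# center the inputs, then project onto the top N principal directions
X = X - X.mean(axis=0)
U, S, VT = np.linalg.svd(X.T @ X)
X = X @ VT[:N].T
# rescale each component to unit variance (whitening)
X = X / np.std(X, axis=0)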
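
The preprocessing above can be sanity-checked on synthetic data. The sketch below is an illustration rather than part of the original notebook: it assumes random correlated inputs in place of MNIST (the names `rng`, `D`, and `Z` are made up here), applies the same center, project, and rescale steps, and checks that the resulting covariance is close to the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
P, D, N = 500, 20, 5          # examples, raw input dimension, kept components

# correlated synthetic inputs standing in for the MNIST pixels (an assumption)
Z = rng.normal(size=(P, D)) @ rng.normal(size=(D, D))

# same steps as above: center, project onto the top-N principal directions, rescale
Z = Z - Z.mean(axis=0)
U, S, VT = np.linalg.svd(Z.T @ Z)
Z = Z @ VT[:N].T
Z = Z / np.std(Z, axis=0)

# whitened inputs should have (approximately) identity covariance
C = Z.T @ Z / P
print(np.allclose(C, np.eye(N), atol=1e-6))
```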

### Forward pass

The architecture of an $L$-layer linear network is described by the layer sizes $(N_0, N_1 \dots N_L)$.

In this convention, layer $0$ assumes the value of the input vector and layer $L$ assumes the output values.

In [162]:
# 2-layer linear network with 32 hidden units
Nl = [N, 32, M]

The network computes a **programmable mapping** $g$ parameterized by the matrices $(W_1, W_2 \dots W_L)$.

$W_l$ is an $N_l \times N_{l - 1}$ matrix that maps vectors from layer $l - 1$ to layer $l$.

In [163]:
# initialize weights as norm-preserving random projections
Wl = []
for Nin, Nout in zip(Nl[:-1], Nl[1:]):
    W = np.random.randn(Nout, Nin)
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    Wl.append(W)

The network computes the function

$$g(X) = X {W_1}^T {W_2}^T \dots {W_L}^T$$

when applied to matrix $X$ with input vectors $x$ on its rows.

In [164]:
# forward pass computes internal activation and output
def forward(X, Wl):
    Xl = [X]
    for W in Wl:
        X = X @ W.T
        Xl.append(X)
    gX = Xl[-1]
    return gX, Xl

In [165]:
# map X to g(X)
gX = forward(X, Wl)[0]
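
Because every layer is linear, the whole network collapses into a single matrix $W_L \cdots W_1$. The following sketch illustrates this on small random matrices; the layer sizes and the names `Ws`, `X_demo`, and `W_total` are assumptions chosen for the example, not values from the notebook.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(0)
sizes = [6, 4, 3]                          # hypothetical layer sizes (N_0, N_1, N_2)
Ws = [rng.normal(size=(n_out, n_in)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
X_demo = rng.normal(size=(10, sizes[0]))   # 10 input vectors on the rows

# layer-by-layer forward pass, as in forward() above
out = X_demo
for W in Ws:
    out = out @ W.T

# end-to-end matrix W_L ... W_1
W_total = reduce(lambda A, W: W @ A, Ws)

# both routes give the same outputs
print(np.allclose(out, X_demo @ W_total.T))
```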

We want to optimize $g$ to approximate the target mapping $f$ over the dataset $\mathcal D$.

### Optimization objective

To do this we define an **objective function $\mathcal L$** that outputs a quantity called the loss.

The goal of gradient descent is to find $(W_1, W_2 \dots W_L)$ that minimize the loss given $\mathcal D$.

As a loss function we use the mean squared error, where the error is $E = f(X) - g(X)$:

In [166]:
# loss measures model error over the training data
def loss(fX, gX):
    return np.mean((fX - gX) ** 2)

In [167]:
# compute the loss
loss(Y, gX)

1.1043516264560895

### Backward pass

Ignoring the normalization of the mean, observe that $\frac{d\mathcal L}{dE} = 2 \cdot E$ and $\frac{dE}{dg} = -1$.

Then, dropping the constant $2$ and denoting the activation at the output layer by $X_L$,

$$\frac{d\mathcal L}{dX_L} = g(X) - f(X).$$

Since layer $l$ computes $X_{l-1} {W_l}^T$, the gradient of its output with respect to its input is ${W_l}^T$.

Then denoting activations at layer $l$ as $X_l$,

$$\frac{dX_l}{dX_{l - 1}} = {W_l}^T$$

Likewise, the gradient of its output with respect to its weight matrix ${W_l}^T$ is $X_{l-1}$, the layer's input:

$$\frac{dX_l}{{dW_l}^T} = {X_{l - 1}}.$$

Using these equations we can propagate gradients backward from the loss to any part of the network.

We stop at each layer $l$ and compute the gradient of $\mathcal L$ with respect to $W_l$ using that layer's input activation:

$$\frac{d\mathcal L}{dW_l^T} = \frac{d\mathcal L}{dX_L} \frac{dX_L}{dX_{L-1}} \dots \frac{dX_{l+1}}{dX_l} \frac{dX_l}{dW_l^T}$$
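
Written out for the row-vector convention used here, the chain above unrolls into a simple recursion. This is a restatement of the equations above, with $\delta_l$ introduced here as shorthand for $\frac{d\mathcal L}{dX_l}$ (constants dropped):

$$\delta_L = g(X) - f(X), \qquad \delta_{l-1} = \delta_l W_l, \qquad \frac{d\mathcal L}{dW_l} = {\delta_l}^T X_{l-1}.$$

Here $\delta_l$ is a $P \times N_l$ matrix with one row per example, so ${\delta_l}^T X_{l-1}$ has the same $N_l \times N_{l-1}$ shape as $W_l$ and sums the contributions of all examples.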

In [168]:
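# backward pass computes gradients of loss with respect to weights
# note: reads the weights from the global list Wl
def backward(fX, Xl):
    E = fX - Xl[-1]
    # gradient at the output; the factor 2 is kept and errors are summed, not averaged
    dLdg = -2 * E
    dLdWl = []
    # walk from output to input, pairing each W_l with its input activation X_{l-1}
    for W, X in zip(Wl[::-1], Xl[-2::-1]):
        dLdW = dLdg.T @ X  # sum over examples of the outer products
        dLdg = dLdg @ W    # propagate the gradient to the layer's input
        dLdWl.append(dLdW)
    return dLdWl[::-1]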

In [169]:
# run the backward pass
Xl = forward(X, Wl)[-1]
dLdWl = backward(Y, Xl)
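
A common way to gain confidence in a hand-written backward pass is to compare it with finite differences. The sketch below does this on a tiny random problem; everything in it (`X_demo`, `Y_demo`, `Ws`, `sse`, `analytic_grads`, `numeric_grads`) is a hypothetical stand-in written for this check, and it assumes the summed squared error, which is the quantity whose gradient the code above computes.

```python
import numpy as np

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(8, 5))
Y_demo = rng.normal(size=(8, 2))
Ws = [rng.normal(size=(3, 5)), rng.normal(size=(2, 3))]

def sse(Ws):
    # summed squared error of the linear network
    out = X_demo
    for W in Ws:
        out = out @ W.T
    return np.sum((Y_demo - out) ** 2)

def analytic_grads(Ws):
    # forward pass, storing the activation of every layer
    acts = [X_demo]
    for W in Ws:
        acts.append(acts[-1] @ W.T)
    # backward pass, from the output layer down to the first layer
    delta = -2 * (Y_demo - acts[-1])
    grads = []
    for W, A in zip(Ws[::-1], acts[-2::-1]):
        grads.append(delta.T @ A)
        delta = delta @ W
    return grads[::-1]

def numeric_grads(Ws, eps=1e-6):
    # central finite differences, one weight entry at a time
    grads = []
    for W in Ws:
        G = np.zeros_like(W)
        for idx in np.ndindex(*W.shape):
            old = W[idx]
            W[idx] = old + eps
            up = sse(Ws)
            W[idx] = old - eps
            down = sse(Ws)
            W[idx] = old
            G[idx] = (up - down) / (2 * eps)
        grads.append(G)
    return grads

max_err = max(np.max(np.abs(a - n))
              for a, n in zip(analytic_grads(Ws), numeric_grads(Ws)))
print(f"largest absolute difference: {max_err:.2e}")
```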

### Stepwise update

Gradient descent iteratively updates $(W_1, W_2 \dots W_L)$ in a manner that reduces $\mathcal L$.

This is achieved by applying the **update rule**

$$\Delta W_l = -\lambda \cdot \frac{d\mathcal L}{dW_l}$$

where $\lambda$ is a small learning rate.

In [170]:
# apply one gradient descent step to every weight matrix, in place
def update(Wl, dLdWl, lr):
    for W, dLdW in zip(Wl, dLdWl):
        W -= lr * dLdW

Each step of gradient descent updates each weight matrix slightly in a direction opposite to its gradient with respect to the loss.

This reduces $g$'s error at approximating $f$ on the training inputs $X$ and, if the data are representative, its expected error on new inputs drawn from the same distribution as $X$.
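
To see the update rule at work without MNIST, here is a minimal self-contained sketch on synthetic data. It assumes a random linear teacher as the target and a hypothetical `step` function that fuses the forward pass, backward pass, and update; the sizes, learning rate, and step count are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, sizes = 200, [5, 4, 3]
X_demo = rng.normal(size=(P_demo, sizes[0]))
Y_demo = X_demo @ rng.normal(size=(sizes[-1], sizes[0])).T   # linear teacher

# row-normalized random initial weights, as in the notebook
Ws = []
for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    W = rng.normal(size=(n_out, n_in))
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    Ws.append(W)

def step(Ws, lr):
    # forward pass
    acts = [X_demo]
    for W in Ws:
        acts.append(acts[-1] @ W.T)
    # backward pass (summed squared error)
    delta = -2 * (Y_demo - acts[-1])
    grads = []
    for W, A in zip(Ws[::-1], acts[-2::-1]):
        grads.append(delta.T @ A)
        delta = delta @ W
    # in-place update of every layer
    for W, G in zip(Ws, grads[::-1]):
        W -= lr * G
    return np.mean((Y_demo - acts[-1]) ** 2)   # loss measured before this update

losses = [step(Ws, lr=5e-5) for _ in range(100)]
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```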

In [330]:
# initialize weights as norm-preserving random projections
Wl = []
for Nin, Nout in zip(Nl[:-1], Nl[1:]):
    W = np.random.randn(Nout, Nin)
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    Wl.append(W)
t = 0

### A step of learning

We may compute the loss prior to any training:

In [331]:
gX = forward(X, Wl)[0]
print(f"loss(t={t}): {loss(Y, gX):.6f}")

loss(t=0): 0.936852


One step of gradient descent looks like this:

In [332]:
dLdWl = backward(Y, forward(X, Wl)[-1])
update(Wl, dLdWl, lr=1e-6)
t += 1

Notice the loss has decreased:

In [333]:
print(f"loss(t={t}): {loss(Y, forward(X, Wl)[0]):.6f}")

loss(t=1): 0.594968


### Solving a task

In [334]:
def accuracy(fX, gX):
    return np.mean(np.argmax(fX, axis=1) == np.argmax(gX, axis=1))

Let's see if the network can learn to classify MNIST digits using the $\arg \max$ decision rule.
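
As a small worked example of the $\arg \max$ decision rule, consider three examples and three classes. The target and output values below are made up for illustration.

```python
import numpy as np

# one-hot targets and hypothetical network outputs
targets = np.array([[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1]])
outputs = np.array([[0.7, 0.2, 0.1],    # argmax 0, correct
                    [0.4, 0.5, 0.1],    # argmax 1, correct
                    [0.6, 0.1, 0.3]])   # argmax 0, wrong (target is class 2)

# predicted class is the index of the largest output
predicted = np.argmax(outputs, axis=1)
acc = np.mean(np.argmax(targets, axis=1) == predicted)
print(predicted, acc)
```

Two of the three predictions match their targets, so the accuracy is 2/3.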

Before any training, its accuracy is around 10% (chance):

In [335]:
print(f"accuracy(t=0): {accuracy(Y, gX)}")

accuracy(t=0): 0.11006


After training for 1 step, the accuracy has improved slightly:

In [336]:
print(f"accuracy(t=1): {accuracy(Y, forward(X, Wl)[0])}")

accuracy(t=1): 0.12794


Now, we train the network for 9 more steps:

In [337]:
for i in tqdm(range(9)):
    dLdWl = backward(Y, forward(X, Wl)[-1])
    update(Wl, dLdWl, lr=1e-6)
    t += 1
print(f"accuracy(t={t}): {accuracy(Y, forward(X, Wl)[0])}")

accuracy(t=10): 0.38976


And then 90 more steps:

In [338]:
for i in tqdm(range(90)):
    dLdWl = backward(Y, forward(X, Wl)[-1])
    update(Wl, dLdWl, lr=1e-6)
    t += 1
print(f"accuracy(t={t}): {accuracy(Y, forward(X, Wl)[0])}")

accuracy(t=100): 0.82956


After 100 gradient descent steps the accuracy is above 80%. Therefore **the network learns** to solve the task.

---

# Part 2: ???

The authors write that, due to the linearity of the network, it can be trained by pre-computing the input-input and input-output correlation matrices of the dataset and running the forward-backward pass on them instead of on the full set of training examples.

In [340]:
CXX = X.T @ X / len(X)
CYX = Y.T @ X / len(Y)
CYY = Y.T @ Y / len(Y)
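
These three matrices are enough to evaluate the mean squared error of any linear map without touching the individual examples. The sketch below checks this on random data; `X_demo`, `Y_demo`, and `W` (standing in for the end-to-end map $W_L \cdots W_1$) are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, N_demo, M_demo = 300, 6, 3
X_demo = rng.normal(size=(P_demo, N_demo))
Y_demo = rng.normal(size=(P_demo, M_demo))
W = rng.normal(size=(M_demo, N_demo))        # stand-in for the end-to-end map

# second-moment matrices, computed as above
CXX = X_demo.T @ X_demo / P_demo
CYX = Y_demo.T @ X_demo / P_demo
CYY = Y_demo.T @ Y_demo / P_demo

# mean squared error from the raw data
direct = np.mean((Y_demo - X_demo @ W.T) ** 2)

# the same quantity from the second moments only
from_cov = (np.trace(CYY) - 2 * np.sum(CYX * W) + np.trace(W @ CXX @ W.T)) / M_demo

print(np.isclose(direct, from_cov))
```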

Using these we can write a **covariance update rule** that is equivalent up to a rescaling of the learning rate:

In [342]:
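# assumed helper (not defined above): compose([W_1 ... W_k]) returns W_k ... W_1
def compose(Ws):
    return reduce(lambda A, W: W @ A, Ws)

# covariance-based gradient step; updates the weights in Ws in place
def backward(Ws, CXX, CYX, lr):
    # gradient with respect to the end-to-end map W_L ... W_1
    dLdg = compose(Ws) @ CXX - CYX
    dLdWs = []
    for i in range(len(Ws)):
        dLdW = dLdg
        if i < len(Ws) - 1:
            dLdW = compose(Ws[i + 1:]).T @ dLdW
        if i > 0:
            dLdW = dLdW @ compose(Ws[:i]).T
        dLdWs.append(dLdW)
    # update only after all gradients are computed, so every layer sees the old weights
    for W, dLdW in zip(Ws, dLdWs):
        W -= lr * dLdW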
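
The relation between the covariance rule and the example-by-example backward pass can be checked numerically. The sketch below is a self-contained illustration on random data with hypothetical names (`chain`, `full_batch_grads`, `covariance_grads`); it assumes the full-batch gradient is taken of the summed squared error with the factor 2 kept, in which case the two gradients should differ only by the constant $2P$.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(0)
P_demo, sizes = 100, [6, 5, 4, 3]
X_demo = rng.normal(size=(P_demo, sizes[0]))
Y_demo = rng.normal(size=(P_demo, sizes[-1]))
Ws = [rng.normal(size=(n_out, n_in)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def chain(Ws):
    # hypothetical helper: product W_k ... W_1 of a non-empty list [W_1 ... W_k]
    return reduce(lambda A, W: W @ A, Ws)

def full_batch_grads(Ws):
    # backpropagation over all examples (summed squared error, factor 2 kept)
    acts = [X_demo]
    for W in Ws:
        acts.append(acts[-1] @ W.T)
    delta = -2 * (Y_demo - acts[-1])
    grads = []
    for W, A in zip(Ws[::-1], acts[-2::-1]):
        grads.append(delta.T @ A)
        delta = delta @ W
    return grads[::-1]

def covariance_grads(Ws):
    # the same gradients computed from second moments only
    CXX = X_demo.T @ X_demo / P_demo
    CYX = Y_demo.T @ X_demo / P_demo
    G = chain(Ws) @ CXX - CYX
    grads = []
    for i in range(len(Ws)):
        g = G
        if i < len(Ws) - 1:
            g = chain(Ws[i + 1:]).T @ g   # layers above layer i
        if i > 0:
            g = g @ chain(Ws[:i]).T       # layers below layer i
        grads.append(g)
    return grads

scale = 2 * P_demo
match = all(np.allclose(a, scale * b)
            for a, b in zip(full_batch_grads(Ws), covariance_grads(Ws)))
print(match)
```

Under these assumptions, a learning rate `lr` in the full-batch loop corresponds to `2 * P * lr` in the covariance loop.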