### Gradient descent in the linear network

Load the MNIST dataset: `X` has 70000 rows (images) and 784 columns (28 x 28 pixels), and `Y` holds the matching labels with one column per digit class.

In [560]:
import pickle

with open('mnist.pkl', 'rb') as f:
    mnist = pickle.load(f)
X = mnist['X']
Y = mnist['Y']

Whiten $X$: center it, project it onto its first 100 principal components, and rescale each component to unit variance.

In [561]:
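import numpy as np

# center each pixel column
X = X - X.mean(axis=0)
# X.T @ X is symmetric positive semi-definite, so its SVD is also an
# eigendecomposition: rows of W are principal directions, S the eigenvalues
U, S, W = np.linalg.svd(X.T @ X)
# project onto the top 100 principal directions and equalize their scale
X = X @ W[:100].T / np.sqrt(S[:100])
# rescale each component to unit variance, so that cov[X] = I
X /= X.std(axis=0)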
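The whitening step can be checked on synthetic data: after it, the covariance of the kept components should be the identity. This is a minimal sketch, assuming a small random correlated matrix in place of MNIST and using `np.linalg.eigh` on the covariance, which for a symmetric positive semi-definite matrix gives the same principal directions as the SVD above. The helper name `pca_whiten` is hypothetical.

```python
import numpy as np

def pca_whiten(X, k):
    """Hypothetical helper: center X, project it onto its top-k principal
    components, and scale each component to unit variance."""
    Xc = X - X.mean(axis=0)
    cov = Xc.T @ Xc / len(Xc)
    # eigh returns eigenvalues in ascending order; pick the k largest
    evals, evecs = np.linalg.eigh(cov)
    order = np.argsort(evals)[::-1][:k]
    return Xc @ evecs[:, order] / np.sqrt(evals[order])

rng = np.random.default_rng(0)
# correlated synthetic data: 5000 samples, 20 features
A = rng.standard_normal((20, 20))
X = rng.standard_normal((5000, 20)) @ A
Z = pca_whiten(X, k=10)
C = Z.T @ Z / len(Z)
print(np.allclose(C, np.eye(10)))  # True
```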

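Define a multi-layer linear network trained by stochastic gradient descent (SGD) on the squared-error loss. `mapping` runs the forward pass through `layers`; when targets `Y` and a learning rate `lr` are given, it also backpropagates and updates each weight matrix in place. It returns the output computed before the update.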

In [562]:
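def mapping(layers, X, Y=None, lr=0):
    # forward pass: g[i] is the input to layers[i]
    g = [X]
    for W in layers:
        X = X @ W.T
        g.append(X)
    out = g.pop()
    # gradient of the squared error with respect to the output (zero if no targets)
    err = 0 * out if Y is None else Y - out
    dLdg = -2 * err
    # backward pass, last layer first
    for W in layers[::-1]:
        X = g.pop()
        # batch mean of the outer products dLdg_i x_i^T (same as dLdg.T @ X / len(X))
        dLdW = np.mean(dLdg[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0)
        # backpropagate through the pre-update W (same as dLdg @ W)
        dLdg = np.sum(dLdg[:, :, np.newaxis] * W[np.newaxis, :, :], axis=1)
        # in-place update, so the caller's weight arrays change too
        W -= lr * dLdW
    return out

# squared error summed over output units, averaged over examples
L = lambda fX, gX: np.mean(np.sum((fX - gX) ** 2, axis=1), axis=0)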
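One way to check the backpropagation rule in `mapping` is to compare its weight gradients against central finite differences of the loss. This is a minimal sketch on random data for a two-layer network. It assumes the same per-example squared error averaged over the batch, and re-implements the rule with plain matrix products. The helper names `loss`, `grads`, and `numerical_grad` are hypothetical.

```python
import numpy as np

def loss(W21, W32, X, Y):
    out = X @ W21.T @ W32.T
    return np.mean(np.sum((Y - out) ** 2, axis=1))

def grads(W21, W32, X, Y):
    # forward pass, keeping the hidden layer
    h = X @ W21.T
    out = h @ W32.T
    dLdg = -2 * (Y - out)          # dL/d(out), per example
    dW32 = dLdg.T @ h / len(X)     # mean of outer products dLdg_i h_i^T
    dLdh = dLdg @ W32              # backpropagate through W32
    dW21 = dLdh.T @ X / len(X)
    return dW21, dW32

def numerical_grad(f, W, eps=1e-6):
    # central differences, one entry at a time
    G = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        G[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return G

rng = np.random.default_rng(1)
X = rng.standard_normal((50, 8))
Y = rng.standard_normal((50, 3))
W21 = rng.standard_normal((4, 8))
W32 = rng.standard_normal((3, 4))

dW21, dW32 = grads(W21, W32, X, Y)
num21 = numerical_grad(lambda W: loss(W, W32, X, Y), W21)
num32 = numerical_grad(lambda W: loss(W21, W, X, Y), W32)
print(np.max(np.abs(num21 - dW21)), np.max(np.abs(num32 - dW32)))
```

Because the loss is quadratic in each weight matrix when the other is held fixed, the central difference is exact up to round-off, so the two gradients should agree closely.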

Initialize a two-layer network (100 inputs, 10 hidden units, 10 outputs) with random semi-orthogonal weights.

In [563]:
import functools
np.random.seed(0)
N1, N2, N3 = (100, 10, 10)

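def random_W(M, N):
    # slice an n x n random orthogonal matrix (from QR) down to M x N:
    # the rows are orthonormal if M <= N, the columns if M > N
    n = max(M, N)
    Q, R = np.linalg.qr(np.random.randn(n, n))
    return Q[:M, :N]

W21 = random_W(N2, N1)
W32 = random_W(N3, N2)
# g(X) runs the forward pass; g(X, Y, lr=...) also takes one SGD step,
# updating W21 and W32 in place
g = functools.partial(mapping, [W21, W32])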
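`random_W` slices an orthogonal matrix from a QR decomposition, so the result has orthonormal rows when `M <= N` and orthonormal columns when `M > N`. A quick check of that property, with the definition reproduced so the snippet runs on its own. The explicit `rng` argument is an assumption added here in place of the global `np.random` state.

```python
import numpy as np

def random_W(M, N, rng):
    # slice an n x n orthogonal matrix down to M x N
    n = max(M, N)
    Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    return Q[:M, :N]

rng = np.random.default_rng(0)
W_wide = random_W(10, 100, rng)   # like W21: orthonormal rows
W_tall = random_W(100, 10, rng)   # orthonormal columns
print(np.allclose(W_wide @ W_wide.T, np.eye(10)))  # True
print(np.allclose(W_tall.T @ W_tall, np.eye(10)))  # True
```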

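Train for 10 epochs of mini-batch SGD with learning rate 0.01. Each epoch takes 70 batches of 100 rows, i.e. the first 7000 rows. After each epoch, loss and accuracy are measured on all 70000 rows, and both begin to plateau.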

In [564]:
gX = g(X)
acc = np.mean(np.argmax(gX, axis=1) == np.argmax(Y, axis=1))
print(f'starting loss: {L(gX, Y)}, accuracy: {acc}')
for epoch in range(10):
    start = 0
    for batch in range(70):
        end = start + 100
        # g with targets and a learning rate performs one in-place SGD step
        g(X[start:end], Y[start:end], lr=0.01)
        start = end
    gX = g(X)
    acc = np.mean(np.argmax(gX, axis=1) == np.argmax(Y, axis=1))
    print(f'loss after epoch {epoch}: {L(gX, Y)}, accuracy: {acc}')

starting loss: 10.659795190762066, accuracy: 0.13338571428571427
loss after epoch 0: 1.3605439130761887, accuracy: 0.36551428571428574
loss after epoch 1: 0.8704074676266697, accuracy: 0.5845285714285714
loss after epoch 2: 0.7170727481955943, accuracy: 0.7106428571428571
loss after epoch 3: 0.6387518666852239, accuracy: 0.7727571428571428
loss after epoch 4: 0.5922787330623296, accuracy: 0.8034285714285714
loss after epoch 5: 0.5640342529982761, accuracy: 0.8213857142857143
loss after epoch 6: 0.5470163495276731, accuracy: 0.8305857142857143
loss after epoch 7: 0.5368795570967739, accuracy: 0.8362285714285714
loss after epoch 8: 0.5308914641079793, accuracy: 0.8401714285714286
loss after epoch 9: 0.5273748472260247, accuracy: 0.8427714285714286
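The loop above visits the same 7000 rows in the same order every epoch. If each epoch should instead cover all 70000 rows in a fresh random order, a shuffled mini-batch iterator is one option. This is a minimal sketch, and `minibatches` is a hypothetical helper, not part of the original notebook. It would replace the inner loop with `for idx in minibatches(len(X), 100, rng): g(X[idx], Y[idx], lr=0.01)`. The losses and accuracies printed above come from the original 7000-row loop.

```python
import numpy as np

def minibatches(n, batch_size, rng):
    """Yield index arrays that together cover range(n) exactly once, in random order."""
    perm = rng.permutation(n)
    for start in range(0, n, batch_size):
        yield perm[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(minibatches(70000, 100, rng))
print(len(batches))  # 700
```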


Observe that as training converges, $W^{32} W^{21} X^T X \to Y^T X$. Dividing both sides by the number of rows, this says $\text{cov}[g(X), X] \to \text{cov}[Y, X]$, where the mean gradient of $L$ vanishes.

In [566]:
CYX = Y.T @ X / len(X)
C11 = X.T @ X / len(X)
C31 = W32 @ W21 @ C11
np.sqrt(np.mean((CYX - C31) ** 2))

0.004038315594681743
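The condition $W^{32} W^{21} X^T X = Y^T X$ is the normal equation of linear least squares, so a network whose product satisfies it has reached the least-squares solution. Here is a minimal sketch on synthetic, non-whitened data. It assumes a single matrix `W` stands in for the product $W^{32} W^{21}$: solve the covariance form directly, compare it with `np.linalg.lstsq`, and confirm that the mean gradient of the squared-error loss vanishes there.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 2000, 20, 5
X = rng.standard_normal((n, d)) @ rng.standard_normal((d, d))  # correlated inputs
X -= X.mean(axis=0)
Y = rng.standard_normal((n, k))

CYX = Y.T @ X / n   # cov[Y, X], shape (k, d)
C11 = X.T @ X / n   # cov[X],    shape (d, d)

# fixed point W C11 = CYX, i.e. W = CYX C11^{-1}; C11 is symmetric, so
# solving C11 Z = CYX^T and transposing avoids forming the inverse
W = np.linalg.solve(C11, CYX.T).T

# least squares fits X W^T to Y
W_lstsq = np.linalg.lstsq(X, Y, rcond=None)[0].T
print(np.allclose(W, W_lstsq))  # True

# mean gradient of the loss with respect to W: -2 (CYX - W C11)
grad = -2 * (CYX - W @ C11)
print(np.max(np.abs(grad)))
```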

Recall that $X$ is whitened, i.e. $\text{cov}[X] = I$. The condition above therefore reduces to $W^{32} W^{21} = \text{cov}[Y, X]$, which is a fixed point of the learning dynamics.

In [567]:
np.sqrt(np.mean((CYX - W32 @ W21) ** 2))

0.004038315594681735

### Connectivity modes

Consider the singular value decomposition $W^{32} W^{21} = C^{31} = U^{33} S^{31} {V^{11}}^T$.

In [597]:
U33, s, V11T = np.linalg.svd(C31)
S31 = np.zeros((N3, N1))
S31[:N3, :N3] = np.diag(s)
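Two NumPy conventions matter in this cell. First, `np.linalg.svd` returns ${V^{11}}^T$ (not $V^{11}$) as its third output. Second, by default it returns that matrix full-size (100 x 100), which is why $S^{31}$ must be padded into a 10 x 100 matrix. A small check, using a random 10 x 100 matrix to stand in for $C^{31}$:

```python
import numpy as np

rng = np.random.default_rng(0)
N3, N1 = 10, 100
C31 = rng.standard_normal((N3, N1))

U33, s, V11T = np.linalg.svd(C31)      # full_matrices=True by default
print(U33.shape, s.shape, V11T.shape)  # (10, 10) (10,) (100, 100)

S31 = np.zeros((N3, N1))
S31[:, :N3] = np.diag(s)

# C31 = U S V^T, and S = U^T C31 V, where V is V11T.T
print(np.allclose(U33 @ S31 @ V11T, C31))      # True
print(np.allclose(U33.T @ C31 @ V11T.T, S31))  # True
```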

Define $\overline{W^{32}}$ and $\overline{W^{21}}$ as a change of variables on the synaptic weight spaces $W^{32}$ and $W^{21}$:

$\overline{W^{32}} = {U^{33}}^T W^{32}$

$\overline{W^{21}}^T = W^{21} V^{11}$

In [599]:
W21bar = V11T @ W21.T
W32bar = U33.T @ W32
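In the rotated coordinates, the product of the two weight matrices is diagonal whenever $W^{32} W^{21} = C^{31}$. Here is a minimal sketch with random weights. It assumes $C^{31}$ is exactly their product, which matches the cell defining `C31` when $\text{cov}[X] = I$.

```python
import numpy as np

rng = np.random.default_rng(0)
N1, N2, N3 = 100, 10, 10
W21 = rng.standard_normal((N2, N1))
W32 = rng.standard_normal((N3, N2))

C31 = W32 @ W21
U33, s, V11T = np.linalg.svd(C31)
S31 = np.zeros((N3, N1))
S31[:, :N3] = np.diag(s)

# change of variables: W32bar = U^T W32 and W21bar^T = W21 V, so W21bar = V^T W21^T
W32bar = U33.T @ W32
W21bar = V11T @ W21.T

print(np.allclose(W32bar @ W21bar.T, S31))  # True
```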

Then in the new synaptic weight space, the fixed point $W^{32} W^{21} = C^{31}$ becomes

$\overline{W^{32}} \overline{W^{21}}^T = S^{31}$

In [603]:
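np.sqrt(np.mean((S31 - U33.T @ W32 @ W21 @ V11T.T) ** 2))

Here `V11T.T` is $V^{11}$, since `np.linalg.svd` returns ${V^{11}}^T$ as its third output. Because `C31` was built from this same product and $\text{cov}[X] = I$, the residual should be at the level of floating-point round-off.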

Now, the gradient descent rules can be written in terms of $C^{31}$ without reference to the underlying dataset $X$.
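Averaging the per-example gradients that `mapping` uses over the whole dataset, with the weights held fixed, gives the expected SGD step. Here $\eta$ is the learning rate (`lr`), and both factors are updated from their pre-update values. This is a derivation from the code above, not a formula stated in the original:

$$\Delta W^{32} = 2\eta \left( \text{cov}[Y, X] - W^{32} W^{21} \, \text{cov}[X] \right) {W^{21}}^T$$

$$\Delta W^{21} = 2\eta \, {W^{32}}^T \left( \text{cov}[Y, X] - W^{32} W^{21} \, \text{cov}[X] \right)$$

With $\text{cov}[X] = I$, both updates depend on the data only through $\text{cov}[Y, X]$, which $C^{31}$ matches at the fixed point.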