In [1]:
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(0)

# Programming a linear network to classify MNIST digits

### The target function

We start with a vector-valued **target function** $f$ we wish to approximate:
$$
f: \vec x \mapsto \vec y \quad \vec x \in \mathbb{R}^N, \vec y \in \mathbb{R}^M
$$

Here we specify an $f$ by sampling a random "teacher" $W^* \in \mathbb R^{M \times N}$. $W^*$ sends an $N$-vector $\vec x$ to an $M$-vector $W^* \vec x$ by left multiplication, or a $P\times N$ matrix $X$ with $N$-vectors along its rows to a $P\times M$ matrix $X W^{*\top}$ (`X @ Wstar.T` in code) with $M$-vectors along its rows. This target function is a linear mapping, meaning that in principle a linear neural network can fit it perfectly.
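As a quick sanity check of the two conventions just described, here is a minimal sketch on hypothetical small sizes. The names `W_demo`, `X_demo`, `Y_rows` and `Y_cols` are illustrative and not part of the notebook; the point is only that the batched form `X @ W.T` agrees with applying `W` to one input vector at a time.

```python
import numpy as np

rng = np.random.default_rng(0)
M_demo, N_demo, P_demo = 3, 5, 4                   # hypothetical small sizes
W_demo = rng.standard_normal((M_demo, N_demo))     # stands in for the teacher W*
X_demo = rng.standard_normal((P_demo, N_demo))     # one input vector per row

Y_rows = X_demo @ W_demo.T                         # batched convention: (P, M)
Y_cols = np.stack([W_demo @ x for x in X_demo])    # one vector at a time

print(Y_rows.shape)
print(np.allclose(Y_rows, Y_cols))
```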

In [4]:
M = 10
N = 100
Wstar = np.random.randn(M, N)
f = lambda X: X @ Wstar.T

This gives rise to a **dataset** $\mathcal D$ of examples from $f$:
$$
\mathcal D = \{(\vec x^\mu, \vec y^\mu)\} \quad \mu = 1 \dots P
$$
Here $X$ is the first $P$ rows of `Xall`, the whitened MNIST inputs prepared under "The setting" (cells In [2] and In [3]), and we apply $f$ to make $Y$.

In [5]:
P = 1000
X = Xall[:P]  # first P whitened inputs (Xall is defined in cell In [3])
Y = f(X)      # teacher outputs for those inputs

print(f'X: {X.shape}, Y: {Y.shape}')

X: (1000, 100), Y: (1000, 10)


### The network

Gradient descent operates on a **tunable map** $g$ which we will use to approximate $f$:
$$
g: \vec x \mapsto W \vec x \quad W \in \mathbb{R}^{M \times N}
$$

In [24]:
class Network:
    def __init__(self, N1, N2, N3):
        self.W21 = np.random.randn(N2, N1)
        self.W32 = np.random.randn(N3, N2)
        self.g1 = None
        self.g2 = None
        self.g3 = None

    def __call__(self, X):
        self.g1 = X
        self.g2 = self.g1 @ self.W21.T
        self.g3 = self.g2 @ self.W32.T
        return self.g3

Here we use a two-layer linear network

$g(\vec x) = W^{32} W^{21} \vec x$ parameterized by $W^{21} \in \mathbb{R}^{N_2 \times N_1}$, $W^{32} \in \mathbb{R}^{N_3 \times N_2}$.

The network has $N_1$ input units, $N_2$ hidden units and $N_3$ output units.

In [25]:
g = Network(N, 20, M)
gX = g(X)

print(f'W21: {g.W21.shape}, W32: {g.W32.shape}, g(X): {gX.shape}')

W21: (20, 100), W32: (10, 20), g(X): (1000, 10)
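Because both layers are linear, the network computes the same map as the single matrix $W^{32} W^{21}$, whose rank can be at most $N_2$. The sketch below illustrates this on hypothetical small sizes with a deliberately narrow hidden layer; every name in it is illustrative and independent of the `Network` class above.

```python
import numpy as np

rng = np.random.default_rng(0)
N1, N2, N3, P_demo = 6, 2, 4, 5            # hypothetical sizes, narrow hidden layer
W21 = rng.standard_normal((N2, N1))
W32 = rng.standard_normal((N3, N2))
X_demo = rng.standard_normal((P_demo, N1))

layerwise = (X_demo @ W21.T) @ W32.T       # two successive linear layers
W_eff = W32 @ W21                          # single equivalent (N3, N1) matrix
collapsed = X_demo @ W_eff.T

same_map = np.allclose(layerwise, collapsed)
rank_eff = np.linalg.matrix_rank(W_eff)    # bounded above by N2
print(same_map, rank_eff)
```

In the notebook's configuration the hidden layer ($N_2 = 20$) is wider than the output ($M = 10$), so this rank bound does not restrict which $M \times N$ maps the network can represent.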


### The objective

Gradient descent employs an **objective function** $\mathcal L$ which we minimize with respect to $g$ over $\mathcal D$:
$$
\mathcal L_\mathcal D(g) = \frac{1}{P} \sum_{\mu = 1}^P \lVert g(\vec x^\mu) - f(\vec x^\mu) \rVert_2^2
$$
We use squared error summed across output units and averaged over the $P$ examples in the batch.

In [26]:
E = lambda fX, gX: fX - gX
L = lambda fX, gX: np.sum(E(fX, gX) ** 2) / P

loss = L(Y, gX)
print(f'loss: {loss}, error per unit: {np.sqrt(loss) / M}')

loss: 22516.691148140602, error per unit: 15.005562684598203


Minimizing $\mathcal L$ involves repeatedly computing its **gradients** with respect to model parameters $W$:
$$
\nabla \mathcal L_\mathcal D(W) = \begin{bmatrix}
    \frac{\partial \mathcal L}{\partial W_{11}} & \cdots & \frac{\partial \mathcal L}{\partial W_{1N}} \\
    \vdots & \ddots & \vdots \\
    \frac{\partial \mathcal L}{\partial W_{M1}} & \cdots & \frac{\partial \mathcal L}{\partial W_{MN}}
\end{bmatrix}
$$

In [27]:
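# Local derivatives of a linear layer g = X @ W.T, with singleton axes added for broadcasting
dgdW = lambda X: X[:, np.newaxis, :] # (X rows, W output, W input)
dgdX = lambda W: W[np.newaxis, :, :] # (X rows, W output, X cols)
# Note: dLdW averages over examples (np.mean) while dLdg3 below already carries a 1/P factor,
# so the weight gradients returned by grads() are the exact gradients of L divided by P.
dLdW = lambda dLdg, X: np.mean(dLdg[:, :, np.newaxis] * dgdW(X), axis=0) # (W output, W input)
dLdX = lambda dLdg, W: np.sum(dLdg[:, :, np.newaxis] * dgdX(W), axis=1) # (X rows, X cols)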

def grads(fX, g):
    dLdg3 = -2 * E(fX, g.g3) / P # gradient from loss to last layer
    dLdW32 = dLdW(dLdg3, g.g2)   # gradient from loss to W32
    dLdg2 = dLdX(dLdg3, g.W32)   # gradient from loss to hidden layer
    dLdW21 = dLdW(dLdg2, g.g1)   # gradient from loss to W21
    return dLdW21, dLdW32

dLdW21, dLdW32 = grads(Y, g)
print(f'grad W21 shape: {dLdW21.shape}, rms: {np.sqrt(np.mean(dLdW21 ** 2))}')
print(f'grad W32 shape: {dLdW32.shape}, rms: {np.sqrt(np.mean(dLdW32 ** 2))}')

grad W21 shape: (20, 100), rms: 0.042313856308547684
grad W32 shape: (10, 20), rms: 0.22997267355484222
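Hand-derived gradients like these are commonly validated with a finite-difference check. The sketch below is self-contained and uses hypothetical small sizes; `loss_fn`, `analytic_grads` and `numeric_grad` are illustrative names rather than notebook functions. The analytic gradients use plain matrix products, which give the exact gradient of the mean squared-error loss with no further averaging.

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, N1, N2, N3 = 7, 5, 4, 3                 # hypothetical small sizes
X_demo = rng.standard_normal((P_demo, N1))
Y_demo = rng.standard_normal((P_demo, N3))
W21 = rng.standard_normal((N2, N1))
W32 = rng.standard_normal((N3, N2))

def loss_fn(W21, W32):
    """Squared error summed over units, averaged over examples."""
    G = X_demo @ W21.T @ W32.T
    return np.sum((Y_demo - G) ** 2) / P_demo

def analytic_grads(W21, W32):
    """Backpropagate through the two linear layers with matrix products."""
    g2 = X_demo @ W21.T
    g3 = g2 @ W32.T
    dLdg3 = -2 * (Y_demo - g3) / P_demo         # (P, N3)
    dLdW32 = dLdg3.T @ g2                       # (N3, N2)
    dLdg2 = dLdg3 @ W32                         # (P, N2)
    dLdW21 = dLdg2.T @ X_demo                   # (N2, N1)
    return dLdW21, dLdW32

def numeric_grad(fn, W, eps=1e-6):
    """Central finite differences, one weight entry at a time."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        grad[idx] = (fn(Wp) - fn(Wm)) / (2 * eps)
    return grad

dLdW21_a, dLdW32_a = analytic_grads(W21, W32)
num21 = numeric_grad(lambda W: loss_fn(W, W32), W21)
num32 = numeric_grad(lambda W: loss_fn(W21, W), W32)
err21 = np.max(np.abs(dLdW21_a - num21))
err32 = np.max(np.abs(dLdW32_a - num32))
print(f'max abs error W21: {err21:.2e}, W32: {err32:.2e}')
```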


### The procedure

To train, we repeatedly apply the **update rule** for one step of optimization:
$$
W_{t + 1} = W_t - \lambda \cdot \nabla \mathcal L_\mathcal D(W_t)
$$

In [28]:
def update(g, Y, lr=1):
    dLdW21, dLdW32 = grads(Y, g)
    g.W21 -= lr * dLdW21
    g.W32 -= lr * dLdW32

For a sufficiently small learning rate $\lambda$, the update rule moves the weights in a direction that reduces the loss on each step.

In [29]:
loss1 = L(Y, g(X))
update(g, Y)
loss2 = L(Y, g(X))
print(f'loss before: {loss1}, loss after: {loss2}')

loss before: 22516.691148140602, loss after: 11344.605325611119
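Whether a step reduces the loss depends on the learning rate. As an illustration, the sketch below runs gradient descent on a small single-layer least-squares problem, which is a simplification of the two-layer network and not the notebook's model. For this quadratic loss, steps are guaranteed to descend when the learning rate is at most $1/\lambda_{\max}$ of the Hessian, and the iteration diverges above $2/\lambda_{\max}$. All names (`run`, `lam_max`, the demo sizes) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, N_demo = 50, 5                                # hypothetical sizes
X_demo = rng.standard_normal((P_demo, N_demo))
y_demo = X_demo @ rng.standard_normal(N_demo)         # noiseless linear target

H = 2 * X_demo.T @ X_demo / P_demo                    # Hessian of the mean squared error
lam_max = np.linalg.eigvalsh(H)[-1]                   # largest curvature

def run(lr, nsteps=30):
    """Plain gradient descent from w = 0; returns the loss before and after every step."""
    w = np.zeros(N_demo)
    losses = []
    for _ in range(nsteps):
        err = X_demo @ w - y_demo
        losses.append(np.mean(err ** 2))
        w = w - lr * (2 * X_demo.T @ err / P_demo)
    losses.append(np.mean((X_demo @ w - y_demo) ** 2))
    return np.array(losses)

small = run(lr=1.0 / lam_max)   # inside the stable range: loss falls every step
large = run(lr=3.0 / lam_max)   # beyond 2 / lam_max: loss blows up

print(f'small lr: {small[0]:.3f} -> {small[-1]:.3e}')
print(f'large lr: {large[0]:.3f} -> {large[-1]:.3e}')
```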


A full training run is shown below.

In [30]:
import time

In [31]:
nsteps = 10000

print(f'initial loss: {L(Y, g(X)):.6f}')
t0 = time.time()
for i in range(1, nsteps + 1):
    g(X)  # forward pass: caches g.g1, g.g2, g.g3, which grads() reads
    update(g, Y, lr=10)
    if i % 1000 == 0:  # report after every 1000 completed updates
        print(f'step {i} loss: {L(Y, g(X)):.6f} ({time.time() - t0:.1f} sec)')

initial loss: 11344.605326
step 1000 loss: 1.009923 (3.7 sec)
step 2000 loss: 0.764607 (7.5 sec)
step 3000 loss: 0.593271 (11.0 sec)
step 4000 loss: 0.475792 (14.5 sec)
step 5000 loss: 0.393624 (18.2 sec)
step 6000 loss: 0.335833 (21.6 sec)
step 7000 loss: 0.293393 (25.1 sec)
step 8000 loss: 0.260080 (28.7 sec)
step 9000 loss: 0.232293 (32.2 sec)
step 10000 loss: 0.208130 (35.7 sec)
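The broadcast-based helpers `dLdW` and `dLdX` build an intermediate array of shape (examples, outputs, inputs) on every call, which may account for a good part of the run time above; we have not profiled it. The same quantities can be written as matrix products. The sketch below checks the equivalence on hypothetical random arrays (`dLdg`, `A` and `W` are stand-ins, not notebook variables).

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, n_out, n_in = 200, 10, 20                 # hypothetical sizes
dLdg = rng.standard_normal((P_demo, n_out))       # stand-in upstream gradient
A = rng.standard_normal((P_demo, n_in))           # stand-in layer input
W = rng.standard_normal((n_out, n_in))            # stand-in layer weights

# Broadcast forms, written the same way as the notebook's dLdW and dLdX
dLdW_broadcast = np.mean(dLdg[:, :, np.newaxis] * A[:, np.newaxis, :], axis=0)
dLdX_broadcast = np.sum(dLdg[:, :, np.newaxis] * W[np.newaxis, :, :], axis=1)

# Equivalent matrix products that avoid the 3-D intermediate
dLdW_matmul = dLdg.T @ A / P_demo                 # (n_out, n_in)
dLdX_matmul = dLdg @ W                            # (P, n_in)

print(np.allclose(dLdW_broadcast, dLdW_matmul))
print(np.allclose(dLdX_broadcast, dLdX_matmul))
```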


### The outcome

Given enough training steps, the network can in principle fit the training data perfectly, driving the loss toward 0. After the 10,000 steps above the loss is about 0.21 and still decreasing.

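We find that the network mapping $W^{32} W^{21}$ roughly approximates the input-output covariance $C^{31}$ (computed under "The setting") after training: in the printout below the residual norm is about 0.54, compared with $\lVert C^{31} \rVert \approx 1.40$.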

In [33]:
print(f'|C31|: {np.sqrt(np.sum(C31 ** 2))}')
print(f'|W32|: {np.sqrt(np.sum(g.W32 ** 2))}')
print(f'|W21|: {np.sqrt(np.sum(g.W21 ** 2))}')
print(f'|C31 - W32.W21|: {np.sqrt(np.sum((C31 - g.W32 @ g.W21) ** 2))}')

|C31|: 1.3984750918285602
|W32|: 0.6781734253106437
|W21|: 26.073818835182475
|C31 - W32.W21|: 0.5357242405823567
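The quantities printed above are Frobenius norms, and dividing the residual by $\lVert C^{31} \rVert$ gives a scale-free relative error (about 0.38 for the numbers shown). A minimal sketch of that computation follows, using random stand-in matrices since the notebook's trained weights are not reproduced here; `A` and `B` are hypothetical placeholders for `C31` and `g.W32 @ g.W21`.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((10, 100))               # hypothetical stand-in for C31
B = A + 0.1 * rng.standard_normal(A.shape)       # hypothetical stand-in for W32 @ W21

frob_manual = np.sqrt(np.sum(A ** 2))            # the form used in the cell above
frob_numpy = np.linalg.norm(A)                   # Frobenius norm is the default for 2-D arrays
rel_err = np.linalg.norm(A - B) / np.linalg.norm(A)

print(np.isclose(frob_manual, frob_numpy))
print(f'relative error: {rel_err:.3f}')
```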


### The setting

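Load MNIST, center and whiten the inputs, and keep the first 100 whitened components. The targets are rescaled by $y \mapsto 2y - 1$.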

In [2]:
import pickle
with open('mnist.pkl', 'rb') as f:
    data = pickle.load(f)
Y = data['Y'] * 2 - 1
X = (data['X'] - data['mu_whiten']) @ data['W_whiten'][:, :100]

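Compute the input-output covariance $C^{31} = \mathrm{cov}(Y, X)$ as $Y^\top X$ divided by the number of examples, along with the input covariance $C^{11}$ and the output covariance $C^{33}$. Keep copies of the full dataset in `Xall` and `Yall`.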

In [3]:
C11 = X.T @ X / len(X)
C31 = Y.T @ X / len(Y)
C33 = Y.T @ Y / len(Y)
Xall = X.copy()
Yall = Y.copy()

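Argument: the least-squares optimum for a linear map is $C^{31} (C^{11})^{-1}$. If the inputs are whitened so that $C^{11} = I$, this reduces to $C^{31}$, so the best a linear network can hope to do is make $W^{32} W^{21} = C^{31}$.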
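This argument can be checked numerically. The sketch below uses synthetic data rather than MNIST: it whitens correlated inputs with an eigendecomposition of their (uncentered) second-moment matrix, then compares the least-squares solution from `np.linalg.lstsq` with $C^{31}$. The data, sizes and variable names are all hypothetical, and the whitening recipe is one common choice, not necessarily the one behind `W_whiten`.

```python
import numpy as np

rng = np.random.default_rng(0)
P_demo, N_demo, M_demo = 500, 8, 3                # hypothetical sizes
Z = rng.standard_normal((P_demo, N_demo))
X_raw = Z + 0.8 * Z[:, [0]]                       # correlated inputs via a shared component
Y_demo = rng.standard_normal((P_demo, M_demo))    # arbitrary targets

# Whiten: rotate onto the eigenvectors of the input covariance and rescale each axis
C_raw = X_raw.T @ X_raw / P_demo
evals, evecs = np.linalg.eigh(C_raw)
X_demo = X_raw @ evecs / np.sqrt(evals)

C11 = X_demo.T @ X_demo / P_demo                  # identity after whitening
C31 = Y_demo.T @ X_demo / P_demo                  # input-output covariance, (M, N)

# Least-squares map W minimizing ||X W^T - Y||^2, returned by lstsq as W^T
W_ls = np.linalg.lstsq(X_demo, Y_demo, rcond=None)[0].T

print(np.allclose(C11, np.eye(N_demo)))
print(np.allclose(W_ls, C31))
```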