In [1]:
import numpy as np

In [2]:
def forward(X, Ws):
    activity = [X]
    for W in Ws:
        X = X @ W.T
        activity.append(X)
    return activity

def backward(Y, Ws, activity, lr):
    E = Y - activity.pop()
    dLdg = -2 * E
    for W in Ws[::-1]:
        X = activity.pop()
        dLdW = np.sum(dLdg[:, :, np.newaxis] * X[:, np.newaxis, :], axis=0)
        dLdg = np.sum(dLdg[:, :, np.newaxis] * W[np.newaxis, :, :], axis=1)
        W -= lr * dLdW
    return np.sum(E ** 2)

def network(Ws, X, Y=None, lr=0):
    if Y is None:
        return forward(X, Ws)[-1]
    return backward(Y, Ws, forward(X, Ws), lr)

In [3]:
from functools import partial

In [4]:
N1 = 100
N2 = 32
N3 = 10
P = 1000

np.random.seed(0)
Wstar = np.random.randn(N3, N1) / np.sqrt(N1)

X = np.random.randn(P, N1)
X0 = X - X.mean(axis=0)
_, s, VT = np.linalg.svd(X0.T @ X0)
X = X0 @ VT.T / np.sqrt(s)
Y = X @ Wstar.T

W21 = np.random.randn(N2, N1) / np.sqrt(N1)
W32 = np.random.randn(N3, N2) / np.sqrt(N2)
g = partial(network, [W21, W32])

In [5]:
g(X, Y)

21.154511551381834

In [6]:
%%time
for i in range(150):
    L = g(X, Y, lr=0.1)
    if i % 10 == 0:
        print(i, L)

0 21.154511551381834
10 0.7755734772305993
20 0.04879232359184037
30 0.0034612062519514164
40 0.0002848261540952846
50 2.5556794580694936e-05
60 2.409958043060305e-06
70 2.3390660997556447e-07
80 2.3086055192145785e-08
90 2.300735090816294e-09
100 2.3057644184327465e-10
110 2.3182863738803365e-11
120 2.3352330815355485e-12
130 2.354845202984483e-13
140 2.3761068625380395e-14
CPU times: user 9.57 s, sys: 6.46 s, total: 16 s
Wall time: 2.15 s


In [7]:
np.allclose(W32 @ W21, Wstar)

True