In [1]:
import numpy as np
from functools import reduce, partial
import pickle

In [2]:
def rand_W(M, N):
    return np.random.randn(M, N) / np.sqrt(N)

def orth_W(M, N):
    W = np.random.randn(N, N)
    Q, R = np.linalg.qr(W)
    return Q[:M, :N]

def init_weights(Ns, orth=False):
    Ws = []
    for M, N in zip(Ns[1:], Ns[:-1]):
        W = orth_W(M, N) if orth else rand_W(M, N)
        Ws.append(W)
    return Ws

In [3]:
def compose(Ws):
    return reduce(np.dot, Ws[::-1])

def forward(Ws, X):
    return X @ compose(Ws).T

def backward(Ws, C11, C31, lr):
    dLdg = compose(Ws) @ C11 - C31
    for i, W in enumerate(Ws):
        if i == 0:
            dLdW = compose(Ws[1:]).T @ dLdg
        elif i == len(Ws) - 1:
            dLdW = dLdg @ compose(Ws[:-1]).T
        else:
            dLdW = compose(Ws[i + 1:]).T @ dLdg @ compose(Ws[:i]).T
        W -= lr * dLdW

In [4]:
class Network:
    def __init__(self, Ws, X, Y):
        self.Ws = Ws
        self.C11 = X.T @ X / len(X)
        self.C31 = Y.T @ X / len(X)

    def map(self, X):
        return forward(self.Ws, X)

    def loss(self, X, Y):
        return np.mean((Y - self.map(X)) ** 2)

    def update(self, lr):
        return backward(self.Ws, self.C11, self.C31, lr)

In [5]:
with open('mnist.pkl', 'rb') as f:
    mnist = pickle.load(f)
X = mnist['X'].astype('float32')
Y = mnist['Y'].astype('float32')
X -= X.mean(axis=0)
std = X.std(axis=0)
std[std == 0] = 1
X /= std
Y /= Y.std(axis=0)
train = np.arange(50000)
test = np.arange(50000, 60000)

In [6]:
np.random.seed(0)
N = [784] + 100 * [100] + [10]
W1 = init_weights(N)
W2 = init_weights(N, orth=True)
g1 = Network(W1, X[train], Y[train])
g2 = Network(W2, X[train], Y[train])

In [7]:
for i in range(100):
    if i % 1 == 0:
        print(f"{i}: g1 train {g1.loss(X[train], Y[train]):.4f}, valid {g1.loss(X[test], Y[test]):.4f};",
              f"g2 train {g2.loss(X[train], Y[train]):.4f} valid {g2.loss(X[test], Y[test]):.4f}")
    g1.update(0.00005)
    g2.update(0.001)
    i += 1

0: g1 train 2.2702, valid 2.1648; g2 train 2.0891 valid 1.9865
1: g1 train 1.5745, valid 1.5285; g2 train 1.6279 valid 1.5437
2: g1 train 1.3667, valid 1.3406; g2 train 1.4212 valid 1.3476
3: g1 train 1.2772, valid 1.2581; g2 train 1.2975 valid 1.2315
4: g1 train 1.2340, valid 1.2186; g2 train 1.2127 valid 1.1527
5: g1 train 1.2079, valid 1.1951; g2 train 1.1501 valid 1.0949
6: g1 train 1.1900, valid 1.1788; g2 train 1.1016 valid 1.0503
7: g1 train 1.1766, valid 1.1666; g2 train 1.0627 valid 1.0146
8: g1 train 1.1663, valid 1.1572; g2 train 1.0306 valid 0.9853
9: g1 train 1.1580, valid 1.1496; g2 train 1.0036 valid 0.9606
10: g1 train 1.1512, valid 1.1433; g2 train 0.9806 valid 0.9395
11: g1 train 1.1454, valid 1.1380; g2 train 0.9607 valid 0.9213
12: g1 train 1.1404, valid 1.1334; g2 train 0.9432 valid 0.9053
13: g1 train 1.1361, valid 1.1294; g2 train 0.9278 valid 0.8911
14: g1 train 1.1323, valid 1.1259; g2 train 0.9141 valid 0.8785
15: g1 train 1.1289, valid 1.1227; g2 train 0.9017