In [1]:
%pylab inline
import sys

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from sklearn import preprocessing
# Print floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)
# Load the MNIST train and test splits through keras.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

def preproc(X_train, y_train):
    """Flatten images, standardize the features and one-hot encode the labels.

    X_train is indexed as a 3-D array (samples, dim1, dim2).  Each sample is
    flattened to a float32 vector of length dim1 * dim2 and the resulting
    matrix is passed through sklearn's ``preprocessing.scale``.  y_train is
    converted with ``np_utils.to_categorical`` using 10 classes.

    Returns the tuple ``(X, Y)`` of scaled features and categorical labels.
    """
    n_samples = X_train.shape[0]
    n_features = X_train.shape[1] * X_train.shape[2]
    flat = X_train.reshape(n_samples, n_features).astype("float32")
    # Dividing by 255 was an alternative normalization; scale() is used instead.
    scaled = preprocessing.scale(flat)
    onehot = np_utils.to_categorical(y_train, 10)
    return scaled, onehot

# Flatten, standardize and one-hot encode the training and test splits.
# NOTE(review): preproc calls preprocessing.scale on whatever it is given, so
# the test split is standardized independently of the training split --
# confirm this is intended.
X, Y = preproc(X_train, y_train)
Xt, Yt = preproc(X_test, y_test)



Populating the interactive namespace from numpy and matplotlib


In [2]:
# Biases are included: every layer adds B[i] to its pre-activation.
def update(y0, weights=None, biases=None):
    """Run one forward/backward pass for a single sample and return gradients.

    The network is a stack of fully connected layers with tanh on every layer
    except the last, which is linear.  The sample is used as its own target
    (autoencoder-style reconstruction), so the output of the last layer must
    have the same shape as ``y0``.

    Parameters
    ----------
    y0 : 1-D array
        Input sample; also the reconstruction target.
    weights, biases : list of arrays, optional
        Per-layer weight matrices and bias vectors.  When omitted, the
        module-level ``W`` and ``B`` are used, which is what existing callers
        of ``update(sample)`` rely on.

    Returns
    -------
    (loss, gW, gB)
        ``loss`` is the mean squared reconstruction error.  ``gW`` and ``gB``
        are lists (one entry per layer) of gradients with respect to the
        weights and biases.  The gradients are those of the *summed* squared
        error (the output delta is ``2 * err``, not divided by the number of
        outputs).  Nothing is modified in place: applying the gradients is
        the caller's job.
    """
    if weights is None:
        weights = W
    if biases is None:
        biases = B
    n_layers = len(weights)

    # forward pass: y[i] is the input to layer i, y[-1] is the network output
    y = [y0]
    for i in range(n_layers):
        pre = np.dot(y[i], weights[i]) + biases[i]
        if i != n_layers - 1:
            y.append(np.tanh(pre))
        else:
            # last layer is linear
            y.append(pre)

    # squared error derivative, (computed - target); the target is the input
    err = y[-1] - y[0]
    e = [2 * err]

    # backward pass: propagate the delta through each hidden layer.
    # y[i+1] already holds tanh of layer i's pre-activation, so the tanh
    # derivative is 1 - y[i+1]**2 and tanh does not need to be recomputed.
    for i in range(n_layers - 2, -1, -1):
        e.insert(0, np.dot(weights[i + 1], e[0]) * (1 - y[i + 1] ** 2))

    # weight gradient is (layer input) outer (layer delta); bias gradient is
    # the delta itself
    gW = [np.outer(y[i], e[i]) for i in range(n_layers)]
    gB = [e[i] for i in range(n_layers)]

    return np.mean(err ** 2), gW, gB

In [42]:
# Seed numpy's global RNG so the weight initialization below is repeatable.
np.random.seed(1337)
# Weights and biases are drawn uniformly from [-init, init).
init = 0.08

# Layer sizes; an alternative configuration was [784, 256, 32, 256, 784].
sz = [32,32]

# Toy dataset: the 32x32 identity matrix, one one-hot row per sample.
# Note that this rebinds X, replacing whatever an earlier cell assigned to it.
X = np.eye(32, dtype=np.float32)

# One float32 weight matrix and bias vector per pair of consecutive layer sizes.
W, B = [], []
for n_in, n_out in zip(sz[:-1], sz[1:]):
    layer_w = np.random.uniform(low=-init, high=init, size=(n_in, n_out))
    W.append(layer_w.astype(np.float32))
    layer_b = np.random.uniform(low=-init, high=init, size=n_out)
    B.append(layer_b.astype(np.float32))

In [49]:
# Minibatch gradient descent with momentum over the rows of X, using the
# per-sample gradients returned by update().  W and B are modified in place.
lr = 0.1
mom = 0.9
minibatch_size = 4

# per-sample losses, accumulated across every epoch of this run
err = []
# momentum buffers: the previous step applied to each weight matrix / bias
momW = [np.zeros(w.shape).astype(np.float32) for w in W]
momB = [np.zeros(b.shape).astype(np.float32) for b in B]

# set when a NaN loss is seen so that the epoch loop stops as well
diverged = False
for ep in range(10):
    for i in range(0, X.shape[0], minibatch_size):
        # Slicing (instead of indexing X[i+k]) lets a short final batch through
        # without an IndexError when len(X) is not a multiple of
        # minibatch_size.
        batch = X[i:i + minibatch_size]

        # sum the per-sample gradients over the minibatch
        terr, gW, gB = update(batch[0])
        err.append(terr)
        for sample in batch[1:]:
            terr, tgW, tgB = update(sample)
            err.append(terr)
            for j in range(len(W)):
                gW[j] += tgW[j]
                gB[j] += tgB[j]

        # average over the samples actually present in this batch
        step = lr / len(batch)
        for j in range(len(W)):
            # clip the summed gradient element-wise to [-1, 1]
            gW[j] = np.clip(gW[j], -1, 1)
            gB[j] = np.clip(gB[j], -1, 1)

            updW = momW[j]*mom - gW[j]*step
            updB = momB[j]*mom - gB[j]*step

            W[j] += updW
            B[j] += updB

            momW[j] = updW
            momB[j] = updB

        if np.isnan(err[-1]):
            # same text the Python 2 print statement produced, but portable
            sys.stdout.write("FAILED AT %s %s\n" % (i, err[-10:]))
            diverged = True
            break

        # i advances in steps of minibatch_size, so for datasets of at most
        # 2000 rows this fires only at i == 0, i.e. once per epoch
        if (i % 2000) == 0:
            sys.stdout.write("%6d: %f\r\n" % (i, np.mean(err)))
            sys.stdout.flush()

    if diverged:
        # the weights are NaN; further epochs cannot recover
        break


     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000
     0: 0.000000


In [46]:
# Plain per-sample gradient descent (no momentum, no clipping) over the rows
# of X.  W and B are modified in place; the mean loss is logged per epoch.
lr = 0.001

# set when a NaN loss is seen so that the epoch loop stops as well
diverged = False
for ep in range(10):
    # per-sample losses for this epoch only
    err = []
    for i in range(0,X.shape[0]):
        terr, gW, gB = update(X[i])
        err.append(terr)

        for j in range(len(W)):
            W[j] -= gW[j]*lr
            B[j] -= gB[j]*lr

        if np.isnan(err[-1]):
            # same text the Python 2 print statement produced, but portable
            sys.stdout.write("FAILED AT %s %s\n" % (i, err[-10:]))
            diverged = True
            break

    if diverged:
        # the weights are NaN; further epochs cannot recover
        break

    sys.stdout.write("%6d: %f\r\n" % (ep, np.mean(err)))
    sys.stdout.flush()

     0: 0.029202
     1: 0.029080
     2: 0.028958
     3: 0.028838
     4: 0.028719
     5: 0.028601
     6: 0.028483
     7: 0.028367
     8: 0.028251
     9: 0.028136


In [None]:
# Scratch cell: repeatedly evaluates update() on the first (up to) 100 samples.
# update() only returns gradients and the results are discarded here, so W and
# B are left unchanged by this loop.
for i in range(1000):
    # slice instead of indexing X[j] for j < 100, so that a dataset with
    # fewer than 100 rows does not raise IndexError
    for sample in X[:100]:
        update(sample)


# show (loss, gW, gB) for the first two samples as the cell output
update(X[0]), update(X[1])

In [50]:
# Display the first layer's bias vector as the cell output.
B[0]

array([ 0.02711952,  0.02568256,  0.03881054,  0.04218581,  0.02783629,
        0.02157393,  0.04471104,  0.01391777,  0.02890016,  0.02981861,
        0.03383147,  0.01474664,  0.02984582,  0.03125127,  0.03090135,
        0.02411032,  0.03365911,  0.03792664,  0.03710413,  0.02560797,
        0.02848255,  0.01875023,  0.02629327,  0.02598338,  0.03478689,
        0.02178761,  0.03270526,  0.03306433,  0.03505816,  0.027119  ,
        0.04663128,  0.00851507], dtype=float32)