# 誤差逆伝播法

In [1]:
import numpy as np
from misc.neuralnet import TwoLayerNet
from misc.mnist import load_train_data

# Gradient check: compare the analytic gradients computed by backpropagation
# (network.gradient) against numerically computed gradients
# (network.numerical_gradient) on a tiny batch. Mean absolute differences
# close to zero indicate the backprop implementation agrees with the
# numerical reference.

# NOTE(review): the meaning of the `True` flag is defined in misc.mnist --
# presumably it returns flattened 784-pixel rows to match the 784-unit input
# passed to TwoLayerNet below; confirm.
X_train, y_train = load_train_data(True)
# np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
# Dividing by 255 scales pixel intensities (assumed 0-255) into [0, 1].
X_train = X_train.astype(np.float64) / 255

# Presumably TwoLayerNet(input_size, hidden_size, output_size); confirm in
# misc.neuralnet.
network = TwoLayerNet(784, 50, 10)
# Only the first 3 samples are used -- presumably because the numerical
# gradient is expensive to compute.
X_batch = X_train[:3]
y_batch = y_train[:3]

grad_numerical = network.numerical_gradient(X_batch, y_batch)
grad_backprop = network.gradient(X_batch, y_batch)
for key, numerical in grad_numerical.items():
    # Mean absolute difference for each parameter array (e.g. W1, b1, W2, b2).
    diff = np.average(np.abs(grad_backprop[key] - numerical))
    print(f'{key}:{diff}')

W1:3.9698151289792513e-10
b1:2.4430850746818136e-09
W2:7.306421356617458e-09
b2:1.395481852198288e-07


In [2]:
import numpy as np
from misc.neuralnet import TwoLayerNet
from misc.mnist import load_train_data, load_test_data

# Train TwoLayerNet with mini-batch stochastic gradient descent, using the
# backpropagation gradients from network.gradient, and report the final test
# accuracy.
n_iter = 10000      # total number of parameter updates
batch_size = 100    # samples per mini-batch
alpha = 0.01        # learning rate

X_train, y_train = load_train_data(True)
X_test, y_test = load_test_data(True)
# np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
# Dividing by 255 scales pixel intensities (assumed 0-255) into [0, 1].
X_train = X_train.astype(np.float64) / 255
X_test = X_test.astype(np.float64) / 255

# Iterations per epoch, computed as an integer. A float divisor
# (X_train.shape[0] / batch_size) only triggers the per-epoch report when the
# sizes divide evenly; otherwise only k == 0 would ever match. max(..., 1)
# keeps the modulus valid when batch_size exceeds the training-set size.
iter_per_epoch = max(X_train.shape[0] // batch_size, 1)

network = TwoLayerNet(784, 50, 10)
for k in range(n_iter):
    # Random mini-batch indices (np.random.choice samples with replacement
    # by default).
    batch_mask = np.random.choice(X_train.shape[0], batch_size)
    X_batch = X_train[batch_mask]
    y_batch = y_train[batch_mask]

    # Plain SGD step: param -= alpha * grad for every parameter.
    grad = network.gradient(X_batch, y_batch)
    for key, g in grad.items():
        network.params[key] -= alpha * g

    # Once per epoch, report training accuracy (%) and loss over the full
    # training set.
    if k % iter_per_epoch == 0:
        train_acc = network.accuracy(X_train, y_train) * 100
        train_loss = network.loss(X_train, y_train)
        print(f'{k:5d}: {train_acc:12.8f} ({train_loss:.8f})', flush=True)

# Final test accuracy in percent; as the last expression it is displayed as
# the notebook cell's output.
network.accuracy(X_test, y_test) * 100

    0:   5.10833333 (2.30298820)
  600:  50.59333333 (1.99929746)
 1200:  79.00666667 (0.89139679)
 1800:  84.57333333 (0.59263677)
 2400:  87.14000000 (0.48477820)
 3000:  88.26166667 (0.43016099)
 3600:  89.05833333 (0.39641585)
 4200:  89.56166667 (0.37302795)
 4800:  89.89833333 (0.35622458)
 5400:  90.37166667 (0.34244861)
 6000:  90.68333333 (0.33062081)
 6600:  90.93166667 (0.32133308)
 7200:  91.15333333 (0.31286099)
 7800:  91.38666667 (0.30429227)
 8400:  91.57166667 (0.29803324)
 9000:  91.76500000 (0.29135719)
 9600:  91.95000000 (0.28542555)


92.36999999999999