# Trainer による訓練

In [21]:
# coding: UTF-8

import chainer
from chainer import Variable, Chain, optimizers
import chainer.links as L
import chainer.functions as F

# Trainer用
from chainer.datasets import tuple_dataset
from chainer import training, iterators
from chainer.training import extensions

import numpy as np
from sklearn import datasets # Scikit Learn にIris dataset が含まれている

In [22]:
# Iris データの読み込み
iris_data = datasets.load_iris()

In [23]:
x = iris_data.data.astype(np.float32)
t = iris_data.target
n = t.size

In [24]:
# 教師データの下処理
t_matrix = np.zeros(3 * n).reshape(n, 3).astype(np.float32)
for i in range(n):
    t_matrix[i, t[i]] = 1.0

In [25]:
# 訓練用データとテスト用データ 半分が訓練用データで残りがテスト用データ
indexes = np.arange(n)
indexes_train = indexes[indexes%2 != 0]
indexes_test = indexes[indexes%2 == 0]

x_train = x[indexes_train, :] # 訓練用 入力
t_train = t_matrix[indexes_train, :] # 訓練用 正解
x_test = x[indexes_test, :] # テスト用 入力
t_test = t[indexes_test] # テスト用 正解

# tuple_dataset を使う
train = tuple_dataset.TupleDataset(x_train, t_train)
#x_train_v = Variable(x_train)
#t_train_v = Variable(t_train)
x_test_v = Variable(x_test)

In [26]:
# Chain の記述
class IrisChain(Chain):
    def __init__(self):
        super(IrisChain, self).__init__(
            l1 = L.Linear(4, 6),
            l2 = L.Linear(6, 6),
            l3 = L.Linear(6, 3),
        )
    
    # __call_追加
    def __call__(self, x, t):
        return F.mean_squared_error(self.predict(x), t)
        
    def predict(self, x):
        h1 = F.sigmoid(self.l1(x))
        h2 = F.sigmoid(self.l2(h1))
        h3 = self.l3(h2)
        return h3

In [27]:
# モデルとOptimizerの設定
model = IrisChain()
optimizer = optimizers.Adam()
optimizer.setup(model)

In [28]:
# 学習

#for i in range(10000):
#    
#    model.cleargrads()
#    y_train_v = model.predict(x_train_v)
#    
#    # 損失関数による誤差の計算、この場合は平均2乗誤差
#    loss = F.mean_squared_error(y_train_v, t_train_v)
#    loss.backward()
#    
#    # Optimizer による重みの更新
#    optimizer.update()

# Trainer を使う
train_iter = iterators.SerialIterator(train, 30) # 1回の学習30個だけ使う
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (5000, 'epoch'))
trainer.extend(extensions.ProgressBar())
trainer.run()

     total [..................................................]  0.80%
this epoch [..................................................]  0.00%
       100 iter, 40 epoch / 5000 epochs
       inf iters/sec. Estimated time to finish: 0:00:00.
     total [..................................................]  1.60%
this epoch [..................................................]  0.00%
       200 iter, 80 epoch / 5000 epochs
    252.55 iters/sec. Estimated time to finish: 0:00:48.703121.
     total [#.................................................]  2.40%
this epoch [..................................................]  0.00%
       300 iter, 120 epoch / 5000 epochs
    249.09 iters/sec. Estimated time to finish: 0:00:48.978097.
     total [#.................................................]  3.20%
this epoch [..................................................]  0.00%
       400 iter, 160 epoch / 5000 epochs
    256.22 iters/sec. Estimated time to finish: 0:00:47.225622.
     total [##.......

     total [##############....................................] 28.00%
this epoch [..................................................]  0.00%
      3500 iter, 1400 epoch / 5000 epochs
    314.73 iters/sec. Estimated time to finish: 0:00:28.595964.
     total [##############....................................] 28.80%
this epoch [..................................................]  0.00%
      3600 iter, 1440 epoch / 5000 epochs
    315.23 iters/sec. Estimated time to finish: 0:00:28.233063.
     total [##############....................................] 29.60%
this epoch [..................................................]  0.00%
      3700 iter, 1480 epoch / 5000 epochs
    315.99 iters/sec. Estimated time to finish: 0:00:27.849215.
     total [###############...................................] 30.40%
this epoch [..................................................]  0.00%
      3800 iter, 1520 epoch / 5000 epochs
    316.54 iters/sec. Estimated time to finish: 0:00:27.484549.
     tot

     total [###########################.......................] 55.20%
this epoch [..................................................]  0.00%
      6900 iter, 2760 epoch / 5000 epochs
    316.22 iters/sec. Estimated time to finish: 0:00:17.709053.
     total [############################......................] 56.00%
this epoch [..................................................]  0.00%
      7000 iter, 2800 epoch / 5000 epochs
    316.89 iters/sec. Estimated time to finish: 0:00:17.355946.
     total [############################......................] 56.80%
this epoch [..................................................]  0.00%
      7100 iter, 2840 epoch / 5000 epochs
    317.52 iters/sec. Estimated time to finish: 0:00:17.006756.
     total [############################......................] 57.60%
this epoch [..................................................]  0.00%
      7200 iter, 2880 epoch / 5000 epochs
    317.98 iters/sec. Estimated time to finish: 0:00:16.667953.
     tot

     total [#########################################.........] 82.40%
this epoch [..................................................]  0.00%
     10300 iter, 4120 epoch / 5000 epochs
    333.23 iters/sec. Estimated time to finish: 0:00:06.601980.
     total [#########################################.........] 83.20%
this epoch [..................................................]  0.00%
     10400 iter, 4160 epoch / 5000 epochs
     334.4 iters/sec. Estimated time to finish: 0:00:06.279842.
     total [##########################################........] 84.00%
this epoch [..................................................]  0.00%
     10500 iter, 4200 epoch / 5000 epochs
    335.63 iters/sec. Estimated time to finish: 0:00:05.959004.
     total [##########################################........] 84.80%
this epoch [..................................................]  0.00%
     10600 iter, 4240 epoch / 5000 epochs
    336.77 iters/sec. Estimated time to finish: 0:00:05.641866.
     tot

In [29]:
# テスト
model.cleargrads()
y_test_v = model.predict(x_test_v)
y_test = y_test_v.data

In [30]:
# 正解数のカウント
correct = 0
rowCount = y_test.shape[0]
for i in range(rowCount):
    maxIndex = np.argmax(y_test[i, :])
    print(y_test[i, :], maxIndex)
    if maxIndex == t_test[i]:
        correct += 1

# 正解率
print("Correct:" , correct, "Total:", rowCount, "Accuracy:", correct / rowCount * 100, "%")

[  1.00500143e+00   1.93521380e-04  -1.67141855e-03] 0
[ 1.00321567  0.00160341 -0.00115667] 0
[  1.00537026e+00  -8.50707293e-05  -1.74383819e-03] 0
[  1.00163150e+00   2.82947719e-03  -7.40632415e-04] 0
[  9.96649981e-01   6.62927330e-03   2.18942761e-04] 0
[  1.00566006e+00  -3.40059400e-04  -1.90351903e-03] 0
[  1.00211370e+00   2.43870914e-03  -9.98780131e-04] 0
[ 1.00778008 -0.00206526 -0.00264959] 0
[  1.00629926e+00  -8.29413533e-04  -1.95114315e-03] 0
[  1.00404572e+00   8.91014934e-04  -1.53450668e-03] 0
[ 1.001562    0.00280778 -0.00100236] 0
[ 1.00717258 -0.00148602 -0.00217883] 0
[  9.94685054e-01   8.07599723e-03   4.04015183e-04] 0
[  9.96546209e-01   6.63547218e-03   1.52900815e-04] 0
[  1.00459647e+00   4.98339534e-04  -1.59429014e-03] 0
[  9.97166932e-01   6.19344413e-03  -1.65551901e-05] 0
[ 1.00720227 -0.00155742 -0.00238733] 0
[  1.00194335e+00   2.56161392e-03  -9.92640853e-04] 0
[  1.00618935e+00  -7.64325261e-04  -2.05485523e-03] 0
[  1.00029159e+00   3.86826694