# Serializerによるモデルの保存（読み込み）

保存済みの学習済みパラメータ `my_iris.npz` を `serializers.load_npz` で読み込み、Iris のテスト用データ（偶数番目の行）で正解率を評価する。

In [1]:
# coding: UTF-8

import chainer
from chainer import Variable, Chain, optimizers, serializers
import chainer.links as L
import chainer.functions as F

# Trainer用
from chainer.datasets import tuple_dataset
from chainer import training, iterators
from chainer.training import extensions

import numpy as np
from sklearn import datasets # Scikit Learn にIris dataset が含まれている

In [2]:
# Load the Iris dataset bundled with scikit-learn
iris_data = datasets.load_iris()

In [3]:
# Integer class labels, the sample count, and the input features as a float32 copy
t = iris_data.target
n = len(t)
x = np.array(iris_data.data, dtype=np.float32)

In [4]:
# Preprocess the teacher data: one-hot encode the integer labels `t` into an
# (n, n_classes) float32 matrix (row i has 1.0 in column t[i], 0.0 elsewhere).
n_classes = 3  # must match the 3-unit output layer of the model
t_matrix = np.zeros((n, n_classes), dtype=np.float32)
# Vectorized fancy-index assignment replaces the per-row Python loop.
t_matrix[np.arange(n), t] = 1.0

In [5]:
# Split into training and test data, half each:
# odd-numbered rows are used for training, even-numbered rows for testing.
indexes = np.arange(n)
is_odd = indexes % 2 == 1
indexes_train = indexes[is_odd]
indexes_test = indexes[~is_odd]

x_train = x[indexes_train]          # training inputs
t_train = t_matrix[indexes_train]   # training targets (one-hot rows)
x_test = x[indexes_test]            # test inputs
t_test = t[indexes_test]            # test targets (integer class labels)

# Pair training inputs with targets in a TupleDataset; wrap test inputs in a Variable
train = tuple_dataset.TupleDataset(x_train, t_train)
x_test_v = Variable(x_test)

In [6]:
# Chain definition
class IrisChain(Chain):
    """Three-layer fully connected network for Iris classification.

    Architecture: Linear(n_in, n_hidden) -> sigmoid -> Linear(n_hidden, n_hidden)
    -> sigmoid -> Linear(n_hidden, n_out). The defaults (4, 6, 3) reproduce the
    original fixed architecture.

    NOTE: parameters are stored/looked up under the link names l1, l2, l3 when
    (de)serialized (e.g. by ``serializers.load_npz``), so keep these names stable
    or previously saved files will no longer load.
    """

    def __init__(self, n_in=4, n_hidden=6, n_out=3):
        super(IrisChain, self).__init__()
        # Register child links inside init_scope(); passing links as keyword
        # arguments to Chain.__init__ is deprecated since Chainer v2.
        with self.init_scope():
            self.l1 = L.Linear(n_in, n_hidden)
            self.l2 = L.Linear(n_hidden, n_hidden)
            self.l3 = L.Linear(n_hidden, n_out)

    def __call__(self, x, t):
        """Return the mean squared error between ``predict(x)`` and ``t``.

        ``t`` must have the same shape as the network output (the training
        data in this notebook uses one-hot float32 rows).
        """
        return F.mean_squared_error(self.predict(x), t)

    def predict(self, x):
        """Forward pass; returns the raw (un-activated) outputs of ``l3``."""
        h1 = F.sigmoid(self.l1(x))
        h2 = F.sigmoid(self.l2(h1))
        h3 = self.l3(h2)
        return h3

In [7]:
# Model and optimizer setup: create an Adam optimizer and attach it to a fresh IrisChain
optimizer = optimizers.Adam()
model = IrisChain()
optimizer.setup(model)

In [8]:
# 学習

# Trainer を使う
#train_iter = iterators.SerialIterator(train, 30) # 1回の学習30個だけ使う
#updater = training.StandardUpdater(train_iter, optimizer)
#trainer = training.Trainer(updater, (5000, 'epoch'))
#trainer.extend(extensions.ProgressBar())
#trainer.run()

In [10]:
# Load the saved parameters into `model` (path is relative to the working directory)
model_path = "my_iris.npz"
serializers.load_npz(model_path, model)

In [11]:
# Test: run the loaded model on the held-out inputs
model.cleargrads()  # kept from the original; gradients are not used below
# Inference only: don't build the backprop graph, and run links in test mode
# (a no-op for this sigmoid/linear model, but correct if dropout/BN is added).
with chainer.using_config('train', False), chainer.no_backprop_mode():
    y_test_v = model.predict(x_test_v)
y_test = y_test_v.data

In [12]:
# Count correct predictions. Each row's predicted class is the index of its
# largest output; every row is printed together with that index.
rowCount = y_test.shape[0]
predicted_classes = np.argmax(y_test, axis=1)
correct = 0
for i, maxIndex in enumerate(predicted_classes):
    print(y_test[i, :], maxIndex)
    correct += int(maxIndex == t_test[i])

# Accuracy in percent
accuracy = correct / rowCount * 100
print("Correct:" , correct, "Total:", rowCount, "Accuracy:", accuracy, "%")

[ 1.00479472 -0.00155996 -0.00389822] 0
[  1.00487888e+00  -6.99326396e-04  -5.55403531e-03] 0
[ 1.00367308 -0.00137131 -0.00209506] 0
[  1.00009620e+00   7.42450356e-04   3.67090106e-04] 0
[ 0.99768138  0.00497569 -0.00294013] 0
[ 1.00406015 -0.00160502 -0.00247093] 0
[ 1.00387084  0.00185962 -0.0077184 ] 0
[ 1.01073432 -0.00662695 -0.00699253] 0
[ 1.00643158 -0.00479363 -0.00227375] 0
[  1.00063419e+00  -6.70999289e-05   6.94081187e-04] 0
[ 0.9988575   0.00292368 -0.00119169] 0
[ 1.00872374 -0.00496553 -0.00560804] 0
[ 0.98563766  0.01047422  0.00918202] 0
[ 0.99441898  0.00414182  0.00336616] 0
[ 1.0059092  -0.00168888 -0.00577946] 0
[  9.94938970e-01   6.08690083e-03  -9.65148211e-05] 0
[  1.00091159e+00  -6.83113933e-04   2.16142833e-03] 0
[ 1.001683    0.00276209 -0.00534149] 0
[ 1.01021051 -0.00447099 -0.00900705] 0
[ 1.00216496  0.00163634 -0.00476964] 0
[ 1.00576472 -0.00284313 -0.00395723] 0
[  1.00243473e+00   4.74169850e-04  -2.96957791e-03] 0
[ 0.98777211  0.00713946  0.01