In [None]:
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from simple_convnet import SimpleConvNet
from common.trainer import Trainer

In [None]:
# Load MNIST with flatten=False so each image keeps its image shape
# (the ConvNet below is built with input_dim=(1,28,28)).
train_set, test_set = load_mnist(flatten=False)
x_train, t_train = train_set
x_test, t_test = test_set

In [None]:
# Display the raw array of the first training image.
x_train[0, ...]

In [None]:
# Display the labels of the first ten training images.
t_train[0:10]

In [None]:
# Show the first 10 training images in a 2x5 grid, each titled with its label.
fig, axes = plt.subplots(2, 5, figsize=(6, 3))
for ax, image, label in zip(axes.ravel(), x_train[:10], t_train[:10]):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('%d' % label)
    ax.imshow(image.reshape((28, 28)), vmin=0, vmax=1,
              cmap=plt.cm.gray_r, interpolation="nearest")

In [None]:
# Prediction preview: run the network on the first 10 training images and title
# each one "predicted / actual".
# NOTE: this cell reads weights from params.pkl, which is only written by the
# "save parameters" cell further down. On a fresh checkout, run the training and
# save cells first; otherwise this cell reports the missing file and skips the plot.
# NOTE(review): load_params presumably unpickles the file -- only load trusted files.
network = SimpleConvNet(input_dim=(1,28,28),
                        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=10, weight_init_std=0.01)

params_path = 'params.pkl'
if not os.path.exists(params_path):
    print("%s not found; run the training and save cells first." % params_path)
else:
    network.load_params(params_path)

    sample_images, sample_labels = x_train[:10], t_train[:10]
    # One batched forward pass; argmax over the class axis gives one int per image,
    # so '%d' below receives a scalar rather than a 1-element array.
    sample_preds = np.argmax(network.predict(sample_images), axis=1)

    fig = plt.figure(figsize=(6,3))
    for c, (image, prediction, label) in enumerate(zip(sample_images, sample_preds, sample_labels)):
        subplot = fig.add_subplot(2,5,c+1)
        subplot.set_xticks([])
        subplot.set_yticks([])
        subplot.set_title('%d / %d' % (prediction, label))
        subplot.imshow(image.reshape((28,28)), vmin=0, vmax=1,
                       cmap=plt.cm.gray_r, interpolation="nearest")

In [None]:
# Training. To keep the run short, train and evaluate on a subset of MNIST;
# raise these limits for a fuller (slower) run.
TRAIN_SUBSET = 5000
TEST_SUBSET = 1000
SEED = 0

# Slicing is idempotent, so re-running this cell does not shrink the data further.
x_train, t_train = x_train[:TRAIN_SUBSET], t_train[:TRAIN_SUBSET]
x_test, t_test = x_test[:TEST_SUBSET], t_test[:TEST_SUBSET]

max_epochs = 20

# Seed numpy's global RNG so training runs are repeatable. This assumes
# SimpleConvNet's weight init and the Trainer's mini-batch sampling draw from
# np.random -- TODO confirm in simple_convnet.py / common/trainer.py.
np.random.seed(SEED)

network = SimpleConvNet(input_dim=(1,28,28),
                        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=10, weight_init_std=0.01)

trainer = Trainer(network, x_train, t_train, x_test, t_test,
                  epochs=max_epochs, mini_batch_size=100,
                  optimizer='Adam', optimizer_param={'lr': 0.001},
                  evaluate_sample_num_per_epoch=1000)
trainer.train()

In [None]:
# Persist the trained weights; the answer-check cell reloads them from this file.
params_file = "params.pkl"
network.save_params(params_file)
print("Saved Network Parameters!")

In [None]:
# Plot the per-epoch train/test accuracy recorded by the Trainer.
markers = {'train': 'o', 'test': 's'}
acc_history = {'train': trainer.train_acc_list, 'test': trainer.test_acc_list}
for split, acc in acc_history.items():
    # Take the x axis from the recorded history length instead of assuming
    # exactly max_epochs entries, so a mismatch cannot raise a shape error.
    plt.plot(np.arange(len(acc)), acc, marker=markers[split], label=split, markevery=2)
plt.title("SimpleConvNet accuracy")
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

In [None]:
# Answer check: for each predicted digit 0-9 (one row per digit), show up to 3
# correctly classified test images in columns 1-3 and up to 3 misclassified
# ones in columns 4-6, each titled "predicted / actual".
network.load_params('params.pkl')

test_size = x_test.shape[0]
# The test set was truncated earlier, so cap the sample at what is available.
batch_size = min(3000, test_size)
# Sample without replacement so the same image cannot appear twice in the grid.
batch_mask = np.random.choice(test_size, batch_size, replace=False)
x_batch = x_test[batch_mask]
t_batch = t_test[batch_mask]

# Predict every sampled image once (in chunks of 100 to bound memory) instead of
# re-running the network for each image in each of the 10 digit rows.
chunk = 100
pred_batch = np.concatenate([
    np.argmax(network.predict(x_batch[start:start + chunk]), axis=1)
    for start in range(0, batch_size, chunk)
])

fig = plt.figure(figsize=(6,12))
for digit in range(10):
    predicted_as_digit = pred_batch == digit
    correct_idx = np.flatnonzero(predicted_as_digit & (t_batch == digit))[:3]
    wrong_idx = np.flatnonzero(predicted_as_digit & (t_batch != digit))[:3]
    # Correct and wrong examples fill their own columns independently, so the
    # mistakes are still shown when a digit has fewer than 3 correct hits.
    slots = list(enumerate(correct_idx, start=1)) + list(enumerate(wrong_idx, start=4))
    for col, idx in slots:
        subplot = fig.add_subplot(10, 6, digit*6 + col)
        subplot.set_xticks([])
        subplot.set_yticks([])
        subplot.set_title('%d / %d' % (pred_batch[idx], t_batch[idx]))
        subplot.imshow(x_batch[idx].reshape((28,28)), vmin=0, vmax=1,
                       cmap=plt.cm.gray_r, interpolation="nearest")