In [1]:
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from simple_convnet import SimpleConvNet
from common.trainer import Trainer

# --- Configuration --------------------------------------------------------
# Fix the global NumPy seed so weight initialisation and mini-batch sampling
# repeat across runs.
# NOTE(review): assumes SimpleConvNet and Trainer draw from NumPy's global RNG
# (np.random.*) — confirm in their source.
SEED = 42
np.random.seed(SEED)

# Reduce the data when processing takes too long (e.g. 5000 / 1000).
# None keeps every sample, because slicing with [:None] returns the whole array.
TRAIN_SIZE = None
TEST_SIZE = None

max_epochs = 20
PARAMS_FILE = "params.pkl"

# --- Load data --------------------------------------------------------------
# flatten=False keeps each sample as an image array instead of a flat vector,
# matching input_dim=(1, 28, 28) passed to the network below.
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)
x_train, t_train = x_train[:TRAIN_SIZE], t_train[:TRAIN_SIZE]
x_test, t_test = x_test[:TEST_SIZE], t_test[:TEST_SIZE]

# --- Build and train the network -----------------------------------------
# One conv layer: 30 filters of size 5x5, no padding, stride 1, followed by a
# 100-unit hidden layer and a 10-way output.
network = SimpleConvNet(input_dim=(1, 28, 28),
                        conv_param={'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=10, weight_init_std=0.01)

# Adam (lr=0.001) on mini-batches of 100 for max_epochs epochs.
# evaluate_sample_num_per_epoch presumably caps how many samples are used for
# the per-epoch accuracy estimate — confirm in common.trainer.Trainer.
trainer = Trainer(network, x_train, t_train, x_test, t_test,
                  epochs=max_epochs, mini_batch_size=100,
                  optimizer='Adam', optimizer_param={'lr': 0.001},
                  evaluate_sample_num_per_epoch=1000)
trainer.train()

# --- Save parameters --------------------------------------------------------
# NOTE(review): a .pkl file is presumably pickled; only load it back from a
# trusted location.
network.save_params(PARAMS_FILE)
print("Saved Network Parameters!")

# --- Plot train/test accuracy per epoch ---------------------------------
markers = {'train': 'o', 'test': 's'}

# Derive the x-axis from the recorded history rather than assuming it holds
# exactly max_epochs points; a mismatch would make plot() raise on unequal
# lengths. Both lists are plotted against the same x, so they are expected to
# have equal length.
x = np.arange(len(trainer.train_acc_list))

fig, ax = plt.subplots()
ax.plot(x, trainer.train_acc_list, marker=markers['train'], label='train', markevery=2)
ax.plot(x, trainer.test_acc_list, marker=markers['test'], label='test', markevery=2)
ax.set(title="SimpleConvNet accuracy on MNIST",
       xlabel="epochs", ylabel="accuracy", ylim=(0, 1.0))
ax.legend(loc='lower right')
plt.show()

train loss:2.3001204113803246
=== epoch:1, train acc:0.092, test acc:0.088 ===
train loss:2.2995025000781957
train loss:2.2958372597560714
train loss:2.2914574687201537
train loss:2.285662852119448
train loss:2.2775827827020807
train loss:2.261566101858999
train loss:2.2443177871594138
train loss:2.232632066087915
train loss:2.22218292024604
train loss:2.179217144964022
train loss:2.1536313502174895
train loss:2.134199290137875
train loss:2.051738431834382
train loss:2.0279708145292448
train loss:1.9861838451866156
train loss:1.9535446153641727
train loss:1.8837259802333266
train loss:1.789418693475663
train loss:1.7292471012178563
train loss:1.6030155806299304
train loss:1.478945011475561
train loss:1.4661727070188895
train loss:1.3963692900359987
train loss:1.3005873323796662
train loss:1.1941585241708086
train loss:1.1104285172992476
train loss:1.1725037774402807
train loss:1.0014485632802421
train loss:0.9140743489194716
train loss:0.9221992789124477
train loss:0.9935141395103136
t

KeyboardInterrupt: 