In [1]:
# coding: utf-8
import sys
sys.path.append('..')
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb
from rnnlm import Rnnlm


# --- Hyperparameters -------------------------------------------------------
batch_size = 2
wordvec_size = 10
hidden_size = 10   # number of elements in the RNN hidden-state vector
time_size = 35     # number of time steps the RNN is unrolled over per chunk
lr = 20.0
max_epoch = 1
max_grad = 0.25    # gradient-clipping threshold handed to trainer.fit

# --- Data: PTB train/test splits --------------------------------------------
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
# Language-model pairs: every input token is matched with the token after it.
xs, ts = corpus[:-1], corpus[1:]

# --- Model, optimizer, trainer ----------------------------------------------
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

# --- Train with gradient clipping, then plot the perplexity curve -----------
trainer.fit(
    xs, ts, max_epoch, batch_size, time_size, max_grad,
    eval_interval=20,
)
trainer.plot(ylim=(0, 500))

# --- Evaluate on the test split ---------------------------------------------
# Reset the model's recurrent state first so leftover training state does not
# leak into the test pass (presumably clears the hidden state; see Rnnlm).
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

# --- Persist the learned parameters -----------------------------------------
# NOTE(review): in the captured run this call raised
# "NameError: name 'GPU' is not defined". The name is presumably referenced
# inside the project's save_params implementation without being imported.
# Verify the common config/model modules before relying on saved weights.
model.save_params()

| epoch 1 |  iter 1 / 13279 | time 0[s] | perplexity 10001.43
| epoch 1 |  iter 21 / 13279 | time 0[s] | perplexity 2294.29
| epoch 1 |  iter 41 / 13279 | time 1[s] | perplexity 1407.56
| epoch 1 |  iter 61 / 13279 | time 2[s] | perplexity 1010.26
| epoch 1 |  iter 81 / 13279 | time 3[s] | perplexity 923.27
| epoch 1 |  iter 101 / 13279 | time 4[s] | perplexity 745.16
| epoch 1 |  iter 121 / 13279 | time 5[s] | perplexity 724.46
| epoch 1 |  iter 141 / 13279 | time 6[s] | perplexity 583.22
| epoch 1 |  iter 161 / 13279 | time 7[s] | perplexity 704.03
| epoch 1 |  iter 181 / 13279 | time 8[s] | perplexity 671.42
| epoch 1 |  iter 201 / 13279 | time 9[s] | perplexity 649.97
| epoch 1 |  iter 221 / 13279 | time 10[s] | perplexity 625.38
| epoch 1 |  iter 241 / 13279 | time 11[s] | perplexity 576.72
| epoch 1 |  iter 261 / 13279 | time 11[s] | perplexity 504.92
| epoch 1 |  iter 281 / 13279 | time 12[s] | perplexity 519.49
| epoch 1 |  iter 301 / 13279 | time 13[s] | perplexity 603.03
| ep

| epoch 1 |  iter 2581 / 13279 | time 118[s] | perplexity 346.01
| epoch 1 |  iter 2601 / 13279 | time 119[s] | perplexity 297.69
| epoch 1 |  iter 2621 / 13279 | time 120[s] | perplexity 417.58
| epoch 1 |  iter 2641 / 13279 | time 120[s] | perplexity 377.25
| epoch 1 |  iter 2661 / 13279 | time 121[s] | perplexity 326.89
| epoch 1 |  iter 2681 / 13279 | time 122[s] | perplexity 310.67
| epoch 1 |  iter 2701 / 13279 | time 123[s] | perplexity 377.35
| epoch 1 |  iter 2721 / 13279 | time 124[s] | perplexity 365.15
| epoch 1 |  iter 2741 / 13279 | time 125[s] | perplexity 370.19
| epoch 1 |  iter 2761 / 13279 | time 126[s] | perplexity 304.08
| epoch 1 |  iter 2781 / 13279 | time 127[s] | perplexity 272.94
| epoch 1 |  iter 2801 / 13279 | time 128[s] | perplexity 276.32
| epoch 1 |  iter 2821 / 13279 | time 129[s] | perplexity 324.44
| epoch 1 |  iter 2841 / 13279 | time 130[s] | perplexity 439.47
| epoch 1 |  iter 2861 / 13279 | time 130[s] | perplexity 481.04
| epoch 1 |  iter 2881 / 

| epoch 1 |  iter 5121 / 13279 | time 234[s] | perplexity 382.57
| epoch 1 |  iter 5141 / 13279 | time 235[s] | perplexity 377.93
| epoch 1 |  iter 5161 / 13279 | time 235[s] | perplexity 264.44
| epoch 1 |  iter 5181 / 13279 | time 236[s] | perplexity 328.54
| epoch 1 |  iter 5201 / 13279 | time 237[s] | perplexity 373.80
| epoch 1 |  iter 5221 / 13279 | time 238[s] | perplexity 464.83
| epoch 1 |  iter 5241 / 13279 | time 239[s] | perplexity 412.27
| epoch 1 |  iter 5261 / 13279 | time 240[s] | perplexity 347.72
| epoch 1 |  iter 5281 / 13279 | time 241[s] | perplexity 347.72
| epoch 1 |  iter 5301 / 13279 | time 242[s] | perplexity 400.38
| epoch 1 |  iter 5321 / 13279 | time 243[s] | perplexity 378.99
| epoch 1 |  iter 5341 / 13279 | time 244[s] | perplexity 372.73
| epoch 1 |  iter 5361 / 13279 | time 245[s] | perplexity 364.78
| epoch 1 |  iter 5381 / 13279 | time 245[s] | perplexity 370.26
| epoch 1 |  iter 5401 / 13279 | time 246[s] | perplexity 394.97
| epoch 1 |  iter 5421 / 

| epoch 1 |  iter 7661 / 13279 | time 349[s] | perplexity 410.88
| epoch 1 |  iter 7681 / 13279 | time 350[s] | perplexity 330.56
| epoch 1 |  iter 7701 / 13279 | time 351[s] | perplexity 353.72
| epoch 1 |  iter 7721 / 13279 | time 352[s] | perplexity 462.13
| epoch 1 |  iter 7741 / 13279 | time 353[s] | perplexity 322.54
| epoch 1 |  iter 7761 / 13279 | time 354[s] | perplexity 337.43
| epoch 1 |  iter 7781 / 13279 | time 355[s] | perplexity 287.75
| epoch 1 |  iter 7801 / 13279 | time 356[s] | perplexity 202.56
| epoch 1 |  iter 7821 / 13279 | time 357[s] | perplexity 329.05
| epoch 1 |  iter 7841 / 13279 | time 358[s] | perplexity 362.31
| epoch 1 |  iter 7861 / 13279 | time 358[s] | perplexity 241.96
| epoch 1 |  iter 7881 / 13279 | time 359[s] | perplexity 428.54
| epoch 1 |  iter 7901 / 13279 | time 360[s] | perplexity 282.39
| epoch 1 |  iter 7921 / 13279 | time 361[s] | perplexity 305.21
| epoch 1 |  iter 7941 / 13279 | time 362[s] | perplexity 366.26
| epoch 1 |  iter 7961 / 

| epoch 1 |  iter 10181 / 13279 | time 464[s] | perplexity 276.32
| epoch 1 |  iter 10201 / 13279 | time 464[s] | perplexity 430.16
| epoch 1 |  iter 10221 / 13279 | time 465[s] | perplexity 348.23
| epoch 1 |  iter 10241 / 13279 | time 466[s] | perplexity 436.67
| epoch 1 |  iter 10261 / 13279 | time 467[s] | perplexity 315.47
| epoch 1 |  iter 10281 / 13279 | time 468[s] | perplexity 259.71
| epoch 1 |  iter 10301 / 13279 | time 469[s] | perplexity 221.98
| epoch 1 |  iter 10321 / 13279 | time 470[s] | perplexity 319.83
| epoch 1 |  iter 10341 / 13279 | time 471[s] | perplexity 324.08
| epoch 1 |  iter 10361 / 13279 | time 472[s] | perplexity 297.10
| epoch 1 |  iter 10381 / 13279 | time 473[s] | perplexity 187.47
| epoch 1 |  iter 10401 / 13279 | time 473[s] | perplexity 289.27
| epoch 1 |  iter 10421 / 13279 | time 474[s] | perplexity 355.34
| epoch 1 |  iter 10441 / 13279 | time 475[s] | perplexity 446.68
| epoch 1 |  iter 10461 / 13279 | time 476[s] | perplexity 344.31
| epoch 1 

| epoch 1 |  iter 12681 / 13279 | time 577[s] | perplexity 256.93
| epoch 1 |  iter 12701 / 13279 | time 578[s] | perplexity 222.21
| epoch 1 |  iter 12721 / 13279 | time 579[s] | perplexity 313.50
| epoch 1 |  iter 12741 / 13279 | time 580[s] | perplexity 405.99
| epoch 1 |  iter 12761 / 13279 | time 581[s] | perplexity 429.26
| epoch 1 |  iter 12781 / 13279 | time 582[s] | perplexity 358.07
| epoch 1 |  iter 12801 / 13279 | time 583[s] | perplexity 364.48
| epoch 1 |  iter 12821 / 13279 | time 584[s] | perplexity 374.65
| epoch 1 |  iter 12841 / 13279 | time 585[s] | perplexity 279.84
| epoch 1 |  iter 12861 / 13279 | time 586[s] | perplexity 370.67
| epoch 1 |  iter 12881 / 13279 | time 586[s] | perplexity 310.18
| epoch 1 |  iter 12901 / 13279 | time 587[s] | perplexity 412.59
| epoch 1 |  iter 12921 / 13279 | time 588[s] | perplexity 326.53
| epoch 1 |  iter 12941 / 13279 | time 589[s] | perplexity 322.85
| epoch 1 |  iter 12961 / 13279 | time 590[s] | perplexity 428.71
| epoch 1 

<Figure size 640x480 with 1 Axes>

evaluating perplexity ...
234 / 235
test perplexity:  358.5866927949406


NameError: name 'GPU' is not defined