In [2]:
import sys
sys.path.append('..')
import matplotlib.pyplot as plt
import numpy as np
from common.optimizer import SGD
from dataset import ptb
from simple_rnnlm import SimpleRnnlm

# ---- hyperparameters -------------------------------------------------------
batch_size, time_size = 10, 5      # rows per mini-batch, columns (time steps) per row
wordvec_size = hidden_size = 100   # both sizes are handed to SimpleRnnlm below
lr, max_epoch = 0.1, 100           # SGD learning rate, number of passes of the outer loop

# ---- data ------------------------------------------------------------------
# Only the first `corpus_size` ids of the 'train' split are kept so the
# experiment stays small.
corpus_size = 1000
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus = corpus[:corpus_size]
# Vocabulary size is derived from the largest id present in the truncated corpus.
vocab_size = int(max(corpus) + 1)

# Inputs are all ids but the last; targets are the same sequence shifted by one.
xs, ts = corpus[:-1], corpus[1:]
data_size = len(xs)
print('corpus size: %d, vocabulary size: %d' % (corpus_size, vocab_size))

# ---- training bookkeeping --------------------------------------------------
# Number of mini-batches per epoch: each one consumes batch_size * time_size ids.
max_iters = data_size // (batch_size * time_size)
time_idx = 0                       # running read position shared by every batch row
total_loss, loss_count = 0, 0      # accumulators, reset after each epoch
ppl_list = []                      # container for per-epoch perplexity values

# ---- model and optimizer ---------------------------------------------------
model = SimpleRnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)

# Each batch row starts reading at its own, evenly spaced position in the corpus.
jump = (corpus_size - 1) // batch_size
offsets = [row * jump for row in range(batch_size)]

# Training loop.
#
# Every epoch runs `max_iters` mini-batches. `time_idx` is never reset, so each
# mini-batch continues reading the corpus where the previous one stopped, and
# the modulo wraps the read position back to the start of the data.
# At the end of an epoch the mean loss is turned into a perplexity value
# (exp of the mean), printed, stored in `ppl_list`, and the accumulators are
# cleared.
# NOTE(review): assumes model.forward returns a scalar mean loss - confirm
# against SimpleRnnlm.
for epoch in range(max_epoch):
    # `_` instead of `iter`: the counter is unused and must not shadow the builtin.
    for _ in range(max_iters):
        # Fill one (batch_size, time_size) block of input ids and target ids.
        batch_x = np.empty((batch_size, time_size), dtype='i')
        batch_t = np.empty((batch_size, time_size), dtype='i')
        for t in range(time_size):
            for i, offset in enumerate(offsets):
                # Row i reads from its own offset; compute the wrapped index once.
                pos = (offset + time_idx) % data_size
                batch_x[i, t] = xs[pos]
                batch_t[i, t] = ts[pos]
            time_idx += 1

        # One optimisation step: forward pass, backward pass, parameter update.
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)
        total_loss += loss
        loss_count += 1

    # Perplexity for this epoch = exp(average loss over its mini-batches).
    ppl = np.exp(total_loss / loss_count)
    print('epoch %d | perplexity %.2f' % (epoch + 1, ppl))
    # Record the history; previously `ppl_list` was created but never filled.
    ppl_list.append(float(ppl))
    total_loss, loss_count = 0, 0

corpus size: 1000, vocabulary size: 418
epoch 1 | perplexity 408.86
epoch 2 | perplexity 298.10
epoch 3 | perplexity 230.38
epoch 4 | perplexity 217.98
epoch 5 | perplexity 207.37
epoch 6 | perplexity 202.58
epoch 7 | perplexity 198.85
epoch 8 | perplexity 196.50
epoch 9 | perplexity 191.31
epoch 10 | perplexity 192.50
epoch 11 | perplexity 188.17
epoch 12 | perplexity 191.69
epoch 13 | perplexity 190.39
epoch 14 | perplexity 190.47
epoch 15 | perplexity 189.24
epoch 16 | perplexity 184.93
epoch 17 | perplexity 183.75
epoch 18 | perplexity 180.25
epoch 19 | perplexity 181.79
epoch 20 | perplexity 182.59
epoch 21 | perplexity 181.20
epoch 22 | perplexity 176.38
epoch 23 | perplexity 175.20
epoch 24 | perplexity 175.59
epoch 25 | perplexity 173.35
epoch 26 | perplexity 171.39
epoch 27 | perplexity 167.57
epoch 28 | perplexity 165.10
epoch 29 | perplexity 161.58
epoch 30 | perplexity 156.91
epoch 31 | perplexity 156.30
epoch 32 | perplexity 153.75
epoch 33 | perplexity 152.54
epoch 34 | p