"""RNN training code for the Penn Treebank (PTB) dataset."""
from __future__ import print_function

import argparse
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, iterators, serializers, optimizers
from chainer.training import extensions

from RNN import RNN
from RNN2 import RNN2
from RNN3 import RNN3
from RNNForLM import RNNForLM
from parallel_sequential_iterator import ParallelSequentialIterator
from bptt_updater import BPTTUpdater
def compute_perplexity(result):
    """Rewrite a LogReport result dictionary in place to add perplexity values.

    Perplexity is ``exp(loss)``. The training entry ``'perplexity'`` is always
    written, so ``'main/loss'`` must be present (otherwise ``KeyError``). The
    validation entry ``'val_perplexity'`` is written only when
    ``'validation/main/loss'`` is present in ``result``.

    Args:
        result: mutable mapping of reported values; modified in place.

    Returns:
        None.
    """
    train_loss = result['main/loss']
    result['perplexity'] = np.exp(train_loss)

    val_loss_key = 'validation/main/loss'
    if val_loss_key not in result:
        return  # validation loss absent from this report: nothing more to add
    result['val_perplexity'] = np.exp(result[val_loss_key])
def _parse_args(archs):
    """Parse the command-line options of the PTB RNN training script.

    Args:
        archs: mapping from architecture name to model class; its keys are
            the accepted values of ``--arch``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='RNN example')
    parser.add_argument('--arch', '-a', choices=sorted(archs),
                        default='rnn', help='Net architecture')
    parser.add_argument('--unit', '-u', type=int, default=100,
                        help='Number of RNN units in each layer')
    parser.add_argument('--bproplen', '-l', type=int, default=20,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--batchsize', '-b', type=int, default=10,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=10,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    return parser.parse_args()


def main():
    """Train an RNN language model on PTB and report the test perplexity.

    Parses options, loads the PTB word sequences, builds the chosen model and
    an Adam optimizer, trains with a chainer Trainer (truncated BPTT), saves
    the trained predictor under ``args.out``, then evaluates perplexity on
    the test split.
    """
    archs = {
        'rnn': RNN,
        'rnn2': RNN2,
        'rnn3': RNN3,
        'lstm': RNNForLM,
    }
    args = _parse_args(archs)
    print('GPU: {}'.format(args.gpu))
    print('# Architecture: {}'.format(args.arch))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    # 1. Load dataset: Penn Tree Bank long word sequence dataset
    train, val, test = chainer.datasets.get_ptb_words()
    n_vocab = max(train) + 1  # train is just an array of integers
    print('# vocab: {}'.format(n_vocab))

    # 2. Setup model
    model = archs[args.arch](n_vocab=n_vocab, n_units=args.unit)
    classifier_model = L.Classifier(model)
    classifier_model.compute_accuracy = False  # we only want the perplexity
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU
    # Model with shared params and distinct states, used for evaluation.
    eval_classifier_model = classifier_model.copy()
    # Must be the copy's predictor: resetting `model` instead would wipe the
    # training RNN's state and leave the evaluation state untouched.
    eval_model = eval_classifier_model.predictor

    # 3. Setup an optimizer (alternative: optimizers.MomentumSGD())
    optimizer = optimizers.Adam(alpha=0.001)
    optimizer.setup(classifier_model)

    # 4. Setup iterators; validation/test use batch size 1 and a single pass.
    train_iter = ParallelSequentialIterator(train, args.batchsize)
    val_iter = ParallelSequentialIterator(val, 1, repeat=False)
    test_iter = ParallelSequentialIterator(test, 1, repeat=False)

    # 5. Setup an updater (truncated BPTT of length args.bproplen)
    updater = BPTTUpdater(train_iter, optimizer, args.bproplen, args.gpu)

    # 6. Setup a trainer (and extensions)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    # Evaluate the model with the validation dataset for each epoch.
    # device= is needed so batches are moved to the GPU like the model.
    trainer.extend(extensions.Evaluator(
        val_iter, eval_classifier_model, device=args.gpu,
        # Reset the RNN state at the beginning of each evaluation
        eval_hook=lambda _: eval_model.reset_state()))
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    interval = 500
    # compute_perplexity adds 'perplexity' / 'val_perplexity' to each log entry.
    trainer.extend(extensions.LogReport(postprocess=compute_perplexity,
                                        trigger=(interval, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'perplexity', 'val_perplexity', 'elapsed_time']
    ), trigger=(interval, 'iteration'))
    trainer.extend(extensions.PlotReport(
        ['perplexity', 'val_perplexity'],
        x_key='epoch', file_name='perplexity.png'))

    # Resume from a snapshot
    if args.resume:
        serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
    # NOTE(review): the file-name pattern of the saved model was not fully
    # visible; confirm it matches what downstream loaders expect.
    serializers.save_npz('{}/{}_ptb.model'
                         .format(args.out, args.arch), model)

    # Evaluate the final model on the test split from a clean recurrent state
    # (the last validation pass leaves state behind in eval_model).
    eval_model.reset_state()
    evaluator = extensions.Evaluator(test_iter, eval_classifier_model,
                                     device=args.gpu)
    result = evaluator()
    print('test perplexity:', np.exp(float(result['main/loss'])))
if __name__ == '__main__':