Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
101 lines (81 sloc) 3.28 KB
"""Inference/predict code for simple_sequence dataset.

The model must be trained before inference: the training script must be
executed beforehand.
"""
from __future__ import print_function
import argparse
import os
import sys
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, iterators, serializers, optimizers, Variable, cuda
from import extensions
from RNN import RNN
from RNN2 import RNN2
from RNN3 import RNN3
from RNNForLM import RNNForLM
def main():
    """Generate a word sequence from a trained RNN model on the PTB words data.

    Steps:
      1. Parse the command-line options (architecture, unit count, GPU id,
         prime index, sequence length, model path).
      2. Load the PTB words dataset and vocabulary via ``chainer.datasets``.
      3. Build the network selected by ``--arch`` and restore its weights from
         ``--modelpath`` when given, otherwise from
         ``result/<arch>_ptb.model``.
      4. Starting from ``--primeindex``, repeatedly feed the last word id to
         the model and take the arg-max of its output as the next word id,
         ``--length`` times.
      5. Print the resulting id sequence and the corresponding text.
    """
    archs = {
        'rnn': RNN,
        'rnn2': RNN2,
        'rnn3': RNN3,
        'lstm': RNNForLM,
    }

    parser = argparse.ArgumentParser(description='simple_sequence RNN predict code')
    parser.add_argument('--arch', '-a', choices=archs.keys(),
                        default='rnn', help='Net architecture')
    parser.add_argument('--unit', '-u', type=int, default=100,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--primeindex', '-p', type=int, default=1,
                        help='base index data, used for sequence generation')
    parser.add_argument('--length', '-l', type=int, default=100,
                        help='length of the generated sequence')
    parser.add_argument('--modelpath', '-m', default='',
                        help='Model path to be loaded')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))

    # Only the training split is needed here, to derive the vocabulary size.
    train, val, test = chainer.datasets.get_ptb_words()
    n_vocab = max(train) + 1  # train is just an array of integers
    print('#vocab =', n_vocab)

    # Load vocabulary and invert it so predicted ids can be mapped to words.
    ptb_word_id_dict = chainer.datasets.get_ptb_words_vocabulary()
    ptb_id_word_dict = dict((v, k) for k, v in ptb_word_id_dict.items())

    # Model setup
    model = archs[args.arch](n_vocab=n_vocab, n_units=args.unit)
    classifier_model = L.Classifier(model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU
    xp = np if args.gpu < 0 else cuda.cupy

    # Load exactly one set of weights: the explicit path wins, otherwise fall
    # back to the default result file for the chosen architecture.
    if args.modelpath:
        serializers.load_npz(args.modelpath, model)
    else:
        serializers.load_npz('result/{}_ptb.model'.format(args.arch), model)

    # Dataset preparation
    prev_index = args.primeindex

    # Predict: each predicted word id becomes the input of the next step.
    predicted_sequence = [prev_index]
    for _ in range(args.length):
        prev = chainer.Variable(xp.array([prev_index], dtype=xp.int32))
        current = model(prev)
        # Bring the output back to host memory before the NumPy arg-max, and
        # cast to a plain int so the sequence holds a single element type.
        current_index = int(np.argmax(cuda.to_cpu(current.data)))
        predicted_sequence.append(current_index)
        prev_index = current_index

    predicted_text_list = [ptb_id_word_dict[idx] for idx in predicted_sequence]
    print('Predicted sequence: ', predicted_sequence)
    print('Predicted text: ', ' '.join(predicted_text_list))
if __name__ == '__main__':