#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Create the data for the LSTM.
"""
import os
import sys
import argparse
import numpy as np
import h5py
from collections import defaultdict
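# Example invocation (the script name and data paths below are placeholders):
#   python preprocess.py --srcfile data/src-train.txt --targetfile data/targ-train.txt \
#       --srcvalfile data/src-val.txt --targetvalfile data/targ-val.txt \
#       --outputfile data/demo
# which writes demo.src.dict, demo.targ.dict, demo-val.hdf5 and demo-train.hdf5.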
class Indexer:
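    """Word-to-index mapper; ids 1-4 are reserved for <blank>, <unk>, <s>, </s>.

    A minimal usage sketch (the tokens are illustrative, not from the real data):

        idx = Indexer()
        idx.add_w(["the", "cat"])            # "the" -> 5, "cat" -> 6
        idx.convert_sequence(["the", "dog"]) # -> [5, 2]; unseen "dog" maps to <unk>
    """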
def __init__(self, symbols = ["<blank>","<unk>","<s>","</s>"]):
self.vocab = defaultdict(int)
self.PAD = symbols[0]
self.UNK = symbols[1]
self.BOS = symbols[2]
self.EOS = symbols[3]
self.d = {self.PAD: 1, self.UNK: 2, self.BOS: 3, self.EOS: 4}
def add_w(self, ws):
for w in ws:
if w not in self.d:
self.d[w] = len(self.d) + 1
def convert(self, w):
return self.d[w] if w in self.d else self.d[self.UNK]
def convert_sequence(self, ls):
return [self.convert(l) for l in ls]
def clean(self, s):
s = s.replace(self.PAD, "")
s = s.replace(self.BOS, "")
s = s.replace(self.EOS, "")
return s
def write(self, outfile):
        with open(outfile, "w") as out:
            # write "word id" lines, sorted by id
            items = sorted((v, k) for k, v in self.d.items())
            for v, k in items:
                out.write("{} {}\n".format(k, v))
def prune_vocab(self, k, cnt = False):
        vocab_list = list(self.vocab.items())
        if cnt:
            # keep every word whose count exceeds the threshold k
            self.pruned_vocab = {word: count for word, count in vocab_list if count > k}
        else:
            # keep the k most frequent words
            vocab_list.sort(key=lambda x: x[1], reverse=True)
            k = min(k, len(vocab_list))
            self.pruned_vocab = {word: count for word, count in vocab_list[:k]}
for word in self.pruned_vocab:
if word not in self.d:
self.d[word] = len(self.d) + 1
    def load_vocab(self, vocab_file):
        self.d = {}
        with open(vocab_file, 'r') as f:
            for line in f:
                word, idx = line.strip().split()
                self.d[word] = int(idx)
def pad(ls, length, symbol):
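    """Truncate or right-pad `ls` to exactly `length` items, e.g.
    pad([5, 6], 4, 1) -> [5, 6, 1, 1] and pad([5, 6, 7], 2, 1) -> [5, 6].
    """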
if len(ls) >= length:
return ls[:length]
    return ls + [symbol] * (length - len(ls))
def get_data(args):
src_indexer = Indexer(["<blank>","<unk>","<s>","</s>"])
target_indexer = Indexer(["<blank>","<unk>","<s>","</s>"])
def make_vocab(srcfile, targetfile, srcseqlength, targetseqlength, train=1):
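        """First pass over a parallel corpus: count how many sentence pairs pass
        the length filter and, for training data (train=1), accumulate word
        frequencies in the source/target indexers.
        """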
num_sents = 0
        for src_orig, targ_orig in zip(open(srcfile, 'r'), open(targetfile, 'r')):
src_orig = src_indexer.clean(src_orig.strip())
targ_orig = target_indexer.clean(targ_orig.strip())
targ = targ_orig.strip().split()
src = src_orig.strip().split()
if len(targ) > targetseqlength or len(src) > srcseqlength or len(targ) < 1 or len(src) < 1:
continue
num_sents += 1
if train == 1:
for word in targ:
target_indexer.vocab[word] += 1
for word in src:
src_indexer.vocab[word] += 1
return num_sents
def convert(srcfile, targetfile, batchsize, srcseqlength, targetseqlength, outfile, num_sents,
max_sent_l=0, shuffle=0):
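        """Second pass: convert each surviving sentence pair to padded index
        vectors, sort by source length, group into minibatches of at most
        `batchsize` sentences, and write everything to `outfile` as HDF5
        datasets. Returns the running maximum sentence length seen (before
        length-based dropping).
        """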
newsrcseqlength = srcseqlength + 2 #add 2 for EOS and BOS
newtargetseqlength = targetseqlength + 2
targets = np.zeros((num_sents, newtargetseqlength), dtype=int)
target_output = np.zeros((num_sents, newtargetseqlength), dtype=int)
sources = np.zeros((num_sents, newsrcseqlength), dtype=int)
source_lengths = np.zeros((num_sents,), dtype=int)
target_lengths = np.zeros((num_sents,), dtype=int)
dropped = 0
sent_id = 0
        for src_orig, targ_orig in zip(open(srcfile, 'r'), open(targetfile, 'r')):
src_orig = src_indexer.clean(src_orig.strip())
targ_orig = target_indexer.clean(targ_orig.strip())
targ = [target_indexer.BOS] + targ_orig.strip().split() + [target_indexer.EOS]
src = [src_indexer.BOS] + src_orig.strip().split() + [src_indexer.EOS]
max_sent_l = max(len(targ), len(src), max_sent_l)
if len(targ) > newtargetseqlength or len(src) > newsrcseqlength or len(targ) < 3 or len(src) < 3:
dropped += 1
continue
targ = pad(targ, newtargetseqlength+1, target_indexer.PAD)
targ = target_indexer.convert_sequence(targ)
targ = np.array(targ, dtype=int)
src = pad(src, newsrcseqlength, src_indexer.PAD)
src = src_indexer.convert_sequence(src)
src = np.array(src, dtype=int)
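            # `targets` holds the decoder input (starting with <s>) and
            # `target_output` the same sequence shifted left by one token
            # (ending with </s>), the usual teacher-forcing pair; `!= 1` skips
            # the <blank> pad index when counting real lengths.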
targets[sent_id] = np.array(targ[:-1],dtype=int)
target_lengths[sent_id] = (targets[sent_id] != 1).sum()
target_output[sent_id] = np.array(targ[1:],dtype=int)
sources[sent_id] = np.array(src, dtype=int)
source_lengths[sent_id] = (sources[sent_id] != 1).sum()
sent_id += 1
if sent_id % 100000 == 0:
print("{}/{} sentences processed".format(sent_id, num_sents))
        print("{}/{} sentences processed".format(sent_id, num_sents))
if shuffle == 1:
rand_idx = np.random.permutation(sent_id)
targets = targets[rand_idx]
target_output = target_output[rand_idx]
sources = sources[rand_idx]
source_lengths = source_lengths[rand_idx]
target_lengths = target_lengths[rand_idx]
#break up batches based on source lengths
source_lengths = source_lengths[:sent_id]
source_sort = np.argsort(source_lengths)
sources = sources[source_sort]
targets = targets[source_sort]
target_output = target_output[source_sort]
target_l = target_lengths[source_sort]
source_l = source_lengths[source_sort]
curr_l = 0
l_location = [] #idx where sent length changes
for j,i in enumerate(source_sort):
if source_lengths[i] > curr_l:
curr_l = source_lengths[i]
l_location.append(j+1)
        l_location.append(len(sources) + 1)  # one-past-the-end sentinel (1-based) so the last batch covers the final row
#get batch sizes
curr_idx = 1
batch_idx = [1]
nonzeros = []
batch_l = []
batch_w = []
target_l_max = []
for i in range(len(l_location)-1):
while curr_idx < l_location[i+1]:
curr_idx = min(curr_idx + batchsize, l_location[i+1])
batch_idx.append(curr_idx)
for i in range(len(batch_idx)-1):
batch_l.append(batch_idx[i+1] - batch_idx[i])
batch_w.append(source_l[batch_idx[i]-1])
nonzeros.append((target_output[batch_idx[i]-1:batch_idx[i+1]-1] != 1).sum().sum())
target_l_max.append(max(target_l[batch_idx[i]-1:batch_idx[i+1]-1]))
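        # Per-batch bookkeeping: batch_idx[i] is the 1-based row where batch i
        # starts, batch_l[i] its number of sentences, batch_w[i] the source
        # length (incl. <s>/</s>) shared by all sentences in that batch,
        # target_l_max[i] the longest target in the batch, and nonzeros[i] the
        # number of non-pad target_output tokens in it.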
# Write output
f = h5py.File(outfile, "w")
f["source"] = sources
f["target"] = targets
f["target_output"] = target_output
f["target_l"] = np.array(target_l_max, dtype=int)
f["target_l_all"] = target_l
f["batch_l"] = np.array(batch_l, dtype=int)
f["batch_w"] = np.array(batch_w, dtype=int)
f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
f["target_nonzeros"] = np.array(nonzeros, dtype=int)
f["source_size"] = np.array([len(src_indexer.d)])
f["target_size"] = np.array([len(target_indexer.d)])
print("Saved {} sentences (dropped {} due to length)".format(len(f["source"]), dropped))
f.close()
return max_sent_l
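    # A quick way to sanity-check an output file afterwards (sketch; the file
    # name is a placeholder, dataset names are as written in convert() above):
    #   with h5py.File("demo-train.hdf5", "r") as f:
    #       print(f["source"].shape, f["batch_l"][:5], f["source_size"][0])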
print("First pass through data to get vocab...")
num_sents_train = make_vocab(args.srcfile, args.targetfile,
args.srcseqlength, args.targetseqlength)
print("Number of sentences in training: {}".format(num_sents_train))
num_sents_valid = make_vocab(args.srcvalfile, args.targetvalfile,
args.srcseqlength, args.targetseqlength, 0)
print("Number of sentences in valid: {}".format(num_sents_valid))
#prune and write vocab
src_indexer.prune_vocab(args.srcvocabminfreq, True)
target_indexer.prune_vocab(args.targetvocabminfreq, True)
if args.srcvocabfile != '':
print('Loading pre-specified source vocab from ' + args.srcvocabfile)
src_indexer.load_vocab(args.srcvocabfile)
if args.targetvocabfile != '':
print('Loading pre-specified target vocab from ' + args.targetvocabfile)
target_indexer.load_vocab(args.targetvocabfile)
src_indexer.write(args.outputfile + ".src.dict")
target_indexer.write(args.outputfile + ".targ.dict")
print("Source vocab size: Original = {}, Pruned = {}".format(len(src_indexer.vocab),
len(src_indexer.d)))
print("Target vocab size: Original = {}, Pruned = {}".format(len(target_indexer.vocab),
len(target_indexer.d)))
max_sent_l = 0
max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.batchsize, args.srcseqlength,
args.targetseqlength, args.outputfile + "-val.hdf5", num_sents_valid,
max_sent_l, args.shuffle)
max_sent_l = convert(args.srcfile, args.targetfile, args.batchsize, args.srcseqlength,
args.targetseqlength,
args.outputfile + "-train.hdf5", num_sents_train,
max_sent_l, args.shuffle)
print("Max sent length (before dropping): {}".format(max_sent_l))
def main(arguments):
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--srcvocabminfreq', help="Source vocab count limit. Words that occur "
                        "no more than this many times are replaced with UNK.",
                        type=int, default=10)
    parser.add_argument('--targetvocabminfreq', help="Target vocab count limit. Words that occur "
                        "no more than this many times are replaced with UNK.",
                        type=int, default=10)
    parser.add_argument('--srcfile', help="Path to source training data, "
                        "where each line represents a single source sequence.",
                        required=True)
    parser.add_argument('--targetfile', help="Path to target training data, "
                        "where each line represents a single target sequence.",
                        required=True)
parser.add_argument('--srcvalfile', help="Path to source validation data.", required=True)
parser.add_argument('--targetvalfile', help="Path to target validation data.", required=True)
parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=128)
parser.add_argument('--srcseqlength', help="Maximum source sequence length. Sequences longer "
"than this are dropped.", type=int, default=50)
parser.add_argument('--targetseqlength', help="Maximum target sequence length. Sequences longer "
"than this are dropped.", type=int, default=50)
parser.add_argument('--outputfile', help="Prefix of the output file names. ", type=str, required=True)
    parser.add_argument('--srcvocabfile', help="If working with a preset vocab, "
                        "then including this will ignore srcvocabminfreq and use "
                        "the vocab provided here.",
                        type=str, default='')
    parser.add_argument('--targetvocabfile', help="If working with a preset vocab, "
                        "then including this will ignore targetvocabminfreq and "
                        "use the vocab provided here.",
                        type=str, default='')
    parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting "
                        "(based on source length).",
                        type=int, default=0)
args = parser.parse_args(arguments)
get_data(args)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))