Merge pull request #85 from neulab/max-len
max_src_len, max_trg_len options
neubig committed Jun 5, 2017
2 parents 5c90fea + 060aaa3 commit 60c3cfd
Showing 2 changed files with 24 additions and 5 deletions.
2 changes: 2 additions & 0 deletions examples/debug.yaml
@@ -14,6 +14,8 @@ defaults:
 dev_src: examples/data/head.ja
 dev_trg: examples/data/head.en
 default_layer_dim: 64
+encoder:
+  dropout: 0.0
 decode:
 src_file: examples/data/head.ja
 evaluate:
27 changes: 22 additions & 5 deletions xnmt/xnmt_train.py
@@ -30,6 +30,8 @@
   Option("train_trg"),
   Option("dev_src"),
   Option("dev_trg"),
+  Option("max_src_len", int, required=False, help_str="Remove sentences from training/dev data that are longer than this on the source side"),
+  Option("max_trg_len", int, required=False, help_str="Remove sentences from training/dev data that are longer than this on the target side"),
   Option("model_file"),
   Option("pretrained_model_file", default_value="", help_str="Path of pre-trained model file"),
   Option("input_vocab", default_value="", help_str="Path of fixed input vocab file"),
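For orientation, a hypothetical config excerpt (not part of this commit; the 60-token limits and exact layout are illustrative only) showing how the two new options might sit next to the existing data options in a file written in the style of examples/debug.yaml:

```yaml
# Illustrative excerpt in the style of examples/debug.yaml.
# max_src_len / max_trg_len are the options added above; 60 is a made-up limit,
# and omitting either option leaves that side unfiltered.
train_src: examples/data/head.ja
train_trg: examples/data/head.en
dev_src: examples/data/head.ja
dev_trg: examples/data/head.en
max_src_len: 60   # drop training/dev pairs whose source side exceeds 60 tokens
max_trg_len: 60   # drop pairs whose target side exceeds 60 tokens
```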
@@ -170,8 +172,11 @@ def create_model(self):
 
 
   def read_data(self):
-    self.train_src = self.input_reader.read_file(self.args.train_src)
-    self.train_trg = self.output_reader.read_file(self.args.train_trg)
+    self.train_src, self.train_trg = \
+      self.remove_long_sents(self.input_reader.read_file(self.args.train_src),
+                             self.output_reader.read_file(self.args.train_trg),
+                             self.args.max_src_len, self.args.max_trg_len,
+                             )
     assert len(self.train_src) == len(self.train_trg)
     self.total_train_sent = len(self.train_src)
     if self.args.eval_every == None:
@@ -180,10 +185,22 @@ def read_data(self):
     self.input_reader.freeze()
     self.output_reader.freeze()
 
-    self.dev_src = self.input_reader.read_file(self.args.dev_src)
-    self.dev_trg = self.output_reader.read_file(self.args.dev_trg)
+    self.dev_src, self.dev_trg = \
+      self.remove_long_sents(self.input_reader.read_file(self.args.dev_src),
+                             self.output_reader.read_file(self.args.dev_trg),
+                             self.args.max_src_len, self.args.max_trg_len,
+                             )
     assert len(self.dev_src) == len(self.dev_trg)
 
-
+  def remove_long_sents(self, src_sents, trg_sents, max_src_len, max_trg_len):
+    filtered_src_sents, filtered_trg_sents = [], []
+    for src_sent, trg_sent in zip(src_sents, trg_sents):
+      if (max_src_len is None or len(src_sent) <= max_src_len) and (max_trg_len is None or len(trg_sent) <= max_trg_len):
+        filtered_src_sents.append(src_sent)
+        filtered_trg_sents.append(trg_sent)
+    if max_src_len or max_trg_len:
+      print("> removed %s out of %s sentences that were too long." % (len(src_sents) - len(filtered_src_sents), len(src_sents)))
+    return filtered_src_sents, filtered_trg_sents
+
   def run_epoch(self):
     self.logger.new_epoch()
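To illustrate the filtering behavior, here is a standalone sketch with made-up toy data that mirrors the remove_long_sents logic added above (rather than importing it from xnmt): a sentence pair is kept only if both sides are within their limits, and passing None for a limit disables that side's check.

```python
# Standalone sketch with toy data; mirrors the remove_long_sents logic added above.
def remove_long_sents(src_sents, trg_sents, max_src_len, max_trg_len):
  filtered_src, filtered_trg = [], []
  for src_sent, trg_sent in zip(src_sents, trg_sents):
    if (max_src_len is None or len(src_sent) <= max_src_len) and \
       (max_trg_len is None or len(trg_sent) <= max_trg_len):
      filtered_src.append(src_sent)
      filtered_trg.append(trg_sent)
  return filtered_src, filtered_trg

src = [["a", "b"], ["a", "b", "c", "d"], ["x"]]
trg = [["1", "2", "3"], ["1"], ["1", "2"]]

# Source limit of 3 tokens, no target limit: only the 4-token source sentence is dropped,
# and its target-side partner is dropped with it so the lists stay parallel.
print(remove_long_sents(src, trg, max_src_len=3, max_trg_len=None))
# -> ([['a', 'b'], ['x']], [['1', '2', '3'], ['1', '2']])
```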
