Merge pull request #108 from neulab/generalize-training-filtering
Generalized training time filtering
neubig committed Jun 27, 2017
2 parents a27c6c0 + 9ab8a0e commit 1207aac
Showing 2 changed files with 20 additions and 12 deletions.
4 changes: 4 additions & 0 deletions examples/standard.yaml
@@ -12,6 +12,10 @@ defaults:
   train_trg: examples/data/train.en
   dev_src: examples/data/dev.ja
   dev_trg: examples/data/dev.en
+  train_filters:
+  - type: length
+    min: 1
+    max: 50
 decode:
   src_file: examples/data/test.ja
 evaluate:
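For orientation, the added train_filters entry is a list of filter specs that the training code turns into filter objects via SentenceFilterer.from_spec (see xnmt/xnmt_train.py below). A rough, hypothetical sketch of what a length-type filter with these min/max settings would do, assuming only the keep((src_sent, trg_sent)) interface used by filter_sents; the class name and constructor are illustrative, not the actual xnmt/preproc.py API:

```python
class LengthFilterSketch:
  """Illustrative stand-in for a 'length'-type sentence filter."""
  def __init__(self, min_len=1, max_len=50):
    self.min_len = min_len  # corresponds to 'min' in the YAML spec
    self.max_len = max_len  # corresponds to 'max' in the YAML spec

  def keep(self, sents):
    # sents is the (src_sent, trg_sent) pair passed in by filter_sents();
    # keep the pair only if both sides are within [min_len, max_len] tokens.
    return all(self.min_len <= len(sent) <= self.max_len for sent in sents)
```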
28 changes: 16 additions & 12 deletions xnmt/xnmt_train.py
@@ -14,6 +14,7 @@
 from model_params import *
 from loss_tracker import *
 from serializer import *
+from preproc import SentenceFilterer
 from options import Option, OptionParser, general_options

 '''
@@ -30,6 +31,8 @@
Option("train_trg"),
Option("dev_src"),
Option("dev_trg"),
Option("train_filters", list, required=False, help_str="Specify filtering criteria for the training data"),
Option("dev_filters", list, required=False, help_str="Specify filtering criteria for the development data"),
Option("max_src_len", int, required=False, help_str="Remove sentences from training/dev data that are longer than this on the source side"),
Option("max_trg_len", int, required=False, help_str="Remove sentences from training/dev data that are longer than this on the target side"),
Option("max_num_train_sents", int, required=False, help_str="Load only the first n sentences from the training data"),
@@ -172,11 +175,11 @@ def create_model(self):


   def read_data(self):
+    train_filters = SentenceFilterer.from_spec(self.args.train_filters)
     self.train_src, self.train_trg = \
-      self.remove_long_sents(self.input_reader.read_file(self.args.train_src, max_num=self.args.max_num_train_sents),
-                             self.output_reader.read_file(self.args.train_trg, max_num=self.args.max_num_train_sents),
-                             self.args.max_src_len, self.args.max_trg_len,
-                             )
+      self.filter_sents(self.input_reader.read_file(self.args.train_src, max_num=self.args.max_num_train_sents),
+                        self.output_reader.read_file(self.args.train_trg, max_num=self.args.max_num_train_sents),
+                        train_filters)
     assert len(self.train_src) == len(self.train_trg)
     self.total_train_sent = len(self.train_src)
     if self.args.eval_every == None:
@@ -185,21 +188,22 @@ def read_data(self):
     self.input_reader.freeze()
     self.output_reader.freeze()

+    dev_filters = SentenceFilterer.from_spec(self.args.dev_filters)
     self.dev_src, self.dev_trg = \
-      self.remove_long_sents(self.input_reader.read_file(self.args.dev_src),
-                             self.output_reader.read_file(self.args.dev_trg),
-                             self.args.max_src_len, self.args.max_trg_len,
-                             )
+      self.filter_sents(self.input_reader.read_file(self.args.dev_src),
+                        self.output_reader.read_file(self.args.dev_trg),
+                        dev_filters)
     assert len(self.dev_src) == len(self.dev_trg)

-  def remove_long_sents(self, src_sents, trg_sents, max_src_len, max_trg_len):
+  def filter_sents(self, src_sents, trg_sents, my_filters):
+    if len(my_filters) == 0:
+      return src_sents, trg_sents
     filtered_src_sents, filtered_trg_sents = [], []
     for src_sent, trg_sent in zip(src_sents, trg_sents):
-      if (max_src_len is None or len(src_sent) <= max_src_len) and (max_trg_len is None or len(trg_sent) <= max_trg_len):
+      if all([my_filter.keep((src_sent,trg_sent)) for my_filter in my_filters]):
         filtered_src_sents.append(src_sent)
         filtered_trg_sents.append(trg_sent)
-    if max_src_len or max_trg_len:
-      print("> removed %s out of %s sentences that were too long." % (len(src_sents)-len(filtered_src_sents),len(src_sents)))
+    print("> removed %s out of %s sentences that didn't pass filters." % (len(src_sents)-len(filtered_src_sents),len(src_sents)))
     return filtered_src_sents, filtered_trg_sents

   def run_epoch(self):
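To see the new code path end to end: read_data builds a list of filter objects from the config and filter_sents keeps only the sentence pairs that every filter accepts. A self-contained sketch of that flow, where make_filters and LengthFilter are hypothetical stand-ins for SentenceFilterer.from_spec and the real length filterer in xnmt/preproc.py:

```python
class LengthFilter:
  """Hypothetical length filter with the keep() contract used by filter_sents()."""
  def __init__(self, min_len=1, max_len=50):
    self.min_len, self.max_len = min_len, max_len

  def keep(self, pair):
    return all(self.min_len <= len(sent) <= self.max_len for sent in pair)

def make_filters(spec):
  """Hypothetical stand-in for SentenceFilterer.from_spec: dispatch on 'type'."""
  if spec is None:  # option left unset -> no filtering
    return []
  filters = []
  for entry in spec:
    if entry["type"] == "length":
      filters.append(LengthFilter(entry.get("min", 1), entry.get("max", 50)))
    else:
      raise ValueError("unknown filter type: %s" % entry["type"])
  return filters

# Mirrors the filter_sents() loop: keep a pair only if every filter accepts it.
spec = [{"type": "length", "min": 1, "max": 50}]  # as in examples/standard.yaml
my_filters = make_filters(spec)

src_sents = [["a", "dog", "barks"], ["tok"] * 80]   # second pair exceeds max=50
trg_sents = [["ein", "hund", "bellt"], ["tok"] * 80]

filtered_src, filtered_trg = [], []
for src_sent, trg_sent in zip(src_sents, trg_sents):
  if all(f.keep((src_sent, trg_sent)) for f in my_filters):
    filtered_src.append(src_sent)
    filtered_trg.append(trg_sent)

print("> removed %s out of %s sentences that didn't pass filters."
      % (len(src_sents) - len(filtered_src), len(src_sents)))  # removed 1 out of 2
```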
