Skip to content

Commit

Permalink
A few changes to support input of large files (#448)
Browse files Browse the repository at this point in the history
* Changes to input_reader

* Removed "self." in "super()"

* Revert to standard loop for file reading

* Added lru_cache
  • Loading branch information
neubig committed Jun 29, 2018
1 parent 6274a12 commit 77bd33d
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions xnmt/input_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import zip_longest
from functools import lru_cache
import ast
from typing import Sequence, Iterator

Expand Down Expand Up @@ -67,12 +68,13 @@ def read_sent(self, line: str) -> xnmt.input.Input:
"""
raise RuntimeError("Input readers must implement the read_sent function")

@lru_cache(maxsize=128)
def count_sents(self, filename):
  """Return the number of lines (sentences) in *filename*.

  Counts are memoized with ``lru_cache`` so repeated queries for the same
  file (e.g. src and trg length checks over multiple epochs) do not re-scan
  large corpora from disk.

  NOTE(review): ``lru_cache`` on an instance method keys on ``self`` and
  keeps each reader instance alive for the cache's lifetime (ruff B019).
  Presumably acceptable because readers live for the whole run — confirm.

  Args:
    filename: path to a newline-delimited text file.

  Returns:
    int: number of lines in the file.
  """
  # Open read-only in binary mode: counting b'\n'-terminated lines avoids
  # per-line UTF-8 decoding, and 'rb' (not 'r+b') does not request write
  # access we never use. The 'with' block guarantees the handle is closed
  # even if iteration raises — the previous version leaked it.
  with open(filename, 'rb') as f:
    return sum(1 for _ in f)

def iterate_filtered(self, filename, filter_ids=None):
"""
Expand Down Expand Up @@ -399,12 +401,14 @@ def read_parallel_corpus(src_reader, trg_reader, src_file, trg_file,
src_data = []
trg_data = []
if sample_sents:
logger.info(f"Starting to read {sample_sents} parallel sentences of {src_file} and {trg_file}")
src_len = src_reader.count_sents(src_file)
trg_len = trg_reader.count_sents(trg_file)
if src_len != trg_len: raise RuntimeError(f"training src sentences don't match trg sentences: {src_len} != {trg_len}!")
if max_num_sents and max_num_sents < src_len: src_len = trg_len = max_num_sents
filter_ids = np.random.choice(src_len, sample_sents, replace=False)
else:
logger.info(f"Starting to read {src_file} and {trg_file}")
filter_ids = None
src_len, trg_len = 0, 0
src_train_iterator = src_reader.read_sents(src_file, filter_ids)
Expand All @@ -420,10 +424,14 @@ def read_parallel_corpus(src_reader, trg_reader, src_file, trg_file,
src_data.append(src_sent)
trg_data.append(trg_sent)

logger.info(f"Done reading {src_file} and {trg_file}. Packing into batches.")

# Pack batches
if batcher is not None:
src_batches, trg_batches = batcher.pack(src_data, trg_data)
else:
src_batches, trg_batches = src_data, trg_data

logger.info(f"Done packing batches.")

return src_data, trg_data, src_batches, trg_batches

0 comments on commit 77bd33d

Please sign in to comment.