def _build_index(self, file_path: str):
    """Build an index of the start position of each line.

    Args:
        file_path: Path of the file being indexed. It is only used in the
            log message; the data is read from ``self._get_mmap()``.

    Returns:
        A list with one entry per line, holding the position (as reported
        by ``tell()``) at which that line begins. The end-of-file position
        is not included.
    """
    logger.info(f"Building index for file: {file_path}")
    # NOTE(review): readline() is compared against b"", so this handle
    # yields bytes (presumably an mmap.mmap, not a TextIOWrapper) -- confirm
    # against _get_mmap.
    data = self._get_mmap()
    # Rewind first: the offsets below are only valid if reading starts at
    # the beginning, and the handle's current position is not guaranteed.
    data.seek(0)

    offsets = []
    line_start = 0
    for _ in iter(data.readline, b""):
        # Record where the line just read began, then advance to the start
        # of the next one. This avoids appending an end-of-file offset that
        # would have to be sliced off (and the whole list copied) afterwards.
        offsets.append(line_start)
        line_start = data.tell()
    return offsets