Skip to content

Commit

Permalink
Improve shuffle handling in datasets (#573)
Browse files: browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Jun 6, 2024
1 parent 55e7a14 commit 992fdf1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 23 deletions.
31 changes: 23 additions & 8 deletions src/fairseq2/datasets/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def create_reader(
dtype: DataType = torch.float32,
min_audio_len: int = 1,
normalize_audio: bool = False,
shuffle_window_size: int = 1,
example_shuffle_window: int = 1,
batch_shuffle_window: int = 1,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
Expand All @@ -80,9 +81,14 @@ def create_reader(
this value will be dropped.
:param normalize_audio:
If ``True``, normalizes audio to have zero mean and unit variance.
:param shuffle_window_size:
The size of the shuffle window. If ``1``, no shuffling is performed;
if ``0``, performs true shuffling by loading the entire dataset.
:param example_shuffle_window:
The size of the sliding window for shuffling examples. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by
loading the entire dataset.
:param batch_shuffle_window:
The size of the sliding window for shuffling batches. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by
loading the entire dataset.
:param num_accumulate:
The number of batches to accumulate in each iteration. Typically
used with gradient accumulation during training.
Expand Down Expand Up @@ -139,7 +145,8 @@ def create_reader(
dtype: DataType = torch.float32,
min_audio_len: int = 1,
normalize_audio: bool = False,
shuffle_window_size: int = 1,
example_shuffle_window: int = 1,
batch_shuffle_window: int = 1,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
Expand All @@ -162,12 +169,17 @@ def create_reader(

builder = read_sequence(manifest)

# Shuffle the entire manifest. Must be consistent across all processes.
builder.shuffle(shuffle_window=0, seed=seed)
# Shuffle examples. Must be consistent across all processes.
if example_shuffle_window != 1:
builder.shuffle(example_shuffle_window, seed=seed)

seed += 1

# Shard.
builder.shard(gang.rank, gang.size, allow_uneven=True)

seed += gang.rank

# Bucket by audio length.
bucket_sizes = create_bucket_sizes(
max_num_elements=max_num_elements,
Expand All @@ -185,7 +197,10 @@ def create_reader(
)

# Shuffle buckets.
builder.shuffle(shuffle_window_size, seed=seed + gang.rank)
if batch_shuffle_window != 1:
builder.shuffle(batch_shuffle_window, seed=seed)

seed += 1

# Memory map audio files.
file_mapper = FileMapper(root_data_dir, cached_fd_count=cached_fd_count)
Expand Down
29 changes: 20 additions & 9 deletions src/fairseq2/datasets/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def create_reader(
max_seq_len: int,
max_num_tokens: int,
*,
shuffle_window_size: int = 1,
example_shuffle_window: int = 1,
batch_shuffle_window: int = 1,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
Expand All @@ -63,9 +64,14 @@ def create_reader(
this value will be dropped.
:param max_num_tokens:
The maximum number of tokens in each batch.
:param shuffle_window_size:
The size of the shuffle window. If ``1``, no shuffling is performed;
if ``0``, performs true shuffling by loading the entire dataset.
:param example_shuffle_window:
The size of the sliding window for shuffling examples. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by
loading the entire dataset.
:param batch_shuffle_window:
The size of the sliding window for shuffling batches. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by
loading the entire dataset.
:param num_accumulate:
The number of batches to accumulate in each iteration. Typically
used with gradient accumulation during training.
Expand Down Expand Up @@ -109,22 +115,26 @@ def create_reader(
max_seq_len: int,
max_num_tokens: int,
*,
shuffle_window_size: int = 1,
example_shuffle_window: int = 1,
batch_shuffle_window: int = 1,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
**extras: Any,
) -> DataPipelineReader[SequenceBatch]:
builder = list_files(self._data_dir, pattern="*.jsonl")

# Shuffle the files. Must be consistent across all processes.
builder.shuffle(shuffle_window=0, seed=seed)
# Shuffle files. Must be consistent across all processes.
if example_shuffle_window != 1:
builder.shuffle(shuffle_window=0, seed=seed)

seed += 1

builder.yield_from(partial(self._read_jsonl, tokenizer=tokenizer))

builder.shuffle(shuffle_window_size, seed=seed)
# Shuffle examples.
if example_shuffle_window != 1:
builder.shuffle(example_shuffle_window, seed=seed)

seed += 1

Expand All @@ -143,7 +153,8 @@ def create_reader(
)

# Shuffle buckets.
builder.shuffle(shuffle_window_size, seed=seed)
if batch_shuffle_window != 1:
builder.shuffle(batch_shuffle_window, seed=seed)

seed += 1

Expand Down
10 changes: 7 additions & 3 deletions src/fairseq2/recipes/lm/instruction_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ class InstructionFinetuneConfig:
max_num_tokens: int = 8192 * 2
"""The maximum number of tokens per batch."""

shuffle_window_size: int = 10_000
"""The size of the sliding data shuffle window."""
example_shuffle_window: int = 10_000
"""The size of the sliding window for shuffling examples."""

batch_shuffle_window: int = 100
"""The size of the sliding window for shuffling batches."""

num_prefetch: int = 4
"""The number of batches to prefetch in background."""
Expand Down Expand Up @@ -198,7 +201,8 @@ def load_instruction_finetuner(
gang=dp_gang,
max_seq_len=config.max_seq_len,
max_num_tokens=config.max_num_tokens,
shuffle_window_size=config.shuffle_window_size,
example_shuffle_window=config.example_shuffle_window,
batch_shuffle_window=config.batch_shuffle_window,
num_accumulate=config.gradient_accumulation,
num_prefetch=config.num_prefetch,
seed=config.seed,
Expand Down
10 changes: 7 additions & 3 deletions src/fairseq2/recipes/wav2vec2/asr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ class Wav2Vec2AsrTrainConfig:
normalize_audio: bool = False
"""If ``True``, normalizes audio to have zero mean and unit variance."""

shuffle_window_size: int = 1000
"""The size of the sliding data shuffle window."""
example_shuffle_window: int = 0
"""The size of the sliding window for shuffling examples."""

batch_shuffle_window: int = 1000
"""The size of the sliding window for shuffling batches."""

num_prefetch: int = 4
"""The number of batches to prefetch in background."""
Expand Down Expand Up @@ -220,7 +223,8 @@ def load_wav2vec2_asr_trainer(
max_audio_len=config.max_audio_len,
max_num_elements=config.max_num_elements,
normalize_audio=config.normalize_audio,
shuffle_window_size=config.shuffle_window_size,
example_shuffle_window=config.example_shuffle_window,
batch_shuffle_window=config.batch_shuffle_window,
num_accumulate=config.gradient_accumulation,
num_prefetch=config.num_prefetch,
seed=config.seed,
Expand Down

0 comments on commit 992fdf1

Please sign in to comment.