Skip to content

Commit

Permalink
Add support for epoch-based training (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Jun 6, 2024
1 parent 992fdf1 commit 69cb3db
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/fairseq2/recipes/lm/instruction_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ class InstructionFinetuneConfig:

# Regime
max_num_steps: int = 5000
"""The maximum number of training steps."""
"""The maximum number of steps to train for."""

max_num_data_epochs: Optional[int] = None
"""The maximum number of data epochs to train for."""

checkpoint_every_n_steps: int = 1000
"""The step interval at which to checkpoint."""
Expand Down Expand Up @@ -369,6 +372,7 @@ def load_instruction_finetuner(
max_gradient_norm=config.max_gradient_norm,
data_readers=data_readers,
max_num_steps=config.max_num_steps,
max_num_data_epochs=config.max_num_data_epochs,
checkpoint_manager=checkpoint_manager,
checkpoint_every_n_steps=config.checkpoint_every_n_steps,
keep_last_n_checkpoints=config.keep_last_n_checkpoints,
Expand Down
6 changes: 5 additions & 1 deletion src/fairseq2/recipes/wav2vec2/asr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ class Wav2Vec2AsrTrainConfig:

# Regime
max_num_steps: int = 20_000
"""The maximum number of training steps."""
"""The maximum number of steps to train for."""

max_num_data_epochs: Optional[int] = None
"""The maximum number of data epochs to train for."""

freeze_encoder_for_n_steps: int = 10_000
"""The encoder will be frozen for this number of steps."""
Expand Down Expand Up @@ -407,6 +410,7 @@ def load_wav2vec2_asr_trainer(
max_gradient_norm=config.max_gradient_norm,
data_readers=data_readers,
max_num_steps=config.max_num_steps,
max_num_data_epochs=config.max_num_data_epochs,
validate_after_n_steps=config.validate_after_n_steps,
validate_every_n_steps=config.validate_every_n_steps,
checkpoint_manager=checkpoint_manager,
Expand Down

0 comments on commit 69cb3db

Please sign in to comment.