code adaptation/changes according to the commits on Nov 4, 2020
freewym committed Nov 7, 2020
1 parent 1b19c66 · commit e5c9d31
Showing 1 changed file with 67 additions and 30 deletions.
espresso/speech_train.py: 67 additions & 30 deletions
@@ -49,8 +49,9 @@ def main(cfg: DictConfig) -> None:

utils.import_user_module(cfg.common)

assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \
"Must specify batch size either with --max-tokens or --batch-size"
assert (
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
metrics.reset()

np.random.seed(cfg.common.seed)
@@ -71,19 +72,21 @@ def main(cfg: DictConfig) -> None:
for valid_sub_split in cfg.dataset.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)

assert cfg.criterion, "Please specify criterion to train a model"

# Build model and criterion
model = task.build_model(cfg.model)
criterion = task.build_criterion(cfg.criterion)
logger.info(model)
logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
logger.info("task: {}".format(task.__class__.__name__))
logger.info("model: {}".format(model.__class__.__name__))
logger.info("criterion: {})".format(criterion.__class__.__name__))
logger.info(
"criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
"num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
)
logger.info("num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
))

# (optionally) Configure quantization
if cfg.common.quantization_config_path is not None:
@@ -101,11 +104,17 @@ def main(cfg: DictConfig) -> None:
else:
trainer = MegatronTrainer(cfg, task, model, criterion)

logger.info("training on {} devices (GPUs/TPUs)".format(cfg.distributed_training.distributed_world_size))
logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
cfg.dataset.max_tokens,
cfg.dataset.batch_size,
))
logger.info(
"training on {} devices (GPUs/TPUs)".format(
cfg.distributed_training.distributed_world_size
)
)
logger.info(
"max tokens per GPU = {} and batch size per GPU = {}".format(
cfg.dataset.max_tokens,
cfg.dataset.batch_size,
)
)

# Load the latest checkpoint if one is available and restore the
# corresponding train iterator
@@ -120,10 +129,7 @@ def main(cfg: DictConfig) -> None:
lr = trainer.get_lr()
train_meter = meters.StopwatchMeter()
train_meter.start()
while (
lr > cfg.optimization.min_lr
and epoch_itr.next_epoch_idx <= max_epoch
):
while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
# train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
if should_stop:
@@ -161,14 +167,20 @@ def is_better(a, b):
else:
should_stop_early.num_runs += 1
if should_stop_early.num_runs >= cfg.checkpoint.patience:
logger.info("early stop since valid performance hasn't improved for last {} runs".format(cfg.checkpoint.patience))
logger.info(
"early stop since valid performance hasn't improved for last {} runs".format(
cfg.checkpoint.patience
)
)
return True
else:
return False


@metrics.aggregate("train")
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
def train(
cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
"""Train the model for one epoch and return validation losses."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
@@ -189,7 +201,9 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
log_interval=cfg.common.log_interval,
epoch=epoch_itr.epoch,
tensorboard_logdir=(
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
cfg.common.tensorboard_logdir
if distributed_utils.is_master(cfg.distributed_training)
else None
),
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
wandb_project=(
@@ -244,7 +258,14 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
return valid_losses, should_stop


def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
def validate_and_save(
cfg: DictConfig,
trainer: Trainer,
task: tasks.FairseqTask,
epoch_itr,
valid_subsets: List[str],
end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
num_updates = trainer.get_num_updates()
max_update = cfg.optimization.max_update or math.inf
do_save = (
@@ -279,14 +300,17 @@ def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask
or num_updates >= max_update
or (
cfg.optimization.stop_time_hours > 0
and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
and trainer.cumulative_training_time() / (60 * 60)
> cfg.optimization.stop_time_hours
)
)

# Save checkpoint
if do_save or should_stop:
logger.info("begin save checkpoint")
checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])
checkpoint_utils.save_checkpoint(
cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
)

return valid_losses, should_stop

@@ -296,7 +320,13 @@ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
return stats


def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
def validate(
cfg: DictConfig,
trainer: Trainer,
task: tasks.FairseqTask,
epoch_itr,
subsets: List[str],
) -> List[Optional[float]]:
"""Evaluate the model on the validation set(s) and return the losses."""

if cfg.dataset.fixed_validation_seed is not None:
@@ -306,7 +336,7 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
trainer.begin_valid_epoch(epoch_itr.epoch)
valid_losses = []
for subset in subsets:
logger.info("begin validation on '{}' subset".format(subset))
logger.info('begin validation on "{}" subset'.format(subset))

# Initialize data iterator
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
@@ -319,7 +349,9 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
epoch=epoch_itr.epoch,
prefix=f"valid on '{subset}' subset",
tensorboard_logdir=(
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
cfg.common.tensorboard_logdir
if distributed_utils.is_master(cfg.distributed_training)
else None
),
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
wandb_project=(
@@ -341,13 +373,16 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
return valid_losses


def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
def get_valid_stats(
cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
) -> Dict[str, Any]:
stats["num_updates"] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, "best"):
key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
stats[key] = best_function(
checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
checkpoint_utils.save_checkpoint.best,
stats[cfg.checkpoint.best_checkpoint_metric],
)
return stats

@@ -359,7 +394,9 @@ def print_options_meaning_changes(args):
logger.info("--max-tokens is the maximum number of input frames in a batch")


def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
def cli_main(
modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
print_options_meaning_changes(args)
