diff --git a/espresso/speech_train.py b/espresso/speech_train.py
index 9391ba72e5..ffe14b661f 100755
--- a/espresso/speech_train.py
+++ b/espresso/speech_train.py
@@ -49,8 +49,9 @@ def main(cfg: DictConfig) -> None:
     utils.import_user_module(cfg.common)

-    assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \
-        "Must specify batch size either with --max-tokens or --batch-size"
+    assert (
+        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
+    ), "Must specify batch size either with --max-tokens or --batch-size"

     metrics.reset()

     np.random.seed(cfg.common.seed)
@@ -71,19 +72,21 @@ def main(cfg: DictConfig) -> None:
     for valid_sub_split in cfg.dataset.valid_subset.split(","):
         task.load_dataset(valid_sub_split, combine=False, epoch=1)

+    assert cfg.criterion, "Please specify criterion to train a model"
+
     # Build model and criterion
     model = task.build_model(cfg.model)
     criterion = task.build_criterion(cfg.criterion)
     logger.info(model)
-    logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
-    logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
+    logger.info("task: {}".format(task.__class__.__name__))
+    logger.info("model: {}".format(model.__class__.__name__))
+    logger.info("criterion: {}".format(criterion.__class__.__name__))
     logger.info(
-        "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
+        "num. model params: {} (num. trained: {})".format(
+            sum(p.numel() for p in model.parameters()),
+            sum(p.numel() for p in model.parameters() if p.requires_grad),
+        )
     )
-    logger.info("num. model params: {} (num. trained: {})".format(
-        sum(p.numel() for p in model.parameters()),
-        sum(p.numel() for p in model.parameters() if p.requires_grad),
-    ))

     # (optionally) Configure quantization
     if cfg.common.quantization_config_path is not None:
@@ -101,11 +104,17 @@ def main(cfg: DictConfig) -> None:
     else:
         trainer = MegatronTrainer(cfg, task, model, criterion)

-    logger.info("training on {} devices (GPUs/TPUs)".format(cfg.distributed_training.distributed_world_size))
-    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
-        cfg.dataset.max_tokens,
-        cfg.dataset.batch_size,
-    ))
+    logger.info(
+        "training on {} devices (GPUs/TPUs)".format(
+            cfg.distributed_training.distributed_world_size
+        )
+    )
+    logger.info(
+        "max tokens per GPU = {} and batch size per GPU = {}".format(
+            cfg.dataset.max_tokens,
+            cfg.dataset.batch_size,
+        )
+    )

     # Load the latest checkpoint if one is available and restore the
     # corresponding train iterator
@@ -120,10 +129,7 @@ def main(cfg: DictConfig) -> None:
     lr = trainer.get_lr()
     train_meter = meters.StopwatchMeter()
     train_meter.start()
-    while (
-        lr > cfg.optimization.min_lr
-        and epoch_itr.next_epoch_idx <= max_epoch
-    ):
+    while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
         # train for one epoch
         valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
         if should_stop:
@@ -161,14 +167,20 @@ def is_better(a, b):
     else:
         should_stop_early.num_runs += 1
         if should_stop_early.num_runs >= cfg.checkpoint.patience:
-            logger.info("early stop since valid performance hasn't improved for last {} runs".format(cfg.checkpoint.patience))
+            logger.info(
+                "early stop since valid performance hasn't improved for last {} runs".format(
+                    cfg.checkpoint.patience
+                )
+            )
             return True
         else:
             return False


 @metrics.aggregate("train")
-def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
+def train(
+    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
+) -> Tuple[List[Optional[float]], bool]:
     """Train the model for one epoch and return validation losses."""
     # Initialize data iterator
     itr = epoch_itr.next_epoch_itr(
@@ -189,7 +201,9 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
         log_interval=cfg.common.log_interval,
         epoch=epoch_itr.epoch,
         tensorboard_logdir=(
-            cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
+            cfg.common.tensorboard_logdir
+            if distributed_utils.is_master(cfg.distributed_training)
+            else None
         ),
         default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
         wandb_project=(
@@ -244,7 +258,14 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
     return valid_losses, should_stop


-def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
+def validate_and_save(
+    cfg: DictConfig,
+    trainer: Trainer,
+    task: tasks.FairseqTask,
+    epoch_itr,
+    valid_subsets: List[str],
+    end_of_epoch: bool,
+) -> Tuple[List[Optional[float]], bool]:
     num_updates = trainer.get_num_updates()
     max_update = cfg.optimization.max_update or math.inf
     do_save = (
@@ -279,14 +300,17 @@ def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask
         or num_updates >= max_update
         or (
             cfg.optimization.stop_time_hours > 0
-            and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
+            and trainer.cumulative_training_time() / (60 * 60)
+            > cfg.optimization.stop_time_hours
         )
     )

     # Save checkpoint
     if do_save or should_stop:
         logger.info("begin save checkpoint")
-        checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])
+        checkpoint_utils.save_checkpoint(
+            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
+        )

     return valid_losses, should_stop

@@ -296,7 +320,13 @@ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
     return stats


-def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
+def validate(
+    cfg: DictConfig,
+    trainer: Trainer,
+    task: tasks.FairseqTask,
+    epoch_itr,
+    subsets: List[str],
+) -> List[Optional[float]]:
     """Evaluate the model on the validation set(s) and return the losses."""

     if cfg.dataset.fixed_validation_seed is not None:
@@ -306,7 +336,7 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
     trainer.begin_valid_epoch(epoch_itr.epoch)
     valid_losses = []
     for subset in subsets:
-        logger.info("begin validation on '{}' subset".format(subset))
+        logger.info('begin validation on "{}" subset'.format(subset))

         # Initialize data iterator
         itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
@@ -319,7 +349,9 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
         epoch=epoch_itr.epoch,
         prefix=f"valid on '{subset}' subset",
         tensorboard_logdir=(
-            cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
+            cfg.common.tensorboard_logdir
+            if distributed_utils.is_master(cfg.distributed_training)
+            else None
         ),
         default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
         wandb_project=(
@@ -341,13 +373,16 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
     return valid_losses


-def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
+def get_valid_stats(
+    cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
+) -> Dict[str, Any]:
     stats["num_updates"] = trainer.get_num_updates()
     if hasattr(checkpoint_utils.save_checkpoint, "best"):
         key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
         best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
         stats[key] = best_function(
-            checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
+            checkpoint_utils.save_checkpoint.best,
+            stats[cfg.checkpoint.best_checkpoint_metric],
         )
     return stats

@@ -359,7 +394,9 @@ def print_options_meaning_changes(args):
     logger.info("--max-tokens is the maximum number of input frames in a batch")


-def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
+def cli_main(
+    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
+) -> None:
     parser = options.get_training_parser()
     args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
     print_options_meaning_changes(args)
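
Note on the early-stopping path touched in the `is_better` hunk above: the patch only re-wraps the `logger.info` call, but the surrounding `should_stop_early` bookkeeping is what ends training once validation performance has not improved for `cfg.checkpoint.patience` consecutive validations. The snippet below is a minimal, self-contained sketch of that patience behavior, not espresso's actual implementation; the `cfg` namespace and the driver loop are made-up stand-ins for illustration:

    from types import SimpleNamespace

    def should_stop_early(cfg, valid_loss):
        # Stop once the metric fails to improve for `patience` consecutive
        # runs, mirroring the counter/best-value bookkeeping in speech_train.py.
        if valid_loss is None or cfg.patience <= 0:
            return False

        def is_better(a, b):
            return a > b if cfg.maximize_best_checkpoint_metric else a < b

        prev_best = getattr(should_stop_early, "best", None)
        if prev_best is None or is_better(valid_loss, prev_best):
            should_stop_early.best = valid_loss
            should_stop_early.num_runs = 0
            return False
        should_stop_early.num_runs += 1
        return should_stop_early.num_runs >= cfg.patience

    cfg = SimpleNamespace(patience=2, maximize_best_checkpoint_metric=False)
    print([should_stop_early(cfg, loss) for loss in [1.0, 0.9, 0.95, 0.97, 0.99]])
    # [False, False, False, True, True]: two non-improving runs trigger the stop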