diff --git a/fairseq/options.py b/fairseq/options.py
index 33ffbaf64b..8d80eb050a 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
                        help='filename in save-dir from which to load checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
+    group.add_argument('--save-interval-updates', type=int, metavar='N',
+                       help='save the best/last checkpoint every N updates; validation '
+                            'runs first to determine whether the val loss improved')
+    group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
+                       help='with --save-interval-updates, keep only the last N checkpoints '
+                            'saved mid-epoch (named checkpoint_[epoch]_[numupd].pt)')
     group.add_argument('--no-save', action='store_true',
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
diff --git a/fairseq/progress_bar.py b/fairseq/progress_bar.py
index e89e0bd28b..3da6255c66 100644
--- a/fairseq/progress_bar.py
+++ b/fairseq/progress_bar.py
@@ -117,6 +117,7 @@ def log(self, stats):
 
     def print(self, stats):
         """Print end-of-epoch stats."""
+        self.stats = stats
         stats = self._format_stats(self.stats, epoch=self.epoch)
         print(json.dumps(stats), flush=True)
 
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 9e178886f6..84380e4515 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -9,6 +9,7 @@
 import contextlib
 import logging
 import os
+import re
 import torch
 import traceback
 
@@ -316,11 +317,11 @@ def buffered_arange(max):
 
 
 def convert_padding_direction(
-        src_tokens,
-        src_lengths,
-        padding_idx,
-        right_to_left=False,
-        left_to_right=False,
+    src_tokens,
+    src_lengths,
+    padding_idx,
+    right_to_left=False,
+    left_to_right=False,
 ):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)
@@ -356,3 +357,19 @@ def clip_grad_norm_(tensor, max_norm):
 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float('-inf')).type_as(t)
+
+
+def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
+    """Retrieves all checkpoints found in `path` directory, identified by matching filenames against
+    `pattern`. If the pattern contains groups, the result is sorted by the first group, in descending order."""
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = int(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
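For context, a small runnable sketch of how the new checkpoint_paths helper orders update-based checkpoints; the directory and filenames below are made up for illustration:

    import os
    import tempfile

    from fairseq.utils import checkpoint_paths

    # Fake save dir mixing epoch-based and update-based checkpoint names.
    save_dir = tempfile.mkdtemp()
    for name in ['checkpoint3.pt', 'checkpoint_1_500.pt', 'checkpoint_2_1500.pt', 'checkpoint_1_1000.pt']:
        open(os.path.join(save_dir, name), 'w').close()

    # Only filenames matching the pattern are returned, sorted by the first
    # regex group (the update count) in descending order; 'checkpoint3.pt'
    # does not match the update-based pattern and is ignored.
    print(checkpoint_paths(save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt'))
    # [..._2_1500.pt, ..._1_1000.pt, ..._1_500.pt]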
diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py
index 02951eb631..873cfe01c2 100644
--- a/scripts/average_checkpoints.py
+++ b/scripts/average_checkpoints.py
@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
     return new_state
 
 
-def last_n_checkpoints(paths, n):
+def last_n_checkpoints(paths, n, update_based):
     assert len(paths) == 1
     path = paths[0]
-    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
+    if update_based:
+        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+    else:
+        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
     files = os.listdir(path)
 
     entries = []
@@ -81,7 +84,7 @@ def main():
     parser = argparse.ArgumentParser(
         description='Tool to average the params of input checkpoints to '
-        'produce a new checkpoint',
+                    'produce a new checkpoint',
     )
     parser.add_argument(
@@ -95,7 +98,7 @@
         required=True,
         metavar='FILE',
         help='Write the new checkpoint containing the averaged weights to this '
-        'path.',
+             'path.',
     )
     parser.add_argument(
         '--num',
         type=int,
         help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
             'and average last num of those',
     )
+    parser.add_argument(
+        '--update-based-checkpoints',
+        action='store_true',
+        help='if set together with --num, average update-based checkpoints (checkpoint_[epoch]_[numupd].pt) instead of epoch-based checkpoints',
+    )
     args = parser.parse_args()
     print(args)
 
     if args.num is not None:
-        args.inputs = last_n_checkpoints(args.inputs, args.num)
+        args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
     print('averaging checkpoints: ', args.inputs)
 
     new_state = average_checkpoints(args.inputs)
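The two filename schemes are mutually exclusive, so epoch-based checkpoints are never picked up in update-based mode; a quick runnable check (the sample filenames are hypothetical):

    import re

    epoch_re = re.compile(r'checkpoint(\d+)\.pt')         # epoch-based scheme
    update_re = re.compile(r'checkpoint_\d+_(\d+)\.pt')   # update-based scheme

    # In both schemes, group(1) is the value last_n_checkpoints sorts on.
    print(epoch_re.fullmatch('checkpoint12.pt').group(1))         # '12' (epoch)
    print(update_re.fullmatch('checkpoint_3_24000.pt').group(1))  # '24000' (updates)
    print(update_re.fullmatch('checkpoint12.pt'))                 # None, no overlap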
diff --git a/train.py b/train.py
index 1582d19642..45569deed3 100644
--- a/train.py
+++ b/train.py
@@ -15,10 +15,10 @@
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
+from fairseq.utils import checkpoint_paths
 
 
 def main(args):
-
     if args.max_tokens is None:
         args.max_tokens = 6000
@@ -82,26 +82,22 @@ def main(args):
     max_epoch = args.max_epoch or math.inf
     max_update = args.max_update or math.inf
     lr = trainer.get_lr()
+    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
-        train(args, trainer, next(train_dataloader), epoch)
+        train(args, trainer, next(train_dataloader), epoch, dataset)
 
-        # evaluate on validate set
-        first_val_loss = None
         if epoch % args.validate_interval == 0:
-            for k, subset in enumerate(args.valid_subset.split(',')):
-                val_loss = validate(args, trainer, dataset, subset, epoch)
-                if k == 0:
-                    first_val_loss = val_loss
+            first_val_loss = val_loss(args, trainer, dataset, epoch)
 
         # only use first validation loss to update the learning rate
         lr = trainer.lr_step(epoch, first_val_loss)
 
         # save checkpoint
         if not args.no_save and epoch % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, first_val_loss)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=True,
+                            val_loss=first_val_loss)
 
         epoch += 1
     train_meter.stop()
@@ -120,7 +116,7 @@ def load_dataset(args, splits):
     return dataset
 
 
-def train(args, trainer, itr, epoch):
+def train(args, trainer, itr, epoch, dataset):
     """Train the model for one epoch."""
 
     # Set seed based on args.seed and the epoch number so that we get
@@ -168,7 +164,12 @@
         if i == 0:
             trainer.get_meter('wps').reset()
 
-        if trainer.get_num_updates() >= max_update:
+        num_updates = trainer.get_num_updates()
+        if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
+            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
+
+        if num_updates >= max_update:
             break
 
     # log end-of-epoch stats
@@ -202,7 +203,7 @@ def get_training_stats(trainer):
     return stats
 
 
-def validate(args, trainer, dataset, subset, epoch):
+def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
     """Evaluate the model on the validation set and return the average loss."""
 
     # Initialize dataloader
@@ -236,19 +237,24 @@
     for sample in progress:
         log_output = trainer.valid_step(sample)
 
-        # log mid-validation stats
-        stats = get_valid_stats(trainer)
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-            stats[k] = extra_meters[k].avg
-        progress.log(stats)
+        for k, v in log_output.items():
+            if k in ['loss', 'nll_loss', 'sample_size']:
+                continue
+            extra_meters[k].update(v)
+        if verbose:
+            # log mid-validation stats
+            stats = get_valid_stats(trainer)
+            stats.update({k: m.avg for k, m in extra_meters.items()})
+            progress.log(stats)
 
     # log validation stats
     stats = get_valid_stats(trainer)
     for k, meter in extra_meters.items():
         stats[k] = meter.avg
+
+    if num_updates is not None:
+        stats['num_updates'] = num_updates
+
     progress.print(stats)
 
     return stats['valid_loss']
@@ -273,16 +279,33 @@ def get_perplexity(loss):
         return float('inf')
 
 
-def save_checkpoint(trainer, args, epoch, val_loss=None):
+def val_loss(args, trainer, dataset, epoch, num_updates=None):
+    # evaluate on the validation set; validate every subset so all results
+    # get printed, but return only the first subset's loss
+    subsets = args.valid_subset.split(',')
+    losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
+    return losses[0] if len(losses) > 0 else None
+
+
+def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
     extra_state = {
         'epoch': epoch,
         'val_loss': val_loss,
         'wall_time': trainer.get_meter('wall').elapsed_time,
     }
-    if not args.no_epoch_checkpoints:
+    if end_of_epoch and not args.no_epoch_checkpoints:
         epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
         trainer.save_checkpoint(epoch_filename, extra_state)
+    elif not end_of_epoch and args.keep_interval_updates > 0:
+        checkpoint_filename = os.path.join(args.save_dir,
+                                           'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
+        trainer.save_checkpoint(checkpoint_filename, extra_state)
+        # remove old checkpoints
+        checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
+        # checkpoint_paths returns paths sorted in descending order, newest first
+        for old_chk in checkpoints[args.keep_interval_updates:]:
+            os.remove(old_chk)
 
     assert val_loss is not None
     if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
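To make the retention policy concrete, here is a minimal standalone simulation of the pruning step above (the filenames and flag values are made up):

    # Suppose --save-interval-updates 1000 and --keep-interval-updates 3; after
    # 5000 updates in epoch 1, five interval checkpoints have been written.
    saved = ['checkpoint_1_{}.pt'.format(n) for n in range(1000, 6000, 1000)]
    keep_interval_updates = 3

    # checkpoint_paths sorts descending by the update count (the regex group),
    # so everything past the first `keep_interval_updates` entries is stale.
    ordered = sorted(saved, key=lambda f: int(f.split('_')[2].split('.')[0]), reverse=True)
    for old_chk in ordered[keep_interval_updates:]:
        print('would remove', old_chk)  # checkpoint_1_2000.pt, checkpoint_1_1000.pt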
@@ -317,9 +340,11 @@ def load_checkpoint(args, trainer, train_dataloader):
     if args.distributed_port > 0 or args.distributed_init_method is not None:
         from distributed_train import main as distributed_main
+
         distributed_main(args)
     elif args.distributed_world_size > 1:
         from multiprocessing_train import main as multiprocessing_main
+
         multiprocessing_main(args)
     else:
         main(args)
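Putting the pieces together, a typical workflow would be to train with, e.g., `python train.py <data-dir> --save-dir checkpoints --save-interval-updates 1000 --keep-interval-updates 5` (placeholder values), then average the surviving update-based checkpoints with `python scripts/average_checkpoints.py --inputs checkpoints --num 5 --update-based-checkpoints --output checkpoints/averaged.pt`. Note that `--inputs` and `--output` are the script's pre-existing flags, referenced here on the assumption they are unchanged; they sit outside the hunks above.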