ability to checkpoint when reaching a certain number of updates (#282)
alexeib committed May 23, 2018
1 parent e560a12 commit bd110fd
Showing 5 changed files with 89 additions and 32 deletions.
6 changes: 6 additions & 0 deletions fairseq/options.py
@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, metavar='N',
help='if specified, saves best/last checkpoint every this many updates. '
'will also validate before saving to determine if val loss is better')
group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
help='if --save-interval-updates is specified, keep the last this many checkpoints'
' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt)')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
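The two new flags work together: --save-interval-updates N triggers a validate-and-save every N updates, and --keep-interval-updates caps how many of those update-based checkpoints are kept. A minimal standalone sketch of the parsing behaviour (plain argparse, not the full fairseq parser, shown only to illustrate the defaults):

import argparse

# Standalone sketch mirroring the two new options above; fairseq's real
# parser registers many more arguments around these.
parser = argparse.ArgumentParser()
parser.add_argument('--save-interval-updates', type=int, metavar='N')
parser.add_argument('--keep-interval-updates', type=int, default=0, metavar='N')

# e.g. validate + save every 1000 updates, keep only the 5 newest update-based checkpoints
args = parser.parse_args('--save-interval-updates 1000 --keep-interval-updates 5'.split())
print(args.save_interval_updates, args.keep_interval_updates)  # 1000 5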
1 change: 1 addition & 0 deletions fairseq/progress_bar.py
@@ -117,6 +117,7 @@ def log(self, stats):

def print(self, stats):
"""Print end-of-epoch stats."""
self.stats = stats
stats = self._format_stats(self.stats, epoch=self.epoch)
print(json.dumps(stats), flush=True)

27 changes: 22 additions & 5 deletions fairseq/utils.py
@@ -9,6 +9,7 @@
import contextlib
import logging
import os
import re
import torch
import traceback

@@ -316,11 +317,11 @@ def buffered_arange(max):


def convert_padding_direction(
src_tokens,
src_lengths,
padding_idx,
right_to_left=False,
left_to_right=False,
src_tokens,
src_lengths,
padding_idx,
right_to_left=False,
left_to_right=False,
):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
@@ -356,3 +357,19 @@ def clip_grad_norm_(tensor, max_norm):
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
""" retrieves all checkpoints found in `path` directory. checkpoints are identified by matching filename to
the specified pattern. if the pattern contains groups, the result will be sorted by the first group in descending
order """
pt_regexp = re.compile(pattern)
files = os.listdir(path)

entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
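A quick sketch of how the new helper sorts update-based checkpoints; the scratch directory and filenames below are made up for illustration, and it assumes fairseq at this commit is importable:

import os
import tempfile
from fairseq.utils import checkpoint_paths

# Fabricate a few empty update-based checkpoint files
d = tempfile.mkdtemp()
for name in ('checkpoint_1_500.pt', 'checkpoint_1_1000.pt', 'checkpoint_2_1500.pt'):
    open(os.path.join(d, name), 'w').close()

# Sorted by the captured update count, newest first
print(checkpoint_paths(d, pattern=r'checkpoint_\d+_(\d+)\.pt'))
# -> [.../checkpoint_2_1500.pt, .../checkpoint_1_1000.pt, .../checkpoint_1_500.pt]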
18 changes: 13 additions & 5 deletions scripts/average_checkpoints.py
@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
return new_state


def last_n_checkpoints(paths, n):
def last_n_checkpoints(paths, n, update_based):
assert len(paths) == 1
path = paths[0]
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
if update_based:
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path)

entries = []
@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n):
def main():
parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
'produce a new checkpoint',
)

parser.add_argument(
@@ -95,19 +98,24 @@ def main():
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
'path.',
)
parser.add_argument(
'--num',
type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those',
)
parser.add_argument(
'--update-based-checkpoints',
action='store_true',
help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
)
args = parser.parse_args()
print(args)

if args.num is not None:
args.inputs = last_n_checkpoints(args.inputs, args.num)
args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
print('averaging checkpoints: ', args.inputs)

new_state = average_checkpoints(args.inputs)
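To make the two modes of last_n_checkpoints concrete, a small sketch (using the same regexps as above; the filenames are invented) of which files each mode picks up:

import re

epoch_re = re.compile(r'checkpoint(\d+)\.pt')          # default: epoch-based
update_re = re.compile(r'checkpoint_\d+_(\d+)\.pt')    # with --update-based-checkpoints

files = ['checkpoint3.pt', 'checkpoint_3_12000.pt', 'checkpoint_best.pt']
print([f for f in files if epoch_re.fullmatch(f)])     # ['checkpoint3.pt']
print([f for f in files if update_re.fullmatch(f)])    # ['checkpoint_3_12000.pt']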
69 changes: 47 additions & 22 deletions train.py
@@ -15,10 +15,10 @@
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import checkpoint_paths


def main(args):

if args.max_tokens is None:
args.max_tokens = 6000

@@ -82,26 +82,22 @@ def main(args):
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
first_val_loss = None
train_meter = StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, next(train_dataloader), epoch)
train(args, trainer, next(train_dataloader), epoch, dataset)

# evaluate on validate set
first_val_loss = None
if epoch % args.validate_interval == 0:
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
first_val_loss = val_loss
first_val_loss = val_loss(args, trainer, dataset, epoch)

# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, first_val_loss)

# save checkpoint
if not args.no_save and epoch % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, first_val_loss)
save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)

epoch += 1
train_meter.stop()
@@ -120,7 +116,7 @@ def load_dataset(args, splits):
return dataset


def train(args, trainer, itr, epoch):
def train(args, trainer, itr, epoch, dataset):
"""Train the model for one epoch."""

# Set seed based on args.seed and the epoch number so that we get
@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
if i == 0:
trainer.get_meter('wps').reset()

if trainer.get_num_updates() >= max_update:
num_updates = trainer.get_num_updates()
if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)

if num_updates >= max_update:
break

# log end-of-epoch stats
@@ -202,7 +203,7 @@ def get_training_stats(trainer):
return stats


def validate(args, trainer, dataset, subset, epoch):
def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
"""Evaluate the model on the validation set and return the average loss."""

# Initialize dataloader
@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch):
for sample in progress:
log_output = trainer.valid_step(sample)

# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
if verbose:
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)

# log validation stats
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg

if num_updates is not None:
stats['num_updates'] = num_updates

progress.print(stats)

return stats['valid_loss']
@@ -273,16 +279,33 @@ def get_perplexity(loss):
return float('inf')


def save_checkpoint(trainer, args, epoch, val_loss=None):
def val_loss(args, trainer, dataset, epoch, num_updates=None):
# evaluate on validate set
subsets = args.valid_subset.split(',')
# we want to validate all subsets so the results get printed out, but return only the first
losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
return losses[0] if len(losses) > 0 else None


def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
extra_state = {
'epoch': epoch,
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
}

if not args.no_epoch_checkpoints:
if end_of_epoch and not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
trainer.save_checkpoint(epoch_filename, extra_state)
elif not end_of_epoch and args.keep_interval_updates > 0:
checkpoint_filename = os.path.join(args.save_dir,
'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
trainer.save_checkpoint(checkpoint_filename, extra_state)
# remove old checkpoints
checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
# checkpoints are sorted in descending order
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)

assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
@@ -317,9 +340,11 @@ def load_checkpoint(args, trainer, train_dataloader):

if args.distributed_port > 0 or args.distributed_init_method is not None:
from distributed_train import main as distributed_main

distributed_main(args)
elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main

multiprocessing_main(args)
else:
main(args)
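Finally, a rough standalone sketch (not part of the commit) of the pruning behaviour that save_checkpoint now performs for update-based checkpoints: with --keep-interval-updates 2, only the two newest checkpoint_[epoch]_[numupd].pt files survive each save.

import os
import re
import tempfile

def prune_update_checkpoints(save_dir, keep):
    # Mirror of the logic in save_checkpoint/checkpoint_paths: collect
    # update-based checkpoints, sort them newest-first by update count,
    # and delete everything past the first `keep`.
    regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    entries = []
    for f in os.listdir(save_dir):
        m = regexp.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), f))
    for _, name in sorted(entries, reverse=True)[keep:]:
        os.remove(os.path.join(save_dir, name))

d = tempfile.mkdtemp()
for name in ('checkpoint_1_1000.pt', 'checkpoint_1_2000.pt', 'checkpoint_2_3000.pt'):
    open(os.path.join(d, name), 'w').close()
prune_update_checkpoints(d, keep=2)
print(sorted(os.listdir(d)))  # ['checkpoint_1_2000.pt', 'checkpoint_2_3000.pt']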
