Skip to content

Commit

Permalink
Added option to save checkpoints using Path Manager.
Browse the repository at this point in the history
Summary: Added option to save checkpoints using Path Manager.

Reviewed By: hudeven

Differential Revision: D17392754

fbshipit-source-id: 4b8e556ef8455a1548e5a083d779ed809cd785be
Branch information:
sujitoc authored and facebook-github-bot committed Oct 12, 2019
1 parent 02b74c5 commit d80ad54
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
29 changes: 24 additions & 5 deletions fairseq/checkpoint_utils.py
Expand Up @@ -65,7 +65,11 @@ def is_better(a, b):
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
shutil.copyfile(checkpoints[0], cp)
try:
from fairseq.fb_pathmgr import fb_pathmgr
fb_pathmgr.copy(checkpoints[0], cp, True)
except (ModuleNotFoundError, ImportError):
shutil.copyfile(checkpoints[0], cp)

write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
Expand Down Expand Up @@ -132,9 +136,17 @@ def load_checkpoint(args, trainer, data_selector=None):

def load_checkpoint_to_cpu(path, arg_overrides=None):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
state = torch.load(
path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
try:
from fairseq.fb_pathmgr import fb_pathmgr
with fb_pathmgr.open(path, "rb") as f:
state = torch.load(
f, map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
except (ModuleNotFoundError, ImportError):
# if path manager not found, continue with local file.
state = torch.load(
path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
)
args = state['args']
if arg_overrides is not None:
for arg_name, arg_val in arg_overrides.items():
Expand Down Expand Up @@ -244,7 +256,14 @@ def save_state(
state_dict['criterion'] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
torch_persistent_save(state_dict, filename)

try:
from fairseq.fb_pathmgr import fb_pathmgr
with fb_pathmgr.open(filename, "wb") as f:
torch_persistent_save(state_dict, f)
except (ModuleNotFoundError, ImportError):
# if path manager not found, continue with local file.
torch_persistent_save(state_dict, filename)


def _upgrade_state_dict(state):
Expand Down
8 changes: 7 additions & 1 deletion fairseq/trainer.py
Expand Up @@ -170,7 +170,13 @@ def load_checkpoint(
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = None, [], None

if os.path.exists(filename):
try:
from fairseq.fb_pathmgr import fb_pathmgr
bexists = fb_pathmgr.isfile(filename)
except Exception:
bexists = os.path.exists(filename)

if bexists:
state = checkpoint_utils.load_checkpoint_to_cpu(filename)

# load model parameters
Expand Down
11 changes: 11 additions & 0 deletions train.py
Expand Up @@ -19,10 +19,21 @@
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter

fb_pathmgr_registerd = False


def main(args, init_distributed=False):
utils.import_user_module(args)

try:
from fairseq.fb_pathmgr import fb_pathmgr
global fb_pathmgr_registerd
if not fb_pathmgr_registerd:
fb_pathmgr.register()
fb_pathmgr_registerd = True
except (ModuleNotFoundError, ImportError):
pass

assert args.max_tokens is not None or args.max_sentences is not None, \
'Must specify batch size either with --max-tokens or --max-sentences'

Expand Down

0 comments on commit d80ad54

Please sign in to comment.