[Feature] support load model/trainer states from s3 (#702)
* support load model/trainer states from s3

* update training script

* bug fix

* fix lint

* rename

* move functions to utils/data.py

* fix lint and rename

* fix import

* fix lint

* make lint happy..
eric-haibin-lin committed May 10, 2019
1 parent 0cdb38d commit e2b2766
Showing 8 changed files with 196 additions and 119 deletions.
16 changes: 10 additions & 6 deletions scripts/bert/pretraining_utils.py
@@ -31,7 +31,8 @@
 from gluonnlp.metric import MaskedAccuracy
 
 __all__ = ['get_model_loss', 'get_pretrain_dataset', 'get_dummy_dataloader',
-           'save_params', 'evaluate', 'forward', 'split_and_load', 'get_argparser']
+           'save_parameters', 'save_states', 'evaluate', 'forward', 'split_and_load',
+           'get_argparser']
 
 def get_model_loss(ctx, model, pretrained, dataset_name, dtype, ckpt_dir=None, start_step=None):
     """Get model for pre-training."""
@@ -46,7 +46,7 @@ def get_model_loss(ctx, model, pretrained, dataset_name, dtype, ckpt_dir=None, start_step=None):
 
     if ckpt_dir and start_step:
         param_path = os.path.join(ckpt_dir, '%07d.params'%start_step)
-        model.load_parameters(param_path, ctx=ctx)
+        nlp.utils.load_parameters(model, param_path, ctx=ctx)
         logging.info('Loading step %d checkpoints from %s.', start_step, param_path)
 
     model.hybridize(static_alloc=True)
@@ -126,13 +127,16 @@ def __iter__(self):
 
     return DummyIter(data_batch)
 
-def save_params(step_num, model, trainer, ckpt_dir):
+def save_parameters(step_num, model, ckpt_dir):
     """Save the model parameter, marked by step_num."""
     param_path = os.path.join(ckpt_dir, '%07d.params'%step_num)
-    trainer_path = os.path.join(ckpt_dir, '%07d.states'%step_num)
-    logging.info('[step %d] Saving checkpoints to %s, %s.',
-                 step_num, param_path, trainer_path)
+    logging.info('[step %d] Saving model params to %s.', step_num, param_path)
+    nlp.utils.save_parameters(model, param_path)
+
+def save_states(step_num, trainer, ckpt_dir, local_rank=0):
+    """Save the trainer states, marked by step_num."""
+    trainer_path = os.path.join(ckpt_dir, '%07d.states.%02d'%(step_num, local_rank))
+    logging.info('[step %d] Saving trainer states to %s.', step_num, trainer_path)
+    nlp.utils.save_states(trainer, trainer_path)
 
 def log(begin_time, running_num_tks, running_mlm_loss, running_nsp_loss, step_num,
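For orientation, here is a minimal sketch of how the two renamed helpers above are meant to be called from a training script, assuming a gluonnlp build that already contains this commit and that the snippet is run from scripts/bert/ so that pretraining_utils is importable. The toy model, the local ./ckpt_dir path and the step number are illustrative only; whether ckpt_dir may also be an s3:// URI is an assumption based on this PR's title, since the gluonnlp.utils.parameter changes are not part of the hunks shown here.

import mxnet as mx
from mxnet import autograd, gluon
import gluonnlp as nlp
# run from scripts/bert/ so that pretraining_utils is importable
from pretraining_utils import save_parameters, save_states

model = gluon.nn.Dense(2)
model.initialize()
trainer = gluon.Trainer(model.collect_params(), 'adam')

# one dummy step so the trainer has optimizer states worth checkpointing
with autograd.record():
    loss = model(mx.nd.ones((1, 4))).sum()
loss.backward()
trainer.step(1)

ckpt_dir = './ckpt_dir'   # hypothetically also an s3://bucket/prefix URI
nlp.utils.mkdir(ckpt_dir)
step_num = 20000

save_parameters(step_num, model, ckpt_dir)               # writes 0020000.params
save_states(step_num, trainer, ckpt_dir, local_rank=0)   # writes 0020000.states.00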
16 changes: 10 additions & 6 deletions scripts/bert/run_pretraining.py
@@ -42,7 +42,8 @@
 from utils import profile
 from fp16_utils import FP16Trainer
 from pretraining_utils import get_model_loss, get_pretrain_dataset, get_dummy_dataloader
-from pretraining_utils import save_params, log, evaluate, forward, split_and_load, get_argparser
+from pretraining_utils import log, evaluate, forward, split_and_load, get_argparser
+from pretraining_utils import save_parameters, save_states
 
 # arg parser
 parser = get_argparser()
@@ -107,9 +108,9 @@ def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
         fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)
 
     if args.start_step:
-        state_path = os.path.join(args.ckpt_dir, '%07d.states' % args.start_step)
+        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, 0))
         logging.info('Loading trainer state from %s', state_path)
-        trainer.load_states(state_path)
+        nlp.utils.load_states(trainer, state_path)
 
     accumulate = args.accumulate
     num_train_steps = args.num_steps
@@ -206,10 +207,13 @@ def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
 
                 # saving checkpoints
                 if (step_num + 1) % args.ckpt_interval == 0 \
-                   and (batch_num + 1) % accumulate == 0:
-                    save_params(step_num, model, trainer, args.ckpt_dir)
+                   and (batch_num + 1) % accumulate == 0 and store.rank == 0:
+                    save_states(step_num, trainer, args.ckpt_dir)
+                    save_parameters(step_num, model, args.ckpt_dir)
                 batch_num += 1
-    save_params(step_num, model, trainer, args.ckpt_dir)
+    if store.rank == 0:
+        save_states(step_num, trainer, args.ckpt_dir)
+        save_parameters(step_num, model, args.ckpt_dir)
     mx.nd.waitall()
     train_end_time = time.time()
     logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
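A quick illustration of the naming change above, with example values only: trainer states now carry a two-digit rank suffix, so resuming with --start_step 20000 looks for the rank-0 state file next to the parameter file.

import os

ckpt_dir, start_step = './ckpt_dir', 20000
param_path = os.path.join(ckpt_dir, '%07d.params' % start_step)
state_path = os.path.join(ckpt_dir, '%07d.states.%02d' % (start_step, 0))
print(param_path)   # ./ckpt_dir/0020000.params
print(state_path)   # ./ckpt_dir/0020000.states.00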
22 changes: 15 additions & 7 deletions scripts/bert/run_pretraining_hvd.py
@@ -42,7 +42,8 @@
 from utils import profile
 from fp16_utils import FP16Trainer
 from pretraining_utils import get_model_loss, get_pretrain_dataset, get_dummy_dataloader
-from pretraining_utils import save_params, split_and_load, log, evaluate, forward, get_argparser
+from pretraining_utils import split_and_load, log, evaluate, forward, get_argparser
+from pretraining_utils import save_parameters, save_states
 
 # parser
 parser = get_argparser()
@@ -64,6 +65,7 @@
 num_workers = hvd.size()
 rank = hvd.rank()
 local_rank = hvd.local_rank()
+is_master_node = rank == local_rank
 if not args.use_avg_len and hvd.size() > 1:
     logging.info('Specifying --use-avg-len and setting --batch_size with the '
                  'target number of tokens would help improve training throughput.')
@@ -93,7 +95,9 @@ def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx):
                                    loss_scaler_params=loss_scale_param)
 
     if args.start_step:
-        trainer.load_states(os.path.join(args.ckpt_dir, '%07d.states'%args.start_step))
+        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
+        logging.info('Loading trainer state from %s', state_path)
+        nlp.utils.load_states(trainer, state_path)
 
     accumulate = args.accumulate
     num_train_steps = args.num_steps
@@ -191,14 +195,18 @@ def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx):
                     nsp_metric.reset_local()
 
                 # saving checkpoints
-                if (step_num + 1) % args.ckpt_interval == 0 \
-                   and (batch_num + 1) % accumulate == 0 and local_rank == 0:
-                    save_params(step_num, model, trainer, args.ckpt_dir)
+                if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
+                    if is_master_node:
+                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
+                    if local_rank == 0:
+                        save_parameters(step_num, model, args.ckpt_dir)
 
                 batch_num += 1
 
-    if local_rank == 0:
-        save_params(step_num, model, trainer, args.ckpt_dir)
+    if is_master_node:
+        save_states(step_num, trainer, args.ckpt_dir, local_rank)
+    if local_rank == 0:
+        save_parameters(step_num, model, args.ckpt_dir)
     mx.nd.waitall()
     train_end_time = time.time()
     logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
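A toy, Horovod-free sketch of which processes write which checkpoint files under the gating above. It assumes 2 nodes with 2 GPUs each and ranks laid out node by node (the actual mapping depends on the launcher): workers whose global rank equals their local rank, i.e. those on the first node, each save their own trainer states, while local rank 0 on each node saves the parameter file.

step_name = '0020000'
for rank in range(4):
    local_rank = rank % 2                 # assumed node-by-node rank layout
    is_master_node = rank == local_rank
    writes = []
    if is_master_node:
        writes.append('%s.states.%02d' % (step_name, local_rank))
    if local_rank == 0:
        writes.append('%s.params' % step_name)
    print('rank %d (local_rank %d) writes %s' % (rank, local_rank, writes))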
6 changes: 3 additions & 3 deletions src/gluonnlp/utils/__init__.py
@@ -20,10 +20,10 @@
 # pylint: disable=wildcard-import, arguments-differ
 """Module for utility functions."""
 
-from . import (parallel, parameter, data)
+from . import (parallel, parameter, files)
 
 from .parallel import *
 from .parameter import *
-from .data import *
+from .files import *
 
-__all__ = parallel.__all__ + parameter.__all__ + data.__all__
+__all__ = parallel.__all__ + parameter.__all__ + files.__all__
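Because files.__all__ is re-exported here, user code keeps going through the gluonnlp.utils namespace and never needs the new module name. A one-line check, assuming a build that contains this commit:

import gluonnlp as nlp

nlp.utils.mkdir('./ckpt_dir')         # re-exported from utils/files.py
print('mkdir' in nlp.utils.__all__)   # True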
47 changes: 0 additions & 47 deletions src/gluonnlp/utils/data.py

This file was deleted.

90 changes: 90 additions & 0 deletions src/gluonnlp/utils/files.py
@@ -0,0 +1,90 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation
"""Utility functions for files."""

__all__ = ['mkdir']

import os
import warnings
import logging
import tempfile
from .. import _constants as C

def mkdir(dirname):
"""Create directory.
Parameters
----------
dirname : str
The name of the target directory to create.
"""
if C.S3_PREFIX in dirname:
warnings.warn('Directory %s is not created because it contains %s'
%(dirname, C.S3_PREFIX))
return
dirname = os.path.expanduser(dirname)
if not os.path.exists(dirname):
try:
os.makedirs(dirname)
except OSError as e:
# errno 17 means the file already exists
if e.errno != 17:
raise e

class _TempFilePath(object):
"""A TempFilePath that provides a path to a temporarily file, and automatically
cleans up the temp file at exit.
"""
def __init__(self):
self.temp_dir = os.path.join(tempfile.gettempdir(), str(hash(os.times())))
if not os.path.exists(self.temp_dir):
os.makedirs(self.temp_dir)

def __enter__(self):
self.temp_path = os.path.join(self.temp_dir, str(hash(os.times())))
return self.temp_path

def __exit__(self, exec_type, exec_value, traceback):
os.remove(self.temp_path)

def _transfer_file_s3(filename, s3_filename, upload=True):
"""Transfer a file between S3 and local file system."""
try:
import boto3
except ImportError:
raise ImportError('boto3 is required to support s3 URI. Please install'
'boto3 via `pip install boto3`')
# parse s3 uri
prefix_len = len(C.S3_PREFIX)
bucket_idx = s3_filename[prefix_len:].index('/') + prefix_len
bucket_name = s3_filename[prefix_len:bucket_idx]

# filename after the bucket, excluding '/'
key_name = s3_filename[bucket_idx + 1:]

log_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.INFO)
# upload to s3
s3 = boto3.client('s3')
if upload:
s3.upload_file(filename, bucket_name, key_name)
else:
s3.download_file(bucket_name, key_name, filename)
logging.getLogger().setLevel(log_level)
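Below is a short sketch exercising the new helpers, assuming a gluonnlp build that contains this commit. No AWS access is needed, and the 's3://' value for C.S3_PREFIX is an assumption; the last block simply repeats the slicing done inside _transfer_file_s3 to show how a bucket and key are recovered from an s3:// path.

from gluonnlp.utils.files import mkdir, _TempFilePath

# mkdir() refuses to create anything for an s3:// target and only creates
# local directories.
mkdir('s3://my-bucket/ckpt_dir')   # warns and returns without creating anything
mkdir('./local_ckpt_dir')          # created on the local file system

# _TempFilePath hands out a scratch path, e.g. for staging a checkpoint
# before it is pushed to S3, and deletes the file when the block exits.
with _TempFilePath() as temp_path:
    with open(temp_path, 'w') as f:
        f.write('staged checkpoint bytes')
# the file at temp_path has been removed here

# The same slicing _transfer_file_s3 uses to split an S3 URI:
S3_PREFIX = 's3://'                 # assumed value of gluonnlp._constants.S3_PREFIX
s3_filename = 's3://my-bucket/bert/0020000.params'
prefix_len = len(S3_PREFIX)
bucket_idx = s3_filename[prefix_len:].index('/') + prefix_len
bucket_name = s3_filename[prefix_len:bucket_idx]   # 'my-bucket'
key_name = s3_filename[bucket_idx + 1:]            # 'bert/0020000.params'
print(bucket_name, key_name)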
