
Merge pull request microsoft#24 from microsoft/olruwase/checkpoint_lr_scheduler

Optional loading of optimizer and LR scheduler states
tjruwase committed May 11, 2020
2 parents fc713d9 + 2cb5784 commit 96b2224
Showing 2 changed files with 176 additions and 22 deletions.
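
Before the diffs, a minimal usage sketch of the flags this commit adds. The engine setup, checkpoint directory, and tag below are illustrative assumptions, not part of this change:

    import deepspeed

    # Assumes `args` and `model` are prepared as in the unit tests below.
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())

    # Restore module weights only; skip optimizer and LR scheduler state (the new options).
    load_path, client_state = engine.load_checkpoint('/tmp/checkpoints',     # hypothetical directory
                                                     tag='global_step1000',  # hypothetical tag
                                                     load_optimizer_states=False,
                                                     load_lr_scheduler_states=False)
    if load_path is None:
        print('Checkpoint could not be loaded')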
28 changes: 23 additions & 5 deletions deepspeed/pt/deepspeed_light.py
@@ -1041,19 +1041,30 @@ def _ensure_directory_exists(self, filename):
if not os.path.exists(dirname):
os.makedirs(dirname)

def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
r"""Load training checkpoint
Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from Checkpoint.
Return:
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
client_state: State dictionary used for loading required training states in the client code.
"""

load_path, client_states = self._load_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
load_path, client_states = self._load_checkpoint(load_dir,
tag,
load_module_strict=load_module_strict,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)

if self.zero_optimization() and load_path is not None:
self._load_zero_checkpoint(load_dir,
@@ -1062,7 +1073,12 @@ def load_checkpoint(self, load_dir, tag, load_optimizer_states=True):

return load_path, client_states

def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def _load_checkpoint(self,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):

load_path = self._get_ckpt_name(load_dir, tag)

@@ -1075,12 +1091,13 @@ def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logging.info('Loading checkpoint: {}'.format(load_path))
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

self.load_module_state_dict(checkpoint['module'])
self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
self.optimizer.load_state_dict(checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)

if self.lr_scheduler is not None:
if load_lr_scheduler_states and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
@@ -1089,6 +1106,7 @@ def _load_checkpoint(self, load_dir, tag, load_optimizer_states=True):
deepspeed_states = [
'module',
'optimizer',
'lr_scheduler',
'csr_tensor_module_names',
'skipped_steps',
'global_step'
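
For orientation, a sketch of how the engine plausibly separates its own checkpoint entries from client state using the deepspeed_states list above, with 'lr_scheduler' now excluded so it is not echoed back to the caller. This is an assumption about the surrounding code, not text taken from this diff:

    # Sketch: `checkpoint` is the dict produced by torch.load(...) earlier in _load_checkpoint.
    # Keys owned by the engine are filtered out; everything else is returned as client state.
    client_state = {
        key: value
        for key, value in checkpoint.items() if key not in deepspeed_states
    }
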
170 changes: 153 additions & 17 deletions tests/unit/test_checkpointing.py
@@ -10,6 +10,7 @@
import pytest
import json
import os
import numbers
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict

@@ -41,8 +42,6 @@ def compare_model_states(saved_model, loaded_model):


def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
compare_model_states(saved_model, loaded_model)

for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
loaded_model.optimizer.optimizer.state.values()):
for s0, s1 in zip(state0.values(), state1.values()):
@@ -52,13 +51,35 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
assert s0 == s1


def checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=True,
):
def compare_lr_scheduler_states(saved_model, loaded_model):
assert hasattr(saved_model, 'lr_scheduler')
assert hasattr(loaded_model, 'lr_scheduler')

saved_scheduler = saved_model.lr_scheduler
loaded_scheduler = loaded_model.lr_scheduler

assert hasattr(saved_scheduler, 'state_dict')
assert hasattr(loaded_scheduler, 'state_dict')

saved_sd = saved_scheduler.state_dict()
loaded_sd = loaded_scheduler.state_dict()

print(f"saved_sd = {saved_sd}")
print(f"loaded_sd = {loaded_sd}")

assert saved_sd.keys() == loaded_sd.keys()

for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
assert state0 == state1


def checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=False,
load_lr_scheduler_states=False):

ds_model, _, _,_ = deepspeed.initialize(args=args,
model=model,
@@ -85,12 +106,16 @@ def checkpoint_correctness_verification(

loaded_model.load_checkpoint(save_folder,
save_tag,
load_optimizer_states=load_optimizer_states)
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)

compare_model_states(trained_model, loaded_model)

if load_optimizer_states:
compare_optimizer_states(trained_model, loaded_model, hidden_dim)
else:
compare_model_states(trained_model, loaded_model)

if load_lr_scheduler_states:
compare_lr_scheduler_states(trained_model, loaded_model)


def test_checkpoint_unfused_optimizer(tmpdir):
@@ -246,14 +271,125 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[2])
def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
def _test_checkpoint_zero_no_optimizer(args,
model,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)

_test_checkpoint_zero_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False)
_test_checkpoint_zero_no_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False)


@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"betas": [0.8,
0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[2])
def _test_checkpoint_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)

_test_checkpoint_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=True)


@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[2])
def _test_checkpoint_no_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)

_test_checkpoint_no_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=False)
