Refactor dist tests: Checkpointing #2202

Merged
24 commits merged into master from olruwase/refactor_dist_ci on Aug 23, 2022
Commits
e53a8c9
Refactor dist tests: Checkpointing
tjruwase Aug 9, 2022
30177d1
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 9, 2022
477a7b8
Remove local functions
tjruwase Aug 9, 2022
66cb4c5
ds.init() with config_dict
tjruwase Aug 9, 2022
274fbb6
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 9, 2022
ce38962
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 9, 2022
eb2bee7
Hardcode to simplify
tjruwase Aug 9, 2022
a3b0be7
Merge branch 'olruwase/refactor_dist_ci' of github.com:microsoft/Deep…
tjruwase Aug 9, 2022
8a64bd9
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 10, 2022
2947970
Merge branch 'master' into olruwase/refactor_dist_ci
mrwyattii Aug 11, 2022
4af09fe
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 12, 2022
4aab5d1
Format fixes
tjruwase Aug 12, 2022
45aafd4
Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase…
tjruwase Aug 12, 2022
8dba9db
Try avoiding race
tjruwase Aug 12, 2022
d2da10e
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 15, 2022
3be389b
Barrier for checkpoint saves
tjruwase Aug 15, 2022
94a8d41
Merge branch 'olruwase/refactor_dist_ci' of github.com:microsoft/Deep…
tjruwase Aug 15, 2022
968bd75
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 15, 2022
ecb8abb
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 17, 2022
b139325
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 17, 2022
b80cacf
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 18, 2022
7876e5a
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 19, 2022
4a00f62
Merge branch 'master' into olruwase/refactor_dist_ci
tjruwase Aug 22, 2022
808d428
Merge branch 'master' into olruwase/refactor_dist_ci
mrwyattii Aug 22, 2022
202 changes: 202 additions & 0 deletions tests/unit/checkpoint/common.py
@@ -0,0 +1,202 @@
import os
import torch
import numbers

import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

from tests.unit.simple_model import *


def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')

    assert saved_model.sparse_tensor_module_names == loaded_model.sparse_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps


def compare_model_states(saved_model,
                         loaded_model,
                         compare_optimizer=True,
                         load_module_only=False):
    if not load_module_only:
        compare_deepspeed_states(saved_model, loaded_model)

    for p0, p1 in zip(saved_model.module.named_parameters(), loaded_model.module.named_parameters()):
        np0, p0 = p0
        np1, p1 = p1
        if 'deepspeed_moe.gate.wg' in np0:
            # these params are converted to float at runtime, cast to half for comparison
            p1 = p1.half()
            p0 = p0.half()
        assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
        try:
            assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}"
        except RuntimeError as err:
            print(f"FP16 model state {p0} is not equal to {p1}, names:{np0}, {np1}")
            raise err

    if not compare_optimizer:
        return

    if DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
            saved_model.optimizer,
            DeepSpeedZeroOptimizer_Stage3):
        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

    elif isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer):
        for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

    elif isinstance(saved_model.optimizer, FP16_Optimizer):
        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
            assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"

    elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
        for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
            for p0, p1 in zip(params0, params1):
                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
        pass
    else:
        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


def compare_state_dicts(state0, state1, expected_mismatch_keys=[]):
    for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()):
        assert k0 == k1, f'failure due to key mismatch {k0} != {k1}'
        if k0 in expected_mismatch_keys:
            continue
        if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
            assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
            assert torch.equal(s0.to('cpu'), s1.to('cpu'))
        else:
            assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}'


def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer

    for state0, state1 in zip(saved_optimizer.state.values(),
                              loaded_optimizer.state.values()):
        compare_state_dicts(state0, state1)


def compare_lr_scheduler_states(saved_model, loaded_model):
    assert hasattr(saved_model, 'lr_scheduler')
    assert hasattr(loaded_model, 'lr_scheduler')

    saved_scheduler = saved_model.lr_scheduler
    loaded_scheduler = loaded_model.lr_scheduler

    assert hasattr(saved_scheduler, 'state_dict')
    assert hasattr(loaded_scheduler, 'state_dict')

    saved_sd = saved_scheduler.state_dict()
    loaded_sd = loaded_scheduler.state_dict()

    print(f"saved_sd = {saved_sd}")
    print(f"loaded_sd = {loaded_sd}")

    assert saved_sd.keys() == loaded_sd.keys()

    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
            assert state0 == state1


def create_deepspeed_model(args, model, base_optimizer):
    if base_optimizer is None:
        ds_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 model_parameters=model.parameters())
    else:
        ds_model, _, _, _ = deepspeed.initialize(args=args,
                                                 model=model,
                                                 optimizer=base_optimizer)

    return ds_model


def checkpoint_correctness_verification(args,
                                        models,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True,
                                        train_batch=False,
                                        base_optimizers=[None,
                                                         None],
                                        empty_tag=False,
                                        seq_dataloader=False,
                                        load_module_only=False):
    dtype = torch.half if fp16 else torch.float32
    ds_model = create_deepspeed_model(args=args,
                                      model=models[0],
                                      base_optimizer=base_optimizers[0])

    if seq_dataloader:
        data_loader = sequence_dataloader(model=ds_model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=ds_model.device,
                                          dtype=dtype)
    else:
        data_loader = random_dataloader(model=ds_model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=ds_model.device,
                                        dtype=dtype)

    if train_batch:
        ds_model.set_dataloader(data_loader)
        for _, batch in enumerate(data_loader):
            loss = ds_model.train_batch()
    else:
        for _, batch in enumerate(data_loader):
            loss = ds_model(batch[0], batch[1])
            ds_model.backward(loss)
            ds_model.step()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = None if empty_tag else '1'

    trained_model.save_checkpoint(save_folder, tag=save_tag)

    dist.barrier()

    loaded_model = create_deepspeed_model(args=args,
                                          model=models[1],
                                          base_optimizer=base_optimizers[1])
    assert list(trained_model.parameters())[0].dtype == list(
        loaded_model.parameters())[0].dtype

    loaded_model.load_checkpoint(save_folder,
                                 tag=save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states,
                                 load_module_only=load_module_only)

    compare_model_states(trained_model,
                         loaded_model,
                         compare_optimizer=load_optimizer_states,
                         load_module_only=load_module_only)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
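
Editorial note (not part of the diff): the sketch below shows how a checkpoint test might drive checkpoint_correctness_verification, using the SimpleModel, args_from_dict, and DistributedTest helpers imported in this PR's test files. The class name, world size, and config values are illustrative assumptions; the PR's actual call sites appear in the test file that follows.

# Minimal usage sketch only; names and config values here are assumptions.
from tests.unit.common import DistributedTest
from tests.unit.simple_model import SimpleModel, args_from_dict
from tests.unit.checkpoint.common import checkpoint_correctness_verification


class TestExampleCheckpoint(DistributedTest):
    world_size = 2  # assumed GPU count for this sketch

    def test_save_and_load(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            }
        }
        hidden_dim = 10
        args = args_from_dict(tmpdir, config_dict)
        # Two identical models: the first is trained and saved,
        # the second loads the checkpoint and is compared against it.
        models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
        checkpoint_correctness_verification(args,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True)
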
62 changes: 62 additions & 0 deletions tests/unit/checkpoint/test_latest_checkpoint.py
@@ -0,0 +1,62 @@
import deepspeed

from tests.unit.common import DistributedTest
from tests.unit.simple_model import *

from tests.unit.checkpoint.common import checkpoint_correctness_verification


class TestLatestCheckpoint(DistributedTest):
    world_size = 1

    def test_existing_latest(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }
        hidden_dim = 10
        args = args_from_dict(tmpdir, config_dict)
        models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]

        def _helper(args, models):
            checkpoint_correctness_verification(args,
                                                models=models,
                                                hidden_dim=hidden_dim,
                                                tmpdir=tmpdir,
                                                load_optimizer_states=True,
                                                load_lr_scheduler_states=False,
                                                fp16=False,
                                                empty_tag=True)

        _helper(args, models)

    def test_missing_latest(self, tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            }
        }
        hidden_dim = 10
        args = args_from_dict(tmpdir, config_dict)

        model = SimpleModel(hidden_dim)

        def _helper(args, model):
            model, _, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
            # should be no-op, since latest doesn't exist
            model.load_checkpoint(tmpdir)

        _helper(args=args, model=model)