Fix Pipeline Parallel resize unit test (#2833)
* fix overlapping checkpoint names in unit tests

* remove running cpu-only on master merge
mrwyattii committed Feb 15, 2023
1 parent 639aa7b commit cc1054d
Showing 2 changed files with 62 additions and 11 deletions.
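The test-side change (the second file below) centers on a new checkpoint_tag pytest fixture. Every parametrization of TestConfigurableResizePP shares one class-scoped tmpdir, so previously they all wrote their baseline tensor to the same output.pt and saved checkpoints under the same name, letting one parameter combination overwrite or load another's artifacts. The condensed sketch below illustrates the pattern in isolation; the toy test_unique_artifacts function and its parameter values are illustrative stand-ins, not part of the DeepSpeed suite, and tmp_path stands in for the shared class_tmpdir.

import os

import pytest
import torch


# Same idea as the fixture added in the diff: derive a unique tag from the
# test parameters so parametrizations sharing one directory cannot collide.
@pytest.fixture
def checkpoint_tag(mp_size, pp_size, mp_resize, pp_resize):
    return f"{mp_size}-{pp_size}-{mp_resize}-{pp_resize}"


# Hypothetical toy test showing that each parametrization now writes its own
# baseline file instead of a shared "output.pt".
@pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize",
                         [(1, 2, 1, 1), (2, 2, 1, 1)])
def test_unique_artifacts(tmp_path, checkpoint_tag, mp_size, pp_size,
                          mp_resize, pp_resize):
    # tmp_path stands in for the shared class_tmpdir used by the real tests.
    save_path = os.path.join(tmp_path, f"output-{checkpoint_tag}.pt")
    torch.save(torch.zeros(1), save_path)
    assert os.path.isfile(save_path)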
1 change: 0 additions & 1 deletion .github/workflows/nv-torch-latest-cpu.yml
@@ -3,7 +3,6 @@ name: nv-torch-latest-cpu
on:
  push:
    branches:
-      - 'master'
      - 'staging**'
    paths-ignore:
      - 'docs/**'
72 changes: 62 additions & 10 deletions tests/unit/model_parallelism/test_configurable_parallel_pp.py
@@ -123,12 +123,19 @@ def test_pp_basic(self, inputs, tmpdir):
assert torch.allclose(b, t, atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"


+# Fixture for defining the checkpoint path since all tests in
+# TestConfigurableResizePP will use the same tmpdir
+@pytest.fixture
+def checkpoint_tag(mp_size, pp_size, mp_resize, pp_resize):
+    return f"{mp_size}-{pp_size}-{mp_resize}-{pp_resize}"
+
+
# Base class for creating / saving model output for baseline models. This is
# not meant to be used directly as a fixture to any classes
class _baseline(DistributedFixture):
world_size = None

-    def run(self, inputs, class_tmpdir, mp_size, pp_size):
+    def run(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size):
assert int(os.environ["WORLD_SIZE"]) == (pp_size * mp_size), "world size does not match provided pp_size and mp_size"
args_defaults = {
'num_layers': 8,
@@ -162,12 +169,14 @@ def run(self, inputs, class_tmpdir, mp_size, pp_size):
assert len(baseline) == 1
assert len(baseline[0]) == 1
assert torch.is_tensor(baseline[0][0])
-        save_path = os.path.join(class_tmpdir, "output.pt")
+        save_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
torch.save(baseline[0][0].cpu(), save_path)

state_dict = {}
state_dict['checkpoint_version'] = get_megatron_version()
-        model.save_checkpoint(class_tmpdir, client_state=state_dict)
+        model.save_checkpoint(class_tmpdir,
+                              tag=checkpoint_tag,
+                              client_state=state_dict)


# This may look odd, but there is a limitation with DistributedFixture that
@@ -186,7 +195,14 @@ class baseline_ws4(_baseline):


class TestConfigurableResizePP(ConfigurablePP):
-    def _test(self, inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize):
+    def _test(self,
+              inputs,
+              class_tmpdir,
+              checkpoint_tag,
+              mp_size,
+              pp_size,
+              mp_resize,
+              pp_resize):
args_defaults = {
'num_layers': 8,
'hidden_size': 128,
@@ -204,6 +220,7 @@ def _test(self, inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize):

with torch.no_grad():
model.load_checkpoint(class_tmpdir,
+                                  tag=checkpoint_tag,
load_optimizer_states=False,
load_lr_scheduler_states=False)
inputs = [x.cuda() for x in inputs]
@@ -223,7 +240,7 @@ def _test(self, inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize):
assert len(test[0]) == 1
assert torch.is_tensor(test[0][0])
test = test[0][0].cpu()
-        load_path = os.path.join(class_tmpdir, "output.pt")
+        load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
baseline = torch.load(load_path)
assert torch.allclose(baseline, test, atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"

@@ -233,48 +250,76 @@ def _test(self, inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize):
def test_world_size_2to1(self,
inputs,
class_tmpdir,
+                             checkpoint_tag,
baseline_ws2,
mp_size,
pp_size,
mp_resize,
pp_resize):
-        self._test(inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize)
+        self._test(inputs,
+                   class_tmpdir,
+                   checkpoint_tag,
+                   mp_size,
+                   pp_size,
+                   mp_resize,
+                   pp_resize)

@pytest.mark.world_size(1)
@pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 1, 1)])
def test_world_size_4to1(self,
inputs,
class_tmpdir,
+                             checkpoint_tag,
baseline_ws4,
mp_size,
pp_size,
mp_resize,
pp_resize):
-        self._test(inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize)
+        self._test(inputs,
+                   class_tmpdir,
+                   checkpoint_tag,
+                   mp_size,
+                   pp_size,
+                   mp_resize,
+                   pp_resize)

@pytest.mark.world_size(2)
@pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 2, 1)])
def test_world_size_4to2(self,
inputs,
class_tmpdir,
+                             checkpoint_tag,
baseline_ws4,
mp_size,
pp_size,
mp_resize,
pp_resize):
-        self._test(inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize)
+        self._test(inputs,
+                   class_tmpdir,
+                   checkpoint_tag,
+                   mp_size,
+                   pp_size,
+                   mp_resize,
+                   pp_resize)

@pytest.mark.world_size(4)
@pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 1, 2, 2)])
def test_world_size_1to4(self,
inputs,
class_tmpdir,
+                             checkpoint_tag,
baseline_ws1,
mp_size,
pp_size,
mp_resize,
pp_resize):
-        self._test(inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize)
+        self._test(inputs,
+                   class_tmpdir,
+                   checkpoint_tag,
+                   mp_size,
+                   pp_size,
+                   mp_resize,
+                   pp_resize)

@pytest.mark.world_size(4)
@pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize",
@@ -289,9 +334,16 @@ def test_world_size_1to4(self,
def test_world_size_2to4(self,
inputs,
class_tmpdir,
+                             checkpoint_tag,
baseline_ws2,
mp_size,
pp_size,
mp_resize,
pp_resize):
-        self._test(inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize)
+        self._test(inputs,
+                   class_tmpdir,
+                   checkpoint_tag,
+                   mp_size,
+                   pp_size,
+                   mp_resize,
+                   pp_resize)
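
For reference, the sketch below shows the checkpoint round trip that each resize test now performs with an explicit tag, outside of the test harness. It is a hedged, single-GPU sketch rather than code from the commit: the linear model, config values, checkpoint directory, and tag strings are illustrative, and it relies on DeepSpeed storing each checkpoint under a subdirectory named after its tag, which is what keeps differently tagged saves from overwriting one another.

import torch
import deepspeed

# Illustrative model and config; run under the `deepspeed` launcher, e.g.
# `deepspeed --num_gpus 1 sketch.py`.
model = torch.nn.Linear(8, 8)
ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Two saves into the same root directory stay separate because each tag
# maps to its own subdirectory (e.g. <root>/1-2-1-1/...).
engine.save_checkpoint("/tmp/pp_resize_ckpt", tag="1-2-1-1")
engine.save_checkpoint("/tmp/pp_resize_ckpt", tag="2-2-1-1")

# Loading by tag picks the matching checkpoint, mirroring the
# load_checkpoint(..., tag=checkpoint_tag, ...) call in the updated test.
engine.load_checkpoint("/tmp/pp_resize_ckpt",
                       tag="1-2-1-1",
                       load_optimizer_states=False,
                       load_lr_scheduler_states=False)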
