[hotfix]: add pp sanity check and fix mbs arg (#5268)
* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check
CWHer authored and ver217 committed Feb 7, 2024
1 parent fd05db8 commit 5bf7d66
Showing 5 changed files with 21 additions and 7 deletions.
colossalai/pipeline/schedule/interleaved_pp.py (4 additions, 0 deletions)
@@ -72,6 +72,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch

+ assert (
+     self.num_microbatch % self.stage_manager.num_stages == 0
+ ), "Number of microbatches should be an integer multiple of the number of pipeline parallel devices"

if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
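The new assertion guards the interleaved schedule, which deals the microbatch stream out evenly across the pipeline ranks, so the microbatch count must be a whole multiple of the stage count. Below is a minimal sketch of the arithmetic the check enforces; the helper name and example values are illustrative, not part of the library.

# Minimal sketch of the divisibility requirement (illustrative helper, not the library's API).
def check_interleaved(batch_size: int, microbatch_size: int, num_stages: int) -> int:
    # The schedule first derives the microbatch count from the batch it loads.
    num_microbatch = batch_size // microbatch_size
    # Interleaved PP distributes microbatches across all stages, so the count
    # must divide evenly by the number of stages.
    assert num_microbatch % num_stages == 0, (
        "Number of microbatches should be an integer multiple of the number of pipeline parallel devices"
    )
    return num_microbatch

print(check_interleaved(16, 2, 4))  # 8 microbatches over 4 stages: passes
# check_interleaved(12, 2, 4) would raise: 6 microbatches cannot be split evenly over 4 stages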
colossalai/pipeline/schedule/one_f_one_b.py (4 additions, 0 deletions)
@@ -85,6 +85,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches

+ assert (
+     self.num_microbatches >= self.stage_manager.num_stages
+ ), "Number of microbatches should be at least the number of pipeline stages"

if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
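The 1F1B schedule has the weaker requirement: it only needs enough microbatches for every stage to receive work during pipeline warm-up, hence a greater-or-equal check rather than exact divisibility. A minimal sketch under the same illustrative assumptions:

# Minimal sketch of the 1F1B requirement (illustrative helper, not the library's API).
def check_one_f_one_b(batch_size: int, microbatch_size: int, num_stages: int) -> int:
    num_microbatches = batch_size // microbatch_size
    # During warm-up every stage must be fed at least one forward microbatch,
    # so fewer microbatches than stages would leave the tail of the pipeline idle.
    assert num_microbatches >= num_stages, (
        "Number of microbatches should be at least the number of pipeline stages"
    )
    return num_microbatches

print(check_one_f_one_b(12, 3, 4))  # 4 microbatches, 4 stages: passes
# check_one_f_one_b(12, 6, 4) would raise: only 2 microbatches for 4 stages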
examples/language/llama2/benchmark.py (11 additions, 5 deletions)
@@ -74,8 +74,8 @@ def main():
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1)
parser.add_argument("--zero", type=int, default=0)
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
args = parser.parse_args()

colossalai.launch_from_torch({})
@@ -98,7 +98,13 @@ def empty_init():
extra_dp_size=args.extra_dp,
)
elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
+ plugin = GeminiPlugin(
+     placement_policy="auto",
+     precision="bf16",
+     warmup_non_model_data_ratio=args.warmup_ratio,
+     tp_size=args.tp,
+     extra_dp_size=args.extra_dp,
+ )
elif args.plugin == "fsdp":
if use_empty_init:
plugin = TorchFSDPPlugin(
@@ -137,7 +143,7 @@ def empty_init():
zero_stage=args.zero,
num_model_chunks=2,
enable_fused_normalization=torch.cuda.is_available(),
- num_microbatches=args.mbs,
+ microbatch_size=args.mbs,
precision="bf16",
)
elif args.plugin == "3d_cpu":
@@ -147,7 +153,7 @@ def empty_init():
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
- num_microbatches=args.mbs,
+ microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
)
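These benchmark changes are the "misleading mbs arg" fix from the commit message: args.mbs was previously forwarded to num_microbatches, so a flag named "micro batch size" actually set the microbatch count. Passing it as microbatch_size makes the flag mean what its name says. The sketch below illustrates the distinction with a hypothetical helper; only the two keyword names appear in the diff, everything else is assumed for illustration.

# Hypothetical helper illustrating the two interpretations of --mbs (not plugin code).
def resolve_microbatching(batch_size: int, mbs_flag: int) -> tuple[int, int]:
    microbatch_size = mbs_flag                        # what --mbs means after the fix
    num_microbatches = batch_size // microbatch_size  # count follows from the batch size
    return microbatch_size, num_microbatches

# Before the fix, the same flag value was forwarded as num_microbatches, so
# --mbs 4 meant "4 microbatches", not "microbatches of size 4".
print(resolve_microbatching(8, 1))  # (1, 8): 8 microbatches of size 1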
examples/language/llama2/scripts/benchmark_70B/3d.sh (1 addition, 1 deletion)
@@ -14,4 +14,4 @@ cd ../..

export OMP_NUM_THREADS=8

- colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 4
+ colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py -c 70b -p 3d -g -x -b 8 --tp 4 --pp 2 --mbs 1
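With the corrected meaning of --mbs, the 70B script drops the value from 4 to 1. Assuming -b 8 is the batch size seen by each pipeline schedule and --pp 2 gives two stages, a microbatch size of 1 yields 8 microbatches, which satisfies both of the new sanity checks; under the old wiring, --mbs 4 had meant 4 microbatches. The following is a rough arithmetic check under those assumptions, not the benchmark's own code.

# Rough arithmetic under the assumptions stated above (illustrative only).
batch_size, microbatch_size, pp_stages = 8, 1, 2
num_microbatches = batch_size // microbatch_size   # 8
assert num_microbatches >= pp_stages               # 1F1B check
assert num_microbatches % pp_stages == 0           # interleaved check
print(num_microbatches)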
tests/test_pipeline/test_schedule/test_oneF_oneB.py (1 addition, 1 deletion)
@@ -155,7 +155,7 @@ def run_dist(


@pytest.mark.dist
- @pytest.mark.parametrize("num_microbatch", [4, 12])
+ @pytest.mark.parametrize("num_microbatch", [4, 6])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
