-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pipeline/chimera] reconstruct PipelineBase and Worker to support mor…
…e feasible custom schedule | finish Chimera (#1595) * [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera
- Loading branch information
1 parent
c9e8ce6
commit edc9e41
Showing
8 changed files
with
614 additions
and
163 deletions.
There are no files selected for viewing
347 changes: 188 additions & 159 deletions
347
colossalai/pipeline/rpc/PipelineBase.py → colossalai/pipeline/rpc/_pipeline_base.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
from typing import List, Callable, Dict | ||
|
||
import torch.nn as nn | ||
from torch.futures import Future | ||
from torch._C._distributed_rpc import PyRRef | ||
|
||
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase | ||
|
||
# Implementation of different Pipeline schedule | ||
# <strategy>Worker defines the worker for each stage | ||
# <strategy>PipelineEngine is the class for use | ||
|
||
|
||
class FillDrainWorker(WorkerBase): | ||
|
||
def _get_work_item_key(self) -> UniqueKey: | ||
# execute backward first (if backward phase in work_list) | ||
num_microbatches = self.num_microbatches | ||
|
||
if self.forward_times < num_microbatches: | ||
target_phase = Phase.FORWARD | ||
target_microbatch_id = self.forward_times | ||
else: | ||
target_phase = Phase.BACKWARD | ||
target_microbatch_id = self.backward_times | ||
|
||
target_key = UniqueKey(target_microbatch_id, target_phase) | ||
|
||
with self.work_list_condition_lock: | ||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) | ||
|
||
return target_key | ||
|
||
|
||
class FillDrainPipelineEngine(PipelineEngineBase): | ||
|
||
def __init__(self, | ||
module_partitions: List[nn.Module], | ||
stage_num: int, | ||
num_microbatches: int, | ||
device: str, | ||
chunk: int = 1, | ||
criterion: Callable = None, | ||
metric: Callable = None, | ||
checkpoint: bool = False) -> None: | ||
|
||
if chunk > 1: | ||
assert num_microbatches % stage_num == 0, \ | ||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" | ||
use_1F1B = False | ||
|
||
super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, | ||
criterion, metric, checkpoint) | ||
|
||
|
||
class OneFOneBWorker(WorkerBase): | ||
|
||
def _get_work_item_key(self) -> UniqueKey: | ||
# execute backward first (if backward phase in work_list) | ||
pp_rank = self.pp_rank | ||
actual_stage_num = self.actual_stage_num | ||
num_microbatches = self.num_microbatches | ||
is_last_stage = pp_rank == actual_stage_num - 1 | ||
|
||
if self.outstanding <= self.outstanding_range[0]: | ||
target_phase = Phase.FORWARD | ||
target_microbatch_id = self.forward_times | ||
elif self.outstanding >= self.outstanding_range[1]: | ||
target_phase = Phase.BACKWARD | ||
target_microbatch_id = self.backward_times | ||
else: | ||
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]") | ||
|
||
target_key = UniqueKey(target_microbatch_id, target_phase) | ||
|
||
# change outstanding_range at: | ||
# 1. forward times reach actual_stage_num, this is the end of continuous forward | ||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode | ||
if not is_last_stage and \ | ||
target_key.phase == Phase.FORWARD: | ||
if target_key.microbatch_id == actual_stage_num - 1: | ||
outstanding_min = actual_stage_num - pp_rank - 1 | ||
outstanding_max = actual_stage_num - pp_rank | ||
self.outstanding_range = (outstanding_min, outstanding_max) | ||
elif target_key.microbatch_id == num_microbatches - 1: | ||
self.outstanding_range = (0, 0) | ||
|
||
with self.work_list_condition_lock: | ||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) | ||
|
||
return target_key | ||
|
||
|
||
class OneFOneBPipelineEngine(PipelineEngineBase): | ||
|
||
def __init__(self, | ||
module_partitions: List[nn.Module], | ||
stage_num: int, | ||
num_microbatches: int, | ||
device: str, | ||
chunk: int = 1, | ||
criterion: Callable = None, | ||
metric: Callable = None, | ||
checkpoint: bool = False) -> None: | ||
|
||
if chunk > 1: | ||
assert num_microbatches % stage_num == 0, \ | ||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" | ||
use_1F1B = True | ||
|
||
super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, | ||
criterion, metric, checkpoint) | ||
|
||
|
||
class ChimeraWorker(WorkerBase): | ||
|
||
def _get_producer_consumer(self) -> None: | ||
rank = self.pp_rank | ||
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num | ||
max_pp_rank = min_pp_rank + self.actual_stage_num - 1 | ||
|
||
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" | ||
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" | ||
|
||
# should be aranged in order, the order of the input of current forward | ||
self.producer_stage_ids = [] | ||
self.consumer_stage_ids = [] | ||
|
||
# Just for demo | ||
prev_rank = rank - 1 | ||
next_rank = rank + 1 | ||
if prev_rank >= min_pp_rank: | ||
self.producer_stage_ids.append(prev_rank) | ||
if next_rank <= max_pp_rank: | ||
self.consumer_stage_ids.append(next_rank) | ||
|
||
def _get_work_item_key(self) -> UniqueKey: | ||
pp_rank = self.pp_rank | ||
stage_num = self.actual_stage_num | ||
real_microbatch_num = self.num_microbatches // 2 | ||
|
||
if self.forward_times < real_microbatch_num: | ||
if (pp_rank + 1) % stage_num == 0: # last rank | ||
forward_blocks = self.forward_times // (self.num_microbatches // stage_num) | ||
if forward_blocks > self.backward_times: | ||
target_phase = Phase.BACKWARD | ||
target_microbatch_id = self.backward_times | ||
else: | ||
target_phase = Phase.FORWARD | ||
target_microbatch_id = self.forward_times | ||
else: # others | ||
target_phase = Phase.FORWARD | ||
target_microbatch_id = self.forward_times | ||
else: | ||
target_phase = Phase.BACKWARD | ||
target_microbatch_id = self.backward_times | ||
|
||
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n) | ||
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1) | ||
real_target_microbatch_id = target_microbatch_id * 2 | ||
if pp_rank >= stage_num: | ||
real_target_microbatch_id += 1 | ||
target_key = UniqueKey(real_target_microbatch_id, target_phase) | ||
|
||
with self.work_list_condition_lock: | ||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) | ||
|
||
return target_key | ||
|
||
def is_first_stage(self): | ||
return (self.pp_rank % self.actual_stage_num) == 0 | ||
|
||
def is_last_stage(self): | ||
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1 | ||
|
||
|
||
class ChimeraPipelineEngine(PipelineEngineBase): | ||
|
||
def __init__(self, | ||
module_partitions, | ||
stage_num, | ||
num_microbatches, | ||
device: str, | ||
criterion: Callable = None, | ||
metric: Callable = None, | ||
checkpoint: bool = False) -> None: | ||
|
||
assert num_microbatches % stage_num == 0, \ | ||
"In Chimera, num_microbatches must be the multiply of stage_num!" | ||
use_1F1B = False | ||
chunk = 1 | ||
super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, | ||
criterion, metric, checkpoint) | ||
|
||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]], | ||
input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]): | ||
pass | ||
|
||
def _create_pp_rank_to_rpc_worker_id(self) -> None: | ||
stage_num = self.stage_num | ||
self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2) | ||
for pp_rank in range(stage_num): | ||
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank | ||
self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1 | ||
|
||
def _create_pp_rank_to_module_partition_id(self) -> None: | ||
stage_num = self.stage_num | ||
self.pp_rank_to_module_partition_id = [0] * (stage_num * 2) | ||
for pp_rank in range(stage_num): | ||
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank | ||
self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank | ||
|
||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]: | ||
num_microbatches = self.num_microbatches | ||
stage_num = self.stage_num | ||
up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks} | ||
down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks} | ||
# merge up and down | ||
return {**up_ret_future, **down_ret_future} | ||
|
||
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool): | ||
# offset is 0 for all the ranks in up pipeline | ||
# offset is stage_num for all the ranks in down pipeline | ||
offset = (microbatch_id % 2) * self.stage_num | ||
for pp_rank in input_pp_ranks: | ||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] | ||
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only) | ||
|
||
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels): | ||
# offset is 0 for all the ranks in up pipeline | ||
# offset is stage_num for all the ranks in down pipeline | ||
offset = (microbatch_id % 2) * self.stage_num | ||
for pp_rank in output_pp_ranks: | ||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] | ||
worker_rref.remote().set_labels(microbatch_id, microlabels) | ||
|
||
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]): | ||
key = UniqueKey(microbatch_id, Phase.FORWARD) | ||
offset = (microbatch_id % 2) * self.stage_num | ||
for pp_rank in output_pp_ranks: | ||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset] | ||
ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key) | ||
|
||
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): | ||
stage_num = self.stage_num | ||
num_microbatches = self.num_microbatches | ||
if not forward_only: | ||
for pp_rank in input_pp_ranks: | ||
up_last_microbatch_id = num_microbatches - 2 | ||
down_last_microbatch_id = num_microbatches - 1 | ||
|
||
up_worker_rref = self.pp_rank_to_worker_rref[pp_rank] | ||
down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num] | ||
|
||
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD) | ||
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD) | ||
|
||
up_worker_rref.rpc_sync().get_output_by_key(up_key) | ||
down_worker_rref.rpc_sync().get_output_by_key(down_key) | ||
|
||
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]): | ||
"""Logic of collection of forward in Chimera. | ||
Currently, only one input one output model is supported | ||
""" | ||
stage_num = self.stage_num | ||
forward_result = [] | ||
for pp_rank in output_pp_ranks: | ||
worker_forward_result = [None] * self.num_microbatches | ||
for microbatch_id in range(self.num_microbatches): | ||
offset = (microbatch_id % 2) * stage_num | ||
ret = ret_future[pp_rank + offset][microbatch_id].wait() | ||
worker_forward_result[microbatch_id] = ret | ||
|
||
worker_forward_result = list(zip(*worker_forward_result)) | ||
forward_result.extend(worker_forward_result) | ||
|
||
return forward_result |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine | ||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel | ||
|
||
|
||
def run_master(args): | ||
torch.manual_seed(100) | ||
|
||
epoch = args.epoch | ||
device = args.device | ||
stage_num = 4 | ||
chunk = 1 | ||
num_microbatches = 4 | ||
actual_stage_num = 4 | ||
use_checkpoint = False | ||
|
||
sample_num = 1024 | ||
feat_num = 10 | ||
h = 10 | ||
batch_size = 1024 | ||
|
||
assert sample_num % batch_size == 0 | ||
|
||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] | ||
engine = ChimeraPipelineEngine(module_partitions=module_partitions, | ||
stage_num=stage_num, | ||
num_microbatches=num_microbatches, | ||
device=device, | ||
checkpoint=use_checkpoint) | ||
|
||
input_sample = torch.randn((sample_num, feat_num), device=device) | ||
|
||
for _ in range(epoch): | ||
_ = engine.forward_backward(input_sample, forward_only=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
args.world_size = 4 | ||
args.num_microbatches = 4 | ||
rpc_run(args, run_master) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.