[autoparallel] implemented all matmul strategy generator (#1650)
FrankLeeeee committed Sep 27, 2022
1 parent 03978aa commit 30e50c8
Showing 8 changed files with 440 additions and 76 deletions.
20 changes: 18 additions & 2 deletions colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
@@ -50,8 +50,16 @@ def post_process(self, strategy: ShardingStrategy_V2):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict

             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition

             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
@@ -111,8 +119,16 @@ def post_process(self, strategy: ShardingStrategy_V2):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict

             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition

             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
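Both hunks above replace an in-place tuple swap with pop-and-reinsert, presumably because dim_partition_dict only carries entries for the dimensions that are actually sharded, so indexing a missing key in the old swap raised a KeyError. A minimal standalone sketch of the failure and the fix (the toy dict is a stand-in, not the handler's real OperationData structures):

    # Minimal sketch (toy values) of why the old in-place swap can fail:
    # dim_partition_dict only holds entries for the sharded dimensions.
    dim_partition_dict = {0: [0]}  # only dim 0 is sharded (on mesh axis 0)

    try:
        # old approach: the tuple swap assumes both keys exist
        dim_partition_dict[0], dim_partition_dict[-1] = \
            dim_partition_dict[-1], dim_partition_dict[0]
    except KeyError as err:
        print(f"KeyError for unsharded dim: {err}")  # KeyError for unsharded dim: -1

    # new approach: pop with a default tolerates missing dims
    first_dim_partition = dim_partition_dict.pop(-1, None)
    last_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if last_dim_partition:
        dim_partition_dict[-1] = last_dim_partition

    print(dim_partition_dict)  # {-1: [0]}: dim 0's shard moved to the last dim
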
6 changes: 3 additions & 3 deletions colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -33,12 +33,12 @@ def register_strategy(self) -> StrategiesVector:
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
-        operand_mapping = self.get_operation_data_mapping()
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)

-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy_V2):
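Two things change here: generate() no longer takes the operand mapping as an argument (the generators evidently obtain their operation data elsewhere, e.g. at construction; this diff alone does not show where), and the result of map() is now materialized with list() before being stored. The latter matters because in Python 3, map() returns a lazy, single-use iterator, so assigning it directly to self.strategies_vector would break any later code that expects list semantics. A minimal sketch with a plain list standing in for StrategiesVector:

    # Minimal sketch of the map() pitfall fixed above.
    strategies = ["strategy_a", "strategy_b", "strategy_c"]

    lazy = map(str.upper, strategies)
    # a map object is a single-use iterator, not a list:
    # len(lazy) raises TypeError, and it is exhausted after one pass
    first_pass = list(lazy)
    second_pass = list(lazy)
    print(first_pass)   # ['STRATEGY_A', 'STRATEGY_B', 'STRATEGY_C']
    print(second_pass)  # [] -- already exhausted

    # materializing immediately, as the commit does, restores list semantics
    materialized = list(map(str.upper, strategies))
    print(len(materialized), materialized[0])  # 3 STRATEGY_A
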
6 changes: 6 additions & 0 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -75,6 +75,12 @@ def __post_init__(self):
         if self.logical_shape is None:
             self.logical_shape = self.data.shape

+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+

 @dataclass
 class TrainCycleItem:
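OperationData gains a readable __repr__ and an explicit __hash__. The explicit __hash__ matters because a @dataclass with the default eq=True and frozen=False sets __hash__ to None, leaving instances unhashable; defining it in the class body preserves it, so instances can serve as dict keys (the post_process code above iterates sharding specs keyed by operation data). A minimal sketch with simplified, hypothetical fields, mirroring the commit's hashing scheme:

    from dataclasses import dataclass

    # Minimal sketch: OpData is a simplified stand-in for OperationData.
    @dataclass
    class OpData:
        name: str
        type: str

        def __hash__(self) -> int:
            # without this, @dataclass (eq=True, frozen=False) sets __hash__ = None
            return hash(f'{self.name}-{self.type}')

    weight = OpData(name='weight', type='PARAM')
    specs = {weight: 'S0R'}  # usable as a dict key thanks to the explicit __hash__

    # equal field values hash alike, so lookup by an equivalent instance works
    print(specs[OpData(name='weight', type='PARAM')])  # S0R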
