Skip to content

Commit

Permalink
[autoparallel] added compute resharding costs for node handler (#1662)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee committed Sep 28, 2022
1 parent 9ec401a commit 50f16a2
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
45 changes: 43 additions & 2 deletions colossalai/auto_parallel/solver/op_handler/node_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
from ..strategy import StrategyGenerator_V2


Expand All @@ -28,13 +29,53 @@ def __init__(
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector

def register_strategy(self) -> StrategiesVector:
def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
"""
Compute the resharding costs and save the costs in the ShardingStrategy object.
"""
# TODO: test this function when other handlers are ready
resharding_costs = {}
shape_consistency_manager = ShapeConsistencyManager()
for node in self.predecessor_node:
node_name = str(node)

# get the sharding specs for this node generated
# in its own node handler
assert hasattr(node, 'strategies_vector'), \
f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]

# get the current sharding spec generated by this node handler
op_data = strategy.get_op_data_by_name(node_name)
current_sharding_spec = strategy.sharding_specs[op_data]

# create data structrure to store costs
if op_data not in resharding_costs:
resharding_costs[op_data] = {}

# for each sharding spec generated by the predecessor's node handler
# compute the resharding cost to switch to the sharding spec generated
# by the current node handler
for prev_sharding_spec in prev_sharding_specs:
fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
resharding_costs[op_data][prev_sharding_spec] = resharding_cost
strategy.resharding_costs = resharding_costs

def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
"""
strategy_generators = self.get_strategy_generator()
for generator in strategy_generators:
strategies = generator.generate()

# compute the resharding costs based on the previous node
# strategies if specified
if compute_resharding_cost:
strategies = list(map(self.update_resharding_cost, strategies))
self.strategies_vector.extend(strategies)

strategies_vector = map(self.post_process, self.strategies_vector)
Expand Down
13 changes: 13 additions & 0 deletions colossalai/auto_parallel/solver/sharding_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class ShardingStrategy_V2:
memory_cost: TrainCycleItem = None
input_resharding_costs: Dict[OperationData, List[float]] = None
communication_actions: Dict[OperationData, CommSpec] = None
resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None

@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
Expand All @@ -153,6 +154,18 @@ def _get_sharding_spec(self, operation_data_type: OperationDataType):
specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
return specs

def get_op_data_by_name(self, name: str):
for op_data in self.sharding_specs.keys():
if op_data.name == name:
return op_data
raise KeyError(f"Could not find the OperationData with name {name}")

def get_sharding_spec_by_name(self, name: str):
for op_data, sharding_spec in self.sharding_specs.items():
if op_data.name == name:
return sharding_spec
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")


class StrategiesVector(list):
'''
Expand Down

0 comments on commit 50f16a2

Please sign in to comment.