[autoparallel] added compute resharding costs for node handler #1662

Merged
45 changes: 43 additions & 2 deletions colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from typing import Dict, List
-from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
+from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2

@@ -28,13 +29,53 @@ def __init__(
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector

-    def register_strategy(self) -> StrategiesVector:
+    def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
+        """
+        Compute the resharding costs and save the costs in the ShardingStrategy object.
+        """
+        # TODO: test this function when other handlers are ready
+        resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
+        for node in self.predecessor_node:
+            node_name = str(node)
+
+            # get the sharding specs generated for this predecessor node
+            # by its own node handler
+            assert hasattr(node, 'strategies_vector'), \
+                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+            prev_strategy_vector = node.strategies_vector
+            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]
+
+            # create a data structure to store the costs
+            if op_data not in resharding_costs:
+                resharding_costs[op_data] = {}
+
+            # for each sharding spec generated by the predecessor's node handler,
+            # compute the resharding cost of switching to the sharding spec
+            # generated by the current node handler
+            for prev_sharding_spec in prev_sharding_specs:
+                fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
+                bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
Review thread on the two shape_consistency calls above:

Contributor:
ShapeConsistencyManager distinguishes the task type (training or inference) via a forward_only label and returns the cost for the selected mode. Therefore, we don't have to compute the fwd and bwd costs separately here.
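
(Editorial aside: a runnable toy sketch of the behaviour this comment describes. The real ShapeConsistencyManager API is not shown in this thread, so the forward_only label and the return convention below are assumptions taken from the comment, not from the source.)

```python
# Toy model of the reviewer's description, NOT the real ShapeConsistencyManager:
# a forward_only label selects between an inference-only cost and a full
# training cost, so callers get one mode-dependent number per conversion.
class ToyShapeConsistencyManager:
    def __init__(self) -> None:
        self.forward_only = False

    def shape_consistency(self, src_spec: str, dst_spec: str) -> float:
        fwd = 0.0 if src_spec == dst_spec else 1.0  # placeholder forward cost
        bwd = 0.0 if src_spec == dst_spec else 2.0  # placeholder backward cost
        return fwd if self.forward_only else fwd + bwd


manager = ToyShapeConsistencyManager()
manager.forward_only = True
print(manager.shape_consistency("S0R", "RR"))  # 1.0: inference-only cost
manager.forward_only = False
print(manager.shape_consistency("S0R", "RR"))  # 3.0: full training cost
```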

Contributor (Author):
I see. I notice that ShapeConsistencyManager only returns a total cost. How can I get the fwd and bwd costs then?

Contributor:
[image attachment]

Contributor (Author):
I mean I want to store this cost as a TrainCycleItem so that the solver can choose whether to consider the bwd cost.
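
(Editorial aside: a minimal sketch of the container the author refers to, assuming TrainCycleItem is a simple fwd/bwd/total record; the real class is defined in colossalai/auto_parallel/solver/sharding_strategy.py and may differ.)

```python
from dataclasses import dataclass


# Stand-in for TrainCycleItem: keeping fwd and bwd separately lets the
# solver read only the cost it cares about instead of a single total.
@dataclass
class TrainCycleItemSketch:
    fwd: float
    bwd: float
    total: float


cost = TrainCycleItemSketch(fwd=1.0, bwd=2.0, total=3.0)
inference_view = cost.fwd   # solver ignoring the backward pass
training_view = cost.total  # solver considering the full train cycle
```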

+                resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
+                resharding_costs[op_data][prev_sharding_spec] = resharding_cost
+        strategy.resharding_costs = resharding_costs
+
+    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
         for generator in strategy_generators:
             strategies = generator.generate()
+
+            # compute the resharding costs based on the previous node's
+            # strategies if specified; update_resharding_cost mutates each
+            # strategy in place and returns None, so iterate instead of
+            # rebinding strategies to the result of map
+            if compute_resharding_cost:
+                for strategy in strategies:
+                    self.update_resharding_cost(strategy)
             self.strategies_vector.extend(strategies)

         strategies_vector = map(self.post_process, self.strategies_vector)
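(Editorial aside: a toy sketch of the mapping that update_resharding_cost attaches to each strategy; the strings below stand in for OperationData and ShardingSpec objects, and plain dicts stand in for TrainCycleItem, so all names and values are illustrative only.)

```python
# Shape of strategy.resharding_costs after update_resharding_cost runs:
#   {OperationData: {predecessor ShardingSpec: TrainCycleItem}}
resharding_costs = {
    "op_data:input": {
        "prev_spec:S0R": {"fwd": 1.0, "bwd": 2.0, "total": 3.0},
        "prev_spec:RR": {"fwd": 0.0, "bwd": 0.0, "total": 0.0},
    },
}

# For each predecessor strategy's output spec, the solver can look up the
# cost of converting that spec to the one the current strategy expects.
print(resharding_costs["op_data:input"]["prev_spec:S0R"]["total"])  # 3.0
```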
13 changes: 13 additions & 0 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -128,6 +129,7 @@ class ShardingStrategy_V2:
     memory_cost: TrainCycleItem = None
     input_resharding_costs: Dict[OperationData, List[float]] = None
     communication_actions: Dict[OperationData, CommSpec] = None
+    resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None

     @property
     def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
@@ -152,6 +153,18 @@ def _get_sharding_spec(self, operation_data_type: OperationDataType):
         specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
         return specs

+    def get_op_data_by_name(self, name: str):
+        for op_data in self.sharding_specs.keys():
+            if op_data.name == name:
+                return op_data
+        raise KeyError(f"Could not find the OperationData with name {name}")
+
+    def get_sharding_spec_by_name(self, name: str):
+        for op_data, sharding_spec in self.sharding_specs.items():
+            if op_data.name == name:
+                return sharding_spec
+        raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
+

 class StrategiesVector(list):
     '''
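(Editorial aside: a toy sketch of how the two new lookup helpers pair up. Both scan the same sharding_specs mapping by OperationData.name, so for a given name they return a matching key/value pair. The OpData class and the node names below are hypothetical stand-ins.)

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OpData:  # stand-in for OperationData
    name: str


# toy sharding_specs mapping: OperationData -> sharding spec
sharding_specs = {OpData("conv_input"): "S0R", OpData("conv_weight"): "RR"}


def get_op_data_by_name(name: str) -> OpData:
    for op_data in sharding_specs:
        if op_data.name == name:
            return op_data
    raise KeyError(f"Could not find the OperationData with name {name}")


def get_sharding_spec_by_name(name: str) -> str:
    for op_data, sharding_spec in sharding_specs.items():
        if op_data.name == name:
            return sharding_spec
    raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")


# the spec returned for a name is the one stored under the op data of that name
assert get_sharding_spec_by_name("conv_input") == sharding_specs[get_op_data_by_name("conv_input")]
```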