Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] added new strategy constructor template #1661

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/auto_parallel/solver/op_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from .bcast_op_handler import BcastOpHandler
from .embedding_handler import EmbeddingHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler

# Names re-exported by the op_handler package. Every entry is separated by a
# comma: two adjacent string literals would be silently concatenated into a
# single bogus name, and each handler is listed exactly once.
__all__ = [
    'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
    'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
]
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/solver/op_handler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def wrapper(func):
return wrapper

def get(self, source):
    """Return the target stored in ``self.store`` under ``source``.

    Args:
        source: the key to look up in ``self.store``.

    Returns:
        The value stored under ``source``.

    Raises:
        AssertionError: if ``source`` is not in ``self.store``; the message
            names both the missing key and ``self.name``.
    """
    # A single assertion that carries the message. A bare duplicate check
    # placed before it would fire first and hide the explanation.
    assert source in self.store, f'{source} not found in the {self.name} registry'
    return self.store[source]

Expand Down
7 changes: 4 additions & 3 deletions colossalai/auto_parallel/solver/sharding_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
ARG = 0
PARAM = 1
OUTPUT = 2
INPUT = 0
ARG = 1
PARAM = 2
OUTPUT = 3


@dataclass
Expand Down
88 changes: 88 additions & 0 deletions colossalai/auto_parallel/solver/strategies_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
Expand All @@ -16,6 +17,8 @@
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins

# Public names exported by this module.
__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2']


class StrategiesConstructor:
"""
Expand Down Expand Up @@ -49,6 +52,7 @@ def remove_duplicated_strategy(self, strategies_vector):
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)

for strategy in remove_list:
strategies_vector.remove(strategy)

Expand Down Expand Up @@ -394,3 +398,87 @@ def build_strategies_and_cost(self):
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector


class StrategiesConstructor_V2:
    """
    Build a ``StrategiesVector`` for every node of a graph by dispatching to the
    handlers stored in ``operator_registry``.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
    """

    def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
        self.root_module = self.graph.owning_module
        # Snapshot of the graph's nodes taken at construction time.
        self.nodes = list(graph.nodes)
        self.device_mesh = device_mesh
        # One StrategiesVector per node, in iteration order of ``self.nodes``.
        self.leaf_strategies = []
        # Maps each node to its StrategiesVector.
        self.strategy_map = {}
        self.solver_options = solver_options

    def remove_duplicated_strategy(self, strategies_vector):
        """
        Drop strategies whose ``name`` has already been seen, in place.

        The first strategy carrying a given name is kept and every later one with
        the same name is deleted; the relative order of the survivors is unchanged.
        Duplicates are deleted by position rather than with ``remove()``, because
        ``remove()`` matches by equality and could delete the earlier, equal-comparing
        entry instead of the duplicate.

        Args:
            strategies_vector: a mutable, index-deletable sequence of objects that
                expose a hashable ``name`` attribute.
        """
        seen_names = set()
        duplicate_positions = []
        for position, strategy in enumerate(strategies_vector):
            if strategy.name in seen_names:
                duplicate_positions.append(position)
            else:
                seen_names.add(strategy.name)

        # Delete from the back so the remaining recorded positions stay valid.
        for position in reversed(duplicate_positions):
            del strategies_vector[position]

    def _register_strategies(self, registry_key, node, strategies_vector):
        """Instantiate the handler registered under ``registry_key`` and let it register its strategies."""
        handler_cls = operator_registry.get(registry_key)
        handler = handler_cls(node, self.device_mesh, strategies_vector)
        handler.register_strategy()

    def build_strategies_and_cost(self):
        """
        Build the strategy vector for each node in the computation graph.

        For every node a fresh ``StrategiesVector`` is created, filled by the
        registered handler (for ``call_module``, ``call_function`` and
        ``call_method`` nodes), de-duplicated by strategy name, attached to the
        node as ``strategies_vector`` and recorded in ``self.leaf_strategies``
        and ``self.strategy_map``. Nodes of any other kind get an empty vector.
        """
        for node in self.nodes:
            strategies_vector = StrategiesVector(node)

            if node.op == 'placeholder':
                # TODO: implement placeholder node handler
                pass

            elif node.op == 'get_attr':
                # TODO: implement getattr node handler
                pass

            elif node.op == 'call_module':
                # Module handlers are registered under the submodule's class.
                submod = self.root_module.get_submodule(node.target)
                self._register_strategies(type(submod), node, strategies_vector)

            elif node.op == 'call_function':
                # Function handlers are registered under the callable itself.
                self._register_strategies(node.target, node, strategies_vector)

            elif node.op == 'call_method':
                # Resolve the method name on the class of the first argument's
                # ``_meta_data`` and use that attribute as the registry key.
                method = getattr(node.args[0]._meta_data.__class__, node.target)
                self._register_strategies(method, node, strategies_vector)

            elif node.op == 'output':
                # TODO: implement output node handler
                pass

            self.remove_duplicated_strategy(strategies_vector)
            setattr(node, 'strategies_vector', strategies_vector)
            self.leaf_strategies.append(strategies_vector)
            self.strategy_map[node] = strategies_vector