[autoparallel] added new strategy constructor template
FrankLeeeee committed Sep 27, 2022
1 parent 30e50c8 commit 9ebccfe
Showing 4 changed files with 102 additions and 5 deletions.
3 changes: 2 additions & 1 deletion colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -6,8 +6,9 @@
from .bcast_op_handler import BcastOpHandler
from .embedding_handler import EmbeddingHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
+from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler

__all__ = [
    'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler', 'EmbeddingHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
]
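
With this re-export, downstream code can import the new v2 handlers from the op_handler package root instead of reaching into dot_handler_v2; a one-line sketch:

from colossalai.auto_parallel.solver.op_handler import LinearFunctionHandler, LinearModuleHandler
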
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/solver/op_handler/registry.py
@@ -14,7 +14,7 @@ def wrapper(func):
        return wrapper

    def get(self, source):
-        assert source in self.store
+        assert source in self.store, f'{source} not found in the {self.name} registry'
        target = self.store[source]
        return target

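The get method above is the lookup half of a register/get pattern. The sketch below reconstructs the whole registry for illustration, assuming the register decorator stores its argument as the lookup key — the wrapper fragment shown in the hunk suggests this shape, but the body here is not a verbatim copy of the colossalai implementation:

import torch

class Registry:
    def __init__(self, name):
        self.name = name
        self.store = {}    # maps a source key (module class, function, or method) to a handler class

    def register(self, source):
        # used as a decorator: @operator_registry.register(torch.nn.Linear)
        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store, f'{source} not found in the {self.name} registry'
        return self.store[source]

operator_registry = Registry('operator')

@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler:
    ...

# the lookup mirrors the call_module branch of StrategiesConstructor_V2 below
handler_cls = operator_registry.get(torch.nn.Linear)
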
7 changes: 4 additions & 3 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -49,9 +49,10 @@ class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
ARG = 0
PARAM = 1
OUTPUT = 2
INPUT = 0
ARG = 1
PARAM = 2
OUTPUT = 3


@dataclass
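The new INPUT member appears to separate activation inputs from other positional arguments. The tagging below of a Linear layer's operands is a hypothetical illustration of how a handler might use the enum, not code from this commit:

from enum import Enum

class OperationDataType(Enum):
    INPUT = 0
    ARG = 1
    PARAM = 2
    OUTPUT = 3

# hypothetical classification of torch.nn.Linear operands
linear_operand_types = {
    'input': OperationDataType.INPUT,      # activation tensor fed into the layer
    'weight': OperationDataType.PARAM,     # module parameter
    'bias': OperationDataType.PARAM,       # module parameter
    'output': OperationDataType.OUTPUT,    # result tensor
}
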
95 changes: 95 additions & 0 deletions colossalai/auto_parallel/solver/strategies_constructor.py
@@ -4,6 +4,7 @@
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
@@ -16,6 +17,8 @@
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins

+__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2']


class StrategiesConstructor:
    """
@@ -49,6 +52,10 @@ def remove_duplicated_strategy(self, strategies_vector):
                name_checklist.append(strategy.name)
            else:
                remove_list.append(strategy)
+
+        if remove_list:
+            print(f'remove list: {[item.name for item in remove_list]}')
+            print(f'keep list: {name_checklist}')
        for strategy in remove_list:
            strategies_vector.remove(strategy)

@@ -394,3 +401,91 @@ def build_strategies_and_cost(self):
            setattr(node, 'strategies_vector', strategies_vector)
            self.leaf_strategies.append(strategies_vector)
            self.strategy_map[node] = strategies_vector


class StrategiesConstructor_V2:
    """
    StrategiesConstructor_V2 is used to construct the parallelization plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
    """

    def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        self.device_mesh = device_mesh
        self.leaf_strategies = []
        self.strategy_map = {}
        self.solver_options = solver_options

    def remove_duplicated_strategy(self, strategies_vector):
        '''
        The build_strategies_and_cost method may produce duplicated strategies.
        This method removes the duplicates based on the strategy name.
        Note that this operation is in-place.
        '''
        name_checklist = []
        remove_list = []
        for strategy in strategies_vector:
            if strategy.name not in name_checklist:
                name_checklist.append(strategy.name)
            else:
                remove_list.append(strategy)
        for strategy in remove_list:
            strategies_vector.remove(strategy)

    def build_strategies_and_cost(self):
        """
        This method builds the strategy vector for each node in the computation graph.
        """
        for node in self.nodes:
            strategies_vector = StrategiesVector(node)
            # count the predecessor nodes that carry tensor meta-data (not used yet)
            input_nodes_len = 0
            for check_node in strategies_vector.predecessor_nodes:
                if isinstance(check_node._meta_data, torch.Tensor):
                    input_nodes_len += 1
            # input_nodes_len = len(strategies_vector.predecessor_nodes)

            # placeholder node
            if node.op == 'placeholder':
                # TODO: implement placeholder node handler
                pass

            # get_attr node
            elif node.op == 'get_attr':
                # TODO: implement get_attr node handler
                pass

            # call_module node
            elif node.op == 'call_module':
                target = node.target
                submod = self.root_module.get_submodule(target)
                submod_type = type(submod)
                handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
                handler.register_strategy()

            # call_function node
            elif node.op == 'call_function':
                target = node.target
                handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
                handler.register_strategy()

            # call_method node
            elif node.op == 'call_method':
                method = getattr(node.args[0]._meta_data.__class__, node.target)
                handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
                handler.register_strategy()

            # output node
            elif node.op == 'output':
                # TODO: implement output node handler
                pass

            self.remove_duplicated_strategy(strategies_vector)
            setattr(node, 'strategies_vector', strategies_vector)
            self.leaf_strategies.append(strategies_vector)
            self.strategy_map[node] = strategies_vector
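
A minimal end-to-end sketch of driving the new constructor. The DeviceMesh and SolverOptions arguments are assumptions (marked in comments), and the ColoTracer usage follows common colossalai test patterns of this era — the tracer attaches the _meta_data that build_strategies_and_cost reads off each predecessor node, which plain torch.fx symbolic tracing does not:

import torch
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer    # assumed import path
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 8))
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'input': torch.rand(4, 16).to('meta')})
# constructing the GraphModule sets graph.owning_module, which __init__ asserts on
gm = GraphModule(model, graph, model.__class__.__name__)

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))    # assumed signature
solver_options = SolverOptions(fast=True)                        # assumed option

constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
constructor.build_strategies_and_cost()

# every node now carries a strategies_vector attribute
for node in graph.nodes:
    print(node.op, node.target, len(node.strategies_vector))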
