[autoparallel] adapt solver with resnet #1583

Merged
9 changes: 5 additions & 4 deletions colossalai/auto_parallel/solver/__init__.py
@@ -1,7 +1,8 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *

__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser']
__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
9 changes: 7 additions & 2 deletions colossalai/auto_parallel/solver/constants.py
@@ -3,13 +3,14 @@

__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP'
'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@@ -20,3 +21,7 @@
]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]

INFINITY_COST = 1e13
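
For orientation, the operator groups above act as lookup tables for deciding which handler processes a traced FX node. The sketch below is illustrative only and not part of this PR: pick_handler_type and TinyNet are made-up names, while the constant names and the call_module convention come from the code above.

# Illustrative sketch only: a minimal dispatcher that maps a traced FX node to a
# handler category using the constants above. The real StrategiesConstructor
# performs this mapping internally; this helper is hypothetical.
import torch
from colossalai.auto_parallel.solver.constants import (
    CONV_MODULE_OP, BATCHNORM_MODULE_OP, POOL_MODULE_OP, LINEAR_MODULE_OP)


def pick_handler_type(node, root_module: torch.nn.Module) -> str:
    """Return a coarse handler tag for a call_module node."""
    if node.op != 'call_module':
        return 'other'
    submodule = root_module.get_submodule(node.target)
    if type(submodule) in CONV_MODULE_OP:
        return 'conv'
    if type(submodule) in BATCHNORM_MODULE_OP:
        return 'batch_norm'
    if type(submodule) in POOL_MODULE_OP:
        return 'pool'
    if type(submodule) in LINEAR_MODULE_OP:
        return 'linear'
    return 'other'


if __name__ == '__main__':
    # Toy module for demonstration; layer names and shapes are made up.
    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)
            self.pool = torch.nn.AdaptiveAvgPool2d(1)

        def forward(self, x):
            return self.pool(self.bn(self.conv(x)))

    model = TinyNet()
    traced = torch.fx.symbolic_trace(model)
    for n in traced.graph.nodes:
        print(n.name, pick_handler_type(n, model))
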
1 change: 0 additions & 1 deletion colossalai/auto_parallel/solver/cost_graph.py
@@ -1,4 +1,3 @@
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from typing import List
import math
from torch.fx.node import Node
61 changes: 25 additions & 36 deletions colossalai/auto_parallel/solver/graph_analysis.py
@@ -15,7 +15,7 @@ class LiveVariable:
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
meta: Union[Any, List[Any]]
node: Node
is_inplace: bool


@@ -80,13 +80,13 @@ def graph(self) -> Graph:
"""
return self._graph

def liveness_analysis(self) -> OrderedDict[int, LiveStage]:
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects, one per compute stage, each recording the variables that are live at that stage.
"""
compute_nodes = self.graph.nodes
liveness_dict = ODict()
liveness_list = []

# checked: record all variables created since the first stage
# all: record the live variables that only exist until the current stage.
@@ -97,25 +97,6 @@ def liveness_analysis(self) -> OrderedDict[int, LiveStage]:
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()

def _add_param_or_buf(node, tensor_type):
module = get_node_module(node)

if tensor_type == 'param':
iterator = module.named_parameters()
elif tensor_type == 'buffer':
iterator = module.named_buffers()
else:
raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}")

for name, tensor in iterator:
tensor_name = f'{node.name}.{name}'

if not checked_variables.exists(tensor_name):
live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False)
unique_live_vars.append(live_tensor)
checked_variables.append(live_tensor)
all_live_variables.append(live_tensor)

for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
@@ -135,26 +116,19 @@ def _add_param_or_buf(node, tensor_type):

# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)

# add the model parameters
if node.op == 'call_module':
_add_param_or_buf(node, tensor_type='param')
_add_param_or_buf(node, tensor_type='buffer')

# add this output variable to the checked list
checked_variables.append(live_var)

# check if any input is not checked yet
for arg in node.args:
arg_name = str(arg)
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
meta = getattr(node, '_meta_data', None)
live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False)
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
@@ -167,8 +141,23 @@
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
liveness_dict[idx] = stage
return liveness_dict
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)

return liveness_list

def get_alias_set(self):
pass
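
The new return-value handling in liveness_analysis keeps a stage only while no later stage's unique live variables fully cover it. A standalone sketch of that subsumption rule, with plain sets standing in for LiveStage.unique_live_vars and a hypothetical merge_stage helper, reads roughly as follows:

# Standalone sketch of the coverage rule added to liveness_analysis (not the
# PR's code): each kept "stage" is modelled as a set of variable names, and a
# previously kept stage is replaced once a new stage's live set covers it.
def merge_stage(liveness_list, new_stage_vars):
    for index, prev_stage_vars in enumerate(liveness_list):
        if prev_stage_vars.issubset(new_stage_vars):
            # the new stage covers the old one, so keep only the larger stage
            liveness_list[index] = new_stage_vars
            return
    liveness_list.append(new_stage_vars)


stages = []
merge_stage(stages, {'conv1'})
merge_stage(stages, {'conv1', 'bn1'})
merge_stage(stages, {'conv1', 'bn1', 'relu'})
merge_stage(stages, {'fc'})  # not covered by any kept stage, so it is appended
print(stages)  # [{'conv1', 'bn1', 'relu'}, {'fc'}]
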
6 changes: 6 additions & 0 deletions colossalai/auto_parallel/solver/op_handler/__init__.py
@@ -0,0 +1,6 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler

__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']