[autoparallel] recover the merged node strategy index (#1613)
YuliangLiu0306 committed Sep 23, 2022
1 parent d6b01fe commit bf77d3a
Showing 2 changed files with 26 additions and 13 deletions.
22 changes: 20 additions & 2 deletions colossalai/auto_parallel/solver/solver.py
@@ -61,6 +61,23 @@ def __init__(self,
        # The last objective value of the best ILP solution.
        self.last_objective = None

    def _recover_merged_node_strategy(self):
        '''
        During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp nodes, are merged into the previous node.
        Therefore, the strategy index of those nodes is copied from the previous node. This method recovers the strategy index of those
        merged nodes.
        '''
        for node_index, node in enumerate(self.graph.nodes):
            if node.strategies_vector.check_merge():
                # the merged node has only one input, and its strategies follow the input sharding strategy
                input_strategies_vector = node.args[0].strategies_vector
                input_best_strategy_index = self.last_s_val[node_index - 1]
                input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
                for strategy_index, strategy in enumerate(node.strategies_vector):
                    if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
                        self.last_s_val[node_index] = strategy_index
                        break

    def _generate_node_index_dict(self) -> Dict[Node, int]:
        node_index_dict = {}
        for index, strategies_vector in enumerate(self.leaf_strategies):
@@ -411,13 +428,14 @@ def get_non_zero_index(binary_vector):
            if verbose and r[idx][e_val[idx]] > 0:
                print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")

        self.last_s_val = s_val
        self.last_s_val = list(s_val)
        self._recover_merged_node_strategy()
        self.last_objective = objective

        if objective > INFINITY_COST:
            warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")

        return s_val, e_val, objective, status
        return self.last_s_val, e_val, self.last_objective, status

    def call_solver_serialized_args(self):
        """
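To make the recovery step concrete, here is a small self-contained sketch of the same idea outside the solver: the ILP solution stores the predecessor's strategy index for a merged node, and recovery means finding the merged node's own strategy whose input sharding matches the output sharding the predecessor actually chose. `ToyStrategy` and `recover_merged_index` are hypothetical names used only for illustration; they are not ColossalAI APIs.

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class ToyStrategy:
        # Hypothetical stand-in for a sharding strategy: only the sharding
        # "sequence" of the single input and of the output is modelled.
        input_sharding: str
        output_sharding: str

    def recover_merged_index(pred_strategies: List[ToyStrategy],
                             pred_chosen_index: int,
                             merged_strategies: List[ToyStrategy]) -> Optional[int]:
        # The merged node inherited its predecessor's index from the ILP solution;
        # recover its own index by matching each of its strategies' input sharding
        # against the output sharding of the predecessor's chosen strategy.
        chosen_output = pred_strategies[pred_chosen_index].output_sharding
        for index, strategy in enumerate(merged_strategies):
            if strategy.input_sharding == chosen_output:
                return index
        return None  # no matching strategy; the caller keeps the copied index

    # Example: the predecessor picked strategy 1 (output sharded as "S0,R"),
    # so the merged element-wise node should pick the strategy whose input is "S0,R".
    pred = [ToyStrategy("R,R", "R,R"), ToyStrategy("R,R", "S0,R")]
    merged = [ToyStrategy("R,R", "R,R"), ToyStrategy("S0,R", "S0,R")]
    assert recover_merged_index(pred, 1, merged) == 1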
17 changes: 6 additions & 11 deletions tests/test_auto_parallel/test_solver_with_resnet.py
@@ -80,32 +80,27 @@ def test_cost_graph():
    gm.recompile()
    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    # print(len(liveness_dict[0].unique_live_vars))
    # assert False
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)

    ret = solver.call_solver_serialized_args()
    print(ret)
    strategies_list = list(ret[0])
    print(strategies_list)
    print(ret[0])
    solver._recover_merged_node_strategy()
    print(solver.last_s_val)
    strategies_list = solver.last_s_val

    computation_cost = 0
    communication_cost = 0
    communication_cost_bn = 0
    memory_cost = 0
    for index, node in enumerate(graph.nodes):
        if node.op == 'call_module':
            submod = node.graph.owning_module.get_submodule(node.target)
            if type(submod) in ELEMENTWISE_MODULE_OP:
                input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
                print(node.name, input_spec)
                continue
            if type(submod) in BATCHNORM_MODULE_OP:
                communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
                print(node.name, node.strategies_vector[strategies_list[index]].name)
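As a possible follow-up to the test above (not part of this commit), a hypothetical consistency check could assert the invariant that `_recover_merged_node_strategy` establishes, reusing only attributes that already appear in the diff (`strategies_vector`, `check_merge`, `input_shardings`, `output_sharding_spec`, `sharding_sequence`); like the method itself, it assumes a merged node's predecessor sits at `node_index - 1` in graph order.

    def check_recovered_indices(solver):
        # For every merged node, the strategy picked after recovery should consume
        # exactly the sharding produced by its predecessor's chosen strategy.
        for node_index, node in enumerate(solver.graph.nodes):
            if node.strategies_vector.check_merge():
                chosen = node.strategies_vector[solver.last_s_val[node_index]]
                pred_chosen = node.args[0].strategies_vector[solver.last_s_val[node_index - 1]]
                assert chosen.input_shardings[0].sharding_sequence == \
                    pred_chosen.output_sharding_spec.sharding_sequence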
