[autoparallel] remove no strategy nodes (#1652)
* [autoparallel] remove no strategy nodes

* fix none object iteration issue
YuliangLiu0306 committed Sep 29, 2022
1 parent 50f16a2 commit c27e701
Showing 4 changed files with 84 additions and 36 deletions.
19 changes: 17 additions & 2 deletions colossalai/auto_parallel/solver/cost_graph.py
@@ -1,6 +1,7 @@
from typing import List
import math
from torch.fx.node import Node
+from colossalai.auto_parallel.solver.constants import INFINITY_COST


class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:

def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@ def __init__(self, leaf_strategies, simplify=True):
self.simplify = simplify
self._build_cost_graph()

+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
def _build_cost_graph(self):
'''
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
@@ -39,6 +50,8 @@ def _build_cost_graph(self):
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
+                if src_node not in self.nodes:
+                    continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
@@ -49,6 +62,8 @@ def _build_cost_graph(self):
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
+            self._remove_invalid_node(dst_node, 'parents')
+            self._remove_invalid_node(dst_node, 'children')

if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -83,11 +98,11 @@ def merge_node(self, src_node, dst_node):
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
-            min_cost = math.inf
+            min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
-                if resharding_cost < min_cost:
+                if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
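To make the new pruning step concrete, here is a minimal standalone sketch of what _remove_invalid_node does, using a hypothetical ToyNode stand-in instead of a real torch.fx.Node. Note also that switching math.inf to INFINITY_COST together with the <= comparison in merge_node presumably guarantees that lowest_cost_index is always set, even when every resharding cost equals INFINITY_COST.

# Minimal sketch of the edge pruning above (ToyNode is a hypothetical stand-in).
class ToyNode:
    def __init__(self, name):
        self.name = name

a, b, c = ToyNode('a'), ToyNode('b'), ToyNode('c')
a.parents = [b, c]
nodes = [a, b]  # c produced no strategies, so it never entered the cost graph

# collect first, then remove, so the list is not mutated while iterating
remove_list = [n for n in a.parents if n not in nodes]
for element in remove_list:
    a.parents.remove(element)

print([n.name for n in a.parents])  # ['b']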
2 changes: 2 additions & 0 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -182,6 +182,8 @@ def __init__(self, node: Node):
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
+        if self.node.op == 'output':
+            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())

def check_merge(self):
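A hedged illustration of the new 'output' special case: an fx output node may reference several input nodes, but only the first one is kept as a predecessor. The dict and op variable below are toy stand-ins for node._input_nodes and node.op.

# Toy stand-ins; the real code reads these attributes off a torch.fx.Node.
_input_nodes = {'linear_1': None, 'loss_fn': None}  # dicts preserve insertion order
op = 'output'

predecessor_nodes = list(_input_nodes.keys())
if op == 'output':
    # keep only the first input, mirroring the change above
    predecessor_nodes = list(_input_nodes.keys())[:1]

print(predecessor_nodes)  # ['linear_1']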
9 changes: 6 additions & 3 deletions colossalai/auto_parallel/solver/solver.py
@@ -45,8 +45,8 @@ def __init__(self,
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
-        self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ def _recover_merged_node_strategy(self):
Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node.
'''
-        for node_index, node in enumerate(self.graph.nodes):
+        for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
@@ -297,7 +297,8 @@ def get_non_zero_index(binary_vector):
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])

+        for element in s:
+            assert len(element) > 0
# 2. Set initial value
######################################
# set an initial value for warm start #
@@ -317,12 +318,14 @@
###################################################################
obj = 0
for i in range(node_nums):
+            assert len(s[i]) == len(c[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])

#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
+            assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])

prob += obj
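The new assertions guard the construction of the LP objective: every strategy-selector vector s[i] must be non-empty and the same length as its cost vector, otherwise lpDot would pair terms incorrectly. A minimal pure-Python analogue with toy numbers (no pulp dependency):

# Toy analogue of the guarded objective; the values are illustrative only.
s = [[1, 0], [0, 1, 0]]            # one-hot strategy selections per node
c = [[2.0, 3.0], [1.0, 5.0, 2.0]]  # per-strategy costs per node

for element in s:
    assert len(element) > 0, 'zero-strategy nodes must be removed before solving'

obj = 0.0
for s_i, c_i in zip(s, c):
    assert len(s_i) == len(c_i)
    obj += sum(si * ci for si, ci in zip(s_i, c_i))  # plain-Python lpDot

print(obj)  # 7.0: node 0 selects cost 2.0, node 1 selects cost 5.0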
90 changes: 59 additions & 31 deletions colossalai/auto_parallel/solver/strategies_constructor.py
@@ -214,6 +214,21 @@ def build_strategies_and_cost(self):
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()

+            # where function
+            elif target == torch.where:
+                if input_nodes_len == 1:
+                    # both x and y are scalars
+                    pass
+
+                elif input_nodes_len == 2:
+                    # one of x or y is a scalar
+                    pass
+
+                else:
+                    # general case
+                    where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
+                    where_handler.register_strategy()
+
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for reshape node
@@ -222,9 +237,8 @@ def build_strategies_and_cost(self):

# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
-                if isinstance(node._meta_data, torch.Tensor):
-                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
-                    unary_elementwise_handler.register_strategy()
+                unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                unary_elementwise_handler.register_strategy()

# bcast op
elif target in BCAST_FUNC_OP:
@@ -291,32 +305,34 @@ def build_strategies_and_cost(self):
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
-                if isinstance(input_tensor_node, torch.Tensor):
-                    for strategy in input_tensor_node.strategies_vector:
+                for strategy in input_tensor_node.strategies_vector:
+                    if isinstance(strategy.output_sharding_spec, ShardingSpec):
+                        input_sharding_spec = strategy.output_sharding_spec
+                    else:
+                        input_sharding_spec = strategy.output_sharding_spec[index]
-                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
-                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
-                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
-                        output_sharding_spec = ShardingSpec(self.device_mesh,
-                                                            entire_shape_output,
-                                                            dim_partition_dict=dim_partition_dict_for_output)
-                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
-                        compute_cost = 0
-                        memory_cost = 0
-                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                     [input_sharding_spec])
-                        # to prevent the resharding happening, set their resharding cost to inf.
-                        resharding_costs[input_tensor_node] = [
-                            cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
-                        ]
-                        sharding_strategy = ShardingStrategy(
-                            name,
-                            output_sharding_spec,
-                            compute_cost=compute_cost,
-                            memory_cost=memory_cost,
-                            resharding_costs=resharding_costs,
-                            input_shardings=[input_tensor_node.output_sharding_spec])
-                        strategies_vector.append(sharding_strategy)
+                    assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
+                    dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
+                    entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
+                    output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                        entire_shape_output,
+                                                        dim_partition_dict=dim_partition_dict_for_output)
+                    # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                    compute_cost = 0
+                    memory_cost = 0
+                    resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                 [input_sharding_spec],
+                                                                 index=index)
+                    # to prevent the resharding happening, set their resharding cost to inf.
+                    resharding_costs[input_tensor_node] = [
+                        cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
+                    ]
+                    sharding_strategy = ShardingStrategy(name,
+                                                         output_sharding_spec,
+                                                         compute_cost=compute_cost,
+                                                         memory_cost=memory_cost,
+                                                         resharding_costs=resharding_costs,
+                                                         input_shardings=[strategy.output_sharding_spec])
+                    strategies_vector.append(sharding_strategy)

# torch.arange function
elif target == torch.arange:
@@ -334,8 +350,7 @@ def build_strategies_and_cost(self):
strategies_vector.append(sharding_strategy)

# op list to be processed to support gpt2
-            elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
-                            torch.nn.functional.softmax, torch.pow, torch.tanh):
+            elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
@@ -344,7 +359,7 @@ def build_strategies_and_cost(self):
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
-            if method in (torch.Tensor.size, torch.Tensor.contiguous):
+            if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -400,6 +415,18 @@ def build_strategies_and_cost(self):
self.strategy_map[node] = strategies_vector


+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
+
+
class StrategiesConstructor_V2:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
@@ -482,3 +509,4 @@ def build_strategies_and_cost(self):
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
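Finally, the cleanup that gives the commit its name can be sketched in isolation. ToyNode below is a hypothetical stand-in for an fx node carrying a strategies_vector; the two-pass pattern (collect first, then mutate) keeps the iteration over leaf_strategies safe.

# Standalone sketch of the no-strategy-node cleanup (ToyNode is hypothetical).
class ToyNode:
    def __init__(self, name, num_strategies):
        self.name = name
        self.strategies_vector = ['strategy'] * num_strategies

matmul, getattr_node = ToyNode('matmul', 3), ToyNode('getattr', 0)
leaf_strategies = [matmul.strategies_vector, getattr_node.strategies_vector]
strategy_map = {matmul: matmul.strategies_vector, getattr_node: getattr_node.strategies_vector}

remove_list = [n for n in (matmul, getattr_node) if len(n.strategies_vector) == 0]
for node in remove_list:
    if node.strategies_vector in leaf_strategies:
        leaf_strategies.remove(node.strategies_vector)
    if node in strategy_map:
        strategy_map.pop(node)

print(len(leaf_strategies), len(strategy_map))  # 1 1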
