[autoparallel] remove no strategy nodes #1652

Merged
17 changes: 16 additions & 1 deletion colossalai/auto_parallel/solver/cost_graph.py
@@ -1,6 +1,7 @@
from typing import List
import math
from torch.fx.node import Node
from colossalai.auto_parallel.solver.constants import INFINITY_COST


class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:

def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@ def __init__(self, leaf_strategies, simplify=True):
self.simplify = simplify
self._build_cost_graph()

def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, None)
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)

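The helper above is the new guard for pruned nodes: it collects stale entries first and removes them afterwards, so the list is never mutated while being iterated. A minimal sketch of the same pattern, using a stand-in node object rather than the library's classes:

from types import SimpleNamespace

# 'size' is a node that ended up with no strategies and was pruned.
node = SimpleNamespace(parents=['conv1', 'size'])
valid_nodes = ['conv1', 'relu']  # i.e. CostGraph.nodes

remove_list = [p for p in node.parents if p not in valid_nodes]
for element in remove_list:
    node.parents.remove(element)

print(node.parents)  # ['conv1'] -- the stale reference is gone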
def _build_cost_graph(self):
'''
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
@@ -39,6 +50,8 @@ def _build_cost_graph(self):
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
Expand All @@ -49,6 +62,8 @@ def _build_cost_graph(self):
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')

if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -83,7 +98,7 @@ def merge_node(self, src_node, dst_node):
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = math.inf
min_cost = INFINITY_COST + 1
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
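The sentinel change in merge_node is easy to miss: resharding costs are now capped at the finite INFINITY_COST (imported from solver/constants.py at the top of this file) instead of math.inf, and the initial minimum sits one above that cap. This presumably guarantees lowest_cost_index is always assigned even when every candidate is forbidden; with the old all-inf costs, inf < inf never held and -1 leaked into the merge map. A self-contained illustration (the constant's value here is made up):

import math

# Before: the sentinel and the "forbidden" costs were both math.inf.
old_costs = [math.inf, math.inf]
min_cost, lowest_cost_index = math.inf, -1
for i, cost in enumerate(old_costs):
    if cost < min_cost:
        min_cost, lowest_cost_index = cost, i
print(lowest_cost_index)  # -1: an invalid index

# After: costs are capped at a finite INFINITY_COST and the sentinel
# sits one above it, so a valid index is always selected.
INFINITY_COST = 1e13  # illustrative value only
new_costs = [INFINITY_COST, INFINITY_COST]
min_cost, lowest_cost_index = INFINITY_COST + 1, -1
for i, cost in enumerate(new_costs):
    if cost < min_cost:
        min_cost, lowest_cost_index = cost, i
print(lowest_cost_index)  # 0
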
2 changes: 2 additions & 0 deletions colossalai/auto_parallel/solver/sharding_strategy.py
@@ -240,6 +240,8 @@ def __init__(self, node: Node):
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
if self.node.op == 'output':
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())

def check_merge(self):
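The output node of a traced fx graph aggregates every value the module returns, so it can have several input nodes; the change above keeps only the first as a predecessor. A quick way to see this, relying on the same private _input_nodes attribute the code itself uses (TwoHeads is an illustrative module):

import torch
from torch import fx

class TwoHeads(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2

graph = fx.symbolic_trace(TwoHeads()).graph
output_node = next(n for n in graph.nodes if n.op == 'output')
print(list(output_node._input_nodes.keys()))  # [add, mul]: two predecessors
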
9 changes: 6 additions & 3 deletions colossalai/auto_parallel/solver/solver.py
@@ -45,8 +45,8 @@ def __init__(self,
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ def _recover_merged_node_strategy(self):
Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy indices of those merged
nodes.
'''
for node_index, node in enumerate(self.graph.nodes):
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
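Note that self.nodes is now derived from leaf_strategies (see the constructor change above) rather than from self.graph.nodes, so this recovery loop, like the rest of the solver, skips any node that was pruned for lacking strategies.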
@@ -297,7 +297,8 @@ def get_non_zero_index(binary_vector):
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])

for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set an initial value for warm start #
@@ -317,12 +318,14 @@
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])

#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])

prob += obj
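The added assertions guard the ILP assembly: every strategy vector s[i] must be non-empty and line up element-for-element with its cost vectors before lpDot combines them, which is exactly the invariant that strategy-less nodes used to break. A toy sketch of the objective's shape with made-up costs (names mirror the solver's, but this is not the full model):

from pulp import LpProblem, LpMinimize, LpVariable, lpDot, lpSum

prob = LpProblem('toy_strategy_selection', LpMinimize)
s = LpVariable.matrix('s[0]', (range(3),), cat='Binary')  # one-hot strategy choice
c = [1.0, 2.0, 3.0]  # compute cost per strategy
d = [0.1, 0.2, 0.3]  # memory cost per strategy

assert len(s) > 0 and len(s) == len(c)  # the checks added above
prob += lpDot(s, c) + lpDot(s, d)       # node cost enters the objective
prob += lpSum(s) == 1                   # exactly one strategy is picked
prob.solve()
print([v.value() for v in s])           # [1.0, 0.0, 0.0]
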
88 changes: 57 additions & 31 deletions colossalai/auto_parallel/solver/strategies_constructor.py
@@ -210,6 +210,21 @@ def build_strategies_and_cost(self):
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()

# where function
elif target == torch.where:
if input_nodes_len == 1:
# both x and y are scalars
pass

elif input_nodes_len == 2:
# either x or y is a scalar
pass

else:
# general case
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
where_handler.register_strategy()

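input_nodes_len counts the arguments that exist as graph nodes: Python scalars passed to torch.where are baked into the call rather than traced, so the two scalar branches above register nothing and leave an empty strategies vector for the cleanup pass at the end of this file to remove. For illustration (assuming a torch build whose torch.where accepts scalar branches):

import torch

cond = torch.rand(4) > 0.5
x, y = torch.ones(4), torch.zeros(4)

torch.where(cond, x, y)      # three tensor inputs  -> general case, WhereHandler
torch.where(cond, x, 0.0)    # one scalar branch    -> input_nodes_len == 2
torch.where(cond, 1.0, 0.0)  # both branches scalar -> input_nodes_len == 1
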
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for reshape node
@@ -218,9 +233,8 @@

# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
if isinstance(node._meta_data, torch.Tensor):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()

# bcast op
elif target in BCAST_FUNC_OP:
@@ -287,32 +301,34 @@
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
if isinstance(input_tensor_node, torch.Tensor):
for strategy in input_tensor_node.strategies_vector:
for strategy in input_tensor_node.strategies_vector:
if isinstance(strategy.output_sharding_spec, ShardingSpec):
input_sharding_spec = strategy.output_sharding_spec
else:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(
name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_tensor_node.output_sharding_spec])
strategies_vector.append(sharding_strategy)
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec],
index=index)
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)

# torch.arange function
elif target == torch.arange:
@@ -330,8 +346,7 @@
strategies_vector.append(sharding_strategy)

# op list to be processed to support gpt2
elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
torch.nn.functional.softmax, torch.pow, torch.tanh):
elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
@@ -340,7 +355,7 @@
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size, torch.Tensor.contiguous):
if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -394,3 +409,14 @@ def build_strategies_and_cost(self):
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector

# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
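
This final pass is the change the PR title names: any node whose strategies vector stayed empty (all of the pass branches above) is dropped from both leaf_strategies and strategy_map, so CostGraph and the solver only ever see nodes with at least one candidate. A toy sketch with stand-in classes (not the library's StrategiesVector):

class Node:
    def __init__(self, name):
        self.name = name

class StrategiesVector(list):
    def __init__(self, node, strategies=()):
        super().__init__(strategies)
        self.node = node

conv, size = Node('conv'), Node('size')
conv.strategies_vector = StrategiesVector(conv, ['S0R', 'RS1'])
size.strategies_vector = StrategiesVector(size)  # nothing was registered

leaf_strategies = [conv.strategies_vector, size.strategies_vector]
strategy_map = {conv: conv.strategies_vector, size: size.strategies_vector}

remove_list = [sv.node for sv in leaf_strategies if len(sv) == 0]
for node in remove_list:
    leaf_strategies.remove(node.strategies_vector)
    strategy_map.pop(node)

print([sv.node.name for sv in leaf_strategies])  # ['conv']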