Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] apply repeat block to reduce solving time #2912

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions colossalai/auto_parallel/tensor_shard/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
# liveness_list = graph_analyser.liveness_analysis()
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])

Expand Down
39 changes: 25 additions & 14 deletions colossalai/auto_parallel/tensor_shard/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
Expand Down Expand Up @@ -63,7 +63,10 @@ def __init__(self,
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self.liveness_list = self.nodes
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
Expand Down Expand Up @@ -140,7 +143,7 @@ def _prepare_data_for_solver(self):
liveness_set = self.liveness_list

# use the alias set built by the strategies constructor
alias_set = None
alias_set = self.strategies_constructor.alias_set
alias_convert_costs = None

# prepare compute_costs, communication_costs and memory_costs
Expand Down Expand Up @@ -230,6 +233,7 @@ def get_non_zero_index(binary_vector):

# 0. Unpack flatten numpy arrays
s_follow = following_nodes
s_alias = alias_set

E = edge_pairs.reshape((-1, 2)) # noqa
r = []
Expand Down Expand Up @@ -294,8 +298,11 @@ def get_non_zero_index(binary_vector):
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
if i not in s_alias:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
s.append(s[s_alias[i]])
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
Expand All @@ -311,15 +318,20 @@ def get_non_zero_index(binary_vector):
#############################
e = []
num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
map_edge_to_idx[(i, j)] = idx
for element in s:
assert len(element) > 0
# 2. Set initial value
Expand Down Expand Up @@ -371,13 +383,12 @@ def get_non_zero_index(binary_vector):
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
if live_variable.node not in self.node_index_dict:
continue
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
mem = 0
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget

# (d). specified by `cat="Binary"`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh

from ..options import DataloaderOption, SolverOptions
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver
self.strategy_map = {}
self.solver_options = solver_options
self.no_strategy_nodes = []
self.alias_set = None

def remove_duplicated_strategy(self, strategies_vector):
'''
Expand All @@ -59,6 +61,22 @@ def remove_duplicated_strategy(self, strategies_vector):
for strategy in remove_list:
strategies_vector.remove(strategy)

def generate_alias_set(self):
    """
    Build the alias mapping between nodes of repeated blocks.

    The nodes of ``self.leaf_strategies`` are passed, together with
    ``self.root_module``, to ``find_repeat_blocks`` (with
    ``common_length_threshold=10``). The first returned block is used as the
    reference block; every node at position ``k`` of each later block is
    aliased to the node at position ``k`` of the reference block.

    Returns:
        dict: maps the index (position in the leaf-strategy node list) of a
            node in a later block to the index of the corresponding node in
            the first block. Empty when no repeated block is found, or when
            only a single block is returned.

    Raises:
        KeyError: if a block contains a node that is not one of the
            leaf-strategy nodes.
    """
    node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
    # NOTE(review): presumably each returned block is a sequence of nodes taken
    # from node_list, and all blocks have the same length -- confirm against
    # find_repeat_blocks.
    common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)

    repeat_block_nums = len(common_blocks)
    alias_set = {}

    if repeat_block_nums == 0:
        return alias_set

    # Resolve node -> index once instead of calling list.index (a linear scan)
    # for every aliased node. setdefault keeps the first occurrence, which is
    # the position list.index would have reported.
    node_to_index = {}
    for position, node in enumerate(node_list):
        node_to_index.setdefault(node, position)

    for offset, common_node in enumerate(common_blocks[0]):
        # The representative index only depends on the outer loop, so it is
        # looked up once per position rather than once per repeated block.
        representative_index = node_to_index[common_node]
        for block_id in range(1, repeat_block_nums):
            alias_set[node_to_index[common_blocks[block_id][offset]]] = representative_index
    return alias_set

def build_strategies_and_cost(self):
"""
This method is to build the strategy vector for each node in the computation graph.
Expand Down Expand Up @@ -175,3 +193,6 @@ def _check_no_strategy_for_data(data):
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)

alias_set = self.generate_alias_set()
self.alias_set = alias_set
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768
HIDDEN_DIM = 384


@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:
Expand Down Expand Up @@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
Expand Down Expand Up @@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,14 @@ def test_cost_graph():
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()

solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
solver = Solver(gm.graph, strategies_constructor, cost_graph)

ret = solver.call_solver_serialized_args()
print(ret[0])
Expand Down