[fx] PoC of runtime shape consistency application #1607
New file: colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
@@ -0,0 +1,111 @@
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy


def apply(*args, **kwargs):
    shape_consistency_manager = ShapeConsistencyManager()
    return shape_consistency_manager.apply(*args, **kwargs)


def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
    mod_graph = gm.graph
Review comment: what is a mod_graph?
Reply: model graph
    nodes = tuple(mod_graph.nodes)

    # the dict to look up the origin sharding spec of each node
    origin_node_sharding_spec_dict = {}
    for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
        strategies_vector = node.strategies_vector
        setattr(node, 'best_strategy', strategies_vector[strategy_index])
        setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec

    # apply the sharding spec to the parameters
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
            setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
            target_weight_sharding_spec = node.best_strategy.input_shardings[1]
            target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
Review comment: Why permute?
Review comment: If this is because the conv/linear weight is in the desired shape, I can accept it for now, but we should handle this in …
Reply: Sure, I just found this problem during testing; I will fix it in a future PR.
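A minimal runnable sketch of the layout juggling discussed above, assuming (as the review exchange suggests) that the permute exists because nn.Conv2d stores its weight as (out_channels, in_channels, kH, kW) while the strategy's weight sharding spec is written for an (in_channels, out_channels, kH, kW) layout:

import torch

# nn.Conv2d(4, 8, kernel_size=3) stores its weight as (out, in, kH, kW)
weight = torch.empty(8, 4, 3, 3)
# swap dims 0 and 1 so the sharding spec sees an (in, out, kH, kW) layout
permuted = weight.permute(1, 0, 2, 3)
assert permuted.shape == torch.Size([4, 8, 3, 3])
# the same permute applied again restores the layout F.conv2d expects
assert permuted.permute(1, 0, 2, 3).shape == weight.shape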
            apply(target_module.weight, target_weight_sharding_spec)
            target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))

    # the dict to look up the input sharding specs of each user node
    sharding_spec_convert_dict = {}
    for index, node in enumerate(nodes):
        target_sharding_specs = []
        for user_node in node.strategies_vector.successor_nodes:
            node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs

    # add the above dicts into the graph as placeholder nodes
    for node in nodes:
        if node.op != 'placeholder':
            with mod_graph.inserting_before(node):
                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
            break

    return sharding_spec_convert_dict, origin_node_sharding_spec_dict


def shape_consistency_pass(gm: torch.fx.GraphModule):
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    input_dict_node = None
    origin_dict_node = None

    # map each node to its index in the origin graph
    node_to_index_dict = {}
    index = 0
    for node in nodes:
        if node.target == 'sharding_spec_convert_dict':
            input_dict_node = node
            continue
        if node.target == 'origin_node_sharding_spec_dict':
            origin_dict_node = node
            continue
        if not hasattr(node, 'best_strategy'):
            continue
        node_to_index_dict[node] = index
        index += 1
    assert input_dict_node is not None

    # insert the shape consistency apply function into the graph
    for node in nodes:
        if not hasattr(node, 'best_strategy'):
            continue
        with mod_graph.inserting_after(node):
            origin_spec_node = mod_graph.create_node('call_function',
                                                     operator.getitem,
                                                     args=(origin_dict_node, node_to_index_dict[node]))
        with mod_graph.inserting_after(origin_spec_node):
            set_sharding_spec_node = mod_graph.create_node('call_function',
                                                           builtins.setattr,
                                                           args=(node, 'sharding_spec', origin_spec_node))

        for user_node in node.strategies_vector.successor_nodes:
            node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            with mod_graph.inserting_before(user_node):
                input_specs_node = mod_graph.create_node('call_function',
                                                         operator.getitem,
                                                         args=(input_dict_node, node_to_index_dict[node]))
            with mod_graph.inserting_before(user_node):
                sharding_spec_node = mod_graph.create_node('call_function',
                                                           operator.getitem,
                                                           args=(input_specs_node, node_index))
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))

    return gm
New file (the test for the passes above):
@@ -0,0 +1,84 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver import Solver
from colossalai.auto_parallel.solver.options import SolverOptions


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input = torch.rand(4, 4, 4, 4).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 4, 8, 8))

    tracer = ColoTracer()
    model = ConvModel(4, 4).cuda()
    origin_output = model(input)
    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    nodes = [node for node in gm.graph.nodes]
    # TODO: wrap the gm to avoid the influence of the user training code
    output = gm(input, sharding_spec_dict, origin_spec_dict)
Review comment: Such usage is kind of not intuitive; I would recommend sticking to …
Reply: Done
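For context on the exchange above, a minimal self-contained sketch (toy module and hypothetical `extra_dict` name, not this PR's model) of why the recompiled gm accepts extra positional arguments: inserting 'placeholder' nodes into an fx graph, as solution_annotatation_pass does, extends the generated forward's signature:

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):

    def forward(self, x):
        return x + 1

gm = symbolic_trace(Toy())
first_node = next(iter(gm.graph.nodes))    # the existing placeholder for x
with gm.graph.inserting_after(first_node):
    gm.graph.create_node('placeholder', target='extra_dict')
gm.recompile()
# forward is now forward(self, x, extra_dict), mirroring
# gm(input, sharding_spec_dict, origin_spec_dict) above
out = gm(torch.zeros(2), {'unused': 0})
assert out.equal(torch.ones(2))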
    assert output.equal(origin_output)


@pytest.mark.skip("for higher testing speed")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    run_func = partial(check_apply, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_apply()
Review comment: Why is solution a list of int?
Reply: The solver's solution is a list of int; the value of each element stands for the best strategy of the corresponding node.
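To make that concrete, a minimal self-contained sketch (hypothetical stand-in classes, not the real solver or strategy types) of how each integer in the solution indexes the matching node's strategies_vector, mirroring the annotation loop in solution_annotatation_pass:

class _Strategy:

    def __init__(self, output_sharding_spec):
        self.output_sharding_spec = output_sharding_spec


class _Node:

    def __init__(self, strategies):
        self.strategies_vector = strategies


# two nodes, each offering candidate strategies; the solver picks one per node
nodes = [_Node([_Strategy('RR'), _Strategy('S0R')]),
         _Node([_Strategy('RS1'), _Strategy('S0S1')])]
solution = [1, 0]    # node 0 takes strategy 1, node 1 takes strategy 0

for node, strategy_index in zip(nodes, solution):
    node.best_strategy = node.strategies_vector[strategy_index]
    node.sharding_spec = node.best_strategy.output_sharding_spec

assert nodes[0].sharding_spec == 'S0R'
assert nodes[1].sharding_spec == 'RS1'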