[fx] PoC of runtime shape consistency application #1607

Merged
colossalai/auto_parallel/solver/strategies_constructor.py (2 additions, 1 deletion)
@@ -399,7 +399,8 @@ def build_strategies_and_cost(self):
        sharding_strategy_attribute = ShardingStrategy(name,
                                                       output_sharding_spec,
                                                       memory_cost=memory_cost,
-                                                      resharding_costs=resharding_costs)
+                                                      resharding_costs=resharding_costs,
+                                                      input_shardings=tuple(input_sharding_specs))
        strategies_vector.append(sharding_strategy_attribute)

        self.remove_duplicated_strategy(strategies_vector)
colossalai/fx/passes/experimental/adding_shape_consistency_pass.py (new file, 112 additions)
@@ -0,0 +1,112 @@
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy


def apply(*args, **kwargs):
    shape_consistency_manager = ShapeConsistencyManager()
    return shape_consistency_manager.apply(*args, **kwargs)


def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
Contributor: Why is solution a list of int?

Contributor Author: The solver returns its solution as a list of ints; the value of each element is the index of the best strategy chosen for the corresponding node.
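
For illustration, a minimal runnable sketch of how such a solution is consumed (the strategy names and counts below are made up, not from this PR):

```python
# Illustration only: the solver emits one int per node, and solution[i]
# indexes into node i's candidate strategy list (strategies_vector).
strategies_per_node = [
    ['S0R', 'RS1', 'RR'],    # node 0 has three candidate strategies
    ['RR', 'S01R'],          # node 1 has two
]
solution = [2, 0]            # pick strategy 2 for node 0, strategy 0 for node 1

best = [candidates[idx] for candidates, idx in zip(strategies_per_node, solution)]
print(best)                  # ['RR', 'RR']
```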

    mod_graph = gm.graph
Contributor: What is a mod_graph?

Contributor Author: The model graph.

    nodes = tuple(mod_graph.nodes)

    # the dict mapping each node index to the node's origin sharding spec
    origin_node_sharding_spec_dict = {}
    for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
        strategies_vector = node.strategies_vector
        setattr(node, 'best_strategy', strategies_vector[strategy_index])
        setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec

    # apply the chosen sharding specs to the parameters
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
            setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
            target_weight_sharding_spec = node.best_strategy.input_shardings[1]
            target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
Contributor: Why permute?

Contributor: If this is done because the conv/linear weight is not in the desired shape, I can accept it for now, but we should handle this in NodeHandler.

Contributor Author: Sure, I only found this problem during testing; I will fix it in a future PR.
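
For reference, a minimal sketch of the permute round trip under discussion; the (in, out, kh, kw) dimension-order assumption is inferred from the conversation, not stated in the PR:

```python
import torch

# Sketch of the workaround discussed above, assuming the chosen strategy's
# weight sharding spec indexes dimensions in (in, out, kh, kw) order while
# nn.Conv2d stores its weight as (out, in, kh, kw).
weight = torch.randn(8, 4, 3, 3)    # (out=8, in=4, kh=3, kw=3)

w = weight.permute(1, 0, 2, 3)      # -> (in, out, kh, kw)
# apply(w, target_weight_sharding_spec) would shard w here
w = w.permute(1, 0, 2, 3)           # restore (out, in, kh, kw)
assert w.shape == weight.shape
```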

            apply(target_module.weight, target_weight_sharding_spec)
            target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))

    # the dict mapping each node index to the input sharding specs of its user nodes
    sharding_spec_convert_dict = {}
    for index, node in enumerate(nodes):
        target_sharding_specs = []
        for user_node in node.strategies_vector.successor_nodes:
            node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs

    # add the dicts above into the graph as placeholders
    for node in nodes:
        if node.op != 'placeholder':
            with mod_graph.inserting_before(node):
                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
            break

    return sharding_spec_convert_dict, origin_node_sharding_spec_dict


def shape_consistency_pass(gm: torch.fx.GraphModule):
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    input_dict_node = None
    origin_dict_node = None

    # map each node to its index in the origin graph
    node_to_index_dict = {}
    index = 0
    for node in nodes:
        if node.target == 'sharding_spec_convert_dict':
            input_dict_node = node
            continue
        if node.target == 'origin_node_sharding_spec_dict':
            origin_dict_node = node
            continue
        if not hasattr(node, 'best_strategy'):
            continue
        node_to_index_dict[node] = index
        index += 1
    assert input_dict_node is not None

    # insert the shape consistency apply function into the graph
    for node in nodes:
        if not hasattr(node, 'best_strategy'):
            continue
        with mod_graph.inserting_after(node):
            origin_spec_node = mod_graph.create_node('call_function',
                                                     operator.getitem,
                                                     args=(origin_dict_node, node_to_index_dict[node]))
        with mod_graph.inserting_after(origin_spec_node):
            set_sharding_spec_node = mod_graph.create_node('call_function',
                                                           builtins.setattr,
                                                           args=(node, 'sharding_spec', origin_spec_node))

        for user_node in node.strategies_vector.successor_nodes:
            node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            with mod_graph.inserting_before(user_node):
                input_specs_node = mod_graph.create_node('call_function',
                                                         operator.getitem,
                                                         args=(input_dict_node, node_to_index_dict[node]))
            with mod_graph.inserting_before(user_node):
                sharding_spec_node = mod_graph.create_node('call_function',
                                                           operator.getitem,
                                                           args=(input_specs_node, node_index))
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))

    gm.recompile()
Contributor: recompile should only be called once all passes have finished.

Contributor Author: Done
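
The intended call order after the fix, sketched here as a usage note (gm, solution and device_mesh as in the test below):

```python
# Sketch: the passes only mutate gm.graph; the caller recompiles once
# after all graph-mutating passes have finished.
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()    # single recompile at the end
```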

    return gm
tests/test_auto_parallel/test_shape_consistency_pass.py (new file, 83 additions)
@@ -0,0 +1,83 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver import Solver
from colossalai.auto_parallel.solver.options import SolverOptions


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input = torch.rand(4, 4, 4, 4).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 4, 8, 8))

    tracer = ColoTracer()
    model = ConvModel(4, 4).cuda()
    origin_output = model(input)
    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    nodes = [node for node in gm.graph.nodes]
    output = gm(input, sharding_spec_dict, origin_spec_dict)
Contributor: Such usage is not very intuitive; I would recommend sticking to gm(input) in the future, but I can let it pass for now. We can annotate it with a TODO tag.

Contributor Author: Done
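
One possible way to restore the plain gm(input) call shape later, sketched as a suggestion rather than anything this PR implements:

```python
# Sketch (not part of this PR): hide the extra dict arguments behind a small
# wrapper so that call sites keep the plain gm(input) shape.
def make_runner(gm, sharding_spec_dict, origin_spec_dict):

    def run(x):
        return gm(x, sharding_spec_dict, origin_spec_dict)

    return run

runner = make_runner(gm, sharding_spec_dict, origin_spec_dict)
output = runner(input)    # reads like gm(input)
```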

    assert output.equal(origin_output)


@pytest.mark.skip("for higher testing speed")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    run_func = partial(check_apply, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_apply()