Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] add resnet autoparallel unit test #1589

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 43 additions & 29 deletions colossalai/auto_parallel/solver/op_handler/conv_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_ac
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)

return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
Expand All @@ -129,15 +129,18 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# This strategy do not need to do all_reduce operation during forward
communication_cost_forward = 0
# compute the backward communication cost of this strategy
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
# compute the backward communication cost to all reduce the input activation grad
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
mesh_dim_1)
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# total communication cost
communication_cost = communication_cost_forward + communication_cost_backward
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight

sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
Expand Down Expand Up @@ -173,11 +176,16 @@ def split_input_batch(self, mesh_dim_0):
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)

# This strategy do not need to do all_reduce operation in both forward and backward phase.
communication_cost = 0
# This strategy do not need to do all_reduce operation in forward phase.
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute the total cost
communication_cost = communication_cost_forward + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
Expand Down Expand Up @@ -213,15 +221,17 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
# This strategy do not need to do all_reduce operation during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
# This strategy do not need to do all_reduce operation to compute the input activation grad
communication_cost_backward_activation = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute total cost
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
Expand Down Expand Up @@ -256,7 +266,7 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# compute the communication cost of this strategy during forward phase
Expand Down Expand Up @@ -298,9 +308,8 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
Expand Down Expand Up @@ -341,7 +350,7 @@ def split_weight_out_channel(self, mesh_dim_0):
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# This strategy do not need to do all_reduce during forward phase
Expand Down Expand Up @@ -383,8 +392,8 @@ def non_split(self):
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)

# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
Expand Down Expand Up @@ -424,11 +433,17 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)

# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
# This strategy do not need to do all_reduce in forward phase
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
# compute the total communication cost
communication_cost = communication_cost_backward_weight + communication_cost_forward

sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
Expand Down Expand Up @@ -466,9 +481,8 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
Expand Down
125 changes: 125 additions & 0 deletions tests/test_auto_parallel/test_solver_with_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser


class ConvModel(nn.Module):

def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.relu = nn.ReLU()

def forward(self, x):
x = x * 2
x = self.conv1(x)
x = self.conv2(x)
x = x / 2
x = self.conv3(x)
x = self.relu(x)
return x


@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()

tracer = ColoTracer()
# model = ConvModel(16, 32)
# input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
model = resnet50(num_classes=100000)
input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}

graph = tracer.trace(root=model, meta_args=input_sample)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
# %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
# %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
# %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
# %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
# %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
# %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
# %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
# %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
# %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
# %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
# %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
# %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
# %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
# %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
# ...
# %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {})
# %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
# %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
# print(len(liveness_dict[0].unique_live_vars))
# assert False
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)

ret = solver.call_solver_serialized_args()
print(ret)
strategies_list = list(ret[0])
print(strategies_list)
computation_cost = 0
communication_cost = 0
communication_cost_bn = 0
memory_cost = 0
for index, node in enumerate(graph.nodes):
if node.op == 'call_module':
submod = node.graph.owning_module.get_submodule(node.target)
if type(submod) in ELEMENTWISE_MODULE_OP:
input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
print(node.name, input_spec)
continue
if type(submod) in BATCHNORM_MODULE_OP:
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost

print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
print(f'bn communication cost is {communication_cost_bn}')


if __name__ == '__main__':
test_cost_graph()