[autoparallel] add unary element wise handler v2 #1674

Merged
colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
@@ -0,0 +1,35 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
from ..strategy import UnaryElementwiseGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
import operator

__all__ = ['UnaryElementwiseHandler']


@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
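# other unary element-wise ops can be registered to this handler with the same decorator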
class UnaryElementwiseHandler(NodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies of unary element-wise ops.
"""

def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
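# pass the predecessor node (self.node.args[0]) so the generator can follow its sharding strategies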
generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# the op consumes the output of its predecessor (self.node.args[0]) directly,
# so the input operand is built from that node's meta data without any shape transformation
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "output": physical_output}

return mapping
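
For context (not part of the diff), a minimal sketch of how a registered node is expected to be dispatched to this handler; the operator_registry.get(...) lookup and the relu_node / device_mesh objects are assumed here for illustration:

handler_cls = operator_registry.get(torch.nn.ReLU)    # assumed lookup API -> UnaryElementwiseHandler
strategies_vector = StrategiesVector(relu_node)       # relu_node: an fx.Node produced by the tracer (assumed)
relu_handler = handler_cls(node=relu_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
relu_handler.register_strategy()                      # populates strategies_vector via UnaryElementwiseGenerator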
3 changes: 2 additions & 1 deletion colossalai/auto_parallel/solver/strategy/__init__.py
@@ -2,10 +2,11 @@
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .layer_norm_generator import LayerNormGenerator

__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
-'BatchNormStrategyGenerator', 'LayerNormGenerator'
+'BatchNormStrategyGenerator', 'UnaryElementwiseGenerator', 'LayerNormGenerator'
]
colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py
@@ -0,0 +1,84 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy

__all__ = ['UnaryElementwiseGenerator']


class UnaryElementwiseGenerator(FollowingStrategyGenerator):
"""
UnaryElementwiseGenerator deals with the sharding strategies of unary element-wise ops.
"""

def validate(self) -> bool:
return super().validate()

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
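# element-wise kernels are cheap compared with matmul/conv, so a small fixed cost is used here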
return TrainCycleItem(fwd=10, bwd=10, total=20)

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}

backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
return super().update_memory_cost(strategy)

def generate(self):
strategy_list = []
# For an element-wise function, the output sharding spec is kept identical to the input's.
# As a result, different strategies of the input node that share the same output sharding spec
# produce the same strategy for the element-wise function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# append the index to the name to pass the duplicate-name check.
# We keep identical strategies under different names for node merging; this does not enlarge the search space,
# because the solver merges this node into other nodes and does not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list
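
As a rough sanity check of the memory-cost formula above (not part of the diff), assuming the fp32 [4, 4, 62, 62] activation used in the test below and a fully replicated sharding spec with no dimension partitioned:

elem_bytes = 4                                       # fp32
numel = 4 * 4 * 62 * 62                              # 61,504 elements
tensor_bytes = numel * elem_bytes                    # 246,016 bytes per device when replicated
fwd_activation = 2 * tensor_bytes                    # input + output -> 492,032 bytes
bwd_activation = tensor_bytes                        # input gradient only -> 246,016 bytes
total_activation = fwd_activation + bwd_activation   # 738,048 bytes; parameter cost is zero (no weights)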
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class ReLuModel(nn.Module):

def __init__(self):
super().__init__()
self.act = torch.nn.ReLU()

def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other)
relu_node = self.act(conv_node)
return relu_node


def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
# return act
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)
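# arrange the 4 physical devices into a 2 x 2 logical device mesh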
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[2]
relu_mod_node = list(graph.nodes)[3]
relu_strategies_vector = StrategiesVector(relu_mod_node)
conv_strategies_vector = StrategiesVector(conv_mod_node)

# build handler
conv_handler = ConvFunctionHandler(node=conv_mod_node,
device_mesh=device_mesh,
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy()
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
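# the ReLU handler follows the conv node's strategies, so they must be generated and attached first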
relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
device_mesh=device_mesh,
strategies_vector=relu_strategies_vector)

relu_handler.register_strategy()

# check operation data mapping
mapping = relu_handler.get_operation_data_mapping()

for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None

assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])

assert mapping['output'].name == "act"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['output'].type == OperationDataType.OUTPUT

# the unary element-wise handler follows its predecessor's strategies, so the number of strategies equals that of the conv node.
assert len(relu_strategies_vector) == len(conv_strategies_vector)


if __name__ == '__main__':
test_elementwise_handler()