[autoparallel] add pooling handler #1690

Merged
40 changes: 40 additions & 0 deletions colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
@@ -0,0 +1,40 @@
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['NormPoolingHandler']


@operator_registry.register(torch.nn.MaxPool1d)
@operator_registry.register(torch.nn.MaxPool2d)
@operator_registry.register(torch.nn.MaxPool3d)
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
"""
A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd and nn.AvgPoolxd modules.
"""

def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# map the input tensor, the kernel size and the output tensor of the pooling op to OperationData objects
# the kernel size is exposed as the "other" operand so the strategy generator can estimate the compute cost
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output}

return mapping
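
For reference, a minimal sketch (not part of the diff) of what get_operation_data_mapping produces for the nn.MaxPool2d(4, padding=1) module traced in the test at the end of this PR; handler here is assumed to be a NormPoolingHandler built exactly as in that test:

mapping = handler.get_operation_data_mapping()
mapping['input'].data.shape    # torch.Size([4, 4, 64, 64]), OperationDataType.ARG
mapping['other'].data          # the module's kernel_size, i.e. 4 here
mapping['output'].data.shape   # torch.Size([4, 4, 16, 16]), OperationDataType.OUTPUT
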
3 changes: 2 additions & 1 deletion colossalai/auto_parallel/solver/strategy/__init__.py
@@ -7,10 +7,11 @@
from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .normal_pooling_generator import NormalPoolStrategyGenerator

__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator'
'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
]
117 changes: 117 additions & 0 deletions colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
@@ -0,0 +1,117 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy


class NormalPoolStrategyGenerator(StrategyGenerator_V2):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pooling operations such as MaxPoolxd.
The reason we call this a normal pool is that both AvgPoolxd and MaxPoolxd take kernel-sized windows of elements
from the input and reduce them depending on the operation type.
"""

def validate(self) -> bool:
'''
In the sanity check, we need to make sure the input data has the correct number of dimensions.
For Pool1d, the dim of the input data should be 3 ([N, C, L]).
For Pool2d, the dim of the input data should be 4 ([N, C, H, W]).
For Pool3d, the dim of the input data should be 5 ([N, C, D, H, W]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4, 5), 'The dim of the input fed into a Pool op should be in the range [3, 5].'

def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
'''
Compute the computation cost per device with this specific strategy.

Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
'''
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: Lout * N * C * kernel
# 2D: (Hout * Wout) * N * C * kernel
# 3D: (Dout * Hout * Wout) * N * C * kernel
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
kernel_size_product = reduce(operator.mul, kernel_size)
output_size_product = reduce(operator.mul, sharded_output_shape)
input_size_product = reduce(operator.mul, sharded_input_shape)

forward_compute_cost = output_size_product * kernel_size_product
backward_compute_cost = input_size_product * kernel_size_product

total_compute_cost = forward_compute_cost + backward_compute_cost

compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost

def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}

backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)

# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}

sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
communication_action_mapping = {}

strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

return strategy

def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
dim_partition_list = []
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
# append {} for non_split case
dim_partition_list.append({})

return dim_partition_list

def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []

dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list
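
As a quick sanity check of the compute cost model above, a worked example (a sketch, not part of the diff) for the unsharded MaxPool2d(4, padding=1) case used in the test below, where the per-device input shape is the full [4, 4, 64, 64] and the output is [4, 4, 16, 16]:

kernel_size_product = 4 * 4                                          # 16
forward_compute_cost = (4 * 4 * 16 * 16) * kernel_size_product       # 65536 (output elements * kernel window)
backward_compute_cost = (4 * 4 * 64 * 64) * kernel_size_product      # 1048576 (input elements * kernel window)
total_compute_cost = forward_compute_cost + backward_compute_cost    # 1114112
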
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})

gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
pool_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(pool_mod_node)

# build handler
handler = NormPoolingHandler(node=pool_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()

for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None

assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])

assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
assert mapping['output'].type == OperationDataType.OUTPUT

strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 9


if __name__ == '__main__':
test_norm_pool_handler()
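
The expected count of 9 strategies follows from enumerate_all_possible_batch_dimensions_dim_partition on the 2 x 2 mesh: presumably 2 one-dimensional shardings per mesh dimension, 4 two-dimensional shardings combining both mesh dimensions, plus the empty dict for the non-split case (2 + 2 + 4 + 1 = 9). A minimal sketch of checking the count directly against the generator, assuming handler is constructed as in the test above:

generator = handler.get_strategy_generator()[0]
dim_partitions = generator.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
assert len(dim_partitions) == 9    # one sharding strategy is generated per dim partition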