Skip to content

Commit

Permalink
[autoparallel] add embedding handler
Browse files Browse the repository at this point in the history
  • Loading branch information
YuliangLiu0306 committed Sep 21, 2022
1 parent 9772581 commit 77231a5
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 3 deletions.
2 changes: 1 addition & 1 deletion colossalai/auto_parallel/solver/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def exception_handler(func):
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
except AssertionError as e:
warnings.warn(f'{e}')

return wrapper
18 changes: 16 additions & 2 deletions colossalai/auto_parallel/solver/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,27 @@

__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
torch.nn.functional.dropout, torch.flatten
]
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
torch.Tensor.split,
torch.Tensor.permute,
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul
Expand All @@ -23,9 +35,11 @@
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP

Expand Down
176 changes: 176 additions & 0 deletions colossalai/auto_parallel/solver/op_handler/embedding_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['EmbeddingHandler']


class EmbeddingHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding).
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data

def _generate_compute_cost(self, total_sharding_size):
input_shape = self.input_data.shape
weight_shape = self.weight.shape
input_shape_product = reduce(operator.mul, input_shape, 1)
weight_shape_product = reduce(operator.mul, weight_shape, 1)
compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
return compute_cost

def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)

return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

@exception_handler
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'

dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {2: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)

# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)

@exception_handler
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'

dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)

# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

# This strategy do not need to do all_reduce during forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward_activation = 0
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)

def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
'''
# RRS = RR x SS
self.split_weight_both_dim(0, 1)
self.split_weight_both_dim(1, 0)

# SSR = SS x RR
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)

return self.strategies_vector

0 comments on commit 77231a5

Please sign in to comment.